Coverage for src / bartz / _profiler.py: 94%
88 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
1# bartz/src/bartz/_profiler.py
2#
3# Copyright (c) 2025-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Module with utilities related to profiling bartz."""
27from collections.abc import Callable, Iterator
28from contextlib import contextmanager
29from functools import wraps
30from typing import Any, TypeVar
32from jax import block_until_ready, debug, jit
33from jax.lax import cond, scan
34from jax.profiler import TraceAnnotation
35from jaxtyping import Array, Bool
37from bartz.mcmcstep._state import vmap_chains
# Module-level switch read by `get_profile_mode` and written by
# `set_profile_mode`; profile mode is off by default.
PROFILE_MODE: bool = False

# Type variables used in the signatures of the helpers in this module: `T` for
# the return value of a wrapped function, `Carry` for the loop state threaded
# through `scan_if_not_profiling`.
T = TypeVar('T')
Carry = TypeVar('Carry')
def get_profile_mode() -> bool:
    """Report whether profile mode is currently active.

    Returns
    -------
    The current value of the module-level `PROFILE_MODE` flag: True when
    profile mode is enabled, False otherwise.
    """
    enabled = PROFILE_MODE
    return enabled
def set_profile_mode(value: bool, /) -> None:
    """Turn profile mode on or off for the whole module.

    Parameters
    ----------
    value
        New value stored in the module-level `PROFILE_MODE` flag. True turns
        profile mode on, False turns it off.
    """
    # rebinding a module global requires the `global` declaration
    global PROFILE_MODE  # noqa: PLW0603
    PROFILE_MODE = value
@contextmanager
def profile_mode(value: bool, /) -> Iterator[None]:
    """Temporarily set profile mode for the duration of a `with` block.

    Parameters
    ----------
    value
        Profile mode value in effect inside the context. The previous value is
        restored on exit, including when the body raises.

    Examples
    --------
    >>> with profile_mode(True):
    ...     # code here runs with profile mode enabled
    ...     pass

    Notes
    -----
    In profiling mode, the MCMC loop is not compiled into a single function,
    but instead compiled in smaller pieces that are instrumented to show up in
    the jax tracer and Python profiling statistics. Search for function names
    starting with 'jab' (see `jit_and_block_if_profiling`).

    This context manager does not enable jax tracing; if wanted, that must be
    set up separately by the user. It only makes sure that the execution flow
    is easier to interpret in the traces when the tracer is used.
    """
    previous = get_profile_mode()
    set_profile_mode(value)
    try:
        yield
    finally:
        # restore the prior setting even if the body of the `with` raised
        set_profile_mode(previous)
def jit_and_block_if_profiling(
    func: Callable[..., T], block_before: bool = False, **kwargs
) -> Callable[..., T]:
    """Apply JIT compilation and block if profiling is enabled.

    When profile mode is off, the function runs without JIT. When profile mode
    is on, the function is JIT compiled and blocks outputs to ensure proper
    timing.

    Parameters
    ----------
    func
        Function to wrap. Callables without a ``__name__`` attribute (e.g.,
        `functools.partial` objects or callable instances) are accepted too;
        their type name is used in the trace event name.
    block_before
        If True block inputs before passing them to the JIT-compiled function.
        This ensures that any pending computations are completed before entering
        the JIT-compiled function. This phase is not included in the trace
        event.
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function. The choice between the plain and the jitted path is made
    at each call, based on the profile mode at that moment.

    Notes
    -----
    Under profiling mode, the function invocation is handled such that a custom
    jax trace event with name `jab[<func_name>]` is created. The statistics on
    the actual Python function will be off, while the function
    `jab_inner_wrapper` represents the actual execution time.
    """
    # wrap with `jax.jit` once, up front; whether the jitted version is used
    # is decided at each call from the current profile mode
    jitted_func = jit(func, **kwargs)

    # `functools.partial` objects and callable instances have no `__name__`;
    # fall back on the type name instead of raising AttributeError
    func_name = getattr(func, '__name__', type(func).__name__)
    event_name = f'jab[{func_name}]'

    # this wrapper is meant to measure the time spent executing the function
    def jab_inner_wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        with TraceAnnotation(event_name):
            result = jitted_func(*call_args, **call_kwargs)
            # wait for the outputs inside the annotation so the trace event
            # spans the whole execution
            return block_until_ready(result)

    @wraps(func)
    def jab_outer_wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        if not get_profile_mode():
            return func(*call_args, **call_kwargs)
        if block_before:
            # settle pending computations on the inputs outside the trace event
            call_args, call_kwargs = block_until_ready((call_args, call_kwargs))
        return jab_inner_wrapper(*call_args, **call_kwargs)

    return jab_outer_wrapper
def jit_if_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Apply JIT compilation only when profiling.

    When profile mode is on, calls go through the `jax.jit`-wrapped version of
    `func`; when it is off, `func` is called directly. The check happens at
    each call.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional positional and keyword arguments forwarded to `jax.jit`.

    Returns
    -------
    Wrapped function.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        target = compiled if get_profile_mode() else func
        return target(*call_args, **call_kwargs)

    return wrapper
def jit_if_not_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Apply JIT compilation only when not profiling.

    When profile mode is off, calls go through the `jax.jit`-wrapped version
    of `func`. When profile mode is on, `func` is called as-is. The check
    happens at each call.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional positional and keyword arguments forwarded to `jax.jit`.

    Returns
    -------
    Wrapped function.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        target = func if get_profile_mode() else compiled
        return target(*call_args, **call_kwargs)

    return wrapper
def scan_if_not_profiling(
    f: Callable[[Carry, None], tuple[Carry, None]],
    init: Carry,
    xs: None,
    length: int,
    /,
) -> tuple[Carry, None]:
    """Restricted replacement for `jax.lax.scan` that uses a Python loop when profiling.

    Parameters
    ----------
    f
        Scan body with signature (carry, None) -> (carry, None).
    init
        Carry value fed to the first iteration.
    xs
        Must be None; scanning over input values is not supported.
    length
        Number of loop iterations.

    Returns
    -------
    Pair (final_carry, None); stacked per-iteration outputs are not supported.
    """
    assert xs is None

    if not get_profile_mode():
        return scan(f, init, None, length)

    # profile mode: iterate in plain Python instead of `jax.lax.scan`
    carry = init
    for _ in range(length):
        carry, _unused = f(carry, None)
    return carry, None
def cond_if_not_profiling(
    pred: bool | Bool[Array, ''],
    true_fun: Callable[..., T],
    false_fun: Callable[..., T],
    /,
    *operands,
) -> T:
    """Restricted replacement for `jax.lax.cond` that uses a Python if when profiling.

    Parameters
    ----------
    pred
        Scalar boolean selecting the branch.
    true_fun
        Branch called when `pred` is true.
    false_fun
        Branch called when `pred` is false.
    *operands
        Arguments forwarded to whichever branch is selected.

    Returns
    -------
    The return value of the selected branch.
    """
    if not get_profile_mode():
        return cond(pred, true_fun, false_fun, *operands)

    # profile mode: evaluate `pred` in Python and call only the chosen branch
    chosen = true_fun if pred else false_fun
    return chosen(*operands)
def callback_if_not_profiling(
    callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any
) -> None:
    """Restricted replacement for `jax.debug.callback` that calls the callback directly in profiling mode.

    Parameters
    ----------
    callback
        Function to invoke as ``callback(*args, **kwargs)``. Its return value
        is discarded.
    *args
        Positional arguments for `callback`.
    ordered
        Forwarded to `jax.debug.callback` when profile mode is off. Ignored in
        profile mode, where `callback` is called directly.
    **kwargs
        Keyword arguments for `callback`.
    """
    if get_profile_mode():
        callback(*args, **kwargs)
    else:
        debug.callback(callback, *args, ordered=ordered, **kwargs)
def vmap_chains_if_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]:
    """Apply `vmap_chains` only when profile mode is enabled.

    Parameters
    ----------
    fun
        Function to wrap.
    **kwargs
        Keyword arguments forwarded to `vmap_chains`.

    Returns
    -------
    A wrapper that, at each call, dispatches to the `vmap_chains`-transformed
    version of `fun` if profile mode is on, and to `fun` itself otherwise.
    """
    vmapped = vmap_chains(fun, **kwargs)

    @wraps(fun)
    def wrapper(*call_args, **call_kwargs):
        target = vmapped if get_profile_mode() else fun
        return target(*call_args, **call_kwargs)

    return wrapper
def vmap_chains_if_not_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]:
    """Apply `vmap_chains` only when profile mode is disabled.

    Parameters
    ----------
    fun
        Function to wrap.
    **kwargs
        Keyword arguments forwarded to `vmap_chains`.

    Returns
    -------
    A wrapper that, at each call, dispatches to `fun` itself if profile mode
    is on, and to the `vmap_chains`-transformed version otherwise.
    """
    vmapped = vmap_chains(fun, **kwargs)

    @wraps(fun)
    def wrapper(*call_args, **call_kwargs):
        target = fun if get_profile_mode() else vmapped
        return target(*call_args, **call_kwargs)

    return wrapper