Coverage for src / bartz / _profiler.py: 94%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-13 00:35 +0000

1# bartz/src/bartz/_profiler.py 

2# 

3# Copyright (c) 2025-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Module with utilities related to profiling bartz.""" 

26 

27from collections.abc import Callable, Iterator 

28from contextlib import contextmanager 

29from functools import wraps 

30from typing import Any, TypeVar 

31 

32from jax import block_until_ready, debug, jit 

33from jax.lax import cond, scan 

34from jax.profiler import TraceAnnotation 

35from jaxtyping import Array, Bool 

36 

37from bartz.mcmcstep._state import vmap_chains 

38 

# Module-wide switch consulted by the `*_if_profiling` / `*_if_not_profiling`
# helpers; change it through `set_profile_mode` or the `profile_mode` context
# manager rather than assigning it directly.
PROFILE_MODE: bool = False

# Type variables for the helpers' signatures: `T` stands for a generic return
# value, `Carry` for the loop state threaded through `scan_if_not_profiling`.
T = TypeVar('T')
Carry = TypeVar('Carry')

43 

44 

def get_profile_mode() -> bool:
    """Report whether profile mode is currently active.

    Returns
    -------
    The current value of the module-level `PROFILE_MODE` flag: True when
    profile mode is on, False when it is off.
    """
    current: bool = PROFILE_MODE
    return current

53 

54 

def set_profile_mode(value: bool, /) -> None:
    """Turn profile mode on or off.

    Parameters
    ----------
    value
        New value for the module-level `PROFILE_MODE` flag; True enables
        profile mode, False disables it.

    Notes
    -----
    The change persists until the next call; use the `profile_mode` context
    manager to change the flag only for the duration of a ``with`` block.
    """
    # rebinding a module global requires the `global` declaration
    global PROFILE_MODE  # noqa: PLW0603
    PROFILE_MODE = value

65 

66 

@contextmanager
def profile_mode(value: bool, /) -> Iterator[None]:
    """Temporarily switch profile mode on or off.

    Parameters
    ----------
    value
        Profile mode value in effect while the ``with`` body runs. The value
        that was active before entering is restored on exit, including when
        the body raises.

    Examples
    --------
    >>> with profile_mode(True):
    ...     # Code runs with profile mode enabled
    ...     pass

    Notes
    -----
    In profiling mode, the MCMC loop is not compiled into a single function, but
    instead compiled in smaller pieces that are instrumented to show up in the
    jax tracer and Python profiling statistics. Search for function names
    starting with 'jab' (see `jit_and_block_if_profiling`).

    Jax tracing is not enabled by this context manager and if used must be
    handled separately by the user; this context manager only makes sure that
    the execution flow will be more interpretable in the traces if the tracer is
    used.
    """
    saved = get_profile_mode()
    set_profile_mode(value)
    try:
        yield
    finally:
        # put back the caller's setting whether or not the body succeeded
        set_profile_mode(saved)

100 

101 

def jit_and_block_if_profiling(
    func: Callable[..., T], block_before: bool = False, **kwargs
) -> Callable[..., T]:
    """Apply JIT compilation and block if profiling is enabled.

    When profile mode is off, the function runs without JIT. When profile mode
    is on, the function is JIT compiled and blocks outputs to ensure proper
    timing.

    Parameters
    ----------
    func
        Function to wrap. Any callable accepted by `jax.jit` works, including
        `functools.partial` objects and callable instances without a
        ``__name__`` attribute.
    block_before
        If True block inputs before passing them to the JIT-compiled function.
        This ensures that any pending computations are completed before entering
        the JIT-compiled function. This phase is not included in the trace
        event.
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.

    Notes
    -----
    Under profiling mode, the function invocation is handled such that a custom
    jax trace event with name `jab[<func_name>]` is created. The statistics on
    the actual Python function will be off, while the function
    `jab_inner_wrapper` represents the actual execution time.
    """
    jitted_func = jit(func, **kwargs)

    # `functools.partial` objects and callable instances have no `__name__`;
    # fall back to the type name instead of raising AttributeError.
    func_name = getattr(func, '__name__', type(func).__name__)
    event_name = f'jab[{func_name}]'

    # this wrapper is meant to measure the time spent executing the function
    def jab_inner_wrapper(*args: Any, **call_kwargs: Any) -> T:
        with TraceAnnotation(event_name):
            result = jitted_func(*args, **call_kwargs)
            # wait for the outputs inside the annotation so the trace event
            # covers the actual computation, not just its dispatch
            return block_until_ready(result)

    @wraps(func)
    def jab_outer_wrapper(*args: Any, **call_kwargs: Any) -> T:
        if not get_profile_mode():
            return func(*args, **call_kwargs)
        if block_before:
            # finish pending input computations outside the trace event so
            # their cost is not attributed to `func`
            args, call_kwargs = block_until_ready((args, call_kwargs))
        return jab_inner_wrapper(*args, **call_kwargs)

    return jab_outer_wrapper

154 

155 

def jit_if_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Apply JIT compilation only when profiling.

    The returned wrapper checks the profile mode on every call: with profile
    mode on, it calls ``jax.jit(func, *args, **kwargs)``; with it off, it calls
    `func` directly.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional positional and keyword arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function carrying the metadata of `func`.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        target = compiled if get_profile_mode() else func
        return target(*call_args, **call_kwargs)

    return wrapper

181 

182 

def jit_if_not_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Apply JIT compilation only when not profiling.

    The returned wrapper checks the profile mode on every call: with profile
    mode off, it calls ``jax.jit(func, *args, **kwargs)``; with it on, it calls
    `func` as-is.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional positional and keyword arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function carrying the metadata of `func`.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        target = func if get_profile_mode() else compiled
        return target(*call_args, **call_kwargs)

    return wrapper

211 

212 

def scan_if_not_profiling(
    f: Callable[[Carry, None], tuple[Carry, None]],
    init: Carry,
    xs: None,
    length: int,
    /,
) -> tuple[Carry, None]:
    """Restricted replacement for `jax.lax.scan` that uses a Python loop when profiling.

    Parameters
    ----------
    f
        Scan body function with signature (carry, None) -> (carry, None).
    init
        Initial carry value.
    xs
        Input values to scan over. Only None is supported.
    length
        Integer specifying the number of loop iterations.

    Returns
    -------
    Tuple of (final_carry, None) (stacked outputs not supported).

    Raises
    ------
    ValueError
        If `xs` is not None.
    """
    # Explicit check rather than `assert`, which is removed under `python -O`;
    # without it a non-None `xs` would be silently ignored by both branches.
    if xs is not None:
        msg = 'scan_if_not_profiling does not support `xs`; it must be None'
        raise ValueError(msg)

    if not get_profile_mode():
        return scan(f, init, None, length)

    # Profiling: run the body in a Python loop, one call to `f` per iteration,
    # discarding the per-step outputs (stacked outputs are not supported).
    carry = init
    for _step in range(length):
        carry, _ = f(carry, None)
    return carry, None

246 

247 

def cond_if_not_profiling(
    pred: bool | Bool[Array, ''],
    true_fun: Callable[..., T],
    false_fun: Callable[..., T],
    /,
    *operands,
) -> T:
    """Restricted replacement for `jax.lax.cond` that uses a Python if when profiling.

    Parameters
    ----------
    pred
        Boolean predicate selecting the branch. In profile mode it is converted
        with Python's ``bool``, so it must be concrete there.
    true_fun
        Branch called when `pred` is True.
    false_fun
        Branch called when `pred` is False.
    *operands
        Arguments forwarded to whichever branch is selected.

    Returns
    -------
    ``true_fun(*operands)`` or ``false_fun(*operands)``, depending on `pred`.
    """
    if not get_profile_mode():
        return cond(pred, true_fun, false_fun, *operands)
    branch = true_fun if pred else false_fun
    return branch(*operands)

279 

280 

def callback_if_not_profiling(
    callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any
):
    """Invoke `callback` directly when profiling, else via `jax.debug.callback`.

    Parameters
    ----------
    callback
        Function to run; its return value is discarded.
    *args
    **kwargs
        Arguments forwarded to `callback`.
    ordered
        Passed to `jax.debug.callback` when profile mode is off. It is unused
        in profile mode, where `callback` is simply called in place.
    """
    if not get_profile_mode():
        debug.callback(callback, *args, ordered=ordered, **kwargs)
        return
    callback(*args, **kwargs)

289 

290 

def vmap_chains_if_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]:
    """Apply `vmap_chains` only when profile mode is enabled.

    Parameters
    ----------
    fun
        Function to wrap.
    **kwargs
        Keyword arguments forwarded to `vmap_chains`.

    Returns
    -------
    Wrapper that checks the profile mode on every call and dispatches to the
    `vmap_chains`-transformed `fun` when it is on, or to `fun` itself when it
    is off.
    """
    vmapped = vmap_chains(fun, **kwargs)

    @wraps(fun)
    def wrapper(*args, **kwargs):
        impl = vmapped if get_profile_mode() else fun
        return impl(*args, **kwargs)

    return wrapper

303 

304 

def vmap_chains_if_not_profiling(fun: Callable[..., T], **kwargs) -> Callable[..., T]:
    """Apply `vmap_chains` only when profile mode is disabled.

    Parameters
    ----------
    fun
        Function to wrap.
    **kwargs
        Keyword arguments forwarded to `vmap_chains`.

    Returns
    -------
    Wrapper that checks the profile mode on every call and dispatches to `fun`
    itself when it is on, or to the `vmap_chains`-transformed `fun` when it is
    off.
    """
    vmapped = vmap_chains(fun, **kwargs)

    @wraps(fun)
    def wrapper(*args, **kwargs):
        impl = fun if get_profile_mode() else vmapped
        return impl(*args, **kwargs)

    return wrapper