Coverage for src/bartz/mcmcloop/_loop.py: 99%

101 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/mcmcloop/_loop.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement `run_mcmc`, the MCMC loop driver.""" 

26 

27from collections.abc import Callable 

28from functools import partial, update_wrapper 

29from typing import ( 

30 Any, 

31 Generic, 

32 NamedTuple, 

33 Protocol, 

34 TypeAlias, 

35 TypeVar, 

36 runtime_checkable, 

37) 

38 

39from equinox import Module 

40from jax import ( 

41 NamedSharding, 

42 device_put, 

43 eval_shape, 

44 lax, 

45 named_call, 

46 random, 

47 tree, 

48 vmap, 

49) 

50from jax import numpy as jnp 

51from jax.sharding import Mesh, PartitionSpec 

52from jaxtyping import Array, Bool, Int32, Key, PyTree, Shaped 

53 

54from bartz._jaxext import jit, jit_active, split 

55from bartz.mcmcloop._trace import BurninTrace, MainTrace 

56from bartz.mcmcstep import State, step 

57from bartz.mcmcstep._axes import trace_sample_axes 

58from bartz.mcmcstep._lazy import add_dummy_axis 

59 

60# WORKAROUND(python<3.12): use `type CallbackState = PyTree[Any, 'T']` 

61CallbackState: TypeAlias = PyTree[Any, 'T'] 

62 

63 

64class RunMCMCResult(NamedTuple): 

65 """Return value of `run_mcmc`.""" 

66 

67 final_state: State 

68 """The final MCMC state.""" 

69 

70 burnin_trace: BurninTrace 

71 """The trace of the burn-in phase.""" 

72 

73 main_trace: MainTrace 

74 """The trace of the main phase.""" 

75 

76 

77@runtime_checkable 

78class Callback(Protocol): 

79 """Callback type for `run_mcmc`.""" 

80 

81 def __call__( 

82 self, 

83 *, 

84 key: Key[Array, ''], 

85 state: State, 

86 burnin: Bool[Array, ''], 

87 i_total: Int32[Array, ''], 

88 callback_state: CallbackState, 

89 n_burn: Int32[Array, ''], 

90 n_save: Int32[Array, ''], 

91 n_skip: Int32[Array, ''], 

92 i_outer: Int32[Array, ''], 

93 inner_loop_length: Int32[Array, ''], 

94 ) -> tuple[State, CallbackState] | None: 

95 """Do an arbitrary action after an iteration of the MCMC. 

96 

97 Parameters 

98 ---------- 

99 key 

100 A key for random number generation. 

101 state 

102 The MCMC state just after updating it. 

103 burnin 

104 Whether the last iteration was in the burn-in phase. 

105 i_total 

106 The index of the last MCMC iteration (0-based). 

107 callback_state 

108 The callback state, initially set to the argument passed to 

109 `run_mcmc`, afterwards to the value returned by the last invocation 

110 of the callback. 

111 n_burn 

112 n_save 

113 n_skip 

114 The corresponding `run_mcmc` arguments as-is. 

115 i_outer 

116 The index of the last outer loop iteration (0-based). 

117 inner_loop_length 

118 The number of MCMC iterations in the inner loop. 

119 

120 Returns 

121 ------- 

122 state : State 

123 A possibly modified MCMC state. To avoid modifying the state, 

124 return the `state` argument passed to the callback as-is. 

125 callback_state : CallbackState 

126 The new state to be passed on the next callback invocation. 

127 

128 Notes 

129 ----- 

130 For convenience, the callback may return `None`, and the states won't 

131 be updated. 

132 """ 

133 ... 

134 

135 

136class _Carry(Module): 

137 """Carry used in the loop in `run_mcmc`.""" 

138 

139 state: State 

140 i_total: Int32[Array, ''] 

141 key: Key[Array, ''] 

142 burnin_trace: BurninTrace 

143 main_trace: MainTrace 

144 callback_state: CallbackState 

145 

146 

147def run_mcmc( 

148 key: Key[Array, ''], 

149 state: State, 

150 n_save: int, 

151 *, 

152 n_burn: int = 0, 

153 n_skip: int = 1, 

154 inner_loop_length: int | None = None, 

155 callback: Callback | None = None, 

156 callback_state: CallbackState = None, 

157) -> RunMCMCResult: 

158 """ 

159 Run the MCMC for the BART posterior. 

160 

161 Parameters 

162 ---------- 

163 key 

164 A key for random number generation. 

165 state 

166 The initial MCMC state, as created and updated by the functions in 

167 `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies, 

168 so this variable is invalidated after running `run_mcmc`. Make a copy 

169 beforehand to use it again. 

170 n_save 

171 The number of iterations to save. 

172 n_burn 

173 The number of initial iterations which are not saved. 

174 n_skip 

175 The number of iterations to skip between each saved iteration, plus 1. 

176 The effective burn-in is ``n_burn + n_skip - 1``. 

177 inner_loop_length 

178 The MCMC loop is split into an outer and an inner loop. The outer loop 

179 is in Python, while the inner loop is in JAX. `inner_loop_length` is the 

180 number of iterations of the inner loop to run for each iteration of the 

181 outer loop. If not specified, the outer loop will iterate just once, 

182 with all iterations done in a single inner loop run. The inner stride is 

183 unrelated to the stride used for saving the trace. 

184 callback 

185 An arbitrary function run during the loop after updating the state. For 

186 the signature, see `Callback`. The callback is called under the jax jit, 

187 so the argument values are not available at the time the Python code is 

188 executed. Use the utilities in `jax.debug` to access the values at 

189 actual runtime. The callback may return new values for the MCMC state 

190 and the callback state. 

191 callback_state 

192 The initial custom state for the callback. 

193 

194 Returns 

195 ------- 

196 A namedtuple with the final state, the burn-in trace, and the main trace. 

197 

198 Raises 

199 ------ 

200 RuntimeError 

201 If `run_mcmc` detects it's being invoked in a `jax.jit`-wrapped context and 

202 with settings that would create unrolled loops in the trace. 

203 

204 Notes 

205 ----- 

206 The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do 

207 not include the initial state, and include the final state. 

208 

209 Resuming is exact: passing the returned `~RunMCMCResult.final_state` and the same `key` to 

210 a new call continues the run as if it had not stopped, so splitting a run 

211 into several consecutive calls gives the same result as a single call. 

212 """ 

213 # copy the key so buffer donation does not invalidate the caller's copy 

214 key = jnp.copy(key) 

215 

216 # create empty traces 

217 burnin_trace = _empty_trace(n_burn, state, BurninTrace) 

218 main_trace = _empty_trace(n_save, state, MainTrace) 

219 

220 # determine number of iterations for inner and outer loops 

221 n_iters = n_burn + n_skip * n_save 

222 if inner_loop_length is None: 

223 inner_loop_length = n_iters 

224 if inner_loop_length: 

225 n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length) 

226 else: 

227 n_outer = 1 

228 # setting to 0 would make for a clean noop, but it's useful to keep the 

229 # same code path for benchmarking and testing 

230 

231 # error if under jit and there are unrolled loops 

232 if jit_active() and n_outer > 1: 

233 msg = ( 

234 '`run_mcmc` was called within a jit-compiled function and ' 

235 'there are more than 1 outer loops, ' 

236 'please either do not jit or set `inner_loop_length=None`' 

237 ) 

238 raise RuntimeError(msg) 

239 

240 replicate = partial(_replicate, mesh=state.config.mesh) 

241 carry = _Carry( 

242 state, 

243 replicate(jnp.int32(0)), 

244 replicate(key), 

245 burnin_trace, 

246 main_trace, 

247 callback_state, 

248 ) 

249 _inner_loop_counter.reset_call_counter() 

250 for i_outer in range(n_outer): 

251 carry = _run_mcmc_inner_loop( 

252 carry, inner_loop_length, callback, n_burn, n_save, n_skip, i_outer, n_iters 

253 ) 

254 

255 return RunMCMCResult(carry.state, carry.burnin_trace, carry.main_trace) 

256 

257 

258def _replicate( 

259 x: Shaped[Array, '*shape'], mesh: Mesh | None 

260) -> Shaped[Array, '*shape']: 

261 if mesh is None: 

262 return x 

263 else: 

264 return device_put(x, NamedSharding(mesh, PartitionSpec())) 

265 

266 

267TraceT = TypeVar('TraceT', bound=BurninTrace) 

268 

269 

270@jit(static_argnums=(0, 2)) 

271def _empty_trace(length: int, state: State, trace_cls: type[TraceT]) -> TraceT: 

272 example_output = eval_shape(trace_cls.from_state, state) 

273 out_axes = trace_sample_axes(add_dummy_axis(example_output)) 

274 

275 return vmap( 

276 trace_cls.from_state, in_axes=None, out_axes=out_axes, axis_size=length 

277 )(state) 

278 

279 

280T = TypeVar('T') 

281 

282 

283class _CallCounter(Generic[T]): 

284 """Wrap a callable to check it's not called more than once.""" 

285 

286 def __init__(self, func: Callable[..., T]) -> None: 

287 self.func = func 

288 self.n_calls = 0 

289 update_wrapper(self, func) 

290 

291 def reset_call_counter(self) -> None: 

292 """Reset the call counter.""" 

293 self.n_calls = 0 

294 

295 def __call__(self, *args: Any, **kwargs: Any) -> T: 

296 if self.n_calls: 

297 msg = ( 

298 'The inner loop of `run_mcmc` was traced more than once, ' 

299 'which indicates a double compilation of the MCMC code. This ' 

300 'probably depends on the input state having different type from the ' 

301 'output state. Check the input is in a format that is the ' 

302 'same jax would output, e.g., all arrays and scalars are jax ' 

303 'arrays, with the right shardings.' 

304 ) 

305 raise RuntimeError(msg) 

306 self.n_calls += 1 

307 return self.func(*args, **kwargs) 

308 

309 

310def _run_mcmc_inner_loop_impl( 

311 carry: _Carry, 

312 inner_loop_length: Int32[Array, ''], 

313 callback: Callback | None, 

314 n_burn: Int32[Array, ''], 

315 n_save: Int32[Array, ''], 

316 n_skip: Int32[Array, ''], 

317 i_outer: Int32[Array, ''], 

318 n_iters: Int32[Array, ''], 

319) -> _Carry: 

320 # determine number of iterations for this loop batch 

321 i_upper = jnp.minimum(carry.i_total + inner_loop_length, n_iters) 

322 

323 def cond(carry: _Carry) -> Bool[Array, '']: 

324 """Whether to continue the MCMC loop.""" 

325 return carry.i_total < i_upper 

326 

327 def body(carry: _Carry) -> _Carry: 

328 """Update the MCMC state.""" 

329 iter_key = random.fold_in(carry.key, carry.state.config.steps_done) 

330 keys = split(iter_key, 2) 

331 

332 # update state 

333 state = step(keys.pop(), carry.state) 

334 

335 # invoke callback 

336 callback_state = carry.callback_state 

337 if callback is not None: 

338 rt = callback( 

339 key=keys.pop(), 

340 state=state, 

341 burnin=carry.i_total < n_burn, 

342 i_total=carry.i_total, 

343 callback_state=callback_state, 

344 n_burn=n_burn, 

345 n_save=n_save, 

346 n_skip=n_skip, 

347 i_outer=i_outer, 

348 inner_loop_length=inner_loop_length, 

349 ) 

350 if rt is not None: 350 ↛ 354line 350 didn't jump to line 354 because the condition on line 350 was always true

351 state, callback_state = rt 

352 

353 # save to trace 

354 burnin_trace, main_trace = _save_state_to_trace( 

355 carry.burnin_trace, carry.main_trace, state, carry.i_total, n_burn, n_skip 

356 ) 

357 

358 return _Carry( 

359 state=state, 

360 i_total=carry.i_total + 1, 

361 key=carry.key, 

362 burnin_trace=burnin_trace, 

363 main_trace=main_trace, 

364 callback_state=callback_state, 

365 ) 

366 

367 return lax.while_loop(cond, body, carry) 

368 

369 

370# Wrap the inner loop in an explicit `_CallCounter`, kept in `_inner_loop_counter` 

371# so `run_mcmc` can reset it directly instead of reaching into jit internals, 

372# then jit the wrapped callable. 

373_inner_loop_counter: _CallCounter[_Carry] = _CallCounter(_run_mcmc_inner_loop_impl) 

374_run_mcmc_inner_loop = jit(donate_argnums=(0,), static_argnums=(2,))( 

375 _inner_loop_counter 

376) 

377 

378 

379@named_call 

380def _save_state_to_trace( 

381 burnin_trace: BurninTrace, 

382 main_trace: MainTrace, 

383 state: State, 

384 i_total: Int32[Array, ''], 

385 n_burn: Int32[Array, ''], 

386 n_skip: Int32[Array, ''], 

387) -> tuple[BurninTrace, MainTrace]: 

388 # trace index where to save during burnin; out-of-bounds => noop after 

389 # burnin 

390 burnin_idx = i_total 

391 

392 # trace index where to save during main phase; force it out-of-bounds 

393 # during burnin 

394 main_idx = (i_total - n_burn) // n_skip 

395 noop_idx = jnp.iinfo(jnp.int32).max 

396 noop_cond = i_total < n_burn 

397 main_idx = jnp.where(noop_cond, noop_idx, main_idx) 

398 

399 # prepare array index 

400 burnin_trace = _set(burnin_trace, burnin_idx, BurninTrace.from_state(state)) 

401 main_trace = _set(main_trace, main_idx, MainTrace.from_state(state)) 

402 

403 return burnin_trace, main_trace 

404 

405 

406def _set( 

407 trace: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array, ' T'] 

408) -> PyTree[Array, ' T']: 

409 """Do ``trace[index] = val`` but fancier.""" 

410 # WORKAROUND(jax<0.7.1): once we bump jax to v0.7.1 we can use mutable 

411 # arrays to save the trace instead of this functional update. 

412 sample_axes = trace_sample_axes(trace) 

413 

414 # `trace` is `(*chains, samples, *shape)` and `val` the same without the 

415 # `samples` axis. The optional `chains` axis cannot share an annotation with 

416 # the variadic `*shape` (two variadics are ambiguous), and a union of the 

417 # with/without-`chains` layouts is rank-ambiguous under the runtime checker, 

418 # so the trace/val shapes are kept independent; their relationship is 

419 # enforced dynamically. The return has the trace shape. 

420 def at_set( 

421 trace: Shaped[Array, '*chains_samples_core'], 

422 val: Shaped[Array, '*chains_core'], 

423 sample_axis: int | None, 

424 ) -> Shaped[Array, '*chains_samples_core']: 

425 if sample_axis is None or trace.size == 0: 

426 # `sample_axis is None`: fields without a `samples` marker have 

427 # no per-iteration slot to update. 

428 # `trace.size == 0`: jax refuses to index into an axis of length 

429 # 0, even in the abstract. 

430 return trace 

431 

432 ndindex = (slice(None),) * sample_axis + (index, ...) 

433 return trace.at[ndindex].set(val, mode='drop') 

434 

435 return tree.map(at_set, trace, val, sample_axes)