Coverage for src/bartz/mcmcloop/_loop.py: 99%
101 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/mcmcloop/_loop.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement `run_mcmc`, the MCMC loop driver."""
27from collections.abc import Callable
28from functools import partial, update_wrapper
29from typing import (
30 Any,
31 Generic,
32 NamedTuple,
33 Protocol,
34 TypeAlias,
35 TypeVar,
36 runtime_checkable,
37)
39from equinox import Module
40from jax import (
41 NamedSharding,
42 device_put,
43 eval_shape,
44 lax,
45 named_call,
46 random,
47 tree,
48 vmap,
49)
50from jax import numpy as jnp
51from jax.sharding import Mesh, PartitionSpec
52from jaxtyping import Array, Bool, Int32, Key, PyTree, Shaped
54from bartz._jaxext import jit, jit_active, split
55from bartz.mcmcloop._trace import BurninTrace, MainTrace
56from bartz.mcmcstep import State, step
57from bartz.mcmcstep._axes import trace_sample_axes
58from bartz.mcmcstep._lazy import add_dummy_axis
# Type of the user-defined state threaded through the callback invocations:
# an arbitrary pytree.
# WORKAROUND(python<3.12): use `type CallbackState = PyTree[Any, 'T']`
CallbackState: TypeAlias = PyTree[Any, 'T']
class RunMCMCResult(NamedTuple):
    """Bundle of the outputs produced by `run_mcmc`."""

    final_state: State
    """The MCMC state after the last iteration."""

    burnin_trace: BurninTrace
    """The trace collected during the burn-in phase."""

    main_trace: MainTrace
    """The trace collected during the main phase."""
@runtime_checkable
class Callback(Protocol):
    """Signature that a callback passed to `run_mcmc` must satisfy."""

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        state: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: Int32[Array, ''],
    ) -> tuple[State, CallbackState] | None:
        """Perform a user-defined action right after one MCMC iteration.

        All arguments are passed by keyword.

        Parameters
        ----------
        key
            A key for random number generation.
        state
            The MCMC state, as it is immediately after the update.
        burnin
            Whether the iteration just completed belongs to the burn-in phase.
        i_total
            The 0-based index of the iteration just completed.
        callback_state
            The callback state. On the first invocation it is the value given
            to `run_mcmc`; afterwards it is whatever the previous invocation
            of the callback returned.
        n_burn
        n_save
        n_skip
            The corresponding `run_mcmc` arguments as-is.
        i_outer
            The 0-based index of the current outer loop iteration.
        inner_loop_length
            The number of MCMC iterations in the inner loop.

        Returns
        -------
        state : State
            The MCMC state to continue from, possibly modified. Return the
            `state` argument unchanged to leave the MCMC state alone.
        callback_state : CallbackState
            The callback state handed to the next invocation.

        Notes
        -----
        As a shortcut, the callback may return `None`; in that case neither
        the MCMC state nor the callback state is updated.
        """
        ...
class _Carry(Module):
    """Carry used in the loop in `run_mcmc`.

    The same object is threaded through the `lax.while_loop` of the inner loop
    and handed from one outer loop iteration to the next.
    """

    # MCMC state after the most recent iteration (initially the input state)
    state: State
    # number of MCMC iterations done so far; starts at 0, incremented by 1 at
    # each iteration
    i_total: Int32[Array, '']
    # base random key; the loop body passes it on unchanged and derives the
    # per-iteration keys from it
    key: Key[Array, '']
    # preallocated trace of the burn-in phase, filled in as the loop advances
    burnin_trace: BurninTrace
    # preallocated trace of the main phase, filled in as the loop advances
    main_trace: MainTrace
    # user-defined state passed to, and returned by, the callback
    callback_state: CallbackState
def run_mcmc(
    key: Key[Array, ''],
    state: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
) -> RunMCMCResult:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        A key for random number generation.
    state
        The starting MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. Because the MCMC loop donates its buffers to avoid
        copies, this variable is invalidated by `run_mcmc`; copy it
        beforehand if it is needed afterwards.
    n_save
        The number of iterations to save.
    n_burn
        The number of initial iterations which are not saved.
    n_skip
        The number of iterations to skip between each saved iteration, plus 1.
        The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The MCMC loop is made of an outer loop, written in Python, and an
        inner loop, written in JAX. `inner_loop_length` is how many iterations
        the inner loop does per iteration of the outer loop. If not specified,
        the outer loop iterates only once and the inner loop runs all the
        iterations in one go. The inner stride is unrelated to the stride used
        for saving the trace.
    callback
        An arbitrary function run during the loop after updating the state.
        See `Callback` for the signature. The callback runs under the jax jit,
        so the values of its arguments are not available while the Python
        code executes; use the utilities in `jax.debug` to access them at
        actual runtime. The callback may return new values for the MCMC state
        and the callback state.
    callback_state
        The initial custom state for the callback.

    Returns
    -------
    A namedtuple with the final state, the burn-in trace, and the main trace.

    Raises
    ------
    RuntimeError
        If `run_mcmc` detects it's being invoked in a `jax.jit`-wrapped context and
        with settings that would create unrolled loops in the trace.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state.

    Resuming is exact: passing the returned `~RunMCMCResult.final_state` and the same `key` to
    a new call continues the run as if it had not stopped, so splitting a run
    into several consecutive calls gives the same result as a single call.
    """
    # work on a private copy of the key, so that buffer donation does not
    # invalidate the one held by the caller
    key = jnp.copy(key)

    # preallocate the traces
    burnin_trace = _empty_trace(n_burn, state, BurninTrace)
    main_trace = _empty_trace(n_save, state, MainTrace)

    # total number of MCMC updates, and how they are batched in outer iterations
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if inner_loop_length:
        # ceil(n_iters / inner_loop_length)
        n_full, n_leftover = divmod(n_iters, inner_loop_length)
        n_outer = n_full + bool(n_leftover)
    else:
        # 0 outer iterations would make for a clean noop, but it's useful to
        # keep the same code path for benchmarking and testing
        n_outer = 1

    # under jit, more than one outer iteration would be unrolled in the trace
    if jit_active() and n_outer > 1:
        msg = (
            '`run_mcmc` was called within a jit-compiled function and '
            'there are more than 1 outer loops, '
            'please either do not jit or set `inner_loop_length=None`'
        )
        raise RuntimeError(msg)

    # assemble the initial carry; the iteration counter and the key are
    # replicated on the mesh of the state, if there is one
    mesh = state.config.mesh
    carry = _Carry(
        state=state,
        i_total=_replicate(jnp.int32(0), mesh),
        key=_replicate(key, mesh),
        burnin_trace=burnin_trace,
        main_trace=main_trace,
        callback_state=callback_state,
    )

    # from here on the inner loop function may be executed at most once
    _inner_loop_counter.reset_call_counter()
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry, inner_loop_length, callback, n_burn, n_save, n_skip, i_outer, n_iters
        )

    return RunMCMCResult(
        final_state=carry.state,
        burnin_trace=carry.burnin_trace,
        main_trace=carry.main_trace,
    )
258def _replicate(
259 x: Shaped[Array, '*shape'], mesh: Mesh | None
260) -> Shaped[Array, '*shape']:
261 if mesh is None:
262 return x
263 else:
264 return device_put(x, NamedSharding(mesh, PartitionSpec()))
TraceT = TypeVar('TraceT', bound=BurninTrace)


@jit(static_argnums=(0, 2))
def _empty_trace(length: int, state: State, trace_cls: type[TraceT]) -> TraceT:
    """Preallocate a trace with `length` slots.

    The trace is built by mapping `trace_cls.from_state` over a new axis of
    size `length` with `state` held fixed, so every slot starts out as the
    entry derived from `state`.
    """
    make_entry = trace_cls.from_state

    # work out, without computing anything, where the new axis goes in each
    # field of the trace
    entry_struct = eval_shape(make_entry, state)
    sample_axes = trace_sample_axes(add_dummy_axis(entry_struct))

    stacked = vmap(make_entry, in_axes=None, out_axes=sample_axes, axis_size=length)
    return stacked(state)
280T = TypeVar('T')
283class _CallCounter(Generic[T]):
284 """Wrap a callable to check it's not called more than once."""
286 def __init__(self, func: Callable[..., T]) -> None:
287 self.func = func
288 self.n_calls = 0
289 update_wrapper(self, func)
291 def reset_call_counter(self) -> None:
292 """Reset the call counter."""
293 self.n_calls = 0
295 def __call__(self, *args: Any, **kwargs: Any) -> T:
296 if self.n_calls:
297 msg = (
298 'The inner loop of `run_mcmc` was traced more than once, '
299 'which indicates a double compilation of the MCMC code. This '
300 'probably depends on the input state having different type from the '
301 'output state. Check the input is in a format that is the '
302 'same jax would output, e.g., all arrays and scalars are jax '
303 'arrays, with the right shardings.'
304 )
305 raise RuntimeError(msg)
306 self.n_calls += 1
307 return self.func(*args, **kwargs)
def _run_mcmc_inner_loop_impl(
    carry: _Carry,
    inner_loop_length: Int32[Array, ''],
    callback: Callback | None,
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Run one batch of MCMC iterations with `lax.while_loop`.

    Parameters
    ----------
    carry
        The loop carry at the start of the batch.
    inner_loop_length
        The maximum number of iterations in the batch.
    callback
        The user callback, or `None`.
    n_burn
    n_save
    n_skip
        The corresponding `run_mcmc` arguments.
    i_outer
        The index of the outer loop iteration this batch belongs to.
    n_iters
        The total number of iterations of the whole run; the batch never goes
        past it.

    Returns
    -------
    The carry after the last iteration of the batch.
    """
    # the batch ends after `inner_loop_length` iterations or with the whole
    # run, whichever comes first
    i_stop = jnp.minimum(carry.i_total + inner_loop_length, n_iters)

    def keep_going(current: _Carry) -> Bool[Array, '']:
        """Whether the batch has iterations left."""
        return current.i_total < i_stop

    def advance(current: _Carry) -> _Carry:
        """Do one MCMC iteration."""
        i_total = current.i_total

        # the keys of this iteration come from the base key and the step
        # counter stored in the state
        base_key = random.fold_in(current.key, current.state.config.steps_done)
        keys = split(base_key, 2)

        # update state
        state = step(keys.pop(), current.state)

        # invoke callback
        callback_state = current.callback_state
        if callback is not None:
            callback_key = keys.pop()
            in_burnin = i_total < n_burn
            outcome = callback(
                key=callback_key,
                state=state,
                burnin=in_burnin,
                i_total=i_total,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # a `None` return leaves both states as they are
            if outcome is not None:
                state, callback_state = outcome

        # save to trace
        burnin_trace, main_trace = _save_state_to_trace(
            current.burnin_trace, current.main_trace, state, i_total, n_burn, n_skip
        )

        return _Carry(
            state=state,
            i_total=i_total + 1,
            key=current.key,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
            callback_state=callback_state,
        )

    return lax.while_loop(keep_going, advance, carry)
# Wrap the inner loop in an explicit `_CallCounter`, kept in `_inner_loop_counter`
# so `run_mcmc` can reset it directly instead of reaching into jit internals,
# then jit the wrapped callable.
#
# The counter raises if the wrapped Python function is executed a second time
# without a reset in between; `run_mcmc` resets it once before its outer loop.
#
# `donate_argnums=(0,)` refers to the `carry` argument and `static_argnums=(2,)`
# to the `callback` argument of `_run_mcmc_inner_loop_impl`.
# NOTE(review): `jit` is the wrapper from `bartz._jaxext`; presumably these
# options have the same meaning as in `jax.jit` -- confirm.
_inner_loop_counter: _CallCounter[_Carry] = _CallCounter(_run_mcmc_inner_loop_impl)
_run_mcmc_inner_loop = jit(donate_argnums=(0,), static_argnums=(2,))(
    _inner_loop_counter
)
@named_call
def _save_state_to_trace(
    burnin_trace: BurninTrace,
    main_trace: MainTrace,
    state: State,
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_skip: Int32[Array, ''],
) -> tuple[BurninTrace, MainTrace]:
    """Write `state` into the slot of iteration `i_total` in both traces.

    Both traces are written at every iteration. The write that does not apply
    to the current phase is aimed at an out-of-bounds index, which `_set`
    discards.
    """
    in_burnin = i_total < n_burn

    # main phase: slot `(i_total - n_burn) // n_skip`; during burn-in the
    # index is forced out-of-bounds
    out_of_bounds = jnp.iinfo(jnp.int32).max
    main_slot = jnp.where(in_burnin, out_of_bounds, (i_total - n_burn) // n_skip)

    # burn-in phase: slot `i_total`, which goes out-of-bounds by itself once
    # the burn-in is over
    updated_burnin = _set(burnin_trace, i_total, BurninTrace.from_state(state))
    updated_main = _set(main_trace, main_slot, MainTrace.from_state(state))

    return updated_burnin, updated_main
def _set(
    trace: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array, ' T']
) -> PyTree[Array, ' T']:
    """Do ``trace[index] = val`` along the samples axis of each field.

    Writes with an out-of-bounds `index` are dropped.
    """
    # WORKAROUND(jax<0.7.1): once we bump jax to v0.7.1 we can use mutable
    # arrays to save the trace instead of this functional update.
    sample_axes = trace_sample_axes(trace)

    # A trace leaf is `(*chains, samples, *shape)` and the matching value leaf
    # is the same without the `samples` axis. The optional `chains` axis
    # cannot share an annotation with the variadic `*shape` (two variadics are
    # ambiguous), and a union of the with/without-`chains` layouts is
    # rank-ambiguous under the runtime checker, so the two shapes are
    # annotated independently and their relationship is enforced dynamically.
    # The return has the shape of the trace leaf.
    def write_leaf(
        leaf: Shaped[Array, '*chains_samples_core'],
        new: Shaped[Array, '*chains_core'],
        axis: int | None,
    ) -> Shaped[Array, '*chains_samples_core']:
        # `axis is None`: fields without a `samples` marker have no
        # per-iteration slot to update.
        # `leaf.size == 0`: jax refuses to index into an axis of length 0,
        # even in the abstract.
        writable = axis is not None and leaf.size != 0
        if not writable:
            return leaf

        # take everything along the axes before `axis`, and `index` along it
        leading = (slice(None),) * axis
        where = (*leading, index, ...)
        return leaf.at[where].set(new, mode='drop')

    return tree.map(write_leaf, trace, val, sample_axes)