Coverage for src / bartz / mcmcloop.py: 96%
227 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/mcmcloop.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions that implement the full BART posterior MCMC loop.
27The entry points are `run_mcmc` and `make_default_callback`.
28"""
30from collections.abc import Callable
31from dataclasses import fields
32from functools import partial, update_wrapper, wraps
33from math import floor
34from typing import Any, NamedTuple, Protocol, TypeVar
36import jax
37import numpy
38from equinox import Module
39from jax import (
40 NamedSharding,
41 ShapeDtypeStruct,
42 debug,
43 device_put,
44 eval_shape,
45 jit,
46 lax,
47 named_call,
48 tree,
49)
50from jax import numpy as jnp
51from jax.nn import softmax
52from jax.sharding import Mesh, PartitionSpec
53from jaxtyping import (
54 Array,
55 ArrayLike,
56 Bool,
57 Float32,
58 Int32,
59 Integer,
60 Key,
61 PyTree,
62 Shaped,
63 UInt,
64)
66from bartz import jaxext, mcmcstep
67from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram
68from bartz.jaxext import autobatch, jit_active
69from bartz.mcmcstep import State
70from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains
class BurninTrace(Module):
    """Diagnostics-only MCMC trace, with no trees stored."""

    error_cov_inv: (
        Float32[Array, '*chains_and_samples']
        | Float32[Array, '*chains_and_samples k k']
        | None
    ) = field(chains=True)
    theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Extract the diagnostic values of one MCMC state into a trace item.

        `error_cov_inv` is read from the state itself, every other field is
        read from the attribute with the same name in `state.forest`.
        """
        forest = state.forest
        forest_fields = (
            'theta',
            'grow_prop_count',
            'grow_acc_count',
            'prune_prop_count',
            'prune_acc_count',
            'log_likelihood',
            'log_trans_prior',
        )
        from_forest = {name: getattr(forest, name) for name in forest_fields}
        return cls(error_cov_inv=state.error_cov_inv, **from_forest)
class MainTrace(BurninTrace):
    """Full MCMC trace: the trees on top of the `BurninTrace` diagnostics."""

    leaf_tree: (
        Float32[Array, '*chains_and_samples 2**d']
        | Float32[Array, '*chains_and_samples k 2**d']
    ) = field(chains=True)
    var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
    varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Extract trees and diagnostics of one MCMC state into a trace item.

        `varprob` is `None` when the forest has no `log_s`; otherwise it is the
        softmax of `log_s`, restricted to the entries where `max_split` is
        nonzero.
        """
        forest = state.forest
        log_s = forest.log_s
        varprob = (
            None
            if log_s is None
            else softmax(log_s, where=forest.max_split.astype(bool))
        )

        # the diagnostic fields are obtained through the parent class extractor
        diagnostics = vars(BurninTrace.from_state(state))

        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            varprob=varprob,
            **diagnostics,
        )
# Type of the user-defined pytree that `run_mcmc` hands to the callback and
# replaces with whatever the callback returns (see `Callback`).
CallbackState = PyTree[Any, 'T']
class RunMCMCResult(NamedTuple):
    """Return value of `run_mcmc`.

    Being a named tuple, it can be unpacked positionally as
    ``final_state, burnin_trace, main_trace``.
    """

    final_state: State
    """The final MCMC state."""

    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    """The trace of the burn-in phase. For the default layout, see `BurninTrace`."""

    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    """The trace of the main phase. For the default layout, see `MainTrace`."""
class Callback(Protocol):
    """Callback type for `run_mcmc`.

    The callback is always invoked with keyword arguments only.
    """

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: Int32[Array, ''],
    ) -> tuple[State, CallbackState] | None:
        """Do an arbitrary action after an iteration of the MCMC.

        Parameters
        ----------
        key
            A key for random number generation.
        bart
            The MCMC state just after updating it.
        burnin
            Whether the last iteration was in the burn-in phase, i.e.,
            whether ``i_total < n_burn``.
        i_total
            The index of the last MCMC iteration (0-based).
        callback_state
            The callback state, initially set to the argument passed to
            `run_mcmc`, afterwards to the value returned by the last invocation
            of the callback.
        n_burn
        n_save
        n_skip
            The corresponding `run_mcmc` arguments as-is.
        i_outer
            The index of the last outer loop iteration (0-based).
        inner_loop_length
            The number of MCMC iterations in the inner loop.

        Returns
        -------
        bart : State
            A possibly modified MCMC state. To avoid modifying the state,
            return the `bart` argument passed to the callback as-is.
        callback_state : CallbackState
            The new state to be passed on the next callback invocation.

        Notes
        -----
        For convenience, the callback may return `None`, and the states won't
        be updated.
        """
        ...
class _Carry(Module):
    """Carry used in the loop in `run_mcmc`.

    It bundles everything that changes from one MCMC iteration to the next, so
    that it can be threaded through `lax.while_loop` as a single pytree.
    """

    # current MCMC state
    bart: State
    # number of MCMC iterations done so far; starts at 0 and is incremented by
    # one at the end of each iteration
    i_total: Int32[Array, '']
    # random key, replaced by a fresh one derived from it at each iteration
    key: Key[Array, '']
    # preallocated burn-in trace, filled in as the loop progresses
    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    # preallocated main trace, filled in as the loop progresses
    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    # user-defined state passed to and returned by the callback
    callback_state: CallbackState
def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> RunMCMCResult:
    """
    Sample from the BART posterior with MCMC.

    Parameters
    ----------
    key
        A key for random number generation.
    bart
        The starting MCMC state, as produced by the functions in
        `bartz.mcmcstep`. Its buffers are donated to the compiled loop to
        avoid copies, so the object passed in can not be used after the call;
        copy it first if it is needed again.
    n_save
        How many iterations are stored in the main trace.
    n_burn
        How many initial iterations are run without being stored in the main
        trace.
    n_skip
        The stride between stored iterations, i.e., the number of skipped
        iterations between each stored one, plus 1. The effective burn-in is
        ``n_burn + n_skip - 1``.
    inner_loop_length
        The iterations are divided between an outer loop, run in Python, and
        an inner loop, run in JAX. This is the number of inner loop iterations
        done for each outer loop iteration. By default there is a single outer
        iteration that does everything in one inner loop. This stride has
        nothing to do with the one used to save the trace.
    callback
        An arbitrary function invoked in the loop after each state update, see
        `Callback` for the signature. It runs under the jax jit, so the values
        of its arguments are not available while the Python code executes; use
        the tools in `jax.debug` to get at them at actual runtime. It may
        return replacements for the MCMC state and the callback state.
    callback_state
        The initial value of the custom state handed to the callback.
    burnin_extractor
    main_extractor
        Functions that, given the MCMC state, return the variables to store
        respectively in the burn-in and in the main trace. They must return a
        pytree and be vmappable.

    Returns
    -------
    A namedtuple with the final state, the burn-in trace, and the main trace.

    Raises
    ------
    RuntimeError
        If `run_mcmc` is invoked in a `jit`-wrapped context with settings that
        would produce unrolled loops in the trace.

    Notes
    -----
    The total number of MCMC updates is ``n_burn + n_skip * n_save``. The
    traces include the final state but not the initial one.
    """
    # preallocate the traces
    burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
    main_trace = _empty_trace(n_save, bart, main_extractor)

    # split the iterations between the outer (Python) and inner (JAX) loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if inner_loop_length:
        # ceiling division
        n_full, n_left = divmod(n_iters, inner_loop_length)
        n_outer = n_full + bool(n_left)
    else:
        # 0 outer iterations would be a clean noop, but going through the usual
        # code path is useful for benchmarking and testing
        n_outer = 1

    # a Python loop with more than one iteration would be unrolled by jit
    if jit_active() and n_outer > 1:
        raise RuntimeError(
            '`run_mcmc` was called within a jit-compiled function and '
            'there are more than 1 outer loops, '
            'please either do not jit or set `inner_loop_length=None`'
        )

    mesh = bart.config.mesh
    carry = _Carry(
        bart=bart,
        i_total=_replicate(jnp.int32(0), mesh=mesh),
        key=_replicate(key, mesh=mesh),
        burnin_trace=burnin_trace,
        main_trace=main_trace,
        callback_state=callback_state,
    )

    # start counting the compilations of the inner loop from scratch
    _run_mcmc_inner_loop._fun.reset_call_counter()  # noqa: SLF001

    # arguments are positional because the inner loop marks some as static by
    # position
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return RunMCMCResult(
        final_state=carry.bart,
        burnin_trace=carry.burnin_trace,
        main_trace=carry.main_trace,
    )
349def _replicate(x: Array, mesh: Mesh | None) -> Array:
350 if mesh is None: 2nbabobB q f c b j a pbg qbrbbbsbtbcbubd h e u k l M K L y dbr N ebE O fbF v m n ; z : w G H gb~ hbvbibwbxbjbyb= kb? x s t lbzbAbi o p BbmbCbP I J C D [ 1 + W X Y Z Q R S 0 T U V 2 3 4 5 6 7 8 9 ! # $ % ' ( ) * , - . /
351 return x 2abf j g bbcbd h e k K dbebfbv m z w G ~ ibjbkbx s t lbo mbI 1 + W X Y Z Q R S 0 T U V 2 3 4 5 6 7 8 9 ! # $ % ' ( ) * , - . /
352 else:
353 return device_put(x, NamedSharding(mesh, PartitionSpec())) 2nbobB q c b a pbqbrbsbtbubd h e u l M L y r N E O F n ; : H gbhbvbwbxbyb= ? zbAbi p BbCbP J C D [
@partial(jit, static_argnums=(0, 2))
def _empty_trace(
    length: int, bart: State, extractor: Callable[[State], PyTree]
) -> PyTree:
    """Preallocate a trace with room for `length` items.

    The output of `extractor` on `bart` is repeated `length` times along a new
    axis. The new axis is the first one, except, with multiple chains, for the
    leaves that `chain_vmap_axes` marks with a chain axis, where it is the
    second one.
    """
    if get_num_chains(bart) is None:
        out_axes = 0
    else:

        def samples_axis(chain_axis: int | None) -> int:
            # leaves with a chain axis keep it in front of the samples axis
            return 0 if chain_axis is None else 1

        item_shape = eval_shape(extractor, bart)
        out_axes = tree.map(
            samples_axis, chain_vmap_axes(item_shape), is_leaf=lambda a: a is None
        )

    repeated_extractor = jax.vmap(
        extractor, in_axes=None, out_axes=out_axes, axis_size=length
    )
    return repeated_extractor(bart)
372T = TypeVar('T')
375class _CallCounter:
376 """Wrap a callable to check it's not called more than once."""
378 def __init__(self, func: Callable[..., T]) -> None:
379 self.func = func
380 self.n_calls = 0
381 update_wrapper(self, func)
383 def reset_call_counter(self) -> None:
384 """Reset the call counter."""
385 self.n_calls = 0 2nbabobB q f c b j a pbg qbrbbbsbtbcbubd h e u k l M K L y dbr N ebE O fbF v m n ; z : w G H gb~ hbvbibwbxbjbyb= kb? x s t lbzbAbi o p BbmbCbP I J C D [ 1 + W X Y Z Q R S 0 T U V 2 3 4 5 6 7 8 9 ! # $ % ' ( ) * , - . /
387 def __call__(self, *args: Any, **kwargs: Any) -> T:
388 if self.n_calls: 1BqfcbjagdheuklMKLyrNEOFvmnzwGHxstiopPIJCD1+WXYZQRS0TUV23456789!#$%'()*,-./
389 msg = ( 11
390 'The inner loop of `run_mcmc` was traced more than once, '
391 'which indicates a double compilation of the MCMC code. This '
392 'probably depends on the input state having different type from the '
393 'output state. Check the input is in a format that is the '
394 'same jax would output, e.g., all arrays and scalars are jax '
395 'arrays, with the right shardings.'
396 )
397 raise RuntimeError(msg) 11
398 self.n_calls += 1 1BqfcbjagdheuklMKLyrNEOFvmnzwGHxstiopPIJCD1+WXYZQRS0TUV23456789!#$%'()*,-./
399 return self.func(*args, **kwargs) 1BqfcbjagdheuklMKLyrNEOFvmnzwGHxstiopPIJCD1+WXYZQRS0TUV23456789!#$%'()*,-./
@partial(jit, donate_argnums=(0,), static_argnums=(2, 3, 4))
@_CallCounter
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: Int32[Array, ''],
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Run up to `inner_loop_length` MCMC iterations in a `lax.while_loop`.

    The function is jitted with `carry` donated and with `callback`,
    `burnin_extractor` and `main_extractor` static; all other arguments are
    traced. The `_CallCounter` wrapper sits under the jit, so it counts how
    many times the Python body is executed, and raises if that happens more
    than once between resets of the counter.

    Parameters
    ----------
    carry
        The loop state at the start of this batch of iterations.
    inner_loop_length
        The maximum number of iterations to run in this call.
    callback
        The optional user callback, see `Callback`.
    burnin_extractor
    main_extractor
        The functions that extract from the state what to save in the traces.
    n_burn
    n_save
    n_skip
        The corresponding `run_mcmc` arguments.
    i_outer
        The index of the current outer loop iteration.
    n_iters
        The total number of MCMC iterations over all outer loop iterations.

    Returns
    -------
    The loop state after this batch of iterations.
    """
    # determine number of iterations for this loop batch
    # (the last batch is cut short so that the total never exceeds `n_iters`)
    i_upper = jnp.minimum(carry.i_total + inner_loop_length, n_iters)

    def cond(carry: _Carry) -> Bool[Array, '']:
        """Whether to continue the MCMC loop."""
        return carry.i_total < i_upper

    def body(carry: _Carry) -> _Carry:
        """Update the MCMC state."""
        # split random key
        # (three keys: one carried over to the next iteration, one for the MCMC
        # step, one for the callback)
        keys = jaxext.split(carry.key, 3)
        key = keys.pop()

        # update state
        bart = mcmcstep.step(keys.pop(), carry.bart)

        # invoke callback
        callback_state = carry.callback_state
        if callback is not None:
            rt = callback(
                key=keys.pop(),
                bart=bart,
                burnin=carry.i_total < n_burn,
                i_total=carry.i_total,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # a callback returning None leaves both states untouched; this is
            # a Python-level check, so it is resolved while tracing
            if rt is not None:
                bart, callback_state = rt

        # save to trace
        # (the state saved is the one after the callback, if it changed it)
        burnin_trace, main_trace = _save_state_to_trace(
            carry.burnin_trace,
            carry.main_trace,
            burnin_extractor,
            main_extractor,
            bart,
            carry.i_total,
            n_burn,
            n_skip,
        )

        return _Carry(
            bart=bart,
            i_total=carry.i_total + 1,
            key=key,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
            callback_state=callback_state,
        )

    return lax.while_loop(cond, body, carry)
@named_call
def _save_state_to_trace(
    burnin_trace: PyTree,
    main_trace: PyTree,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    bart: State,
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_skip: Int32[Array, ''],
) -> tuple[PyTree, PyTree]:
    """Write the current state into the burn-in and main traces.

    Both traces are written to unconditionally; the write that does not apply
    to the current phase targets an out-of-bounds index, which `_set` drops.
    Every `n_skip` consecutive main phase iterations write to the same main
    trace slot.
    """
    in_burnin = i_total < n_burn

    # during the main phase use the strided index; during burn-in the
    # difference is negative, so replace it with an index that is surely
    # out-of-bounds
    out_of_bounds = jnp.iinfo(jnp.int32).max
    main_idx = jnp.where(in_burnin, out_of_bounds, (i_total - n_burn) // n_skip)

    # the burn-in trace is indexed by the iteration count itself, which falls
    # out-of-bounds on its own once the burn-in is over
    num_chains = get_num_chains(bart)
    updated_burnin = _set(burnin_trace, i_total, burnin_extractor(bart), num_chains)
    updated_main = _set(main_trace, main_idx, main_extractor(bart), num_chains)

    return updated_burnin, updated_main
def _set(
    trace: PyTree[Array, ' T'],
    index: Int32[Array, ''],
    val: PyTree[Array, ' T'],
    num_chains: int | None,
) -> PyTree[Array, ' T']:
    """Store `val` at position `index` of the samples axis of `trace`.

    The samples axis is the second one for leaves that have a chain axis and
    the first one otherwise. Writes at out-of-bounds indices are dropped.
    """

    def write_leaf(
        dest: Shaped[Array, 'chains samples *shape']
        | Shaped[Array, ' samples *shape']
        | None,
        item: Shaped[Array, ' chains *shape'] | Shaped[Array, '*shape'] | None,
        axis: int | None,
    ) -> Shaped[Array, 'chains samples *shape'] | None:
        # Leave alone optional elements, which show up as leaves because of
        # `is_leaf` below, and empty arrays, because jax refuses to index into
        # an axis of length 0 even if just in the abstract.
        if dest is None or dest.size == 0:
            return dest

        has_chain_axis = num_chains is not None and axis is not None
        if has_chain_axis:
            # skip over the leading chain axis
            where = (slice(None), index, ...)
        else:
            where = (index, ...)

        return dest.at[where].set(item, mode='drop')

    # `None` must count as a leaf to traverse the chain axes tree
    chain_axes = chain_vmap_axes(val)
    return tree.map(write_leaf, trace, val, chain_axes, is_leaf=lambda x: x is None)
def make_default_callback(
    state: State,
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
) -> dict[str, Any]:
    """
    Set up the default progress-printing callback for `run_mcmc`.

    The callback prints a dot every `dot_every` iterations and a one line
    report every `report_every` iterations.

    Parameters
    ----------
    state
        The bart state the callback will be used with. Its mesh determines
        how the callback state is placed on devices.
    dot_every
        The period in MCMC iterations of the dots, `None` to print no dots.
    report_every
        The period in MCMC iterations of the one line reports, `None` to
        print no reports.

    Returns
    -------
    A dictionary with the keyword arguments to pass to `run_mcmc` to set up the callback.

    Examples
    --------
    >>> run_mcmc(key, state, ..., **make_default_callback(state, ...))
    """

    def on_mesh(every: ArrayLike | None) -> None | Array:
        # disabled periods stay `None`, the others become replicated arrays
        if every is None:
            return None
        return _replicate(jnp.asarray(every), state.config.mesh)

    printer_state = PrintCallbackState(
        dot_every=on_mesh(dot_every),
        report_every=on_mesh(report_every),
    )
    return {'callback': print_callback, 'callback_state': printer_state}
class PrintCallbackState(Module):
    """State for `print_callback`.

    It only holds the printing periods. `print_callback` returns `None`, so
    this state stays the same for the whole MCMC run.
    """

    dot_every: Int32[Array, ''] | None
    """A dot is printed every `dot_every` MCMC iterations, `None` to disable."""

    report_every: Int32[Array, ''] | None
    """A one line report is printed every `report_every` MCMC iterations,
    `None` to disable."""
def print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_: Any,
) -> None:
    """Print a dot and/or a report periodically during the MCMC.

    Parameters
    ----------
    bart
        The current MCMC state, used to compute the statistics in the report.
    burnin
        Whether the iteration is in the burn-in phase.
    i_total
        The 0-based index of the iteration.
    n_burn
    n_save
    n_skip
        The corresponding `run_mcmc` arguments, used to compute the total
        number of iterations shown in the report.
    callback_state
        The printing periods, see `PrintCallbackState`.
    **_
        The other arguments that `run_mcmc` passes to callbacks, ignored.

    Notes
    -----
    This function always returns `None`, so it never modifies the MCMC state
    or the callback state. Printing happens at runtime through
    `jax.debug.callback`. At iterations where both a dot and a report are due,
    the dot is printed as part of the report.
    """
    report_every = callback_state.report_every
    dot_every = callback_state.dot_every

    # 1-based iteration count, used both for display and to check the periods
    it = i_total + 1

    def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
        return False if every is None else it % every == 0

    report_cond = get_cond(report_every)
    dot_cond = get_cond(dot_every)

    def line_report_branch() -> None:
        # a disabled report is known while tracing, so emit nothing at all
        if report_every is None:
            return
        if dot_every is None:
            print_newline = False
        else:
            # The report must start on a new line if some dots were printed
            # after the previous report, i.e., if the last dot is less than
            # `report_every` iterations old. This branch only runs when `it`
            # is a multiple of `report_every`, so `it % report_every` is
            # always 0 here and can not be used as the time elapsed since the
            # previous report.
            print_newline = it % dot_every < report_every
        debug.callback(
            _print_report,
            print_dot=dot_cond,
            print_newline=print_newline,
            burnin=burnin,
            it=it,
            n_iters=n_burn + n_save * n_skip,
            num_chains=bart.forest.num_chains(),
            grow_prop_count=bart.forest.grow_prop_count.mean(),
            grow_acc_count=bart.forest.grow_acc_count.mean(),
            prune_acc_count=bart.forest.prune_acc_count.mean(),
            prop_total=bart.forest.split_tree.shape[-2],
            fill=forest_fill(bart.forest.split_tree),
        )

    def just_dot_branch() -> None:
        # disabled dots are known while tracing, so emit nothing at all
        if dot_every is None:
            return
        debug.callback(
            lambda: print('.', end='', flush=True)  # noqa: T201
        )
        # logging can't do in-line printing so we use print

    # the report takes precedence over the dot
    lax.cond(
        report_cond,
        line_report_branch,
        lambda: lax.cond(dot_cond, just_dot_branch, lambda: None),
    )
def _convert_jax_arrays_in_args(func: Callable[..., T]) -> Callable[..., T]:
    """Decorate a function to never receive jax arrays as arguments.

    Each `jax.Array` found in the positional or keyword arguments, including
    within pytrees, is replaced by a Python scalar if it has empty shape, and
    by a numpy array otherwise. Anything else is passed through unchanged.
    """

    def to_host(leaf: object) -> object:
        if isinstance(leaf, Array):
            return numpy.array(leaf) if leaf.shape else leaf.item()
        return leaf

    @wraps(func)
    def new_func(*args: Any, **kw: Any) -> T:
        host_args = tree.map(to_host, args)
        host_kw = tree.map(to_host, kw)
        return func(*host_args, **host_kw)

    return new_func
@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    *,
    print_dot: bool,
    print_newline: bool,
    burnin: bool,
    it: int,
    n_iters: int,
    num_chains: int | None,
    grow_prop_count: float,
    grow_acc_count: float,
    prune_acc_count: float,
    prop_total: int,
    fill: float,
) -> None:
    """Print the one line report of `print_callback`.

    The line is preceded by a dot and a newline if `print_dot`, else by just a
    newline if `print_newline`. It shows the iteration count and, as
    percentages, the fraction of grow proposals, the fraction of accepted
    moves, and the fill; in parentheses at the end it mentions the number of
    chains, if any, and whether the iteration is in the burn-in phase.
    """
    # counts as fractions of the total
    grow_prop = grow_prop_count / prop_total
    move_acc = (grow_acc_count + prune_acc_count) / prop_total

    # what to print ahead of the line to terminate a pending line of dots
    if print_dot:
        prefix = '.\n'
    else:
        prefix = '\n' if print_newline else ''

    # remarks listed in parentheses at the end
    remarks = (
        (num_chains is not None, f'avg. {num_chains} chains'),
        (burnin, 'burnin'),
    )
    msgs = [text for applies, text in remarks if applies]
    suffix = f' ({", ".join(msgs)})' if msgs else ''

    line = (
        f'{prefix}Iteration {it}/{n_iters}, '
        f'grow prob: {grow_prop:.0%}, '
        f'move acc: {move_acc:.0%}, '
        f'fill: {fill:.0%}{suffix}'
    )
    print(line)  # noqa: T201, see print_callback for why not logging
class Trace(TreeHeaps, Protocol):
    """Protocol for a MCMC trace.

    A trace is a `bartz.grove.TreeHeaps` that additionally has an `offset`.
    """

    # added to the value of the trees in `evaluate_trace`
    offset: Float32[Array, '*trace_shape']
class TreesTrace(Module):
    """A MCMC trace reduced to the fields of `bartz.grove.TreeHeaps`."""

    leaf_tree: (
        Float32[Array, '*trace_shape num_trees 2**d']
        | Float32[Array, '*trace_shape num_trees k 2**d']
    )
    var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
    split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: TreeHeaps) -> 'TreesTrace':
        """Copy the tree fields out of any `bartz.grove.TreeHeaps`.

        Only the attributes of `obj` with the same name as a field of this
        class are read, any other is left behind.
        """
        picked = {}
        for spec in fields(cls):
            picked[spec.name] = getattr(obj, spec.name)
        return cls(**picked)
@jit
def evaluate_trace(
    X: UInt[Array, 'p n'], trace: Trace
) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    The trees are evaluated in batches, first over trees and then over MCMC
    samples, to keep memory usage under a fixed per-device budget. The values
    of the trees of each forest are summed, then the trace offset is added.

    Parameters
    ----------
    X
        The predictors matrix, with `p` predictors and `n` observations.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    The predictions for each chain and iteration of the MCMC.
    """
    # per-device memory limit
    max_io_nbytes = 2**27  # 128 MiB

    # adjust memory limit for number of devices
    # (the sizes of the 'chains' and 'data' mesh axes of the sharding of the
    # leaves are multiplied together)
    mesh = jax.typeof(trace.leaf_tree).sharding.mesh
    num_devices = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
    max_io_nbytes *= num_devices

    # determine batching axes
    has_chains = trace.split_tree.ndim > 3  # chains, samples, trees, nodes
    if has_chains:
        sample_axis = 1
        tree_axis = 2
    else:
        sample_axis = 0
        tree_axis = 1

    # batch and sum over trees
    # (X is not batched, the trace is batched along the trees axis)
    batched_eval = autobatch(
        evaluate_forest,
        max_io_nbytes,
        (None, tree_axis),
        tree_axis,
        reduce_ufunc=jnp.add,
    )

    # determine output shape (to avoid autobatch tracing everything 4 times)
    # (multivariate leaves have one more axis than the splits, of length k,
    # just before the nodes axis)
    is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
    k = trace.leaf_tree.shape[-2] if is_mv else 1
    mv_shape = (k,) if is_mv else ()
    _, n = X.shape
    # drop the trees and nodes axes, append the optional k axis and n
    out_shape = (*trace.split_tree.shape[:-2], *mv_shape, n)

    # adjust memory limit keeping into account that trees are summed over
    num_trees, hts = trace.split_tree.shape[-2:]
    out_size = k * n * jnp.float32.dtype.itemsize  # the value of the forest
    # NOTE(review): the factor 2 on the leaves presumably accounts for the leaf
    # heap being twice as long as the split heap -- confirm
    core_io_size = (
        num_trees
        * hts
        * (
            2 * k * trace.leaf_tree.itemsize
            + trace.var_tree.itemsize
            + trace.split_tree.itemsize
        )
        + out_size
    )
    # size of the values of the individual trees that are summed away
    core_int_size = (num_trees - 1) * out_size
    # shrink the budget by the ratio of intermediate to input/output size,
    # never going below 1 byte
    max_io_nbytes = max(1, floor(max_io_nbytes / (1 + core_int_size / core_io_size)))

    # batch over mcmc samples
    batched_eval = autobatch(
        batched_eval,
        max_io_nbytes,
        (None, sample_axis),
        sample_axis,
        warn_on_overflow=False,  # the inner autobatch will handle it
        result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
    )

    # extract only the trees from the trace
    trees = TreesTrace.from_dataclass(trace)

    # evaluate trees
    y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
    y_centered = batched_eval(X, trees)
    # the offset gets a trailing axis to broadcast over the n observations
    return y_centered + trace.offset[..., None]
@partial(jit, static_argnums=(0,))
def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
    """
    Compute the predictor usage counts for each MCMC state in a trace.

    Parameters
    ----------
    p
        The number of predictors.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """
    # The tree arrays have shape (chains? samples trees nodes), and the last
    # batch axis is summed over.
    # NOTE(review): presumably the batch axes exclude the nodes axis, so that
    # the sum is over the trees of each forest -- confirm in `var_histogram`
    varcount = var_histogram(p, trace.var_tree, trace.split_tree, sum_batch_axis=-1)
    return varcount
849 # var_tree has shape (chains? samples trees nodes)
850 return var_histogram(p, trace.var_tree, trace.split_tree, sum_batch_axis=-1) 2q f c ] ^ _ ` { g gb~ hbs t