Coverage for src / bartz / mcmcloop.py: 95%
222 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
1# bartz/src/bartz/mcmcloop.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions that implement the full BART posterior MCMC loop.
27The entry points are `run_mcmc` and `make_default_callback`.
28"""
30from collections.abc import Callable
31from functools import partial, update_wrapper, wraps
32from math import floor
33from typing import Any, NamedTuple, Protocol, TypeVar
35import jax
36import numpy
37from equinox import Module
38from jax import (
39 NamedSharding,
40 ShapeDtypeStruct,
41 debug,
42 device_put,
43 eval_shape,
44 jit,
45 lax,
46 named_call,
47 tree,
48)
49from jax import numpy as jnp
50from jax.nn import softmax
51from jax.sharding import Mesh, PartitionSpec
52from jaxtyping import (
53 Array,
54 ArrayLike,
55 Bool,
56 Float32,
57 Int32,
58 Integer,
59 Key,
60 PyTree,
61 Shaped,
62 UInt,
63)
65from bartz import jaxext, mcmcstep
66from bartz.grove import (
67 TreeHeaps,
68 TreesTrace,
69 evaluate_forest,
70 forest_fill,
71 var_histogram,
72)
73from bartz.jaxext import autobatch, jit_active
74from bartz.mcmcstep import State
75from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains
class BurninTrace(Module):
    """Diagnostic-only MCMC trace, used for the burn-in phase.

    Every array may carry leading chain and sample axes
    (``*chains_and_samples``). Fields annotated ``... | None`` may be `None`;
    `from_state` copies the state values as they are, so this happens when the
    corresponding state value is `None`.
    """

    # `field(chains=True)` presumably marks the fields that carry a chain axis
    # (it is read back through `chain_vmap_axes`) -- confirm in `_state`.

    # inverse error covariance: scalar variant, or k x k matrix variant
    error_cov_inv: (
        Float32[Array, '*chains_and_samples']
        | Float32[Array, '*chains_and_samples k k']
    ) = field(chains=True)
    theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    # proposal / acceptance counters of the grow and prune moves
    grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a one-item burn-in trace by reading the diagnostics of `state`."""
        forest = state.forest
        return cls(
            error_cov_inv=state.error_cov_inv,
            theta=forest.theta,
            grow_prop_count=forest.grow_prop_count,
            grow_acc_count=forest.grow_acc_count,
            prune_prop_count=forest.prune_prop_count,
            prune_acc_count=forest.prune_acc_count,
            log_likelihood=forest.log_likelihood,
            log_trans_prior=forest.log_trans_prior,
        )
class MainTrace(BurninTrace):
    """MCMC trace for the main phase: the trees, on top of the burn-in diagnostics."""

    # leaf values of the tree heaps, optionally with an extra `k` axis
    leaf_tree: (
        Float32[Array, '*chains_and_samples 2**d']
        | Float32[Array, '*chains_and_samples k 2**d']
    ) = field(chains=True)
    var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    # not marked with `field(chains=True)`: its annotation has sample axes only
    offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
    # per-predictor probabilities; None when the state has no `log_s`
    varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a one-item main trace from `state`, diagnostics included."""
        forest = state.forest

        # softmax of `log_s`, restricted to the entries where `max_split` is
        # nonzero
        varprob = (
            None
            if forest.log_s is None
            else softmax(forest.log_s, where=forest.max_split.astype(bool))
        )

        diagnostics = vars(BurninTrace.from_state(state))
        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            varprob=varprob,
            **diagnostics,
        )
# Arbitrary user pytree threaded through successive invocations of a `Callback`.
CallbackState = PyTree[Any, 'T']
class RunMCMCResult(NamedTuple):
    """Output of `run_mcmc`: the last state and the two traces."""

    final_state: State
    """MCMC state after the last iteration."""

    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    """Values saved during the burn-in phase; `BurninTrace` is the default layout."""

    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    """Values saved during the main phase; `MainTrace` is the default layout."""
class Callback(Protocol):
    """Signature expected for the `callback` argument of `run_mcmc`."""

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: Int32[Array, ''],
    ) -> tuple[State, CallbackState] | None:
        """Run user code after each MCMC iteration.

        Parameters
        ----------
        key
            A random key reserved for the callback.
        bart
            The MCMC state, already updated by the iteration just run.
        burnin
            Whether the iteration just run belongs to the burn-in phase.
        i_total
            0-based index of the iteration just run.
        callback_state
            The callback's own state: the value given to `run_mcmc` at the
            first invocation, then whatever the previous invocation returned.
        n_burn
        n_save
        n_skip
            The homonymous `run_mcmc` arguments, passed through unchanged.
        i_outer
            0-based index of the current outer loop iteration.
        inner_loop_length
            The number of MCMC iterations per inner loop (the last inner loop
            may run fewer).

        Returns
        -------
        bart : State
            The MCMC state to continue from. Return the `bart` argument
            unchanged to leave the state as it is.
        callback_state : CallbackState
            The state handed to the next invocation of the callback.

        Notes
        -----
        Returning `None` is a shortcut that leaves both states unchanged.
        """
        ...
class _Carry(Module):
    """Loop state threaded through the `lax.while_loop` of `run_mcmc`."""

    bart: State  # current MCMC state
    i_total: Int32[Array, '']  # number of MCMC iterations completed so far
    key: Key[Array, '']  # random key for the following iterations
    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    callback_state: CallbackState  # user state of the callback, if any
def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> RunMCMCResult:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        The random key driving the whole run.
    bart
        The starting MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. Its buffers are donated to the loop to avoid copies,
        so this object can no longer be used after `run_mcmc` returns. Copy it
        beforehand if you need it again.
    n_save
        The number of iterations to save in the main trace.
    n_burn
        The number of initial iterations that are not saved in the main trace.
    n_skip
        One plus the number of iterations discarded between two consecutive
        saved iterations. The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The iterations run in an outer Python loop wrapped around an inner JAX
        loop. This is the number of iterations of each inner loop. If not
        specified, a single inner loop runs all the iterations. This stride
        has nothing to do with the one used to save the trace.
    callback
        An arbitrary function invoked after each state update; see `Callback`
        for its signature. It runs under `jax.jit`, so the argument values are
        not available while its Python code executes. Use the utilities in
        `jax.debug` to access them at runtime. It may return replacement
        values for the MCMC state and the callback state.
    callback_state
        The initial custom state of the callback.
    burnin_extractor
    main_extractor
        Functions mapping an MCMC state to the pytree stored in the burn-in
        trace and in the main trace, respectively. They must return a pytree
        and be vmappable.

    Returns
    -------
    A `RunMCMCResult` with the final state, the burn-in trace, and the main trace.

    Raises
    ------
    RuntimeError
        If `run_mcmc` is called inside a `jit`-compiled function with settings
        that need more than one outer loop iteration, which would unroll the
        loop in the traced program.

    Notes
    -----
    The total number of MCMC updates is ``n_burn + n_skip * n_save``. The
    traces exclude the initial state and include the final state.
    """
    # allocate the traces, pre-filled by extracting from the initial state
    burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
    main_trace = _empty_trace(n_save, bart, main_extractor)

    # split the iterations between the outer (Python) and inner (JAX) loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if not inner_loop_length:
        # a zero-length inner loop is a no-op, but running it once keeps the
        # same code path, which is handy for benchmarks and tests
        n_outer = 1
    else:
        n_full, n_leftover = divmod(n_iters, inner_loop_length)
        n_outer = n_full + bool(n_leftover)

    # under jit, the outer Python loop would be unrolled in the trace
    if jit_active() and n_outer > 1:
        msg = (
            '`run_mcmc` was called within a jit-compiled function and '
            'there are more than 1 outer loops, '
            'please either do not jit or set `inner_loop_length=None`'
        )
        raise RuntimeError(msg)

    mesh = bart.config.mesh
    carry = _Carry(
        bart=bart,
        i_total=_replicate(jnp.int32(0), mesh),
        key=_replicate(key, mesh),
        burnin_trace=burnin_trace,
        main_trace=main_trace,
        callback_state=callback_state,
    )

    # the inner loop may be traced at most once per `run_mcmc` invocation
    _run_mcmc_inner_loop._fun.reset_call_counter()  # noqa: SLF001
    for i_outer in range(n_outer):
        # arguments stay positional: the jit wrapper uses argnums
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return RunMCMCResult(carry.bart, carry.burnin_trace, carry.main_trace)
353def _replicate(x: Array, mesh: Mesh | None) -> Array:
354 if mesh is None: 1wajhd
355 return x 1ad
356 else:
357 return device_put(x, NamedSharding(mesh, PartitionSpec())) 1wjh
@partial(jit, static_argnums=(0, 2))
def _empty_trace(
    length: int, bart: State, extractor: Callable[[State], PyTree]
) -> PyTree:
    """Allocate a trace with `length` items for the values produced by `extractor`.

    The arrays are broadcast copies of ``extractor(bart)`` along a new sample
    axis of size `length`. That axis comes first, except for leaves that carry
    a chain axis (when `bart` has chains), where it goes in position 1.
    """
    if get_num_chains(bart) is None:
        out_axes = 0
    else:
        abstract_output = eval_shape(extractor, bart)
        out_axes = tree.map(
            lambda axis: 0 if axis is None else 1,
            chain_vmap_axes(abstract_output),
            is_leaf=lambda axis: axis is None,
        )
    # in_axes=None: the state is not mapped, so the output is just broadcast
    broadcast_extract = jax.vmap(
        extractor, in_axes=None, out_axes=out_axes, axis_size=length
    )
    return broadcast_extract(bart)
T = TypeVar('T')


class _CallCounter:
    """Wrapper that lets the wrapped callable run at most once between resets.

    `run_mcmc` puts it under the `jit` of its inner loop: the wrapped function
    runs only when jax traces it, so a second run reveals a double compilation.
    """

    def __init__(self, func: Callable[..., T]) -> None:
        self.func = func
        self.n_calls = 0
        update_wrapper(self, func)

    def reset_call_counter(self) -> None:
        """Forget the previous calls, allowing one more call."""
        self.n_calls = 0

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        """Forward the call to the wrapped function, unless it already ran.

        Raises
        ------
        RuntimeError
            If the function was already called since the last reset.
        """
        already_called = self.n_calls > 0
        if already_called:
            msg = (
                'The inner loop of `run_mcmc` was traced more than once, '
                'which indicates a double compilation of the MCMC code. This '
                'probably depends on the input state having different type from the '
                'output state. Check the input is in a format that is the '
                'same jax would output, e.g., all arrays and scalars are jax '
                'arrays, with the right shardings.'
            )
            raise RuntimeError(msg)
        self.n_calls += 1
        return self.func(*args, **kwargs)
@partial(jit, donate_argnums=(0,), static_argnums=(2, 3, 4))
@_CallCounter
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: Int32[Array, ''],
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Run up to `inner_loop_length` MCMC iterations, never beyond `n_iters` in total.

    The carry (argument 0) is donated. The callback and the two extractors
    (arguments 2-4) are static. `_CallCounter` makes a second tracing fail.
    """
    # this batch ends after `inner_loop_length` iterations or at the end of
    # the whole MCMC, whichever comes first
    stop_at = jnp.minimum(carry.i_total + inner_loop_length, n_iters)

    def keep_going(loop_carry: _Carry) -> Bool[Array, '']:
        """Check whether this batch has iterations left."""
        return loop_carry.i_total < stop_at

    def one_iteration(loop_carry: _Carry) -> _Carry:
        """Step the MCMC once, run the callback, and store into the traces."""
        keys = jaxext.split(loop_carry.key, 3)
        next_key = keys.pop()  # carried over to the next iteration

        new_bart = mcmcstep.step(keys.pop(), loop_carry.bart)

        new_callback_state = loop_carry.callback_state
        if callback is not None:
            returned = callback(
                key=keys.pop(),
                bart=new_bart,
                burnin=loop_carry.i_total < n_burn,
                i_total=loop_carry.i_total,
                callback_state=new_callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # a `None` return leaves both states as they are
            if returned is not None:
                new_bart, new_callback_state = returned

        new_burnin_trace, new_main_trace = _save_state_to_trace(
            loop_carry.burnin_trace,
            loop_carry.main_trace,
            burnin_extractor,
            main_extractor,
            new_bart,
            loop_carry.i_total,
            n_burn,
            n_skip,
        )

        return _Carry(
            bart=new_bart,
            i_total=loop_carry.i_total + 1,
            key=next_key,
            burnin_trace=new_burnin_trace,
            main_trace=new_main_trace,
            callback_state=new_callback_state,
        )

    return lax.while_loop(keep_going, one_iteration, carry)
@named_call
def _save_state_to_trace(
    burnin_trace: PyTree,
    main_trace: PyTree,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    bart: State,
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_skip: Int32[Array, ''],
) -> tuple[PyTree, PyTree]:
    """Store the values extracted from `bart` at iteration `i_total` into both traces.

    `_set` drops out-of-bounds writes. The indices are chosen so that only
    the trace of the current phase actually receives the values.
    """
    # burn-in trace: slot `i_total`; after the burn-in the index runs past the
    # end of the trace, so the write is dropped
    burnin_idx = i_total

    # main trace: slot `(i_total - n_burn) // n_skip`; during burn-in force the
    # largest int32 instead, so the write is dropped (presumably rather than
    # letting a negative index wrap around)
    in_burnin = i_total < n_burn
    main_idx = jnp.where(
        in_burnin, jnp.iinfo(jnp.int32).max, (i_total - n_burn) // n_skip
    )

    num_chains = get_num_chains(bart)
    new_burnin_trace = _set(
        burnin_trace, burnin_idx, burnin_extractor(bart), num_chains
    )
    new_main_trace = _set(main_trace, main_idx, main_extractor(bart), num_chains)
    return new_burnin_trace, new_main_trace
def _set(
    trace: PyTree[Array, ' T'],
    index: Int32[Array, ''],
    val: PyTree[Array, ' T'],
    num_chains: int | None,
) -> PyTree[Array, ' T']:
    """Write `val` at position `index` of the sample axis of `trace`, leaf by leaf.

    This is ``trace[index] = val`` generalized to pytrees. When the state has
    chains (`num_chains` not None) and `chain_vmap_axes` reports a chain axis
    for a leaf, the sample axis is the second one; otherwise it is the first.
    Out-of-bounds indices are dropped silently (``mode='drop'``).
    """
    chain_axes = chain_vmap_axes(val)

    def set_leaf(
        leaf: Shaped[Array, 'chains samples *shape']
        | Shaped[Array, ' samples *shape']
        | None,
        new: Shaped[Array, ' chains *shape'] | Shaped[Array, '*shape'] | None,
        chain_axis: int | None,
    ) -> Shaped[Array, 'chains samples *shape'] | Shaped[Array, ' samples *shape'] | None:
        # Leave alone empty arrays, because jax refuses to index into a
        # length-0 axis even abstractly, and `None` leaves, which show up
        # because of the `is_leaf` needed to traverse `chain_axes`.
        if leaf is None or leaf.size == 0:
            return leaf
        has_chain_axis = num_chains is not None and chain_axis is not None
        where = (slice(None), index, ...) if has_chain_axis else (index, ...)
        return leaf.at[where].set(new, mode='drop')

    return tree.map(set_leaf, trace, val, chain_axes, is_leaf=lambda x: x is None)
def make_default_callback(
    state: State,
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
) -> dict[str, Any]:
    """
    Prepare a default callback for `run_mcmc`.

    The callback is `print_callback`. It prints a dot every `dot_every` MCMC
    iterations and a one-line progress report every `report_every` MCMC
    iterations. It returns `None`, so it never modifies the MCMC state.

    Parameters
    ----------
    state
        The bart state to use the callback with. Its device mesh
        (``state.config.mesh``) is used to replicate the callback state.
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one-line report is printed every `report_every` MCMC iterations,
        `None` to disable.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(key, state, ..., **make_default_callback(state, ...))
    """

    def as_replicated_array_or_none(val: ArrayLike | None) -> None | Array:
        # non-None periods become jax arrays replicated on the state's mesh,
        # so the callback state lives on the same devices as the MCMC state
        return None if val is None else _replicate(jnp.asarray(val), state.config.mesh)

    return dict(
        callback=print_callback,
        callback_state=PrintCallbackState(
            as_replicated_array_or_none(dot_every),
            as_replicated_array_or_none(report_every),
        ),
    )
class PrintCallbackState(Module):
    """Callback state of `print_callback`: the printing periods."""

    dot_every: Int32[Array, ''] | None
    """Period, in MCMC iterations, of the printed dots; `None` disables them."""

    report_every: Int32[Array, ''] | None
    """Period, in MCMC iterations, of the one-line reports; `None` disables
    them."""
def print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_: Any,
) -> None:
    """Print a dot and/or a report periodically during the MCMC.

    This is a `Callback` for `run_mcmc`. It returns `None`, so it never
    modifies the MCMC state or its own callback state.

    Parameters
    ----------
    bart
    burnin
    i_total
    n_burn
    n_save
    n_skip
        See `Callback`.
    callback_state
        The printing periods, see `PrintCallbackState`.
    **_
        The other `Callback` arguments, ignored.

    Notes
    -----
    Dots are printed in-line, without a line break. When a dot is also due at
    a report iteration, it is printed just before the report. Otherwise the
    report starts on a new line if any dot has been printed since the previous
    report.
    """
    report_every = callback_state.report_every
    dot_every = callback_state.dot_every
    it = i_total + 1  # number of iterations completed, including this one

    def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
        return False if every is None else it % every == 0

    report_cond = get_cond(report_every)
    dot_cond = get_cond(dot_every)

    def line_report_branch() -> None:
        if report_every is None:
            return
        if dot_every is None:
            print_newline = False
        else:
            # In this branch `it` is a multiple of `report_every`. The previous
            # report was at iteration `it - report_every` and ended with a
            # newline. The last dot-only iteration before `it` is
            # `it - it % dot_every` (0 means there was none). If a dot is due
            # right now, `print_dot` handles it instead. That dot left the
            # cursor mid-line iff it came after the previous report, i.e.,
            # iff `it % dot_every < report_every`. (The previous expression,
            # `it % report_every > it % dot_every`, was always False here.)
            print_newline = it % dot_every < report_every
        debug.callback(
            _print_report,
            print_dot=dot_cond,
            print_newline=print_newline,
            burnin=burnin,
            it=it,
            n_iters=n_burn + n_save * n_skip,
            num_chains=bart.forest.num_chains(),
            grow_prop_count=bart.forest.grow_prop_count.mean(),
            grow_acc_count=bart.forest.grow_acc_count.mean(),
            prune_acc_count=bart.forest.prune_acc_count.mean(),
            prop_total=bart.forest.split_tree.shape[-2],
            fill=forest_fill(bart.forest.split_tree),
        )

    def just_dot_branch() -> None:
        if dot_every is None:
            return
        debug.callback(
            lambda: print('.', end='', flush=True)  # noqa: T201
        )
        # logging can't do in-line printing so we use print

    lax.cond(
        report_cond,
        line_report_branch,
        lambda: lax.cond(dot_cond, just_dot_branch, lambda: None),
    )
def _convert_jax_arrays_in_args(func: Callable[..., T]) -> Callable[..., T]:
    """Decorate `func` so that it never receives jax arrays.

    Every `jax.Array` leaf of the positional and keyword arguments is replaced
    before calling `func`. 0-d arrays become Python scalars and arrays with
    at least one axis become numpy arrays. Other leaves pass through unchanged.
    """

    def to_host(leaf: object) -> object:
        if isinstance(leaf, Array):
            return numpy.array(leaf) if leaf.shape else leaf.item()
        return leaf

    def convert_jax_arrays(pytree: PyTree) -> PyTree:
        return tree.map(to_host, pytree)

    @wraps(func)
    def new_func(*args: Any, **kw: Any) -> T:
        return func(*convert_jax_arrays(args), **convert_jax_arrays(kw))

    return new_func
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
@_convert_jax_arrays_in_args
def _print_report(
    *,
    print_dot: bool,
    print_newline: bool,
    burnin: bool,
    it: int,
    n_iters: int,
    num_chains: int | None,
    grow_prop_count: float,
    grow_acc_count: float,
    prune_acc_count: float,
    prop_total: int,
    fill: float,
) -> None:
    """Format and print the one-line progress report of `print_callback`.

    The grow probability is ``grow_prop_count / prop_total``. The move
    acceptance is ``(grow_acc_count + prune_acc_count) / prop_total``.
    `print_dot` puts a dot and a line break before the report; otherwise
    `print_newline` puts just a line break.
    """
    grow_prop = grow_prop_count / prop_total
    move_acc = (grow_acc_count + prune_acc_count) / prop_total

    if print_dot:
        prefix = '.\n'
    else:
        prefix = '\n' if print_newline else ''

    tags = []
    if num_chains is not None:
        tags.append(f'avg. {num_chains} chains')
    if burnin:
        tags.append('burnin')
    suffix = '' if not tags else f' ({", ".join(tags)})'

    print(  # noqa: T201, see print_callback for why not logging
        f'{prefix}Iteration {it}/{n_iters}, '
        f'grow prob: {grow_prop:.0%}, '
        f'move acc: {move_acc:.0%}, '
        f'fill: {fill:.0%}{suffix}'
    )
class Trace(TreeHeaps, Protocol):
    """Structural type of the MCMC trace accepted by `evaluate_trace`."""

    # added to the sum of the trees by `evaluate_trace`
    offset: Float32[Array, '*trace_shape']
@jit
def evaluate_trace(
    X: UInt[Array, 'p n'], trace: Trace
) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    X
        The predictors matrix, with `p` predictors and `n` observations.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    The predictions, offset included, for each chain and iteration of the MCMC.
    """
    # memory budget: 128 MiB per device, times the devices of the trace's mesh
    mesh = jax.typeof(trace.leaf_tree).sharding.mesh
    n_devices = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
    budget = 2**27 * n_devices

    # layout of the tree arrays: ([chains,] samples, trees, nodes)
    sample_axis, tree_axis = (1, 2) if trace.split_tree.ndim > 3 else (0, 1)

    # inner batching: evaluate groups of trees and add up their outputs
    sum_over_trees = autobatch(
        evaluate_forest,
        budget,
        (None, tree_axis),
        tree_axis,
        reduce_ufunc=jnp.add,
    )

    # output shape, passed explicitly so autobatch does not have to trace
    # everything 4 times to find it
    is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
    k = trace.leaf_tree.shape[-2] if is_mv else 1
    _, n = X.shape
    out_shape = (*trace.split_tree.shape[:-2], *((k,) if is_mv else ()), n)

    # shrink the budget of the outer batching to account for the per-tree
    # intermediate outputs that the inner batching sums over
    num_trees, hts = trace.split_tree.shape[-2:]
    out_size = k * n * jnp.float32.dtype.itemsize  # bytes of one forest value
    bytes_per_node = (
        2 * k * trace.leaf_tree.itemsize  # leaf heaps have twice the nodes
        + trace.var_tree.itemsize
        + trace.split_tree.itemsize
    )
    core_io_size = num_trees * hts * bytes_per_node + out_size
    core_int_size = (num_trees - 1) * out_size
    outer_budget = max(1, floor(budget / (1 + core_int_size / core_io_size)))

    # outer batching: over the MCMC samples
    sum_over_samples = autobatch(
        sum_over_trees,
        outer_budget,
        (None, sample_axis),
        sample_axis,
        warn_on_overflow=False,  # the inner autobatch will handle it
        result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
    )

    # keep only the tree arrays of the trace, then evaluate them
    y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
    y_centered = sum_over_samples(X, TreesTrace.from_dataclass(trace))
    return y_centered + trace.offset[..., None]
@partial(jit, static_argnums=(0,))
def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
    """
    Count how many times each predictor is used in each MCMC state.

    Parameters
    ----------
    p
        The number of predictors (static under `jit`).
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """
    # tree arrays are laid out as ([chains,] samples, trees, nodes);
    # `sum_batch_axis=-1` presumably sums the per-tree counts over the last
    # batch axis (trees) -- confirm against `var_histogram`
    var_tree, split_tree = trace.var_tree, trace.split_tree
    return var_histogram(p, var_tree, split_tree, sum_batch_axis=-1)