Coverage for src / bartz / mcmcloop.py: 95%
205 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
1# bartz/src/bartz/mcmcloop.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions that implement the full BART posterior MCMC loop.
27The entry points are `run_mcmc` and `make_default_callback`.
28"""
30from collections.abc import Callable
31from dataclasses import fields
32from functools import partial, wraps
33from math import floor
34from typing import Any, Protocol
36import jax
37import numpy
38from equinox import Module
39from jax import ShapeDtypeStruct, debug, eval_shape, jit, tree
40from jax import numpy as jnp
41from jax.nn import softmax
42from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt
44from bartz import jaxext, mcmcstep
45from bartz._profiler import (
46 cond_if_not_profiling,
47 jit_if_not_profiling,
48 scan_if_not_profiling,
49)
50from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram
51from bartz.jaxext import autobatch
52from bartz.mcmcstep import State
53from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains
class BurninTrace(Module):
    """MCMC trace holding only diagnostic values, without the trees.

    Every field is declared with ``field(chains=True)``, i.e. flagged as a
    per-chain quantity for the multi-chain layout handled by
    `chain_vmap_axes`.
    """

    # `state.error_cov_inv`; the annotation allows a scalar, a k x k matrix,
    # or None
    error_cov_inv: (
        Float32[Array, '*chains_and_samples']
        | Float32[Array, '*chains_and_samples k k']
        | None
    ) = field(chains=True)
    # `state.forest.theta`, possibly None
    theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    # grow/prune proposal and acceptance counts copied from `state.forest`
    grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    # `state.forest.log_likelihood` / `state.forest.log_trans_prior`,
    # possibly None
    log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a single-entry burn-in trace from an MCMC state.

        Parameters
        ----------
        state
            The MCMC state to copy the diagnostics from.

        Returns
        -------
        An instance of `cls` whose fields are read from `state` and
        `state.forest`.
        """
        forest = state.forest
        copied = {
            name: getattr(forest, name)
            for name in (
                'theta',
                'grow_prop_count',
                'grow_acc_count',
                'prune_prop_count',
                'prune_acc_count',
                'log_likelihood',
                'log_trans_prior',
            )
        }
        return cls(error_cov_inv=state.error_cov_inv, **copied)
class MainTrace(BurninTrace):
    """MCMC trace with the trees, in addition to the burn-in diagnostics."""

    # tree heaps copied from `state.forest`
    leaf_tree: (
        Float32[Array, '*chains_and_samples 2**d']
        | Float32[Array, '*chains_and_samples k 2**d']
    ) = field(chains=True)
    var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    # `state.offset`; unlike the other fields, not declared with chains=True
    offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
    # softmax of `state.forest.log_s`, or None when `log_s` is None
    varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a single-entry main trace from an MCMC state.

        Parameters
        ----------
        state
            The MCMC state to copy the trees and diagnostics from.

        Returns
        -------
        An instance of `cls` holding the tree heaps, the offset, the variable
        selection probabilities (if `state.forest.log_s` is set), and all the
        `BurninTrace` diagnostics.
        """
        forest = state.forest
        log_s = forest.log_s
        # variables whose `max_split` is 0 are masked out of the softmax
        varprob = (
            None
            if log_s is None
            else softmax(log_s, where=forest.max_split.astype(bool))
        )
        diagnostics = vars(BurninTrace.from_state(state))
        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            varprob=varprob,
            **diagnostics,
        )
# Arbitrary pytree owned by the callback: it is passed to each `Callback`
# invocation and replaced by the value the callback returns, if any.
CallbackState = PyTree[Any, 'T']
class Callback(Protocol):
    """Signature of the `callback` argument of `run_mcmc`."""

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        i_skip: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: int,
    ) -> tuple[State, CallbackState] | None:
        """Run arbitrary code after each MCMC iteration.

        Parameters
        ----------
        key
            A random key reserved for the callback.
        bart
            The MCMC state right after the update.
        burnin
            Whether the iteration just done belongs to the burn-in phase.
        i_total
            The 0-based index of the iteration just done.
        i_skip
            How many MCMC updates were done since the last saved state. The
            initial state counts as saved, even though it is not copied into
            the trace.
        callback_state
            At the first call, the value passed to `run_mcmc`; afterwards, the
            value returned by the previous call.
        n_burn
        n_save
        n_skip
            The `run_mcmc` arguments of the same name, unchanged.
        i_outer
            The 0-based index of the current outer loop iteration.
        inner_loop_length
            How many MCMC iterations each inner loop runs.

        Returns
        -------
        bart : State
            The MCMC state to continue from, possibly modified. Return the
            `bart` argument unchanged to leave the state alone.
        callback_state : CallbackState
            The callback state handed to the next invocation.

        Notes
        -----
        Returning `None` instead of a tuple keeps both states as they are.
        """
        ...
class _Carry(Module):
    """State carried across the iterations of the loop in `run_mcmc`."""

    # current MCMC state
    bart: State
    # number of MCMC iterations completed so far
    i_total: Int32[Array, '']
    # random key to be split at the next iteration
    key: Key[Array, '']
    # preallocated traces, filled in as the loop advances
    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    # user state passed to, and possibly replaced by, the callback
    callback_state: CallbackState
def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> tuple[
    State,
    PyTree[Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']],
    PyTree[Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']],
]:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        A key for random number generation.
    bart
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
        so this variable is invalidated after running `run_mcmc`. Make a copy
        beforehand to use it again.
    n_save
        The number of iterations to save.
    n_burn
        The number of initial iterations which are not saved.
    n_skip
        The number of iterations to skip between each saved iteration, plus 1.
        The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The MCMC loop is split into an outer and an inner loop. The outer loop
        is in Python, while the inner loop is in JAX. `inner_loop_length` is the
        number of iterations of the inner loop to run for each iteration of the
        outer loop. If not specified, the outer loop will iterate just once,
        with all iterations done in a single inner loop run. The inner stride is
        unrelated to the stride used for saving the trace.
    callback
        An arbitrary function run during the loop after updating the state. For
        the signature, see `Callback`. The callback is called under the jax jit,
        so the argument values are not available at the time the Python code is
        executed. Use the utilities in `jax.debug` to access the values at
        actual runtime. The callback may return new values for the MCMC state
        and the callback state.
    callback_state
        The initial custom state for the callback.
    burnin_extractor
    main_extractor
        Functions that extract the variables to be saved respectively in the
        burnin trace and main traces, given the MCMC state as argument. Must
        return a pytree, and must be vmappable.

    Returns
    -------
    bart : State
        The final MCMC state.
    burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
        The trace of the burn-in phase. For the default layout, see `BurninTrace`.
    main_trace : PyTree[Shaped[Array, 'n_save *']]
        The trace of the main phase. For the default layout, see `MainTrace`.

    Raises
    ------
    ValueError
        If `n_burn` or `n_save` is negative, if `n_skip` is less than 1 while
        `n_save` is positive, or if `inner_loop_length` is negative.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state.
    """
    # Validate the loop sizes up front: with n_skip < 1 the main trace would
    # never be written (or be written at wrapped negative indices) and still
    # be returned as if it held samples; negative sizes would silently run
    # fewer iterations than requested.
    if n_burn < 0:
        msg = f'n_burn must be non-negative, got {n_burn}'
        raise ValueError(msg)
    if n_save < 0:
        msg = f'n_save must be non-negative, got {n_save}'
        raise ValueError(msg)
    if n_save > 0 and n_skip < 1:
        msg = f'n_skip must be at least 1 when saving samples, got {n_skip}'
        raise ValueError(msg)
    if inner_loop_length is not None and inner_loop_length < 0:
        msg = f'inner_loop_length must be non-negative, got {inner_loop_length}'
        raise ValueError(msg)

    burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
    main_trace = _empty_trace(n_save, bart, main_extractor)

    # determine number of iterations for inner and outer loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if inner_loop_length:
        # ceiling division; the excess steps of the last inner loop are noops
        n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length)
    else:
        n_outer = 1
        # setting to 0 would make for a clean noop, but it's useful to keep the
        # same code path for benchmarking and testing

    carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return carry.bart, carry.burnin_trace, carry.main_trace
@partial(jit, static_argnums=(0, 2))
def _empty_trace(
    length: int, bart: State, extractor: Callable[[State], PyTree]
) -> PyTree:
    """Preallocate a trace with `length` entries.

    Every entry is initialized to ``extractor(bart)``. The caller overwrites
    the entries as the MCMC proceeds. With multiple chains, the new axis is
    placed at position 1 for leaves that `chain_vmap_axes` flags as per-chain,
    and at position 0 for all other leaves.
    """
    if get_num_chains(bart) is None:
        sample_axes = 0
    else:
        abstract_item = eval_shape(extractor, bart)
        sample_axes = tree.map(
            lambda chain_ax: 1 if chain_ax is not None else 0,
            chain_vmap_axes(abstract_item),
            is_leaf=lambda node: node is None,
        )
    # in_axes=None broadcasts `bart`, so every entry is the same extraction
    replicate = jax.vmap(
        extractor, in_axes=None, out_axes=sample_axes, axis_size=length
    )
    return replicate(bart)
@jit
def _compute_i_skip(
    i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
) -> Int32[Array, '']:
    """Compute the `i_skip` argument passed to `callback`.

    During burn-in this is ``i_total + 1``, the number of updates since the
    initial state. Afterwards it is the number of updates since the last state
    kept in the main trace. The initial state counts as saved, which is why
    `n_burn` is added back until the first main-phase save.
    """
    n_done_main = i_total - n_burn + 1
    since_last_save = n_done_main % n_skip
    since_last_save = since_last_save + jnp.where(n_done_main < n_skip, n_burn, 0)
    return jnp.where(i_total < n_burn, i_total + 1, since_last_save)
@partial(jit_if_not_profiling, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: int,
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Advance the MCMC by up to `inner_loop_length` iterations.

    The loop always runs `inner_loop_length` steps. A step performs one MCMC
    iteration (state update, optional callback, trace write) while fewer than
    `n_iters` iterations have been done in total, and leaves the carry
    untouched afterwards. `carry` is marked for buffer donation.
    """

    def advance(state: _Carry) -> _Carry:
        """Do one MCMC iteration; selected while ``i_total < n_iters``."""
        keys = jaxext.split(state.key, 3)
        # keep the pop order unchanged: it decides which subkey goes to the
        # next carry, to the MCMC step, and to the callback
        next_key = keys.pop()
        bart = mcmcstep.step(keys.pop(), state.bart)

        cb_state = state.callback_state
        if callback is not None:
            i_skip = _compute_i_skip(state.i_total, n_burn, n_skip)
            returned = callback(
                key=keys.pop(),
                bart=bart,
                burnin=state.i_total < n_burn,
                i_total=state.i_total,
                i_skip=i_skip,
                callback_state=cb_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            if returned is not None:
                bart, cb_state = returned

        new_burnin, new_main = _save_state_to_trace(
            state.burnin_trace,
            state.main_trace,
            burnin_extractor,
            main_extractor,
            bart,
            state.i_total,
            n_burn,
            n_skip,
        )

        return _Carry(
            bart=bart,
            i_total=state.i_total + 1,
            key=next_key,
            burnin_trace=new_burnin,
            main_trace=new_main,
            callback_state=cb_state,
        )

    def idle(state: _Carry) -> _Carry:
        """Leave the carry untouched; selected once ``i_total >= n_iters``."""
        return state

    def body(state: _Carry, _) -> tuple[_Carry, None]:
        # the last outer iteration may overshoot n_iters, hence the guard
        state = cond_if_not_profiling(state.i_total < n_iters, advance, idle, state)
        return state, None

    final, _ = scan_if_not_profiling(body, carry, None, inner_loop_length)
    return final
@partial(jit, donate_argnums=(0, 1), static_argnums=(2, 3))
# this is jitted because under profiling _run_mcmc_inner_loop and the loop
# within it are not, so I need the donate_argnums feature of jit to avoid
# creating copies of the traces
def _save_state_to_trace(
    burnin_trace: PyTree,
    main_trace: PyTree,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    bart: State,
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_skip: Int32[Array, ''],
) -> tuple[PyTree, PyTree]:
    """Write the state after iteration `i_total` into both traces.

    Writes go through `_set`, which drops out-of-bounds indices:

    - the burn-in slot is `i_total` itself, which falls past the end of the
      burn-in trace once burn-in is over;
    - the main slot is ``(i_total - n_burn) // n_skip``, so `n_skip`
      consecutive iterations share a slot and the last of them is what
      remains. During burn-in the slot is pushed to the int32 maximum
      instead, because the plain formula would give a negative index.
    """
    in_burnin = i_total < n_burn
    main_idx = jnp.where(
        in_burnin, jnp.iinfo(jnp.int32).max, (i_total - n_burn) // n_skip
    )

    num_chains = get_num_chains(bart)
    new_burnin = _set(burnin_trace, i_total, burnin_extractor(bart), num_chains)
    new_main = _set(main_trace, main_idx, main_extractor(bart), num_chains)
    return new_burnin, new_main
def _set(
    trace: PyTree[Array, ' T'],
    index: Int32[Array, ''],
    val: PyTree[Array, ' T'],
    num_chains: int | None,
) -> PyTree[Array, ' T']:
    """Do ``trace[index] = val`` but fancier.

    The write is done leaf by leaf along the sample axis. That axis is 1 for
    leaves flagged as per-chain by `chain_vmap_axes` when `num_chains` is not
    None, and 0 otherwise. Out-of-bounds indices are silently dropped
    (``mode='drop'``). `None` leaves and empty arrays are returned unchanged.
    """
    chain_axes = chain_vmap_axes(val)

    def write_leaf(dest, item, chain_axis):
        # jax refuses to index into an axis of length 0, even just in the
        # abstract; `None` leaves show up because of the `is_leaf` below,
        # which is needed to traverse `chain_axes`
        if dest is None or dest.size == 0:
            return dest
        if num_chains is not None and chain_axis is not None:
            where = (slice(None), index, ...)
        else:
            where = (index, ...)
        return dest.at[where].set(item, mode='drop')

    return tree.map(
        write_leaf, trace, val, chain_axes, is_leaf=lambda leaf: leaf is None
    )
def make_default_callback(
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
) -> dict[str, Any]:
    """
    Prepare a default callback for `run_mcmc`.

    The callback is `print_callback`. It prints a dot every `dot_every` MCMC
    iterations and a one-line progress report every `report_every` MCMC
    iterations. With the defaults, that is a dot on every iteration and a
    report every 100 iterations.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.

    Returns
    -------
    A dictionary with the keys ``callback`` and ``callback_state``, to pass to
    `run_mcmc` as keyword arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(..., **make_default_callback())
    """

    def asarray_or_none(val: None | Any) -> None | Array:
        # store the periods as jax arrays so they are traced, not static
        return None if val is None else jnp.asarray(val)

    return dict(
        callback=print_callback,
        callback_state=PrintCallbackState(
            asarray_or_none(dot_every), asarray_or_none(report_every)
        ),
    )
class PrintCallbackState(Module):
    """State for `print_callback`.

    Parameters
    ----------
    dot_every
        Period, in MCMC iterations, of the printed dots; `None` disables them.
    report_every
        Period, in MCMC iterations, of the one-line reports; `None` disables
        them.
    """

    # period of the dots; None disables them
    dot_every: Int32[Array, ''] | None
    # period of the one-line reports; None disables them
    report_every: Int32[Array, ''] | None
def print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_,
):
    """Print a dot and/or a report periodically during the MCMC.

    Implements the `Callback` protocol; the arguments it does not use are
    absorbed by ``**_``. It returns `None`, so it never modifies the MCMC state
    or its own callback state.

    On iterations that are multiples of ``callback_state.report_every``, a
    one-line report is printed. It is preceded by this iteration's dot if one
    is due, otherwise by a newline if dots were printed since the previous
    report. On other iterations that are multiples of
    ``callback_state.dot_every``, only a dot is printed.
    """
    report_every = callback_state.report_every
    dot_every = callback_state.dot_every
    # 1-based number of the iteration just completed
    it = i_total + 1

    def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
        return False if every is None else it % every == 0

    report_cond = get_cond(report_every)
    dot_cond = get_cond(dot_every)

    def line_report_branch():
        if report_every is None:
            return
        if dot_every is None:
            print_newline = False
        else:
            # This branch only executes when `it` is a multiple of
            # `report_every`. So the previous report (or the start of the run)
            # was `report_every` iterations ago, while the last dot was
            # `it % dot_every` iterations ago. If that dot came after the
            # previous report, the current output line holds dots and the
            # report must start on a new line. `it % report_every` cannot be
            # used here because it is always 0 in this branch.
            print_newline = it % dot_every < report_every
        debug.callback(
            _print_report,
            print_dot=dot_cond,
            print_newline=print_newline,
            burnin=burnin,
            it=it,
            n_iters=n_burn + n_save * n_skip,
            num_chains=bart.forest.num_chains(),
            grow_prop_count=bart.forest.grow_prop_count.mean(),
            grow_acc_count=bart.forest.grow_acc_count.mean(),
            prune_acc_count=bart.forest.prune_acc_count.mean(),
            prop_total=bart.forest.split_tree.shape[-2],
            fill=forest_fill(bart.forest.split_tree),
        )

    def just_dot_branch():
        if dot_every is None:
            return
        debug.callback(
            lambda: print('.', end='', flush=True)  # noqa: T201
        )
        # logging can't do in-line printing so we use print

    cond_if_not_profiling(
        report_cond,
        line_report_branch,
        lambda: cond_if_not_profiling(dot_cond, just_dot_branch, lambda: None),
    )
def _convert_jax_arrays_in_args(func: Callable) -> Callable:
    """Wrap `func` so that it never receives `jax.Array` arguments.

    Every `jax.Array` leaf in the positional and keyword arguments is replaced
    by a Python scalar if it is 0-dimensional, or by a `numpy` array
    otherwise. All other leaves pass through unchanged.
    """

    def to_host(leaf: Any) -> Any:
        if isinstance(leaf, Array):
            return leaf.item() if not leaf.shape else numpy.array(leaf)
        return leaf

    def to_host_tree(pytree: PyTree) -> PyTree:
        return tree.map(to_host, pytree)

    @wraps(func)
    def wrapper(*args, **kw):
        return func(*to_host_tree(args), **to_host_tree(kw))

    return wrapper
@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    *,
    print_dot: bool,
    print_newline: bool,
    burnin: bool,
    it: int,
    n_iters: int,
    num_chains: int | None,
    grow_prop_count: float,
    grow_acc_count: float,
    prune_acc_count: float,
    prop_total: int,
    fill: float,
):
    """Print the report for `print_callback`.

    The line shows the iteration count, the grow proposals and the accepted
    (grow + prune) moves as fractions of `prop_total`, and the forest fill.
    It starts with this iteration's dot plus a newline if `print_dot`,
    otherwise with a bare newline if `print_newline`. It ends with a
    parenthesized note on the number of chains and/or the burn-in phase when
    applicable.
    """
    prefix = '.\n' if print_dot else ('\n' if print_newline else '')

    notes = [f'avg. {num_chains} chains'] if num_chains is not None else []
    if burnin:
        notes.append('burnin')
    suffix = ' (' + ', '.join(notes) + ')' if notes else ''

    grow_prop = grow_prop_count / prop_total
    move_acc = (grow_acc_count + prune_acc_count) / prop_total

    line = (
        f'{prefix}Iteration {it}/{n_iters}, '
        f'grow prob: {grow_prop:.0%}, '
        f'move acc: {move_acc:.0%}, '
        f'fill: {fill:.0%}{suffix}'
    )
    print(line)  # noqa: T201, see print_callback for why not logging
class Trace(TreeHeaps, Protocol):
    """Protocol for a MCMC trace: tree heaps plus a per-entry offset."""

    # added to the sum of trees when evaluating predictions
    offset: Float32[Array, '*trace_shape']
class TreesTrace(Module):
    """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace.

    Holds only the three tree heaps, with any number of leading trace axes.
    """

    leaf_tree: (
        Float32[Array, '*trace_shape num_trees 2**d']
        | Float32[Array, '*trace_shape num_trees k 2**d']
    )
    var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
    split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: TreeHeaps):
        """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`.

        Only the attributes of `obj` named like the fields of this class are
        copied; any other attribute (e.g. ``offset``) is ignored.
        """
        names = [f.name for f in fields(cls)]
        return cls(**{name: getattr(obj, name) for name in names})
@jit
def evaluate_trace(
    X: UInt[Array, 'p n'], trace: Trace
) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    X
        The predictors matrix, with `p` predictors and `n` observations.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    The predictions for each chain and iteration of the MCMC.

    Notes
    -----
    The evaluation is batched twice with `autobatch`: an inner level over the
    trees, summing the partial forests, and an outer level over the MCMC
    samples. The memory budget is 128 MiB per device, counting the devices
    along the ``chains`` and ``data`` mesh axes.
    """
    # memory budget: 128 MiB per device
    mesh = jax.typeof(trace.leaf_tree).sharding.mesh
    device_count = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
    budget = 2**27 * device_count

    # heaps are laid out as (chains?, samples, trees, nodes)
    sample_axis, tree_axis = (1, 2) if trace.split_tree.ndim > 3 else (0, 1)

    # inner level: evaluate batches of trees and add them up
    sum_over_trees = autobatch(
        evaluate_forest,
        budget,
        (None, tree_axis),
        tree_axis,
        reduce_ufunc=jnp.add,
    )

    # output shape, passed explicitly so autobatch does not trace everything
    # 4 times to infer it
    is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
    k = trace.leaf_tree.shape[-2] if is_mv else 1
    _, n = X.shape
    out_shape = (*trace.split_tree.shape[:-2], *((k,) if is_mv else ()), n)

    # shrink the budget of the outer level to account for the intermediate
    # per-tree outputs that are summed over
    num_trees, hts = trace.split_tree.shape[-2:]
    out_size = k * n * jnp.float32.dtype.itemsize  # the value of the forest
    bytes_per_node = (
        2 * k * trace.leaf_tree.itemsize
        + trace.var_tree.itemsize
        + trace.split_tree.itemsize
    )
    core_io_size = num_trees * hts * bytes_per_node + out_size
    core_int_size = (num_trees - 1) * out_size
    outer_budget = max(1, floor(budget / (1 + core_int_size / core_io_size)))

    # outer level: loop over batches of MCMC samples
    batched_eval = autobatch(
        sum_over_trees,
        outer_budget,
        (None, sample_axis),
        sample_axis,
        warn_on_overflow=False,  # the inner autobatch will handle it
        result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
    )

    # only the tree heaps go through the batched evaluation
    y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
    y_centered = batched_eval(X, TreesTrace.from_dataclass(trace))
    return y_centered + trace.offset[..., None]
@partial(jit, static_argnums=(0,))
def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
    """
    Count how many times each predictor is used in each MCMC state.

    Parameters
    ----------
    p
        The number of predictors.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """
    # the heaps have shape (chains?, samples, trees, nodes); `sum_batch_axis=-1`
    # presumably sums the histogram over the trees axis
    var_tree, split_tree = trace.var_tree, trace.split_tree
    return var_histogram(p, var_tree, split_tree, sum_batch_axis=-1)