bartz.mcmcloop.evaluate_trace
- bartz.mcmcloop.evaluate_trace(X, trace, *, flatten_chains=False, out_chain_axis_w_trees=0, test_points='none', max_io_nbytes=134217728)
Compute predictions for all iterations of the BART MCMC.
- Parameters:
  - `X` (`UInt[Array, 'p n']`) – The predictors matrix, with `p` predictors and `n` observations.
  - `trace` (`EvaluableTrace`) – A main trace of the BART MCMC, as returned by `run_mcmc`.
  - `flatten_chains` (`bool`, default: `False`) – If `True` and `trace` has a chain axis, collapse the chain axis into the sample axis, so the leading dimension of the output is `num_chains * num_samples` instead of `(num_chains, num_samples)`.
  - `out_chain_axis_w_trees` (`int`, default: `0`) – Position of the chain axis in the output, interpreted against the intermediate, pre-tree-reduction layout `(sample, tree, *k, n)`. After the sum over the tree axis, the chain axis ends up one position earlier if it was placed after the tree axis. Negative values count from the end. Ignored when `trace` has no chain axis or `flatten_chains` is `True`.
  - `test_points` (`Literal['none', 'autobatch', 'shard_and_autobatch']`, default: `'none'`) – How to handle the observation (`n`) axis of `X` across devices. The sharding of `X` can't be read at trace time, so the caller declares it:
    - `'none'` (default): leave `X` alone, neither sharding nor batching its `n` axis. Safe regardless of how `X` is sharded.
    - `'autobatch'`: loop over the `n` axis to bound memory. Assumes `X` is not sharded over the mesh `'data'` axis; batching a sharded axis would serialize the devices.
    - `'shard_and_autobatch'`: shard the `n` axis of `X` over the mesh `'data'` axis with a manual `jax.shard_map` and batch the per-device chunk. Falls back to `'autobatch'` if the mesh has no `'data'` axis.
  - `max_io_nbytes` (`int`, default: `134217728`, i.e. 128 MiB) – Soft limit, in bytes, on the input plus output of each batch of the autobatching loops (per device). Lower it to reduce peak memory at the cost of more iterations.
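The axis bookkeeping behind `out_chain_axis_w_trees` and `flatten_chains` can be checked with plain array shapes. The sketch below is not bartz's implementation: it assumes per-chain predictions are already laid out as `(sample, tree, n)` with no `k` axis, stacks the chains at the requested position, and then sums over the tree axis. The sizes and the helper `reduce_with_chain_at` are hypothetical, chosen only for illustration.

```python
import numpy as np

# Hypothetical sizes, for illustration only.
num_chains, num_samples, num_trees, n = 4, 10, 50, 7

# Stand-in for per-tree predictions: one (sample, tree, n) block per chain.
per_chain = [np.ones((num_samples, num_trees, n), np.float32) for _ in range(num_chains)]


def reduce_with_chain_at(axis):
    """Hypothetical helper: insert the chain axis at `axis`, then sum over trees."""
    stacked = np.stack(per_chain, axis=axis)  # 4-D: chain inserted into (sample, tree, n)
    axis = axis % stacked.ndim  # normalize negative positions
    # The tree axis sits right after the sample axis, one slot later if the
    # chain axis was inserted before it.
    tree_axis = 1 if axis >= 2 else 2
    return stacked.sum(axis=tree_axis)


print(reduce_with_chain_at(0).shape)   # (4, 10, 7): chain stays first
print(reduce_with_chain_at(2).shape)   # (10, 4, 7): chain was after tree, moves to 1
print(reduce_with_chain_at(-1).shape)  # (10, 7, 4): counted from the end

# Collapsing the chain axis into the sample axis, as flatten_chains describes.
print(reduce_with_chain_at(0).reshape(-1, n).shape)  # (40, 7)
```

Requesting position 2 and position 1 give the same final shape, because the tree axis that separated them is summed away. The chain-major ordering produced by `reshape` above is an assumption of the sketch; the parameter description only fixes the leading dimension `num_chains * num_samples`.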
- Returns:
  `Float32[Array, '*trace_shape n']` | `Float32[Array, '*trace_shape k n']` – The predictions for each chain and iteration of the MCMC.
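The `'autobatch'` mode and the `max_io_nbytes` soft limit bound memory by splitting the `n` axis of `X` into batches whose input plus output fit the byte budget. Below is a minimal NumPy sketch of that idea under stated assumptions: the helpers `batch_size_for` and `predict_in_batches`, the fixed batch size, and the toy `predict_fn` are all hypothetical. It does not reproduce bartz's actual loop and ignores sharding and `'shard_and_autobatch'` entirely.

```python
import numpy as np


def batch_size_for(n, in_bytes_per_obs, out_bytes_per_obs, max_io_nbytes):
    """Hypothetical: largest batch along n whose input + output fits the limit (at least 1)."""
    per_obs = in_bytes_per_obs + out_bytes_per_obs
    return max(1, min(n, max_io_nbytes // per_obs))


def predict_in_batches(X, predict_fn, out_bytes_per_obs, max_io_nbytes=134217728):
    """Hypothetical: apply predict_fn to slices of X's n axis and concatenate the results."""
    p, n = X.shape  # X is laid out as (p, n), like the X parameter
    size = batch_size_for(n, p * X.itemsize, out_bytes_per_obs, max_io_nbytes)
    chunks = [predict_fn(X[:, start:start + size]) for start in range(0, n, size)]
    return np.concatenate(chunks, axis=-1)  # reassemble along n


# Toy stand-in for per-sample predictions: shape (num_samples, batch), float32.
num_samples = 5
X = (np.arange(3000) % 7).astype(np.uint8).reshape(3, 1000)
toy_predict = lambda Xb: np.outer(
    np.arange(num_samples, dtype=np.float32), Xb.sum(axis=0, dtype=np.float32)
)

out = predict_in_batches(X, toy_predict, out_bytes_per_obs=num_samples * 4, max_io_nbytes=1000)
print(out.shape)  # (5, 1000)
```

With a 1000-byte budget, each observation costs 3 input bytes plus 20 output bytes, so batches hold 43 observations and the loop runs 24 times. Lowering the limit shrinks the batches and adds iterations, which is the trade-off the `max_io_nbytes` description mentions.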