bartz.mcmcloop.evaluate_trace

bartz.mcmcloop.evaluate_trace(X, trace, *, flatten_chains=False, out_chain_axis_w_trees=0, test_points='none', max_io_nbytes=134217728)

Compute predictions for all iterations of the BART MCMC.

Parameters:
  • X (UInt[Array, 'p n']) – The predictors matrix, with p predictors and n observations.

  • trace (EvaluableTrace) – A main trace of the BART MCMC, as returned by run_mcmc.

  • flatten_chains (bool, default: False) – If True and trace has a chain axis, collapse it into the sample axis, so the output has a single leading axis of length num_chains * num_samples instead of two leading axes (num_chains, num_samples).

  • out_chain_axis_w_trees (int, default: 0) – Position of the chain axis in the output. It is interpreted against the intermediate, pre-tree-reduction layout (sample, tree, *k, n); after the tree axis is summed out, the chain axis ends up one position earlier if it was placed after the tree axis. Negative values count from the end. Ignored when trace has no chain axis or flatten_chains is True.

  • test_points (Literal['none', 'autobatch', 'shard_and_autobatch'], default: 'none') –

    How to handle the observation (n) axis of X across devices. The sharding of X can’t be read at trace time, so the caller declares it:

    • 'none' (default): leave X alone, neither sharding nor batching its n axis. Safe regardless of how X is sharded.

    • 'autobatch': loop over the n axis to bound memory. Assumes X is not sharded over the mesh 'data' axis; batching a sharded axis would serialize the devices.

    • 'shard_and_autobatch': shard the n axis of X over the mesh 'data' axis with a manual jax.shard_map and batch the per-device chunk. Falls back to 'autobatch' if the mesh has no 'data' axis.

  • max_io_nbytes (int, default: 134217728, i.e. 128 MiB) – Soft limit, in bytes, on the input plus output of each batch of the autobatching loops (per device). Lower it to reduce peak memory at the cost of more loop iterations.
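The 'autobatch' mode and max_io_nbytes together bound memory by looping over the n axis in chunks. The NumPy sketch below illustrates that idea under stated assumptions: the batch size is the largest number of observations whose input plus output bytes fit the limit, clamped to at least 1 (which is why the limit is only "soft"). The helpers batch_size_for_limit and batched_over_n, and the per-observation byte accounting, are hypothetical illustrations, not bartz's actual implementation, and the sketch ignores devices and sharding entirely.

```python
import numpy as np


def batch_size_for_limit(in_bytes_per_obs, out_bytes_per_obs, max_io_nbytes):
    """Largest number of observations whose input plus output fits the limit.

    Clamped to at least 1, so a single huge observation can still exceed the
    limit: the bound is soft, not hard.
    """
    per_obs = in_bytes_per_obs + out_bytes_per_obs
    return max(1, max_io_nbytes // per_obs)


def batched_over_n(fn, X, out_bytes_per_obs, max_io_nbytes=134217728):
    """Apply fn to column chunks of X (shape (p, n)) and join along the last axis.

    fn is a hypothetical per-chunk predictor returning an array whose last
    axis is the chunk's observations.
    """
    p, n = X.shape
    in_bytes_per_obs = p * X.itemsize  # one column of X per observation
    size = batch_size_for_limit(in_bytes_per_obs, out_bytes_per_obs, max_io_nbytes)
    # Each chunk holds at most `size` observations; the last one may be shorter.
    chunks = [fn(X[:, start:start + size]) for start in range(0, n, size)]
    return np.concatenate(chunks, axis=-1)
```

Because the output is concatenated along n, the batched result matches a single unbatched call; only the peak size of each intermediate changes.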

Returns:

Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n'] – The predictions for each chain and iteration of the MCMC.
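The axis bookkeeping behind out_chain_axis_w_trees and flatten_chains can be illustrated with a NumPy sketch on a hypothetical per-tree array with layout (chain, sample, tree, n), i.e. the case without a k axis. The helpers are illustrations, not part of bartz. Two behaviors are assumptions: negative positions are resolved the way np.moveaxis resolves them, and flattening is chain-major.

```python
import numpy as np


def place_chain_and_sum_trees(per_tree, out_chain_axis_w_trees=0):
    """Hypothetical helper mirroring the documented chain-axis convention.

    per_tree has layout (chain, sample, tree, n). The chain axis is moved to
    `out_chain_axis_w_trees` in this pre-reduction layout, then the tree axis
    is summed out. Returns (predictions, chain_axis_in_output).
    """
    c = out_chain_axis_w_trees % per_tree.ndim  # negative values count from the end
    moved = np.moveaxis(per_tree, 0, c)
    # Without the chain axis the layout is (sample, tree, n): the tree axis is
    # at index 1, or at index 2 if the chain axis was inserted before it.
    tree_axis = 2 if c <= 1 else 1
    out = moved.sum(axis=tree_axis)
    # The chain axis shifts one position earlier only if it was after the trees.
    chain_axis = c - 1 if c > tree_axis else c
    return out, chain_axis


def flatten_chain_axis(preds):
    """(chain, sample, n) -> (chain * sample, n), chain-major (assumed order)."""
    num_chains, num_samples = preds.shape[:2]
    return preds.reshape(num_chains * num_samples, *preds.shape[2:])
```

For example, with 3 chains, 4 samples, 5 trees and 6 observations, out_chain_axis_w_trees=2 places the chain after the tree axis, (sample, tree, chain, n), so the output is (sample, chain, n) with the chain axis at position 1.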