bartz.mcmcloop.run_mcmc

bartz.mcmcloop.run_mcmc(key, state, n_save, *, n_burn=0, n_skip=1, inner_loop_length=None, callback=None, callback_state=None)

Run the MCMC for the BART posterior.

Parameters:
  • key (Key[Array, '']) – A key for random number generation.

  • state (State) – The initial MCMC state, as created and updated by the functions in bartz.mcmcstep. The MCMC loop uses buffer donation to avoid copies, so this variable is invalidated after running run_mcmc. Make a copy beforehand to use it again.

  • n_save (int) – The number of iterations to save.

  • n_burn (int, default: 0) – The number of initial iterations which are not saved.

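  • n_skip (int, default: 1) – The stride used for saving: one iteration out of every n_skip is saved, so n_skip - 1 iterations are skipped between consecutive saved iterations. The first saved iteration comes a full stride after the burn-in, so the effective burn-in is n_burn + n_skip - 1.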

  • inner_loop_length (int | None, default: None) – The MCMC loop is split into an outer and an inner loop. The outer loop is in Python, while the inner loop is in JAX. inner_loop_length is the number of iterations of the inner loop to run for each iteration of the outer loop. If not specified, the outer loop will iterate just once, with all iterations done in a single inner loop run. The inner stride is unrelated to the stride used for saving the trace.

  • callback (Callback | None, default: None) – An arbitrary function run during the loop after updating the state. For the signature, see Callback. The callback is called under jax.jit, so the argument values are not available when the Python code is executed; use the utilities in jax.debug to access the values at actual runtime. The callback may return new values for the MCMC state and the callback state.

  • callback_state (PyTree[Any, 'T'], default: None) – The initial custom state for the callback.

Returns:

RunMCMCResult – A namedtuple with the final state, the burn-in trace, and the main trace.

Raises:

RuntimeError – If run_mcmc detects it’s being invoked in a jax.jit-wrapped context and with settings that would create unrolled loops in the trace.

Notes

The number of MCMC updates is n_burn + n_skip * n_save. The traces do not include the initial state, and include the final state.
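
The counting rules above can be checked with a small plain-Python sketch. It does not call bartz; the helper name saved_update_indices is hypothetical, and the positions of the saved updates are inferred from the statements that the final state is saved and that saved iterations are n_skip apart.

```python
def saved_update_indices(n_save, n_burn=0, n_skip=1):
    """Return the total number of updates and the 1-based indices of the
    updates that end up in the main trace.

    Hypothetical helper for illustration only, not part of bartz.
    """
    # Total number of MCMC updates, as stated in the notes.
    total_updates = n_burn + n_skip * n_save
    # Assumption: saved updates are n_skip apart and the last update is
    # saved, so the k-th saved update is number n_burn + n_skip * k.
    saved = [n_burn + n_skip * k for k in range(1, n_save + 1)]
    return total_updates, saved


total, saved = saved_update_indices(n_save=3, n_burn=10, n_skip=2)
print(total)  # 16
print(saved)  # [12, 14, 16]
```

Under these assumptions the first saved update is number 12, so 11 updates are discarded before it, which matches the effective burn-in n_burn + n_skip - 1.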

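To resume a run, pass the returned final state to a new call together with a new key, for example one derived with jax.random.split. A JAX key fully determines its random stream, so passing the same key again replays the random numbers already used by the previous call instead of continuing the stream. A run split into several consecutive calls is therefore not guaranteed to reproduce a single longer call.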
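
The outer/inner split described for inner_loop_length can be pictured with another plain-Python sketch. It only computes how many iterations each inner-loop run would cover; the helper name inner_chunks is hypothetical, and the assumption that a leftover remainder goes into a final shorter chunk is made for illustration and may differ from what the library does.

```python
def inner_chunks(n_save, n_burn=0, n_skip=1, inner_loop_length=None):
    """List the number of iterations covered by each outer-loop step.

    Hypothetical helper for illustration only, not part of bartz.
    """
    total = n_burn + n_skip * n_save
    if inner_loop_length is None:
        # Default described in the docs: a single inner-loop run covers
        # every iteration.
        inner_loop_length = max(total, 1)
    if inner_loop_length < 1:
        raise ValueError("inner_loop_length must be at least 1")
    chunks = []
    done = 0
    while done < total:
        # Assumption: the last chunk is shortened to fit the remainder.
        step = min(inner_loop_length, total - done)
        chunks.append(step)
        done += step
    return chunks


print(inner_chunks(n_save=3, n_burn=10, n_skip=2))
# [16]
print(inner_chunks(n_save=3, n_burn=10, n_skip=2, inner_loop_length=5))
# [5, 5, 5, 1]
```

The chunk lengths depend only on the total iteration count, not on n_skip, which reflects the statement that the inner stride is unrelated to the stride used for saving the trace.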