bartz.mcmcstep.step

bartz.mcmcstep.step(key, state)

Do one MCMC step.

Parameters:
  • key (Key[Array, '']) – A JAX random key.

  • state (State) – A BART MCMC state, as created by init.

Returns:

State – The new BART MCMC state.
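A typical driver calls step repeatedly, splitting off a fresh subkey for each call and rebinding the state variable to the returned state. The sketch does not use bartz itself: ToyState and toy_step are hypothetical stand-ins that only copy the (key, state) -> new state calling convention, so the example runs without building a real BART state via init.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a BART State: any pytree of arrays works."""
    x: jax.Array
    count: jax.Array


# Hypothetical toy step with the same signature as bartz.mcmcstep.step.
# The state argument is donated to mimic the memory reuse in the Notes.
@functools.partial(jax.jit, donate_argnums=(1,))
def toy_step(key, state):
    noise = jax.random.normal(key, state.x.shape)
    return ToyState(x=state.x + 0.1 * noise, count=state.count + 1)


def run_chain(key, state, n_steps):
    """Advance the chain n_steps times, using a fresh subkey per step."""
    for _ in range(n_steps):
        key, subkey = jax.random.split(key)
        # Rebind `state`: the previous state must not be touched again.
        state = toy_step(subkey, state)
    return state


state = ToyState(x=jnp.zeros(3), count=jnp.zeros((), dtype=jnp.int32))
state = run_chain(jax.random.key(0), state, n_steps=5)
```

Rebinding the same name on every iteration guarantees that no reference to a donated state survives the call.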

Notes

The memory of the input state is reused for the output state, so the input state cannot be used anymore after calling step. This applies only when step is called outside of jax.jit.
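To illustrate the second sentence: when a function with a donated argument is traced inside an outer jax.jit, the donation is not applied, so the same input state can be passed to it more than once within the traced function. The sketch again uses a hypothetical toy_step (not part of bartz) that donates its state argument in the way the note describes; two_proposals is likewise a hypothetical helper.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a BART State."""
    x: jax.Array
    count: jax.Array


# Hypothetical toy step that donates its state argument (argnum 1).
@functools.partial(jax.jit, donate_argnums=(1,))
def toy_step(key, state):
    noise = jax.random.normal(key, state.x.shape)
    return ToyState(x=state.x + 0.1 * noise, count=state.count + 1)


@jax.jit
def two_proposals(key, state):
    """Advance the same state twice with independent keys."""
    k1, k2 = jax.random.split(key)
    # Inside jit the inner donation is not applied, so reusing `state`
    # for the second call is safe here.
    return toy_step(k1, state), toy_step(k2, state)


state0 = ToyState(x=jnp.zeros(3), count=jnp.zeros((), dtype=jnp.int32))
a, b = two_proposals(jax.random.key(0), state0)
# two_proposals itself donates nothing, so state0 is still valid here.
```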