Do one MCMC step.
- Parameters:
key (Key[Array, '']) – A jax random key.
state (State) – A BART MCMC state, as created by init.
- Returns:
State – The new BART MCMC state.
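A step function with the signature `(key, state) -> state` is usually driven in a loop. Each iteration splits off a fresh key and rebinds the state to the returned value. The sketch below uses a hypothetical `ToyState` and a hypothetical `toy_step` as stand-ins. They are not bartz's `State`, `init` or `step`, and they only mirror the calling pattern.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for the BART MCMC state (not bartz's State)."""
    x: jax.Array        # some sampled quantity
    n_steps: jax.Array  # number of steps taken so far


@jax.jit
def toy_step(key: jax.Array, state: ToyState) -> ToyState:
    """Hypothetical step with the same (key, state) -> state signature."""
    # Random-walk move driven entirely by the supplied key.
    x = state.x + jax.random.normal(key, state.x.shape)
    return ToyState(x=x, n_steps=state.n_steps + 1)


def run_chain(key: jax.Array, state: ToyState, n: int) -> ToyState:
    for _ in range(n):
        # Split so that every step receives a fresh, independent key.
        key, subkey = jax.random.split(key)
        # Rebind `state`: only the returned state is used from here on.
        state = toy_step(subkey, state)
    return state


state0 = ToyState(x=jnp.zeros(3), n_steps=jnp.array(0))
final = run_chain(jax.random.key(0), state0, n=10)
```

Rebinding `state` on every iteration also matches the constraint in the Notes below: the loop never touches a state after passing it to the step function.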
Notes
The memory of the input state is reused for the output state, so the input
state cannot be used after calling step. This applies only when step is
called outside of jax.jit.
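This kind of memory reuse is what JAX calls buffer donation, requested with `jax.jit(..., donate_argnums=...)`. The docstring does not say how step implements it, so the following is only a sketch of the general pattern. It uses a hypothetical `noisy_update` function, not bartz code. At top level the donated argument must be treated as consumed. Inside another jitted function the arguments are tracers rather than device buffers, so the restriction does not arise there. Whether the donated buffer is actually freed depends on the backend.

```python
import jax
import jax.numpy as jnp


def _noisy_update(key, x):
    # Hypothetical update: add standard normal noise to x.
    return x + jax.random.normal(key, x.shape)


# donate_argnums=1: XLA may reuse the buffer of argument 1 (x) for the
# output, so the caller must treat that argument as consumed.
noisy_update = jax.jit(_noisy_update, donate_argnums=1)

x = jnp.zeros(1000)
y = noisy_update(jax.random.key(0), x)
# Where the backend honours donation, x.is_deleted() is now True and
# reading x raises an error; always continue from the returned value.
x = y


@jax.jit
def two_updates(key, x):
    k1, k2 = jax.random.split(key)
    # Inside a traced function x is an abstract tracer rather than a device
    # buffer, so reusing it after the donating call is allowed.
    y = noisy_update(k1, x)
    return x + noisy_update(k2, y)


z = two_updates(jax.random.key(1), x)
```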