bartz.mcmcloop

Functions that implement the full BART posterior MCMC loop.

Running the MCMC

run_mcmc(key, state, n_save, *[, n_burn, ...])

Run the MCMC for the BART posterior.

RunMCMCResult(final_state, burnin_trace, ...)

Return value of run_mcmc.

BurninTrace(has_chains, mesh, ...)

MCMC trace with only diagnostic values.

MainTrace(has_chains, mesh, grow_prop_count, ...)

MCMC trace with trees and diagnostic values.
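
The split between a diagnostics-only burn-in trace and a main trace that also stores the sampled objects can be illustrated with a toy driver. This is a minimal sketch and not the bartz API: `toy_run_mcmc` and `random_walk_step` are hypothetical names, the sampler is a one-dimensional random-walk Metropolis step rather than a BART forest update, and the standard-library `random` module stands in for a JAX key.

```python
import math
import random


def random_walk_step(rng, x):
    """Hypothetical Metropolis step targeting a standard normal."""
    proposal = x + rng.uniform(-1.0, 1.0)
    log_ratio = 0.5 * (x * x - proposal * proposal)
    accepted = rng.random() < math.exp(min(0.0, log_ratio))
    return (proposal if accepted else x), accepted


def toy_run_mcmc(step, state, n_save, *, n_burn=0, seed=0):
    """Run n_burn + n_save steps; store states only after burn-in."""
    rng = random.Random(seed)
    burnin_trace = []  # diagnostics only
    main_trace = []    # sampled state plus diagnostics
    for i in range(n_burn + n_save):
        state, accepted = step(rng, state)
        if i < n_burn:
            burnin_trace.append(accepted)
        else:
            main_trace.append((state, accepted))
    return state, burnin_trace, main_trace


final_state, burnin_trace, main_trace = toy_run_mcmc(
    random_walk_step, 0.0, n_save=5, n_burn=3
)
print(len(burnin_trace), len(main_trace))
```

Keeping only diagnostics during burn-in avoids storing states that will be discarded anyway, while the acceptance record remains available for checking how the sampler behaved.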

Evaluating the trace

evaluate_trace(X, trace, *[, ...])

Compute predictions for all iterations of the BART MCMC.

compute_varcount(p, trace, *[, out_chain_axis])

Count how many times each predictor is used in each MCMC state.

EvaluableTrace(*args, **kwargs)

Structural type of the traces accepted by evaluate_trace.
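
To make "how many times each predictor is used in each MCMC state" concrete, the count can be sketched on a toy trace. The representation below is an assumption made for illustration only: each saved state is a list of trees, and each tree is a list of the predictor indices used at its decision nodes. The real trace layout in bartz is not shown in this listing and may differ; `toy_varcount` is a hypothetical name.

```python
def toy_varcount(p, trace):
    """Return one row per MCMC state with p per-predictor usage counts."""
    counts = []
    for forest in trace:
        row = [0] * p
        for tree in forest:
            for var in tree:  # one entry per decision node
                row[var] += 1
        counts.append(row)
    return counts


# Two saved states, p = 3 predictors (hypothetical data).
trace = [
    [[0, 2], [2]],  # state 0: tree A splits on x0 and x2, tree B on x2
    [[1], []],      # state 1: tree A splits on x1, tree B is a single leaf
]
print(toy_varcount(3, trace))  # [[1, 0, 2], [0, 1, 0]]
```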

Progress callbacks

The entry points are make_print_callback and make_tqdm_callback.

make_print_callback(state, *[, dot_every, ...])

Prepare a progress-printing callback for run_mcmc.

make_tqdm_callback(state, *[, update_every, ...])

Prepare a tqdm progress-bar callback for run_mcmc.

Callback(*args, **kwargs)

Callback type for run_mcmc.

CallbackState

print_callback(*, state, burnin, i_total, ...)

Print a dot and/or a report periodically during the MCMC.

tqdm_callback(*, state, i_total, n_burn, ...)

Advance a tqdm progress bar during the MCMC.

PrintCallbackState(dot_every, report_every, ...)

State for print_callback.

TqdmCallbackState(bar_id, update_every, ...)

State for tqdm_callback.

StatsAccumulator(sums, count)

Running average of the forest diagnostics shown during the MCMC.

StatsReport(grow_prop, move_acc, ...)

Forest diagnostics produced by StatsAccumulator.report for one report.
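
The idea of a running average held as sums plus a count, combined with a callback that emits a dot or a report at fixed intervals, can be sketched as follows. All names here (`ToyAccumulator`, `toy_print_callback`) are hypothetical and the behavior is an assumption chosen for illustration; the actual fields and semantics of `StatsAccumulator`, `StatsReport`, and `print_callback` are only partially visible in the signatures above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyAccumulator:
    """Running average stored as per-diagnostic sums and a sample count."""

    sums: tuple
    count: int

    def add(self, values):
        new_sums = tuple(s + v for s, v in zip(self.sums, values))
        return ToyAccumulator(new_sums, self.count + 1)

    def report(self):
        # Assumes at least one sample has been added.
        return tuple(s / self.count for s in self.sums)


def toy_print_callback(i_total, acc, *, dot_every=1, report_every=5):
    """Return the text to print after iteration i_total (0-based)."""
    n = i_total + 1
    out = ""
    if dot_every and n % dot_every == 0:
        out += "."
    if report_every and n % report_every == 0:
        out += f" it {n}: means={acc.report()}\n"
    return out


acc = ToyAccumulator((0.0, 0.0), 0)
for i in range(10):
    acc = acc.add((1.0 if i % 2 == 0 else 0.0, 0.5))
    print(toy_print_callback(i, acc), end="")
```

Returning a new accumulator from `add` instead of mutating in place mirrors how state is usually threaded through functional code, which is the style JAX-based loops tend to require.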