bartz.mcmcstep

Functions that implement the BART posterior MCMC initialization and update step.

Initialization and stepping

init(*, X, y[, outcome_type, error_cov_inv, ...])

Create the initial MCMC state for sampling the BART posterior.

step(key, state)

Do one MCMC step.

make_p_nonterminal(d[, alpha, beta])

Prepare the p_nonterminal argument to init.
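The summary does not state the formula. Under the conventional BART tree prior, a node at depth k (root at k = 0) is nonterminal with probability alpha / (1 + k)^beta, with customary defaults alpha = 0.95 and beta = 2. The sketch below is minimal and assumes two things: that make_p_nonterminal follows that convention, and that d is the maximum tree depth, so nodes on the deepest level are forced to be leaves and need no entry. The helper name p_nonterminal_sketch is hypothetical and not part of bartz.

```python
import numpy as np

def p_nonterminal_sketch(max_depth, alpha=0.95, beta=2.0):
    """Hypothetical helper: P(node is nonterminal) for depths 0 .. max_depth - 2.

    Assumes the conventional BART prior alpha / (1 + depth) ** beta and that
    nodes at depth max_depth - 1 are always leaves, so they get no entry.
    The defaults alpha=0.95, beta=2.0 are the customary ones, assumed here.
    """
    depth = np.arange(max_depth - 1)
    return alpha / (1.0 + depth) ** beta

p = p_nonterminal_sketch(4)
```

With these defaults, p_nonterminal_sketch(4) gives 0.95, 0.2375 and about 0.1056 for depths 0, 1 and 2. The rapid decay is what keeps prior trees shallow.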

OutcomeType(*values)

Likelihood types for each outcome component in the regression.

MCMC state

State(_chain_anchor, X, binary_y, z, ...)

Represents the MCMC state of BART.

Forest(var_tree, split_tree, affluence_tree, ...)

Represents the MCMC state of a sum of trees.

StepConfig(steps_done, sparse_on_at, ...)

Options for the MCMC step.

Wishart(nu, rate, value)

A precision matrix with a Wishart prior, bundled with its current value.

DiagWishart(nu, rate, value)

A diagonal precision matrix with independent chi-square diagonal entries.

Reduction strategies

Configurations for the per-leaf scatter-add reductions; pass one of these to init.

ReductionConfig()

Select and configure an indexed-reduce (scatter-add) implementation.

BatchedReduction([num_batches, ...])

Segment-sum with optional batching along the datapoints.
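A segment-sum adds each datapoint's value into the bin of the leaf that datapoint falls in. The minimal NumPy sketch below illustrates a batched variant under an assumed meaning of batching: the datapoints are split into num_batches contiguous chunks, each chunk is accumulated into its own row of partial sums, and the rows are then added together. The function batched_segment_sum is hypothetical. It does not reproduce the bartz implementation or its other options.

```python
import numpy as np

def batched_segment_sum(values, leaf_index, num_leaves, num_batches=1):
    """Hypothetical sketch: per-leaf sums of `values`, accumulated per batch.

    values:     (n,) array of per-datapoint quantities (e.g. residuals)
    leaf_index: (n,) integer array, the leaf each datapoint falls into
    """
    n = values.shape[0]
    # Assign each datapoint to one of num_batches contiguous chunks.
    batch_id = np.arange(n) * num_batches // n
    # One row of partial sums per batch.
    partial = np.zeros((num_batches, num_leaves), dtype=values.dtype)
    # Unbuffered scatter-add: repeated (batch, leaf) pairs accumulate correctly.
    np.add.at(partial, (batch_id, leaf_index), values)
    # Combine the per-batch partial sums into the final per-leaf totals.
    return partial.sum(axis=0)
```

Splitting the accumulation this way costs extra memory (num_batches × num_leaves partial sums). In exchange, fewer writes collide on the same bin, which is the usual motivation for batching scatter-adds on parallel hardware. The result does not depend on num_batches, apart from floating-point rounding.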

AutoBatchedReduction([min_batch_size, ...])

BatchedReduction that picks num_batches automatically per platform.

OneHotReduction([method, n_inner])

Dense one-hot reduction.
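A dense one-hot reduction replaces the scatter-add with a matrix product against an n × num_leaves indicator matrix. This suits matmul-oriented hardware, at the cost of memory proportional to n × num_leaves. The minimal NumPy sketch below assumes the same inputs as a segment-sum. The function one_hot_segment_sum is hypothetical, and it ignores the method and n_inner options, whose exact meaning is not described here.

```python
import numpy as np

def one_hot_segment_sum(values, leaf_index, num_leaves):
    """Hypothetical sketch: per-leaf sums via a dense one-hot matrix product."""
    # (n, num_leaves) indicator: row i has a 1 in column leaf_index[i].
    one_hot = (leaf_index[:, None] == np.arange(num_leaves)).astype(values.dtype)
    # A single vector-matrix product replaces the indexed scatter-add.
    return values @ one_hot
```

For the same inputs, this returns the same per-leaf totals as a plain segment-sum. The two approaches differ only in how the work maps onto the hardware.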

AutoOneHotReduction([min_matmul_bins])

OneHotReduction that picks method and n_inner automatically.

PallasReduction([block_size, num_blocks, ...])

Blocked one-hot scatter-add written as a Pallas kernel.