bartz.mcmcstep.Forest

class bartz.mcmcstep.Forest(var_tree, split_tree, affluence_tree, leaf_tree, grow_prop_count, prune_prop_count, grow_acc_count, prune_acc_count, max_split, blocked_vars, p_nonterminal, p_propose_grow, leaf_indices, count_tree, prec_tree, min_points_per_decision_node, min_points_per_leaf, log_trans_prior, log_likelihood, leaf_prior_cov_inv, log_s, theta, a, b, rho)

Represents the MCMC state of a sum of trees.

var_tree: UInt[Array, '*chains num_trees half_tree_size']

The decision axes.

split_tree: UInt[Array, '*chains num_trees half_tree_size']

The decision boundaries.
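To illustrate how var_tree and split_tree work together, here is a minimal pure-Python sketch of routing one datapoint to its leaf in a single tree. It assumes a heap layout (root at index 1, index 0 unused, children of node i at 2*i and 2*i + 1), that a split value of 0 marks a node without a decision rule, that predictors are already binned to integer split indices, and that a point goes right when x[var] >= split. These conventions and the helper name find_leaf are assumptions for illustration, not taken from this page.

```python
def find_leaf(var_tree, split_tree, x):
    """Return the heap index of the leaf that the binned predictor vector x falls into.

    Hypothetical helper. Assumed conventions: root at index 1, children of node i
    at 2*i and 2*i + 1, split_tree[i] == 0 means "no decision rule here",
    and a point goes right when x[var] >= split.
    """
    half_tree_size = len(split_tree)
    index = 1
    # Indices >= half_tree_size are at maximum depth and can only be leaves.
    while index < half_tree_size:
        split = split_tree[index]
        if split == 0:  # no decision rule: this node is a leaf
            break
        var = var_tree[index]
        go_right = x[var] >= split
        index = 2 * index + int(go_right)
    return index
```

Under these assumptions the returned index is always below 2*half_tree_size, consistent with leaf_tree having twice as many slots as var_tree and split_tree.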

affluence_tree: Bool[Array, '*chains num_trees half_tree_size']

Marks leaves that can be grown.

leaf_tree: Float32[Array, '*chains num_trees 2*half_tree_size'] | Float32[Array, '*chains num_trees k 2*half_tree_size']

The leaf values.

grow_prop_count: Int32[Array, '*chains']

The number of grow proposals made during one full MCMC cycle.

prune_prop_count: Int32[Array, '*chains']

The number of prune proposals made during one full MCMC cycle.

grow_acc_count: Int32[Array, '*chains']

The number of grow moves accepted during one full MCMC cycle.

prune_acc_count: Int32[Array, '*chains']

The number of prune moves accepted during one full MCMC cycle.
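The four counters are typically read as acceptance rates. A minimal NumPy sketch, with a hypothetical helper name, that turns a pair of per-chain counters into rates and returns NaN for chains that made no proposal of that kind:

```python
import numpy as np

def acceptance_rates(prop_count, acc_count):
    """Per-chain fraction of accepted proposals; NaN where no proposal was made."""
    prop = np.asarray(prop_count, dtype=np.float64)
    acc = np.asarray(acc_count, dtype=np.float64)
    # Silence the 0/0 warning; those entries are replaced by NaN anyway.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(prop > 0, acc / prop, np.nan)
```

For example, acceptance_rates(forest.grow_prop_count, forest.grow_acc_count), where forest is a Forest instance, gives the grow acceptance rate of each chain over the last cycle.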

max_split: UInt[Array, 'p']

The maximum split index for each predictor.

blocked_vars: UInt[Array, 'q'] | None

Indices of variables that must not be used in decision rules. This must include at least every i such that max_split[i] == 0; otherwise, the behavior is undefined.
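Since the minimal valid blocked_vars is determined by max_split, it can be derived directly. A NumPy sketch (minimal_blocked_vars is a hypothetical name); further variables could be blocked by merging their indices into the result, for example with np.union1d:

```python
import numpy as np

def minimal_blocked_vars(max_split):
    """Indices i with max_split[i] == 0, i.e. predictors that admit no split."""
    return np.flatnonzero(np.asarray(max_split) == 0).astype(np.uint32)
```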

p_nonterminal: Float32[Array, '2*half_tree_size']

The prior probability of each node being nonterminal, conditional on its ancestors. Includes the nodes at maximum depth, which should be set to 0.
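A common choice for this prior is the depth-dependent probability alpha * (1 + depth) ** -beta of the original BART model, with alpha = 0.95 and beta = 2 as usual defaults. This page does not say that bartz uses this form, so treat the following as a sketch under that assumption. It also assumes heap indexing (root at index 1, index 0 unused), so that node i has depth floor(log2(i)) and the nodes at maximum depth occupy indices half_tree_size to 2*half_tree_size - 1:

```python
import numpy as np

def depth_prior(half_tree_size, alpha=0.95, beta=2.0):
    """Array shaped like p_nonterminal, filled with alpha * (1 + depth) ** -beta.

    Assumes heap indexing: index 0 unused, node i has depth i.bit_length() - 1.
    """
    p = np.zeros(2 * half_tree_size, dtype=np.float32)
    # Nodes 1 .. half_tree_size - 1 can still split.
    depth = np.array([i.bit_length() - 1 for i in range(1, half_tree_size)], dtype=np.float64)
    p[1:half_tree_size] = alpha * (1.0 + depth) ** -beta
    # Nodes half_tree_size .. 2*half_tree_size - 1 are at maximum depth and stay 0.
    return p
```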

p_propose_grow: Float32[Array, 'half_tree_size']

The unnormalized probability of picking a leaf for a grow proposal.
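Picking the leaf for a grow proposal amounts to weighted sampling over the eligible nodes. The sketch below restricts the choice to nodes flagged in affluence_tree and samples proportionally to p_propose_grow; treating affluence as the eligibility mask and returning None when nothing can be grown are assumptions, and pick_leaf_to_grow is a hypothetical name:

```python
import random

def pick_leaf_to_grow(p_propose_grow, affluence_tree, rng=random):
    """Sample a growable leaf with probability proportional to p_propose_grow.

    Returns None if no node is both flagged as growable and has positive weight.
    """
    candidates = [
        i for i, growable in enumerate(affluence_tree)
        if growable and p_propose_grow[i] > 0
    ]
    if not candidates:
        return None
    weights = [p_propose_grow[i] for i in candidates]
    # random.choices normalizes the weights internally.
    return rng.choices(candidates, weights=weights, k=1)[0]
```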

leaf_indices: UInt[Array, 'num_trees *chains n']

The index of the leaf each datapoint falls into, for each tree.

The chain axis comes after num_trees rather than first, unlike in the sibling fields. This lets the per-tree lax.scan in step, which runs under the chain vmap, avoid transposing this large array, a transpose that would otherwise inflate peak GPU memory.
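The axis order matters when combining leaf_indices with leaf_tree. A NumPy sketch of the univariate sum-of-trees prediction (sum_of_trees is a hypothetical helper, and it assumes exactly one chain axis): moving the chain axis of leaf_indices to the front aligns it with leaf_tree, after which each tree's leaf value is gathered per datapoint and summed over trees.

```python
import numpy as np

def sum_of_trees(leaf_tree, leaf_indices):
    """Univariate sum-of-trees prediction for each chain and datapoint.

    leaf_tree:    (chains, num_trees, 2*half_tree_size)
    leaf_indices: (num_trees, chains, n), tree axis first as documented
    returns:      (chains, n)
    """
    # Reorder leaf_indices to (chains, num_trees, n) to match leaf_tree.
    idx = np.moveaxis(np.asarray(leaf_indices), 1, 0).astype(np.intp)
    # Gather the leaf value of every datapoint in every tree.
    values = np.take_along_axis(np.asarray(leaf_tree), idx, axis=2)
    # Sum the contributions of all trees.
    return values.sum(axis=1)
```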

count_tree: UInt32[Array, '*chains num_trees 2*half_tree_size'] | None

The number of datapoints per leaf. Valid at the leaves and at the nodes involved in the latest moves, dirty elsewhere. None if the error precision is weighted and there are no minimum-points-per-node constraints, since the counts are then unused.
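The counts can be recomputed from scratch for one tree of one chain with np.bincount. This is only a reference computation under the documented shapes; it does not show how the library keeps count_tree up to date across moves, and leaf_counts is a hypothetical name:

```python
import numpy as np

def leaf_counts(leaf_indices_one_tree, half_tree_size):
    """Number of datapoints in each node of a single tree (nonzero only at leaves)."""
    return np.bincount(
        np.asarray(leaf_indices_one_tree),
        minlength=2 * half_tree_size,  # one slot per node, like count_tree
    ).astype(np.uint32)
```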

prec_tree: Float32[Array, '*chains num_trees 2*half_tree_size'] | Float32[Array, '*chains num_trees k k 2*half_tree_size'] | None

The likelihood precision scale summed over the datapoints in each leaf. Valid at the leaves and at the nodes involved in the latest moves, dirty elsewhere. None if the error precision is not weighted, in which case count_tree takes its place.
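The weighted analogue replaces each datapoint's contribution of 1 with its precision scale. In this sketch prec_scale is a hypothetical per-datapoint array and the univariate case is assumed. With all-ones weights the result equals the counts, which is consistent with count_tree standing in for prec_tree when the error precision is not weighted:

```python
import numpy as np

def leaf_precision_sums(leaf_indices_one_tree, prec_scale, half_tree_size):
    """Sum of the per-datapoint precision scale within each node of one tree."""
    return np.bincount(
        np.asarray(leaf_indices_one_tree),
        weights=np.asarray(prec_scale, dtype=np.float64),  # each point adds its weight
        minlength=2 * half_tree_size,
    ).astype(np.float32)
```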

min_points_per_decision_node: Int32[Array, ''] | None

The minimum number of data points in a decision node.

min_points_per_leaf: Int32[Array, ''] | None

The minimum number of data points in a leaf node.
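How the two minimums constrain a grow move can be sketched as a simple predicate on the datapoint counts of the two would-be children. The helper grow_allowed is hypothetical, and the convention that None disables a constraint is an assumption suggested by the None option of the fields above:

```python
def grow_allowed(left_count, right_count,
                 min_points_per_leaf=None, min_points_per_decision_node=None):
    """Whether splitting a leaf into children with these counts respects the minimums."""
    # The grown node becomes a decision node holding all points of both children.
    parent_count = left_count + right_count
    if min_points_per_decision_node is not None and parent_count < min_points_per_decision_node:
        return False
    # Both new leaves must meet the per-leaf minimum.
    if min_points_per_leaf is not None and min(left_count, right_count) < min_points_per_leaf:
        return False
    return True
```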

log_trans_prior: Float32[Array, '*chains num_trees'] | None

The log transition and prior Metropolis-Hastings ratio for the proposed move on each tree.

log_likelihood: Float32[Array, '*chains num_trees'] | None

The log likelihood ratio for the proposed move on each tree.
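The two per-tree terms combine into the usual Metropolis-Hastings test. A sketch, assuming that their sum is the full log acceptance ratio of the proposed move (mh_accept is a hypothetical name): accept when log(U) < log_trans_prior + log_likelihood, with U uniform on (0, 1].

```python
import math
import random

def mh_accept(log_trans_prior, log_likelihood, rng=random):
    """Metropolis-Hastings accept/reject for one tree's proposed move."""
    log_ratio = log_trans_prior + log_likelihood
    # rng.random() is in [0, 1), so u is in (0, 1] and log(u) is always finite.
    u = 1.0 - rng.random()
    # Accepts with probability min(1, exp(log_ratio)).
    return math.log(u) < log_ratio
```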

leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None

The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1.
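The stated relation, prior covariance of the sum = num_trees * inv(leaf_prior_cov_inv), can be inverted to choose leaf_prior_cov_inv from a desired prior covariance for the whole sum of trees. A NumPy sketch with a hypothetical helper name:

```python
import numpy as np

def leaf_prior_precision(total_cov, num_trees):
    """leaf_prior_cov_inv giving the sum of num_trees leaves the covariance total_cov.

    Solves total_cov = num_trees * inv(P) for P.
    """
    total_cov = np.asarray(total_cov, dtype=np.float64)
    if total_cov.ndim == 0:  # univariate: P is a scalar inverse variance
        return num_trees / total_cov
    return num_trees * np.linalg.inv(total_cov)
```

For instance, a prior variance of 4 for the sum of 200 trees gives a per-leaf precision of 200 / 4 = 50.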

log_s: Float32[Array, '*chains p'] | None

The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If None, use a uniform distribution.

theta: Float32[Array, '*chains'] | None

The concentration parameter for the Dirichlet prior on the variable distribution s. Required only to update log_s.
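One standard way to update s is the Dirichlet sparsity prior of DART (Linero, 2018), where s ~ Dirichlet(theta/p, ..., theta/p) and, given how many decision nodes split on each predictor, the full conditional is again Dirichlet with those counts added to the concentrations. The sketch below assumes this model; the exact parametrization used by bartz is not stated on this page, and sample_log_s is a hypothetical name:

```python
import numpy as np

def sample_log_s(split_var_counts, theta, rng):
    """Draw log_s from s | counts ~ Dirichlet(theta / p + counts).

    split_var_counts[j]: number of decision nodes in the forest splitting on predictor j.
    """
    counts = np.asarray(split_var_counts, dtype=np.float64)
    p = counts.size
    s = rng.dirichlet(theta / p + counts)
    # With very small concentrations some components may underflow to 0,
    # giving -inf here; a robust version would sample directly in log space.
    return np.log(s)
```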

a: Float32[Array, ''] | None

Parameter of the prior on theta. Required only to sample theta.

b: Float32[Array, ''] | None

Parameter of the prior on theta. Required only to sample theta.

rho: Float32[Array, ''] | None

Parameter of the prior on theta. Required only to sample theta.
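In the DART formulation, the prior on theta is usually written as theta / (theta + rho) ~ Beta(a, b), which would match the three parameters listed here; that correspondence is an assumption. Drawing theta from this prior only requires inverting the ratio. The MCMC update of theta would instead use its full conditional given s, which this sketch (with hypothetical name sample_theta_prior) does not implement:

```python
import numpy as np

def sample_theta_prior(a, b, rho, rng):
    """Draw theta from the prior theta / (theta + rho) ~ Beta(a, b)."""
    u = rng.beta(a, b)
    # Invert u = theta / (theta + rho)  =>  theta = rho * u / (1 - u).
    return rho * u / (1.0 - u)
```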

property has_chains: bool

Whether this forest carries an explicit chain axis.
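One way to detect a chain axis, shown only as a hypothetical sketch rather than the property's actual implementation, is to check whether a field documented with shape '*chains' (such as grow_prop_count) has any dimensions at all:

```python
import numpy as np

def has_chain_axis(grow_prop_count):
    """True if a '*chains'-shaped field carries at least one chain axis."""
    # Without chains the field is a scalar (0-d); with chains it has ndim >= 1.
    return np.ndim(grow_prop_count) > 0
```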