bartz.mcmcstep.Forest¶
- class bartz.mcmcstep.Forest(var_tree, split_tree, affluence_tree, leaf_tree, grow_prop_count, prune_prop_count, grow_acc_count, prune_acc_count, max_split, blocked_vars, p_nonterminal, p_propose_grow, leaf_indices, count_tree, prec_tree, min_points_per_decision_node, min_points_per_leaf, log_trans_prior, log_likelihood, leaf_prior_cov_inv, log_s, theta, a, b, rho)¶
Represents the MCMC state of a sum of trees.
- var_tree: UInt[Array, '*chains num_trees half_tree_size']¶
The decision axes.
- split_tree: UInt[Array, '*chains num_trees half_tree_size']¶
The decision boundaries.
- affluence_tree: Bool[Array, '*chains num_trees half_tree_size']¶
Marks leaves that can be grown.
- leaf_tree: Float32[Array, '*chains num_trees 2*half_tree_size'] | Float32[Array, '*chains num_trees k 2*half_tree_size']¶
The leaf values.
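To make the relationship between var_tree, split_tree and leaf_tree concrete, here is a minimal sketch of how one tree could be evaluated at a single point. The storage convention is an assumption, not something the field descriptions state: heap layout with the root at index 1, children of node i at 2*i and 2*i + 1, split_tree[i] == 0 marking a node that is not a decision node, and points going right when x[var] >= split. The helper find_leaf is hypothetical and is not part of bartz.

```python
import jax.numpy as jnp

def find_leaf(x, var_tree, split_tree):
    """Return the index of the leaf that the point x falls into, for one tree.

    Assumptions (not stated in the reference above): heap layout with the root
    at index 1 and children of node i at 2*i and 2*i + 1; split_tree[i] == 0
    marks a node that is not a decision node; x holds integer bin indices and
    a point goes right when x[var] >= split.
    """
    index = 1
    # Nodes at index >= half_tree_size have no decision slot, so they are leaves.
    while index < var_tree.size and int(split_tree[index]) != 0:
        var = int(var_tree[index])
        go_right = int(x[var]) >= int(split_tree[index])
        index = 2 * index + int(go_right)
    return index

# Toy tree with half_tree_size = 4, for a single chain and a single tree.
var_tree = jnp.array([0, 0, 1, 0], dtype=jnp.uint8)
split_tree = jnp.array([0, 2, 1, 0], dtype=jnp.uint8)
leaf_tree = jnp.array([0.0, 0.0, 0.0, 0.5, -1.0, 2.0, 0.0, 0.0], dtype=jnp.float32)

x = jnp.array([1, 3])
leaf = find_leaf(x, var_tree, split_tree)
prediction = leaf_tree[leaf]
```

The Python loop keeps the sketch readable; a jitted version would need a fixed-length loop instead of a data-dependent while.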
- grow_prop_count: Int32[Array, '*chains']¶
The number of grow proposals made during one full MCMC cycle.
- prune_prop_count: Int32[Array, '*chains']¶
The number of prune proposals made during one full MCMC cycle.
- grow_acc_count: Int32[Array, '*chains']¶
The number of grow moves accepted during one full MCMC cycle.
- prune_acc_count: Int32[Array, '*chains']¶
The number of prune moves accepted during one full MCMC cycle.
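The four counters above are typically combined into per-chain acceptance rates. The following is a small sketch of that arithmetic; the helper acceptance_rate and the example values are hypothetical, not part of bartz.

```python
import jax.numpy as jnp

def acceptance_rate(prop_count, acc_count):
    """Accepted fraction per chain; NaN where no move of that kind was proposed."""
    prop = jnp.maximum(prop_count, 1).astype(jnp.float32)  # avoid division by zero
    return jnp.where(prop_count > 0, acc_count / prop, jnp.nan)

# Made-up counters for two chains; the second chain proposed no grow moves.
grow_prop_count = jnp.array([120, 0], dtype=jnp.int32)
grow_acc_count = jnp.array([30, 0], dtype=jnp.int32)

grow_rate = acceptance_rate(grow_prop_count, grow_acc_count)
```

The same function applies to prune_prop_count and prune_acc_count.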
- max_split: UInt[Array, 'p']¶
The maximum split index for each predictor.
- blocked_vars: UInt[Array, 'q'] | None¶
Indices of variables that are not used. This shall include at least the i such that max_split[i] == 0, otherwise behavior is undefined.
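As an illustration of the constraint on blocked_vars, the minimal admissible set can be derived from max_split directly. This is a sketch with made-up values; the uint8 dtype is an arbitrary choice among unsigned integer types.

```python
import jax.numpy as jnp

# Made-up maximum split indices for p = 4 predictors.
max_split = jnp.array([3, 0, 7, 0], dtype=jnp.uint8)

# Predictors with no available split must be listed as blocked.
blocked_vars = jnp.flatnonzero(max_split == 0).astype(jnp.uint8)
```

Further indices may be appended to exclude predictors that do have available splits.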
- p_nonterminal: Float32[Array, '2*half_tree_size']¶
The prior probability of each node being nonterminal, conditional on its ancestors. Includes the nodes at maximum depth, which should be set to 0.
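One way such an array could be filled is sketched below. It assumes the common BART depth prior alpha / (1 + depth)^beta and a heap layout with the root at index 1; neither is stated in the description above, and the function make_p_nonterminal with its alpha and beta parameters is hypothetical.

```python
import jax.numpy as jnp

def make_p_nonterminal(half_tree_size, alpha=0.95, beta=2.0):
    """Per-node prior probability of being nonterminal (illustrative only)."""
    tree_size = 2 * half_tree_size
    # Assumed heap layout: node i has depth floor(log2(i)); index 0 is treated
    # like the root here because it is assumed to be unused.
    depth = jnp.array([max(i, 1).bit_length() - 1 for i in range(tree_size)])
    p = alpha / (1.0 + depth) ** beta
    # Nodes in the second half of the array sit at maximum depth: they cannot
    # split, so their probability is set to 0.
    at_max_depth = jnp.arange(tree_size) >= half_tree_size
    return jnp.where(at_max_depth, 0.0, p).astype(jnp.float32)

p_nonterminal = make_p_nonterminal(half_tree_size=4)
```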
- p_propose_grow: Float32[Array, 'half_tree_size']¶
The unnormalized probability of picking a leaf for a grow proposal.
- leaf_indices: UInt[Array, 'num_trees *chains n']¶
The index of the leaf each datapoint falls into, for each tree.
The chain axis sits after num_trees (not leading, unlike sibling fields) so the per-tree lax.scan in step, under the chain vmap, avoids a transpose of this large array that otherwise inflates GPU peak memory.
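The axis layout can be illustrated with a sketch that scans over trees inside a vmap over chains. It only shows how the differing chain-axis positions are passed to vmap; it is not the code used in step, and it makes no claim about memory use. The function sum_of_trees_one_chain is hypothetical, and the example assumes a single chain axis and scalar leaves.

```python
import jax
import jax.numpy as jnp
from jax import lax

def sum_of_trees_one_chain(leaf_tree, leaf_indices):
    """leaf_tree: (num_trees, tree_size); leaf_indices: (num_trees, n)."""
    def add_tree(total, tree):
        leaves, indices = tree
        # Look up the leaf value of each datapoint in the current tree.
        return total + leaves[indices], None

    init = jnp.zeros(leaf_indices.shape[1], leaf_tree.dtype)
    total, _ = lax.scan(add_tree, init, (leaf_tree, leaf_indices))
    return total

# The chain axis leads in leaf_tree (axis 0) but comes second in leaf_indices
# (axis 1), so the scan inside still runs over the leading num_trees axis.
sum_of_trees = jax.vmap(sum_of_trees_one_chain, in_axes=(0, 1))

chains, num_trees, tree_size, n = 2, 3, 4, 5
leaf_tree = jnp.arange(chains * num_trees * tree_size, dtype=jnp.float32)
leaf_tree = leaf_tree.reshape(chains, num_trees, tree_size)
leaf_indices = (jnp.arange(num_trees * chains * n) % tree_size).astype(jnp.uint8)
leaf_indices = leaf_indices.reshape(num_trees, chains, n)

out = sum_of_trees(leaf_tree, leaf_indices)  # shape (chains, n)
```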
- count_tree: UInt32[Array, '*chains num_trees 2*half_tree_size'] | None¶
The number of datapoints per leaf. Valid at the leaves and at the nodes involved in the latest moves, dirty elsewhere.
None if the error precision is weighted and there are no minimum-points-per-node constraints, which makes the counts unused.
- prec_tree: Float32[Array, '*chains num_trees 2*half_tree_size'] | Float32[Array, '*chains num_trees k k 2*half_tree_size'] | None¶
The likelihood precision scale summed over the datapoints in each leaf. Valid at the leaves and at the nodes involved in the latest moves, dirty elsewhere.
None if the error precision is not weighted, in which case count_tree takes its place.
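The per-leaf quantities held in count_tree and prec_tree can be pictured as scatter-adds over leaf_indices. The sketch below handles one tree of one chain in the univariate case and fills only the leaf entries; the helper per_leaf_sums and the weights are hypothetical, and this is not necessarily how bartz computes or updates these arrays.

```python
import jax.numpy as jnp

def per_leaf_sums(leaf_indices, tree_size, weights=None):
    """Count datapoints per leaf, or sum their precision weights if given."""
    if weights is None:
        # Unweighted case: analogous to count_tree.
        return jnp.zeros(tree_size, jnp.uint32).at[leaf_indices].add(1)
    # Weighted case: analogous to prec_tree.
    return jnp.zeros(tree_size, jnp.float32).at[leaf_indices].add(weights)

# Six datapoints falling into leaves 3, 4 and 5 of a tree with 8 node slots.
leaf_indices = jnp.array([3, 5, 5, 4, 3, 5], dtype=jnp.uint8)
weights = jnp.array([0.5, 1.0, 1.0, 2.0, 0.5, 1.0], dtype=jnp.float32)

counts = per_leaf_sums(leaf_indices, tree_size=8)
precs = per_leaf_sums(leaf_indices, tree_size=8, weights=weights)
```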
- min_points_per_decision_node: Int32[Array, ''] | None¶
The minimum number of data points in a decision node.
- log_trans_prior: Float32[Array, '*chains num_trees'] | None¶
The log transition and prior Metropolis-Hastings ratio for the proposed move on each tree.
- leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None¶
The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1.
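The formula for the prior covariance of the sum of trees can be checked numerically with a short sketch covering both the scalar and the matrix case. The helper sum_of_trees_prior_cov and the values are hypothetical.

```python
import jax.numpy as jnp

def sum_of_trees_prior_cov(leaf_prior_cov_inv, num_trees):
    """Return num_trees * leaf_prior_cov_inv^-1 for a scalar or (k, k) input."""
    prec = jnp.asarray(leaf_prior_cov_inv, dtype=jnp.float32)
    if prec.ndim == 0:
        # Univariate case: the precision is a scalar inverse variance.
        return num_trees / prec
    return num_trees * jnp.linalg.inv(prec)

# Univariate: leaf variance 1/4, 200 trees.
var_sum = sum_of_trees_prior_cov(4.0, num_trees=200)

# Multivariate (k = 2) with a diagonal precision matrix, 10 trees.
cov_sum = sum_of_trees_prior_cov(jnp.diag(jnp.array([2.0, 4.0])), num_trees=10)
```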
- log_s: Float32[Array, '*chains p'] | None¶
The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If None, use a uniform distribution.
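Since log_s is unnormalized, turning it into probabilities requires a normalization step such as a softmax. The sketch below shows this for one chain, together with the uniform fallback for None. It ignores the conditioning on ancestors (for example, variables with no splits left along a path), and the helper variable_probabilities is hypothetical.

```python
import jax
import jax.numpy as jnp

def variable_probabilities(log_s, p):
    """Normalized probabilities over p predictors from unnormalized log_s."""
    if log_s is None:
        # No log_s given: every predictor is equally likely.
        return jnp.full(p, 1.0 / p, dtype=jnp.float32)
    return jax.nn.softmax(log_s, axis=-1)

uniform = variable_probabilities(None, p=4)
skewed = variable_probabilities(jnp.log(jnp.array([1.0, 3.0])), p=2)
```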