MCMC setup and step
Functions that implement the BART posterior MCMC initialization and update step.
- class bartz.mcmcstep.Forest(leaf_tree, var_tree, split_tree, affluence_tree, max_split, blocked_vars, p_nonterminal, p_propose_grow, leaf_indices, min_points_per_decision_node, min_points_per_leaf, log_trans_prior, log_likelihood, grow_prop_count, prune_prop_count, grow_acc_count, prune_acc_count, leaf_prior_cov_inv, log_s, theta, a, b, rho)
Represents the MCMC state of a sum of trees.
- leaf_tree: Float32[Array, '*chains num_trees 2**d'] | Float32[Array, '*chains num_trees k 2**d']
The leaf values.
- var_tree: UInt[Array, '*chains num_trees 2**(d-1)']
The decision axes.
- split_tree: UInt[Array, '*chains num_trees 2**(d-1)']
The decision boundaries.
- affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)']
Marks leaves that can be grown.
- max_split: UInt[Array, 'p']
The maximum split index for each predictor.
- blocked_vars: UInt[Array, 'q'] | None
Indices of variables that are not used. This must include at least every i such that max_split[i] == 0; otherwise behavior is undefined.
- p_nonterminal: Float32[Array, '2**d']
The prior probability of each node being nonterminal, conditional on its ancestors. Includes the nodes at maximum depth, which should be set to 0.
- p_propose_grow: Float32[Array, '2**(d-1)']
The unnormalized probability of picking a leaf for a grow proposal.
- leaf_indices: UInt[Array, '*chains num_trees n']
The index of the leaf each datapoint falls into, for each tree.
- min_points_per_decision_node: Int32[Array, ''] | None
The minimum number of data points in a decision node.
- min_points_per_leaf: Int32[Array, ''] | None
The minimum number of data points in a leaf node.
- log_trans_prior: Float32[Array, '*chains num_trees'] | None
The log transition and prior Metropolis-Hastings ratio for the proposed move on each tree.
- log_likelihood: Float32[Array, '*chains num_trees'] | None
The log likelihood ratio.
- grow_prop_count: Int32[Array, '*chains']
The number of grow proposals made during one full MCMC cycle.
- prune_prop_count: Int32[Array, '*chains']
The number of prune proposals made during one full MCMC cycle.
- grow_acc_count: Int32[Array, '*chains']
The number of grow moves accepted during one full MCMC cycle.
- prune_acc_count: Int32[Array, '*chains']
The number of prune moves accepted during one full MCMC cycle.
- leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1.
- log_s: Float32[Array, '*chains p'] | None
The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If None, use a uniform distribution.
- theta: Float32[Array, '*chains'] | None
The concentration parameter for the Dirichlet prior on the variable distribution s. Required only to update log_s.
- a: Float32[Array, ''] | None
Parameter of the prior on theta. Required only to sample theta. See step_theta.
- b: Float32[Array, ''] | None
Parameter of the prior on theta. Required only to sample theta. See step_theta.
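As a worked illustration of the leaf_prior_cov_inv relationship in the univariate case, the sketch below chooses the leaf precision so that the sum of trees has a target prior standard deviation. It is plain Python arithmetic, not a bartz call; the helper name leaf_precision_for and the numbers are hypothetical.

```python
def leaf_precision_for(num_trees, sum_prior_sd):
    """Scalar leaf precision such that the sum of `num_trees` leaves has
    prior standard deviation `sum_prior_sd`.

    Solves var(sum) = num_trees / leaf_prior_cov_inv for the precision.
    (Hypothetical helper, not part of bartz.)
    """
    return num_trees / sum_prior_sd ** 2


num_trees = 200
leaf_prior_cov_inv = leaf_precision_for(num_trees, sum_prior_sd=0.5)

# Check the stated relationship: num_trees * leaf_prior_cov_inv^-1.
sum_prior_var = num_trees * (1 / leaf_prior_cov_inv)
print(leaf_prior_cov_inv, sum_prior_var)
```

Because the variance of the sum grows linearly with the number of trees, the leaf precision has to grow with num_trees to keep the overall prior scale fixed.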
- class bartz.mcmcstep.State(X, y, z, offset, resid, error_cov_inv, prec_scale, error_cov_df, error_cov_scale, forest, config)
Represents the MCMC state of BART.
- X: UInt[Array, 'p n']
The predictors.
- y: Float32[Array, 'n'] | Float32[Array, 'k n'] | Bool[Array, 'n']
The response. If the data type is bool, the model is binary regression.
- z: None | Float32[Array, '*chains n']
The latent variable for binary regression. None in continuous regression.
- offset: Float32[Array, ''] | Float32[Array, 'k']
Constant shift added to the sum of trees.
- resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n']
- error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None
The inverse error covariance (scalar for univariate, matrix for multivariate). None in binary regression.
- prec_scale: Float32[Array, 'n'] | None
The scale on the error precision, i.e., 1 / error_scale ** 2. None in binary regression.
- error_cov_df: Float32[Array, ''] | None
The df parameter of the inverse Wishart prior on the noise covariance. For the univariate case, the relationship to the inverse gamma prior parameters is alpha = df / 2. None in binary regression.
- error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None
The scale parameter of the inverse Wishart prior on the noise covariance. For the univariate case, the relationship to the inverse gamma prior parameters is beta = scale / 2. None in binary regression.
- config: StepConfig
Metadata and configurations for the MCMC step.
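The two conversions mentioned in the attribute descriptions above (prec_scale from error_scale, and the univariate inverse gamma parameters from df and scale) can be sketched as follows. This is a minimal pure-Python illustration; the function names are hypothetical and are not bartz functions.

```python
def prec_scale_from_error_scale(error_scale):
    """Per-datapoint precision scale, 1 / error_scale ** 2."""
    return [1.0 / s ** 2 for s in error_scale]


def inverse_gamma_params(error_cov_df, error_cov_scale):
    """Univariate case: inverse gamma (alpha, beta) from the inverse Wishart
    df and scale, via alpha = df / 2 and beta = scale / 2."""
    return error_cov_df / 2, error_cov_scale / 2


prec_scale = prec_scale_from_error_scale([1.0, 2.0, 0.5])
alpha, beta = inverse_gamma_params(error_cov_df=3.0, error_cov_scale=1.5)
print(prec_scale)   # [1.0, 0.25, 4.0]
print(alpha, beta)  # 1.5 0.75
```

A datapoint with error_scale 2 has four times the error variance, so it carries one quarter of the precision weight.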
- class bartz.mcmcstep.StepConfig(steps_done, sparse_on_at, resid_num_batches, count_num_batches, prec_num_batches, prec_count_num_trees, mesh)
Options for the MCMC step.
- steps_done: Int32[Array, '']
The number of MCMC steps completed so far.
- sparse_on_at: Int32[Array, ''] | None
After how many steps to turn on variable selection.
- resid_num_batches: int | None
The number of batches for computing the sum of residuals. If None, they are computed with no batching.
- count_num_batches: int | None
The number of batches for computing counts. If None, they are computed with no batching.
- prec_num_batches: int | None
The number of batches for computing precision scales. If None, they are computed with no batching.
- prec_count_num_trees: int | None
Batch size for processing trees to compute count and prec trees.
- bartz.mcmcstep.init(*, X, y, offset, max_split, num_trees, p_nonterminal, leaf_prior_cov_inv, error_cov_df=None, error_cov_scale=None, error_scale=None, min_points_per_decision_node=None, resid_num_batches='auto', count_num_batches='auto', prec_num_batches='auto', prec_count_num_trees='auto', save_ratios=False, filter_splitless_vars=0, min_points_per_leaf=None, log_s=None, theta=None, a=None, b=None, rho=None, sparse_on_at=None, num_chains=None, mesh=None, target_platform=None)
Make a BART posterior sampling MCMC initial state.
- Parameters:
X (UInt[Any, 'p n']) – The predictors. Note this is transposed compared to the usual convention.
y (Float32[Any, 'n'] | Float32[Any, 'k n'] | Bool[Any, 'n']) – The response. If the data type is bool, the regression model is binary regression with probit. If two-dimensional, the outcome is multivariate with the first axis indicating the component.
offset (float | Float32[Any, ''] | Float32[Any, 'k']) – Constant shift added to the sum of trees. 0 if not specified.
max_split (UInt[Any, 'p']) – The maximum split index for each variable. All split ranges start at 1.
num_trees (int) – The number of trees in the forest.
p_nonterminal (Float32[Any, 'd_minus_1']) – The probability of a nonterminal node at each depth. The maximum depth of trees is fixed by the length of this array. Use make_p_nonterminal to set it with the conventional formula.
leaf_prior_cov_inv (float | Float32[Any, ''] | Float32[Array, 'k k']) – The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1. The prior mean of leaves is always zero.
error_cov_df (float | Float32[Any, ''] | None, default: None)
error_cov_scale (float | Float32[Any, ''] | Float32[Array, 'k k'] | None, default: None) – The df and scale parameters of the inverse Wishart prior on the error covariance. For the univariate case, the relationship to the inverse gamma prior parameters is alpha = df / 2, beta = scale / 2. Leave unspecified for binary regression.
error_scale (Float32[Any, 'n'] | None, default: None) – Each error is scaled by the corresponding factor in error_scale, so the error variance for y[i] is sigma2 * error_scale[i] ** 2. Not supported for binary regression. If not specified, defaults to 1 for all points, but potentially skipping calculations.
min_points_per_decision_node (int | Integer[Any, ''] | None, default: None) – The minimum number of data points in a decision node. 0 if not specified.
resid_num_batches (int | None | Literal['auto'], default: 'auto')
count_num_batches (int | None | Literal['auto'], default: 'auto')
prec_num_batches (int | None | Literal['auto'], default: 'auto') – The number of batches, along datapoints, for summing the residuals, counting the number of datapoints in each leaf, and computing the likelihood precision in each leaf, respectively. None for no batching. If ‘auto’, it’s chosen automatically based on the target platform; see the description of target_platform below for how it is determined.
prec_count_num_trees (int | None | Literal['auto'], default: 'auto') – The number of trees to process at a time when counting datapoints or computing the likelihood precision. If None, do all trees at once, which may use too much memory. If ‘auto’ (default), it’s chosen automatically.
save_ratios (bool, default: False) – Whether to save the Metropolis-Hastings ratios.
filter_splitless_vars (int, default: 0) – The maximum number of variables without splits that can be ignored. If there are more, init raises an exception.
min_points_per_leaf (int | Integer[Any, ''] | None, default: None) – The minimum number of datapoints in a leaf node. 0 if not specified. Unlike min_points_per_decision_node, this constraint is not taken into account in the Metropolis-Hastings ratio because it would be expensive to compute. Grow moves that would violate this constraint are vetoed. This parameter is independent of min_points_per_decision_node and there is no check that they are coherent. It makes sense to set min_points_per_decision_node >= 2 * min_points_per_leaf.
log_s (Float32[Any, 'p'] | None, default: None) – The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If not specified, use a uniform distribution. If not specified and theta or rho, a, b are, it’s initialized automatically.
theta (float | Float32[Any, ''] | None, default: None) – The concentration parameter for the Dirichlet prior on s. Required only to update log_s. If not specified, and rho, a, b are specified, it’s initialized automatically.
a (float | Float32[Any, ''] | None, default: None)
b (float | Float32[Any, ''] | None, default: None)
rho (float | Float32[Any, ''] | None, default: None) – Parameters of the prior on theta. Required only to sample theta.
sparse_on_at (int | Integer[Any, ''] | None, default: None) – After how many MCMC steps to turn on variable selection.
num_chains (int | None, default: None) – The number of independent MCMC chains to represent in the state. Single chain with scalar values if not specified.
mesh (Mesh | dict[str, int] | None, default: None) – A jax mesh used to shard data and computation across multiple devices. If it has a ‘chains’ axis, that axis is used to shard the chains. If it has a ‘data’ axis, that axis is used to shard the datapoints.
As a shorthand, if a dictionary mapping axis names to axis size is passed, the corresponding mesh is created, e.g., dict(chains=4, data=2) will let jax pick 8 devices to split chains (which must be a multiple of 4) across 4 pairs of devices, where in each pair the data is split in two.
Note: if a mesh is passed, the arrays are always sharded according to it. In particular even if the mesh has no ‘chains’ or ‘data’ axis, the arrays will be replicated on all devices in the mesh.
target_platform (Literal['cpu', 'gpu'] | None, default: None) – Platform (‘cpu’ or ‘gpu’) used to determine the number of batches automatically. If mesh is specified, the platform is inferred from the devices in the mesh. Otherwise, if y is a concrete array (i.e., init is not invoked in a jax.jit context), the platform is set to the platform of y. Otherwise, use target_platform.
To avoid confusion, in all cases where the target_platform argument would be ignored, init raises an exception if target_platform is set.
- Returns:
State – An initialized BART MCMC state.
- Raises:
ValueError – If y is boolean and arguments unused in binary regression are set.
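To make the arithmetic behind the mesh shorthand concrete, the sketch below summarizes how a dictionary such as dict(chains=4, data=2) divides the work. It is plain Python and does not create a jax mesh; describe_mesh is a hypothetical helper, and the divisibility check only mirrors the "must be a multiple of 4" remark in the mesh description, not necessarily how init validates its arguments.

```python
import math


def describe_mesh(mesh_axes, num_chains):
    """Summarize how a dict(axis_name=size) mesh shorthand splits the work.

    Hypothetical helper for illustration; not part of bartz or jax.
    """
    num_devices = math.prod(mesh_axes.values())
    chain_shards = mesh_axes.get('chains', 1)
    if num_chains % chain_shards != 0:
        raise ValueError('num_chains must be a multiple of the chains axis size')
    return {
        'num_devices': num_devices,
        'chains_per_shard': num_chains // chain_shards,
        'data_shards': mesh_axes.get('data', 1),
    }


summary = describe_mesh(dict(chains=4, data=2), num_chains=8)
print(summary)
# {'num_devices': 8, 'chains_per_shard': 2, 'data_shards': 2}
```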
Notes
In decision nodes, the values in X[i, :] are compared to a cutpoint out of the range [1, 2, ..., max_split[i]]. A point belongs to the left child iff X[i, j] < cutpoint. Thus it makes sense for X[i, :] to be integers in the range [0, 1, ..., max_split[i]].
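The decision rule in the note above can be sketched in a few lines. This is a pure-Python illustration using nested lists in place of arrays; goes_left and the toy data are hypothetical and not part of bartz.

```python
def goes_left(X, i, j, cutpoint):
    """True if datapoint j falls in the left child of a node that splits
    predictor i at `cutpoint`, i.e. X[i][j] < cutpoint."""
    return X[i][j] < cutpoint


# Predictors are stored as (p, n): one row per predictor, one column per point.
X = [
    [0, 1, 2, 3],  # predictor 0, max_split[0] = 3
    [0, 0, 1, 1],  # predictor 1, max_split[1] = 1
]
max_split = [3, 1]

# Valid cutpoints for predictor i are 1, ..., max_split[i].
left_sets = {}
for cutpoint in range(1, max_split[0] + 1):
    left_sets[cutpoint] = [j for j in range(4) if goes_left(X, 0, j, cutpoint)]
    print(cutpoint, left_sets[cutpoint])
```

With values in [0, max_split[i]], the value 0 goes left for every valid cutpoint and the value max_split[i] always goes right, so every cutpoint in the range corresponds to a distinct way of splitting the binned values.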
- bartz.mcmcstep.make_p_nonterminal(d, alpha, beta)
Prepare the p_nonterminal argument to init.
It is calculated according to the formula:
P_nt(depth) = alpha / (1 + depth)^beta, with depth 0-based
- Parameters:
d (int) – The maximum depth of the trees (d=1 means tree with only root node).
alpha (float | Float32[Array, '']) – The a priori probability of the root node having children, conditional on it being possible.
beta (float | Float32[Array, '']) – The exponent of the power decay of the probability of having children with depth.
- Returns:
Float32[Array, '{d}-1'] – An array of probabilities, one per tree level but the last.
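For reference, the formula can be evaluated by hand with a short pure-Python sketch. This re-implements the stated formula with a list instead of a JAX array; p_nonterminal_sketch is a hypothetical name, and the real make_p_nonterminal may differ in dtype and implementation details.

```python
def p_nonterminal_sketch(d, alpha, beta):
    """Return alpha / (1 + depth) ** beta for depth = 0, ..., d - 2,
    i.e. one probability per tree level except the last."""
    return [alpha / (1 + depth) ** beta for depth in range(d - 1)]


probs = p_nonterminal_sketch(d=4, alpha=0.95, beta=2.0)
print(probs)  # three values, decreasing with depth

# d=1 means a tree with only the root node, so there is no level that can split.
print(p_nonterminal_sketch(d=1, alpha=0.95, beta=2.0))  # []
```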