MCMC setup and step
Functions that implement the BART posterior MCMC initialization and update step.
- class bartz.mcmcstep.Forest(leaf_tree, var_tree, split_tree, affluence_tree, max_split, blocked_vars, p_nonterminal, p_propose_grow, leaf_indices, min_points_per_decision_node, min_points_per_leaf, log_trans_prior, log_likelihood, grow_prop_count, prune_prop_count, grow_acc_count, prune_acc_count, leaf_prior_cov_inv, log_s, theta, a, b, rho)
Represents the MCMC state of a sum of trees.
- Parameters:
  - leaf_tree (Float32[Array, '*chains num_trees 2**d'] | Float32[Array, '*chains num_trees k 2**d']) – The leaf values.
  - var_tree (UInt[Array, '*chains num_trees 2**(d-1)']) – The decision axes.
  - split_tree (UInt[Array, '*chains num_trees 2**(d-1)']) – The decision boundaries.
  - affluence_tree (Bool[Array, '*chains num_trees 2**(d-1)']) – Marks leaves that can be grown.
  - max_split (UInt[Array, 'p']) – The maximum split index for each predictor.
  - blocked_vars (UInt[Array, 'q'] | None) – Indices of variables that are not used. This must include at least every i such that max_split[i] == 0; otherwise behavior is undefined.
  - p_nonterminal (Float32[Array, '2**d']) – The prior probability of each node being nonterminal, conditional on its ancestors. Includes the nodes at maximum depth, which should be set to 0.
  - p_propose_grow (Float32[Array, '2**(d-1)']) – The unnormalized probability of picking a leaf for a grow proposal.
  - leaf_indices (UInt[Array, '*chains num_trees n']) – The index of the leaf each datapoint falls into, for each tree.
  - min_points_per_decision_node (Int32[Array, ''] | None) – The minimum number of data points in a decision node.
  - min_points_per_leaf (Int32[Array, ''] | None) – The minimum number of data points in a leaf node.
  - log_trans_prior (Float32[Array, '*chains num_trees'] | None) – The log transition and prior Metropolis-Hastings ratio for the proposed move on each tree.
  - log_likelihood (Float32[Array, '*chains num_trees'] | None) – The log likelihood ratio.
  - grow_prop_count (Int32[Array, '*chains']), prune_prop_count (Int32[Array, '*chains']) – The number of grow/prune proposals made during one full MCMC cycle.
  - grow_acc_count (Int32[Array, '*chains']), prune_acc_count (Int32[Array, '*chains']) – The number of grow/prune moves accepted during one full MCMC cycle.
  - leaf_prior_cov_inv (Float32[Array, ''] | Float32[Array, 'k k'] | None) – The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1.
  - log_s (Float32[Array, '*chains p'] | None) – The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If None, use a uniform distribution.
  - theta (Float32[Array, '*chains'] | None) – The concentration parameter for the Dirichlet prior on the variable distributions. Required only to update s.
  - a (Float32[Array, ''] | None), b (Float32[Array, ''] | None), rho (Float32[Array, ''] | None) – Parameters of the prior on theta. Required only to sample theta. See step_theta.
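The array sizes above (2**d entries for leaf_tree and p_nonterminal, 2**(d-1) for the decision arrays) suggest that each tree is stored as a binary heap. The following is a minimal sketch, not bartz code, of how per-depth probabilities could be expanded into a per-node array of length 2**d. It assumes heap indexing with index 0 unused, the root at index 1, and the children of node i at 2i and 2i+1; the helper name expand_p_nonterminal and the prior values are hypothetical.

```python
def expand_p_nonterminal(p_by_depth):
    """Expand per-depth probabilities into a per-node heap array.

    Assumption (not confirmed by the documentation above): index 0 is
    unused, the root is node 1, and node i has children 2*i and 2*i + 1.
    """
    d = len(p_by_depth) + 1            # number of levels in the tree
    per_node = [0.0] * (2 ** d)        # index 0 stays unused
    for i in range(1, 2 ** d):
        depth = i.bit_length() - 1     # the root has depth 0
        if depth < len(p_by_depth):
            per_node[i] = p_by_depth[depth]
        # nodes at the maximum depth keep probability 0
    return per_node


# Hypothetical prior values: alpha * (1 + depth) ** -beta for each depth
alpha, beta = 0.95, 2.0
p_by_depth = [alpha * (1 + depth) ** -beta for depth in range(3)]
per_node = expand_p_nonterminal(p_by_depth)
```

With three per-depth probabilities the tree has d = 4 levels, so the per-node array has 16 entries and the eight nodes on the deepest level are forced to be leaves.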
- class bartz.mcmcstep.State(X, y, z, offset, resid, error_cov_inv, prec_scale, error_cov_df, error_cov_scale, forest, config)
Represents the MCMC state of BART.
- Parameters:
  - X (UInt[Array, 'p n']) – The predictors.
  - y (Float32[Array, 'n'] | Float32[Array, 'k n'] | Bool[Array, 'n']) – The response. If the data type is bool, the model is binary regression.
  - resid (Float32[Array, '*chains n'] | Float32[Array, '*chains k n']) – The residuals (y or z minus sum of trees).
  - z (None | Float32[Array, '*chains n']) – The latent variable for binary regression. None in continuous regression.
  - offset (Float32[Array, ''] | Float32[Array, 'k']) – Constant shift added to the sum of trees.
  - error_cov_inv (Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None) – The inverse error covariance (scalar for univariate, matrix for multivariate). None in binary regression.
  - prec_scale (Float32[Array, 'n'] | None) – The scale on the error precision, i.e., 1 / error_scale ** 2. None in binary regression.
  - error_cov_df (Float32[Array, ''] | None), error_cov_scale (Float32[Array, ''] | Float32[Array, 'k k'] | None) – The df and scale parameters of the inverse Wishart prior on the noise covariance. For the univariate case, the relationship to the inverse gamma prior parameters is alpha = df / 2, beta = scale / 2. None in binary regression.
  - forest (Forest) – The sum of trees model.
  - config (StepConfig) – Metadata and configurations for the MCMC step.
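The two conversions stated in the parameter descriptions (inverse Wishart to inverse gamma in the univariate case, and error scale to precision scale) are simple enough to check by hand. The snippet below is an illustrative sketch in plain Python; the function names are hypothetical and are not part of bartz.

```python
def inverse_gamma_params(df, scale):
    """Univariate case: alpha = df / 2, beta = scale / 2."""
    return df / 2, scale / 2


def prec_scale_from_error_scale(error_scale):
    """Per-datapoint precision scale, 1 / error_scale ** 2."""
    return [1.0 / s ** 2 for s in error_scale]


# Hypothetical values chosen only to make the arithmetic easy to follow
alpha, beta = inverse_gamma_params(df=3.0, scale=2.0)
weights = prec_scale_from_error_scale([1.0, 2.0, 0.5])
```

A datapoint with error scale 2 therefore gets a quarter of the precision of a datapoint with error scale 1.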
- class bartz.mcmcstep.StepConfig(steps_done, sparse_on_at, resid_batch_size, count_batch_size, mesh)
Options for the MCMC step.
- Parameters:
  - steps_done (Int32[Array, '']) – The number of MCMC steps completed so far.
  - sparse_on_at (Int32[Array, ''] | None) – After how many steps to turn on variable selection.
  - resid_batch_size (int | None), count_batch_size (int | None) – The data batch sizes for computing the sufficient statistics. If None, they are computed with no batching.
  - mesh (Mesh | None) – The mesh used to shard data and computation across multiple devices.
- bartz.mcmcstep.init(*, X, y, offset, max_split, num_trees, p_nonterminal, leaf_prior_cov_inv, error_cov_df=None, error_cov_scale=None, error_scale=None, min_points_per_decision_node=None, resid_batch_size='auto', count_batch_size='auto', save_ratios=False, filter_splitless_vars=True, min_points_per_leaf=None, log_s=None, theta=None, a=None, b=None, rho=None, sparse_on_at=None, num_chains=None, mesh=None, target_platform=None)
Make a BART posterior sampling MCMC initial state.
- Parameters:
  - X (UInt[Any, 'p n']) – The predictors. Note that this is transposed compared to the usual convention.
  - y (Float32[Any, 'n'] | Float32[Any, 'k n'] | Bool[Any, 'n']) – The response. If the data type is bool, the regression model is binary regression with probit. If two-dimensional, the outcome is multivariate with the first axis indicating the component.
  - offset (float | Float32[Any, ''] | Float32[Any, 'k']) – Constant shift added to the sum of trees. 0 if not specified.
  - max_split (UInt[Any, 'p']) – The maximum split index for each variable. All split ranges start at 1.
  - num_trees (int) – The number of trees in the forest.
  - p_nonterminal (Float32[Any, 'd_minus_1']) – The probability of a nonterminal node at each depth. The maximum depth of trees is fixed by the length of this array.
  - leaf_prior_cov_inv (float | Float32[Any, ''] | Float32[Array, 'k k']) – The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1. The prior mean of leaves is always zero.
  - error_cov_df (float | Float32[Any, ''] | None, default: None), error_cov_scale (float | Float32[Any, ''] | Float32[Array, 'k k'] | None, default: None) – The df and scale parameters of the inverse Wishart prior on the error covariance. For the univariate case, the relationship to the inverse gamma prior parameters is alpha = df / 2, beta = scale / 2. Leave unspecified for binary regression.
  - error_scale (Float32[Any, 'n'] | None, default: None) – Each error is scaled by the corresponding factor in error_scale, so the error variance for y[i] is sigma2 * error_scale[i] ** 2. Not supported for binary regression. If not specified, defaults to 1 for all points, potentially skipping some calculations.
  - min_points_per_decision_node (int | Integer[Any, ''] | None, default: None) – The minimum number of data points in a decision node. 0 if not specified.
  - resid_batch_size (int | None | Literal['auto'], default: 'auto'), count_batch_size (int | None | Literal['auto'], default: 'auto') – The batch sizes, along datapoints, for summing the residuals and counting the number of datapoints in each leaf. None for no batching. If ‘auto’, they are chosen automatically based on the target platform; see the description of target_platform below for how it is determined.
  - save_ratios (bool, default: False) – Whether to save the Metropolis-Hastings ratios.
  - filter_splitless_vars (bool, default: True) – Whether to check max_split for variables without available cutpoints. If any are found, they are put into a list of variables to exclude from the MCMC. If False, no check is performed, but the results may be wrong if any variable is blocked. The function is jax-traceable only if this is set to False.
  - min_points_per_leaf (int | Integer[Any, ''] | None, default: None) – The minimum number of datapoints in a leaf node. 0 if not specified. Unlike min_points_per_decision_node, this constraint is not taken into account in the Metropolis-Hastings ratio because it would be expensive to compute. Grow moves that would violate this constraint are vetoed. This parameter is independent of min_points_per_decision_node and there is no check that they are coherent. It makes sense to set min_points_per_decision_node >= 2 * min_points_per_leaf.
  - log_s (Float32[Any, 'p'] | None, default: None) – The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If not specified, use a uniform distribution. If not specified while theta or rho, a, b are specified, it is initialized automatically.
  - theta (float | Float32[Any, ''] | None, default: None) – The concentration parameter for the Dirichlet prior on s. Required only to update log_s. If not specified while rho, a, b are specified, it is initialized automatically.
  - a (float | Float32[Any, ''] | None, default: None), b (float | Float32[Any, ''] | None, default: None), rho (float | Float32[Any, ''] | None, default: None) – Parameters of the prior on theta. Required only to sample theta.
  - sparse_on_at (int | Integer[Any, ''] | None, default: None) – After how many MCMC steps to turn on variable selection.
  - num_chains (int | None, default: None) – The number of independent MCMC chains to represent in the state. Single chain with scalar values if not specified.
  - mesh (Mesh | dict[str, int] | None, default: None) – A jax mesh used to shard data and computation across multiple devices. If it has a ‘chains’ axis, that axis is used to shard the chains. If it has a ‘data’ axis, that axis is used to shard the datapoints.
    As a shorthand, if a dictionary mapping axis names to axis sizes is passed, the corresponding mesh is created, e.g., dict(chains=4, data=2) will let jax pick 8 devices to split chains (which must be a multiple of 4) across 4 pairs of devices, where in each pair the data is split in two.
    Note: if a mesh is passed, the arrays are always sharded according to it. In particular, even if the mesh has no ‘chains’ or ‘data’ axis, the arrays will be replicated on all devices in the mesh.
  - target_platform (Literal['cpu', 'gpu'] | None, default: None) – Platform (‘cpu’ or ‘gpu’) used to determine the batch sizes automatically. If mesh is specified, the platform is inferred from the devices in the mesh. Otherwise, if y is a concrete array (i.e., init is not invoked in a jax.jit context), the platform is set to the platform of y. Otherwise, use target_platform.
    To avoid confusion, in all cases where the target_platform argument would be ignored, init raises an exception if target_platform is set.
- Returns:
  State – An initialized BART MCMC state.
- Raises:
  ValueError – If y is boolean and arguments unused in binary regression are set.
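The device arithmetic behind the dictionary shorthand for mesh can be checked without touching jax. The sketch below is a hypothetical helper, not part of bartz: it assumes that the number of devices is the product of the axis sizes and that the number of chains must be divisible by the size of the ‘chains’ axis, as the description of dict(chains=4, data=2) states.

```python
from math import prod


def check_mesh_shorthand(mesh_axes, num_chains):
    """Return (devices needed, chains per shard) for a mesh shorthand dict.

    Hypothetical helper for illustration only; it does not create a mesh.
    """
    n_devices = prod(mesh_axes.values())
    chains_axis = mesh_axes.get('chains', 1)
    if num_chains % chains_axis != 0:
        raise ValueError(
            f'num_chains={num_chains} is not a multiple of {chains_axis}'
        )
    return n_devices, num_chains // chains_axis


n_devices, chains_per_shard = check_mesh_shorthand(
    dict(chains=4, data=2), num_chains=8
)
```

Here 8 chains on a mesh with chains=4 and data=2 need 8 devices, with 2 chains placed on each pair of devices.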
Notes
In decision nodes, the values in X[i, :] are compared to a cutpoint out of the range [1, 2, ..., max_split[i]]. A point belongs to the left child iff X[i, j] < cutpoint. Thus it makes sense for X[i, :] to be integers in the range [0, 1, ..., max_split[i]].
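To make the note concrete, here is a minimal sketch of how one continuous predictor could be binned into the integer range [0, max_split[i]] and how the left/right rule then applies. It uses only the Python standard library; the helper names bin_predictor and goes_left and the threshold values are hypothetical, and binning with bisect_right is one possible choice rather than the method bartz prescribes.

```python
from bisect import bisect_right


def bin_predictor(values, thresholds):
    """Map each value to the number of sorted thresholds that are <= it.

    The result lies in [0, len(thresholds)], so max_split for this
    predictor would be len(thresholds).
    """
    return [bisect_right(thresholds, v) for v in values]


def goes_left(x_binned, cutpoint):
    """Decision rule from the note: left child iff x < cutpoint."""
    return x_binned < cutpoint


thresholds = [0.5, 1.5, 2.5]            # hypothetical, sorted
values = [0.0, 0.5, 1.0, 3.0]
bins = bin_predictor(values, thresholds)
left = [goes_left(b, 2) for b in bins]  # cutpoint 2 is threshold 1.5
```

With this binning, comparing the integer bin to cutpoint k is equivalent to comparing the original value to the k-th threshold, so the first three values (all below 1.5) go left and the last goes right.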