MCMC setup and step

Functions that implement the BART posterior MCMC initialization and update step.

  • init: Creates an initial State from data and configurations.

  • step: Performs one full MCMC step on a State, returning a new State.

class bartz.mcmcstep.Forest(leaf_tree, var_tree, split_tree, affluence_tree, max_split, blocked_vars, p_nonterminal, p_propose_grow, leaf_indices, min_points_per_decision_node, min_points_per_leaf, log_trans_prior, log_likelihood, grow_prop_count, prune_prop_count, grow_acc_count, prune_acc_count, leaf_prior_cov_inv, log_s, theta, a, b, rho)

Represents the MCMC state of a sum of trees.

Parameters:
  • leaf_tree (Float32[Array, '*chains num_trees 2**d'] | Float32[Array, '*chains num_trees k 2**d']) – The leaf values.

  • var_tree (UInt[Array, '*chains num_trees 2**(d-1)']) – The decision axes.

  • split_tree (UInt[Array, '*chains num_trees 2**(d-1)']) – The decision boundaries.

  • affluence_tree (Bool[Array, '*chains num_trees 2**(d-1)']) – Marks leaves that can be grown.

  • max_split (UInt[Array, 'p']) – The maximum split index for each predictor.

  • blocked_vars (UInt[Array, 'q'] | None) – Indices of variables that are not used. This must include at least every i such that max_split[i] == 0; otherwise, the behavior is undefined.

  • p_nonterminal (Float32[Array, '2**d']) – The prior probability of each node being nonterminal, conditional on its ancestors. Includes the nodes at maximum depth, which should be set to 0.

  • p_propose_grow (Float32[Array, '2**(d-1)']) – The unnormalized probability of picking a leaf for a grow proposal.

  • leaf_indices (UInt[Array, '*chains num_trees n']) – The index of the leaf each datapoint falls into, for each tree.

  • min_points_per_decision_node (Int32[Array, ''] | None) – The minimum number of data points in a decision node.

  • min_points_per_leaf (Int32[Array, ''] | None) – The minimum number of data points in a leaf node.

  • log_trans_prior (Float32[Array, '*chains num_trees'] | None) – The log transition and prior Metropolis-Hastings ratio for the proposed move on each tree.

  • log_likelihood (Float32[Array, '*chains num_trees'] | None) – The log likelihood ratio for the proposed move on each tree.

  • grow_prop_count (Int32[Array, '*chains'])

  • prune_prop_count (Int32[Array, '*chains']) – The number of grow/prune proposals made during one full MCMC cycle.

  • grow_acc_count (Int32[Array, '*chains'])

  • prune_acc_count (Int32[Array, '*chains']) – The number of grow/prune moves accepted during one full MCMC cycle.

  • leaf_prior_cov_inv (Float32[Array, ''] | Float32[Array, 'k k'] | None) – The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1.

  • log_s (Float32[Array, '*chains p'] | None) – The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If None, use a uniform distribution.

  • theta (Float32[Array, '*chains'] | None) – The concentration parameter for the Dirichlet prior on the variable distribution s. Required only to update s.

  • a (Float32[Array, ''] | None)

  • b (Float32[Array, ''] | None)

  • rho (Float32[Array, ''] | None) – Parameters of the prior on theta. Required only to sample theta. See step_theta.
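The array shapes above (2**d slots for leaves, 2**(d-1) for decision nodes) suggest a heap layout for each tree. The sketch below assumes, without confirmation from this page, that node 1 is the root, node i has children 2i and 2i + 1, index 0 is unused, and split_tree[i] == 0 marks node i as a leaf. Under those assumptions it shows how a datapoint is routed to the leaf recorded in leaf_indices. The helpers find_leaf and leaf_indices_for_tree are hypothetical, not part of bartz.

```python
import numpy as np


def find_leaf(var_tree, split_tree, x):
    """Return the heap index of the leaf that the point x falls into.

    Hypothetical helper. Assumes node 1 is the root, node i has children
    2*i and 2*i + 1, index 0 is unused, and split_tree[i] == 0 marks a leaf.
    """
    num_decision_slots = var_tree.shape[0]  # 2**(d-1)
    i = 1
    # Descend while the current node is a decision node.
    while i < num_decision_slots and split_tree[i] != 0:
        # Left child iff x[var] < cutpoint, as stated in the Notes of init.
        go_right = x[var_tree[i]] >= split_tree[i]
        i = 2 * i + int(go_right)
    return i


def leaf_indices_for_tree(var_tree, split_tree, X):
    """Leaf index of every datapoint; X has shape (p, n) as in init."""
    return np.array([find_leaf(var_tree, split_tree, X[:, j])
                     for j in range(X.shape[1])], dtype=np.uint32)


# Depth d = 3: decision arrays have 2**(d-1) = 4 slots, leaf arrays 2**d = 8.
# The root (node 1) splits variable 0 at cutpoint 2; its left child (node 2)
# splits variable 1 at cutpoint 1; its right child (node 3) is a leaf.
var_tree = np.array([0, 0, 1, 0], dtype=np.uint32)
split_tree = np.array([0, 2, 1, 0], dtype=np.uint32)
X = np.array([[1, 1, 2],
              [0, 3, 0]], dtype=np.uint32)
print(leaf_indices_for_tree(var_tree, split_tree, X))  # [4 5 3]
```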

num_chains()

Return the number of chains, or None if not multichain.

Return type:

int | None
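Arrays annotated with '*chains' carry a leading axis of length num_chains in a multichain state and no extra axis otherwise. A minimal sketch of how a chain count could be read off such an array, using grow_prop_count (shape '*chains') as the reference; num_chains_from_counts is a hypothetical helper, not the actual implementation of num_chains.

```python
import numpy as np


def num_chains_from_counts(grow_prop_count):
    """Hypothetical helper: infer the chain count from a '*chains' array.

    A 0-d array means a single chain (None); a 1-d array has one entry
    per chain.
    """
    shape = np.shape(grow_prop_count)
    return None if len(shape) == 0 else shape[0]


print(num_chains_from_counts(np.int32(0)))                  # None
print(num_chains_from_counts(np.zeros(4, dtype=np.int32)))  # 4
```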

class bartz.mcmcstep.State(X, y, z, offset, resid, error_cov_inv, prec_scale, error_cov_df, error_cov_scale, forest, config)

Represents the MCMC state of BART.

Parameters:
  • X (UInt[Array, 'p n']) – The predictors.

  • y (Float32[Array, 'n'] | Float32[Array, 'k n'] | Bool[Array, 'n']) – The response. If the data type is bool, the model is binary regression.

  • resid (Float32[Array, '*chains n'] | Float32[Array, '*chains k n']) – The residuals (y or z minus sum of trees).

  • z (None | Float32[Array, '*chains n']) – The latent variable for binary regression. None in continuous regression.

  • offset (Float32[Array, ''] | Float32[Array, 'k']) – Constant shift added to the sum of trees.

  • error_cov_inv (Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None) – The inverse error covariance (scalar for univariate, matrix for multivariate). None in binary regression.

  • prec_scale (Float32[Array, 'n'] | None) – The scale on the error precision, i.e., 1 / error_scale ** 2. None in binary regression.

  • error_cov_df (Float32[Array, ''] | None)

  • error_cov_scale (Float32[Array, ''] | Float32[Array, 'k k'] | None) – The df and scale parameters of the inverse Wishart prior on the noise covariance. For the univariate case, the relationship to the inverse gamma prior parameters is alpha = df / 2, beta = scale / 2. None in binary regression.

  • forest (Forest) – The sum of trees model.

  • config (StepConfig) – Metadata and configurations for the MCMC step.
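Together, leaf_tree and leaf_indices determine the sum of trees at every datapoint, and hence resid. The sketch below covers the univariate, single-chain case. It assumes that the model mean is offset plus the sum of trees (so resid = y - offset - sum of trees) and, for the toy trees, that node 1 is the root with children 2 and 3. It also shows the documented relation prec_scale = 1 / error_scale ** 2. The function names are illustrative, not bartz API.

```python
import numpy as np


def sum_of_trees(leaf_tree, leaf_indices):
    """leaf_tree: (num_trees, 2**d); leaf_indices: (num_trees, n)."""
    # For each tree and datapoint, pick the value of the leaf it falls into.
    per_tree = np.take_along_axis(leaf_tree, leaf_indices.astype(np.intp), axis=1)
    return per_tree.sum(axis=0)  # shape (n,)


def residuals(y, offset, leaf_tree, leaf_indices):
    # Assumption: the model mean is offset + sum of trees.
    return y - offset - sum_of_trees(leaf_tree, leaf_indices)


# Two trees of depth d = 2 (leaf arrays of length 4, heap indices 1..3).
# Tree 0 splits at the root into leaves 2 and 3; tree 1 is a single leaf (node 1).
leaf_tree = np.array([[0.0, 0.0, 1.0, 2.0],
                      [0.0, 0.5, 0.0, 0.0]], dtype=np.float32)
leaf_indices = np.array([[2, 3, 2],
                         [1, 1, 1]], dtype=np.uint32)
y = np.array([2.0, 3.0, 1.0], dtype=np.float32)
resid = residuals(y, 0.25, leaf_tree, leaf_indices)  # [0.25, 0.25, -0.75]

# Per-point error variance is sigma2 * error_scale**2, so the precision scale is:
error_scale = np.array([1.0, 2.0, 0.5], dtype=np.float32)
prec_scale = 1 / error_scale**2  # [1.0, 0.25, 4.0]
```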

class bartz.mcmcstep.StepConfig(steps_done, sparse_on_at, resid_batch_size, count_batch_size, mesh)

Options for the MCMC step.

Parameters:
  • steps_done (Int32[Array, '']) – The number of MCMC steps completed so far.

  • sparse_on_at (Int32[Array, ''] | None) – After how many steps to turn on variable selection.

  • resid_batch_size (int | None)

  • count_batch_size (int | None) – The data batch sizes for computing the sufficient statistics. If None, they are computed with no batching.

  • mesh (Mesh | None) – The mesh used to shard data and computation across multiple devices.
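Because the sufficient statistics are sums over datapoints, batching changes how they are accumulated but, up to floating-point round-off, not their values. A minimal NumPy sketch of that idea, not bartz's implementation (which this page does not describe): the residual sums and datapoint counts per leaf of a single tree are accumulated over consecutive chunks of datapoints, and batch_size=None means a single pass. The helper leaf_sums is hypothetical.

```python
import numpy as np


def leaf_sums(leaf_indices, resid, num_leaves, batch_size=None):
    """Per-leaf residual sums and datapoint counts for a single tree.

    Hypothetical sketch: statistics are accumulated over consecutive chunks
    of batch_size datapoints; batch_size=None processes all points at once.
    """
    leaf_indices = np.asarray(leaf_indices, dtype=np.intp)
    n = leaf_indices.shape[0]
    chunk = max(n, 1) if batch_size is None else batch_size
    sums = np.zeros(num_leaves)
    counts = np.zeros(num_leaves, dtype=np.int64)
    for start in range(0, n, chunk):
        idx = leaf_indices[start:start + chunk]
        sums += np.bincount(idx, weights=resid[start:start + chunk], minlength=num_leaves)
        counts += np.bincount(idx, minlength=num_leaves)
    return sums, counts


leaf_indices = np.array([2, 3, 2, 3, 3], dtype=np.uint32)
resid = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
unbatched = leaf_sums(leaf_indices, resid, num_leaves=4)
batched = leaf_sums(leaf_indices, resid, num_leaves=4, batch_size=2)
# Both give sums [0, 0, 4, 11] and counts [0, 0, 2, 3].
```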

bartz.mcmcstep.init(*, X, y, offset, max_split, num_trees, p_nonterminal, leaf_prior_cov_inv, error_cov_df=None, error_cov_scale=None, error_scale=None, min_points_per_decision_node=None, resid_batch_size='auto', count_batch_size='auto', save_ratios=False, filter_splitless_vars=True, min_points_per_leaf=None, log_s=None, theta=None, a=None, b=None, rho=None, sparse_on_at=None, num_chains=None, mesh=None, target_platform=None)

Make a BART posterior sampling MCMC initial state.

Parameters:
  • X (UInt[Any, 'p n']) – The predictors. Note that this is transposed compared to the usual convention.

  • y (Float32[Any, 'n'] | Float32[Any, 'k n'] | Bool[Any, 'n']) – The response. If the data type is bool, the regression model is binary regression with probit. If two-dimensional, the outcome is multivariate with the first axis indicating the component.

  • offset (float | Float32[Any, ''] | Float32[Any, 'k']) – Constant shift added to the sum of trees. 0 if not specified.

  • max_split (UInt[Any, 'p']) – The maximum split index for each variable. All split ranges start at 1.

  • num_trees (int) – The number of trees in the forest.

  • p_nonterminal (Float32[Any, 'd_minus_1']) – The probability of a nonterminal node at each depth. The maximum depth of trees is fixed by the length of this array.

  • leaf_prior_cov_inv (float | Float32[Any, ''] | Float32[Array, 'k k']) – The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1. The prior mean of leaves is always zero.

  • error_cov_df (float | Float32[Any, ''] | None, default: None)

  • error_cov_scale (float | Float32[Any, ''] | Float32[Array, 'k k'] | None, default: None) – The df and scale parameters of the inverse Wishart prior on the error covariance. For the univariate case, the relationship to the inverse gamma prior parameters is alpha = df / 2, beta = scale / 2. Leave unspecified for binary regression.

  • error_scale (Float32[Any, 'n'] | None, default: None) – Each error is scaled by the corresponding factor in error_scale, so the error variance for y[i] is sigma2 * error_scale[i] ** 2. Not supported for binary regression. If not specified, it defaults to 1 for all points, which may allow some calculations to be skipped.

  • min_points_per_decision_node (int | Integer[Any, ''] | None, default: None) – The minimum number of data points in a decision node. 0 if not specified.

  • resid_batch_size (int | None | Literal['auto'], default: 'auto')

  • count_batch_size (int | None | Literal['auto'], default: 'auto') – The batch sizes, along datapoints, for summing the residuals and counting the number of datapoints in each leaf. None for no batching. If ‘auto’, it’s chosen automatically based on the target platform; see the description of target_platform below for how it is determined.

  • save_ratios (bool, default: False) – Whether to save the Metropolis-Hastings ratios.

  • filter_splitless_vars (bool, default: True) – Whether to check max_split for variables without available cutpoints. If any are found, they are put into a list of variables to exclude from the MCMC. If False, no check is performed, but the results may be wrong if any variable is blocked. The function is jax-traceable only if this is set to False.

  • min_points_per_leaf (int | Integer[Any, ''] | None, default: None) – The minimum number of datapoints in a leaf node. 0 if not specified. Unlike min_points_per_decision_node, this constraint is not taken into account in the Metropolis-Hastings ratio because it would be expensive to compute. Grow moves that would violate this constraint are vetoed. This parameter is independent of min_points_per_decision_node and there is no check that they are coherent. It makes sense to set min_points_per_decision_node >= 2 * min_points_per_leaf.

  • log_s (Float32[Any, 'p'] | None, default: None) – The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If not specified, a uniform distribution is used, unless theta or rho, a, b are specified, in which case it is initialized automatically.

  • theta (float | Float32[Any, ''] | None, default: None) – The concentration parameter for the Dirichlet prior on s. Required only to update log_s. If not specified, and rho, a, b are specified, it’s initialized automatically.

  • a (float | Float32[Any, ''] | None, default: None)

  • b (float | Float32[Any, ''] | None, default: None)

  • rho (float | Float32[Any, ''] | None, default: None) – Parameters of the prior on theta. Required only to sample theta.

  • sparse_on_at (int | Integer[Any, ''] | None, default: None) – After how many MCMC steps to turn on variable selection.

  • num_chains (int | None, default: None) – The number of independent MCMC chains to represent in the state. If not specified, the state holds a single chain and per-chain values are scalars.

  • mesh (Mesh | dict[str, int] | None, default: None) –

    A jax mesh used to shard data and computation across multiple devices. If it has a ‘chains’ axis, that axis is used to shard the chains. If it has a ‘data’ axis, that axis is used to shard the datapoints.

    As a shorthand, if a dictionary mapping axis names to axis size is passed, the corresponding mesh is created, e.g., dict(chains=4, data=2) will let jax pick 8 devices to split chains (which must be a multiple of 4) across 4 pairs of devices, where in each pair the data is split in two.

    Note: if a mesh is passed, the arrays are always sharded according to it. In particular, even if the mesh has no ‘chains’ or ‘data’ axis, the arrays will be replicated on all devices in the mesh.

  • target_platform (Literal['cpu', 'gpu'] | None, default: None) –

    Platform (‘cpu’ or ‘gpu’) used to determine the batch sizes automatically. If mesh is specified, the platform is inferred from the devices in the mesh. Otherwise, if y is a concrete array (i.e., init is not invoked in a jax.jit context), the platform is set to the platform of y. Otherwise, use target_platform.

    To avoid confusion, in all cases where the target_platform argument would be ignored, init raises an exception if target_platform is set.
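Several of the parameters above are linked by simple formulas. The sketch below turns familiar quantities into init arguments: inverse gamma alpha and beta into error_cov_df and error_cov_scale (univariate case, using alpha = df / 2 and beta = scale / 2); a target prior covariance for the sum of trees into leaf_prior_cov_inv (inverting prior covariance = num_trees * leaf_prior_cov_inv^-1); and a p_nonterminal array of length d - 1 for maximum depth d. For the last one it assumes the standard BART tree prior base * (1 + depth) ** -power with the common choice base = 0.95, power = 2. The helper names are hypothetical.

```python
import numpy as np


def error_prior_from_inverse_gamma(alpha, beta):
    """Univariate case: alpha = df / 2 and beta = scale / 2."""
    return 2.0 * alpha, 2.0 * beta  # (error_cov_df, error_cov_scale)


def leaf_prior_cov_inv_for(total_cov, num_trees):
    """Leaf precision giving the sum of trees the prior covariance total_cov.

    Inverts: prior cov of the sum of trees = num_trees * inv(leaf_prior_cov_inv).
    """
    total_cov = np.asarray(total_cov, dtype=np.float64)
    if total_cov.ndim == 0:  # univariate: a scalar inverse variance
        return num_trees / total_cov
    return num_trees * np.linalg.inv(total_cov)  # multivariate: k x k matrix


def depth_prior(max_depth, base=0.95, power=2.0):
    """p_nonterminal for depths 0 .. max_depth - 2 (length max_depth - 1)."""
    depth = np.arange(max_depth - 1, dtype=np.float64)
    return (base * (1.0 + depth) ** -power).astype(np.float32)


df, scale = error_prior_from_inverse_gamma(alpha=1.5, beta=0.5)  # (3.0, 1.0)
leaf_prec = leaf_prior_cov_inv_for(4.0, num_trees=200)           # 50.0
p_nonterminal = depth_prior(max_depth=6)  # 5 entries, first is 0.95
```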

Returns:

State – An initialized BART MCMC state.

Raises:

ValueError – If y is boolean and arguments unused in binary regression are set.

Notes

In decision nodes, the values in X[i, :] are compared to a cutpoint out of the range [1, 2, ..., max_split[i]]. A point belongs to the left child iff X[i, j] < cutpoint. Thus it makes sense for X[i, :] to be integers in the range [0, 1, ..., max_split[i]].
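These Notes imply that continuous predictors must be discretized before calling init. The sketch below is a hypothetical helper, not part of bartz: it places split values at midpoints between consecutive distinct values of each predictor and encodes each observation as the number of split values less than or equal to it. With this encoding, X[i, j] < c holds exactly when the raw value is below the c-th split value. A constant predictor gets max_split[i] == 0, which is the case that filter_splitless_vars detects and that blocked_vars must cover.

```python
import numpy as np


def bin_predictors(X_raw):
    """Map continuous predictors (shape (p, n)) to integer codes for init.

    Hypothetical helper: split values are midpoints between consecutive
    distinct values of each predictor.
    """
    p, n = X_raw.shape
    X = np.empty((p, n), dtype=np.uint32)
    max_split = np.empty(p, dtype=np.uint32)
    split_values = []
    for i in range(p):
        u = np.unique(X_raw[i])
        s = (u[:-1] + u[1:]) / 2  # sorted split values, len(u) - 1 of them
        # X[i, j] counts split values <= X_raw[i, j], so for cutpoint c in
        # 1..len(s): X[i, j] < c  iff  X_raw[i, j] < s[c - 1].
        X[i] = np.searchsorted(s, X_raw[i], side='right')
        max_split[i] = len(s)
        split_values.append(s)
    return X, max_split, split_values


X_raw = np.array([[0.3, 1.7, 0.9, 1.7],
                  [5.0, 5.0, 5.0, 5.0]])  # second predictor is constant
X, max_split, split_values = bin_predictors(X_raw)
blocked_vars = np.flatnonzero(max_split == 0)  # variables without cutpoints
```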

bartz.mcmcstep.step(key, bart)

Do one MCMC step.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • bart (State) – A BART MCMC state, as created by init.

Returns:

State – The new BART MCMC state.

Notes

The memory of the input state is reused for the output state, so the input state cannot be used anymore after calling step. This applies when step is called outside of jax.jit.