MCMC setup and step

Functions that implement the BART posterior MCMC initialization and update step.

  • init: Creates an initial State from data and configurations.

  • step: Performs one full MCMC step on a State, returning a new State.

class bartz.mcmcstep.Forest(leaf_tree, var_tree, split_tree, affluence_tree, max_split, blocked_vars, p_nonterminal, p_propose_grow, leaf_indices, min_points_per_decision_node, min_points_per_leaf, log_trans_prior, log_likelihood, grow_prop_count, prune_prop_count, grow_acc_count, prune_acc_count, leaf_prior_cov_inv, log_s, theta, a, b, rho)[source]

Represents the MCMC state of a sum of trees.

leaf_tree: Float32[Array, '*chains num_trees 2**d'] | Float32[Array, '*chains num_trees k 2**d']

The leaf values.

var_tree: UInt[Array, '*chains num_trees 2**(d-1)']

The decision axes.

split_tree: UInt[Array, '*chains num_trees 2**(d-1)']

The decision boundaries.

affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)']

Marks leaves that can be grown.

max_split: UInt[Array, 'p']

The maximum split index for each predictor.

blocked_vars: UInt[Array, 'q'] | None

Indices of variables that are not used. This must include at least every index i such that max_split[i] == 0; otherwise the behavior is undefined.

p_nonterminal: Float32[Array, '2**d']

The prior probability of each node being nonterminal, conditional on its ancestors. Includes the nodes at maximum depth, whose entries should be set to 0.

p_propose_grow: Float32[Array, '2**(d-1)']

The unnormalized probability of picking a leaf for a grow proposal.

leaf_indices: UInt[Array, '*chains num_trees n']

The index of the leaf each datapoint falls into, for each tree.

min_points_per_decision_node: Int32[Array, ''] | None

The minimum number of data points in a decision node.

min_points_per_leaf: Int32[Array, ''] | None

The minimum number of data points in a leaf node.

log_trans_prior: Float32[Array, '*chains num_trees'] | None

The log transition and prior Metropolis-Hastings ratio for the proposed move on each tree.

log_likelihood: Float32[Array, '*chains num_trees'] | None

The log likelihood ratio.

grow_prop_count: Int32[Array, '*chains']

The number of grow proposals made during one full MCMC cycle.

prune_prop_count: Int32[Array, '*chains']

The number of prune proposals made during one full MCMC cycle.

grow_acc_count: Int32[Array, '*chains']

The number of grow moves accepted during one full MCMC cycle.

prune_acc_count: Int32[Array, '*chains']

The number of prune moves accepted during one full MCMC cycle.

leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None

The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1.

log_s: Float32[Array, '*chains p'] | None

The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If None, use a uniform distribution.

theta: Float32[Array, '*chains'] | None

The concentration parameter for the Dirichlet prior on the variable distribution s. Required only to update log_s.

a: Float32[Array, ''] | None

Parameter of the prior on theta. Required only to sample theta. See step_theta.

b: Float32[Array, ''] | None

Parameter of the prior on theta. Required only to sample theta. See step_theta.

rho: Float32[Array, ''] | None

Parameter of the prior on theta. Required only to sample theta. See step_theta.

num_chains()[source]

Return the number of chains, or None if not multichain.

Return type:

DTypeLike[int, None]
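
The relation stated for leaf_prior_cov_inv (prior covariance of the sum of trees equals num_trees * leaf_prior_cov_inv^-1) can be inverted to pick a leaf precision from a desired prior spread of the whole sum. The following is a minimal univariate sketch; the helper name leaf_precision_for_total_sd is hypothetical and not part of bartz.

```python
def leaf_precision_for_total_sd(num_trees, total_prior_sd):
    """Scalar leaf precision giving the sum of trees a prior sd of `total_prior_sd`.

    Hypothetical helper (not a bartz function). It inverts the documented
    relation Var(sum of trees) = num_trees / leaf_prior_cov_inv.
    """
    if num_trees <= 0 or total_prior_sd <= 0:
        raise ValueError("num_trees and total_prior_sd must be positive")
    return num_trees / total_prior_sd**2


num_trees = 200
leaf_prior_cov_inv = leaf_precision_for_total_sd(num_trees, total_prior_sd=0.5)

# Check the round trip: the implied prior variance of the sum is 0.5 ** 2.
implied_total_variance = num_trees / leaf_prior_cov_inv
```

The resulting value would be passed as the scalar leaf_prior_cov_inv; the multivariate case (a k by k precision matrix) is not covered by this sketch.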

class bartz.mcmcstep.State(X, y, z, offset, resid, error_cov_inv, prec_scale, error_cov_df, error_cov_scale, forest, config)[source]

Represents the MCMC state of BART.

X: UInt[Array, 'p n']

The predictors.

y: Float32[Array, 'n'] | Float32[Array, 'k n'] | Bool[Array, 'n']

The response. If the data type is bool, the model is binary regression.

z: None | Float32[Array, '*chains n']

The latent variable for binary regression. None in continuous regression.

offset: Float32[Array, ''] | Float32[Array, 'k']

Constant shift added to the sum of trees.

resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n']

The residuals (y or z minus sum of trees).

error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None

The inverse error covariance (scalar for univariate, matrix for multivariate). None in binary regression.

prec_scale: Float32[Array, 'n'] | None

The scale on the error precision, i.e., 1 / error_scale ** 2. None in binary regression.

error_cov_df: Float32[Array, ''] | None

The df parameter of the inverse Wishart prior on the noise covariance. For the univariate case, the relationship to the inverse gamma prior parameters is alpha = df / 2. None in binary regression.

error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None

The scale parameter of the inverse Wishart prior on the noise covariance. For the univariate case, the relationship to the inverse gamma prior parameters is beta = scale / 2. None in binary regression.

forest: Forest

The sum of trees model.

config: StepConfig

Metadata and configurations for the MCMC step.
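
The univariate correspondences given for error_cov_df, error_cov_scale, and prec_scale are simple enough to write out. The sketch below only restates those formulas in plain Python; the function names are hypothetical, and the prior mean uses the standard inverse gamma result beta / (alpha - 1), which holds only for alpha > 1.

```python
def inverse_gamma_params(error_cov_df, error_cov_scale):
    """Univariate case: alpha = df / 2, beta = scale / 2 (hypothetical helper)."""
    return error_cov_df / 2, error_cov_scale / 2


def inverse_gamma_mean(alpha, beta):
    """Mean of an inverse gamma distribution; None when it does not exist."""
    return beta / (alpha - 1) if alpha > 1 else None


def precision_scale(error_scale):
    """prec_scale = 1 / error_scale ** 2, elementwise (hypothetical helper)."""
    return [1.0 / s**2 for s in error_scale]


alpha, beta = inverse_gamma_params(error_cov_df=3.0, error_cov_scale=2.0)
prior_mean_sigma2 = inverse_gamma_mean(alpha, beta)

# A point with error_scale 2 has 4 times the error variance, so 1/4 the precision.
prec_scale = precision_scale([1.0, 2.0])
```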

class bartz.mcmcstep.StepConfig(steps_done, sparse_on_at, resid_num_batches, count_num_batches, prec_num_batches, prec_count_num_trees, mesh)[source]

Options for the MCMC step.

steps_done: Int32[Array, '']

The number of MCMC steps completed so far.

sparse_on_at: Int32[Array, ''] | None

After how many steps to turn on variable selection.

resid_num_batches: int | None

The number of batches for computing the sum of residuals. If None, they are computed with no batching.

count_num_batches: int | None

The number of batches for computing counts. If None, they are computed with no batching.

prec_num_batches: int | None

The number of batches for computing precision scales. If None, they are computed with no batching.

prec_count_num_trees: int | None

Batch size for processing trees to compute count and prec trees.

mesh: Mesh | None

The mesh used to shard data and computation across multiple devices.

bartz.mcmcstep.init(*, X, y, offset, max_split, num_trees, p_nonterminal, leaf_prior_cov_inv, error_cov_df=None, error_cov_scale=None, error_scale=None, min_points_per_decision_node=None, resid_num_batches='auto', count_num_batches='auto', prec_num_batches='auto', prec_count_num_trees='auto', save_ratios=False, filter_splitless_vars=0, min_points_per_leaf=None, log_s=None, theta=None, a=None, b=None, rho=None, sparse_on_at=None, num_chains=None, mesh=None, target_platform=None)[source]

Make a BART posterior sampling MCMC initial state.

Parameters:
  • X (UInt[Any, 'p n']) – The predictors. Note this is transposed compared to the usual convention.

  • y (DTypeLike[Float32[Any, 'n'], Float32[Any, 'k n'], Bool[Any, 'n']]) – The response. If the data type is bool, the regression model is binary regression with probit. If two-dimensional, the outcome is multivariate with the first axis indicating the component.

  • offset (DTypeLike[float, Float32[Any, ''], Float32[Any, 'k']]) – Constant shift added to the sum of trees. 0 if not specified.

  • max_split (UInt[Any, 'p']) – The maximum split index for each variable. All split ranges start at 1.

  • num_trees (int) – The number of trees in the forest.

  • p_nonterminal (Float32[Any, 'd_minus_1']) – The probability of a nonterminal node at each depth. The maximum depth of trees is fixed by the length of this array. Use make_p_nonterminal to set it with the conventional formula.

  • leaf_prior_cov_inv (DTypeLike[float, Float32[Any, ''], Float32[Array, 'k k']]) – The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1. The prior mean of leaves is always zero.

  • error_cov_df (DTypeLike[float, Float32[Any, ''], None], default: None)

  • error_cov_scale (DTypeLike[float, Float32[Any, ''], Float32[Array, 'k k'], None], default: None) – The df and scale parameters of the inverse Wishart prior on the error covariance. For the univariate case, the relationship to the inverse gamma prior parameters is alpha = df / 2, beta = scale / 2. Leave unspecified for binary regression.

  • error_scale (DTypeLike[Float32[Any, 'n'], None], default: None) – Each error is scaled by the corresponding factor in error_scale, so the error variance for y[i] is sigma2 * error_scale[i] ** 2. Not supported for binary regression. If not specified, it defaults to 1 for all points, and some calculations may be skipped.

  • min_points_per_decision_node (DTypeLike[int, Integer[Any, ''], None], default: None) – The minimum number of data points in a decision node. 0 if not specified.

  • resid_num_batches (DTypeLike[int, None, Literal['auto']], default: 'auto')

  • count_num_batches (DTypeLike[int, None, Literal['auto']], default: 'auto')

  • prec_num_batches (DTypeLike[int, None, Literal['auto']], default: 'auto') – The number of batches, along datapoints, for summing the residuals, counting the number of datapoints in each leaf, and computing the likelihood precision in each leaf, respectively. None for no batching. If ‘auto’, it’s chosen automatically based on the target platform; see the description of target_platform below for how it is determined.

  • prec_count_num_trees (DTypeLike[int, None, Literal['auto']], default: 'auto') – The number of trees to process at a time when counting datapoints or computing the likelihood precision. If None, do all trees at once, which may use too much memory. If ‘auto’ (default), it’s chosen automatically.

  • save_ratios (bool, default: False) – Whether to save the Metropolis-Hastings ratios.

  • filter_splitless_vars (int, default: 0) – The maximum number of variables without splits that can be ignored. If there are more, init raises an exception.

  • min_points_per_leaf (DTypeLike[int, Integer[Any, ''], None], default: None) – The minimum number of datapoints in a leaf node. 0 if not specified. Unlike min_points_per_decision_node, this constraint is not taken into account in the Metropolis-Hastings ratio because it would be expensive to compute. Grow moves that would violate this constraint are vetoed. This parameter is independent of min_points_per_decision_node and there is no check that they are coherent. It makes sense to set min_points_per_decision_node >= 2 * min_points_per_leaf.

  • log_s (DTypeLike[Float32[Any, 'p'], None], default: None) – The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If not specified, use a uniform distribution. If not specified while theta, or rho, a, and b, are specified, it is initialized automatically.

  • theta (DTypeLike[float, Float32[Any, ''], None], default: None) – The concentration parameter for the Dirichlet prior on s. Required only to update log_s. If not specified, and rho, a, b are specified, it’s initialized automatically.

  • a (DTypeLike[float, Float32[Any, ''], None], default: None)

  • b (DTypeLike[float, Float32[Any, ''], None], default: None)

  • rho (DTypeLike[float, Float32[Any, ''], None], default: None) – Parameters of the prior on theta. Required only to sample theta.

  • sparse_on_at (DTypeLike[int, Integer[Any, ''], None], default: None) – After how many MCMC steps to turn on variable selection.

  • num_chains (DTypeLike[int, None], default: None) – The number of independent MCMC chains to represent in the state. Single chain with scalar values if not specified.

  • mesh (DTypeLike[Mesh, dict[str, int], None], default: None) –

    A jax mesh used to shard data and computation across multiple devices. If it has a ‘chains’ axis, that axis is used to shard the chains. If it has a ‘data’ axis, that axis is used to shard the datapoints.

    As a shorthand, if a dictionary mapping axis names to axis size is passed, the corresponding mesh is created, e.g., dict(chains=4, data=2) will let jax pick 8 devices to split chains (which must be a multiple of 4) across 4 pairs of devices, where in each pair the data is split in two.

    Note: if a mesh is passed, the arrays are always sharded according to it. In particular even if the mesh has no ‘chains’ or ‘data’ axis, the arrays will be replicated on all devices in the mesh.

  • target_platform (DTypeLike[Literal['cpu', 'gpu'], None], default: None) –

    Platform (‘cpu’ or ‘gpu’) used to determine the number of batches automatically. If mesh is specified, the platform is inferred from the devices in the mesh. Otherwise, if y is a concrete array (i.e., init is not invoked in a jax.jit context), the platform is set to the platform of y. Otherwise, use target_platform.

    To avoid confusion, in all cases where the target_platform argument would be ignored, init raises an exception if target_platform is set.

Returns:

State – An initialized BART MCMC state.

Raises:

ValueError – If y is boolean and arguments unused in binary regression are set.

Notes

In decision nodes, the values in X[i, :] are compared to a cutpoint out of the range [1, 2, ..., max_split[i]]. A point belongs to the left child iff X[i, j] < cutpoint. Thus it makes sense for X[i, :] to be integers in the range [0, 1, ..., max_split[i]].
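
One way to obtain integer predictors compatible with the rule above is to count, for each raw value, how many thresholds it has reached. This is a sketch under assumptions: the sorted real-valued thresholds and the helper names are hypothetical, and this is not necessarily how bartz itself bins data. With this encoding, the integer test X[i, j] < cutpoint is equivalent to comparing the raw value against the threshold with that 1-based index.

```python
from bisect import bisect_right


def bin_predictor(values, thresholds):
    """Map raw values to integers in 0..len(thresholds).

    Each bin is the number of thresholds <= value, so `bin < c` holds
    exactly when `value < thresholds[c - 1]` for a 1-based cutpoint c.
    `thresholds` must be sorted in increasing order.
    """
    return [bisect_right(thresholds, v) for v in values]


def goes_left(x_binned, cutpoint):
    """Decision rule from the Notes: left child iff X[i, j] < cutpoint."""
    return x_binned < cutpoint


thresholds = [0.5, 1.5, 3.0]  # hypothetical; here max_split[i] would be 3
raw = [0.1, 0.5, 2.0, 7.0]
binned = bin_predictor(raw, thresholds)

# Verify that the integer rule matches the raw-value rule for every cutpoint.
consistent = all(
    goes_left(b, c) == (v < thresholds[c - 1])
    for v, b in zip(raw, binned)
    for c in range(1, len(thresholds) + 1)
)
```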

bartz.mcmcstep.make_p_nonterminal(d, alpha=0.95, beta=2.0)[source]

Prepare the p_nonterminal argument to init.

It is calculated according to the formula:

P_nt(depth) = alpha / (1 + depth)^beta, with depth 0-based

Parameters:
  • d (int) – The maximum depth of the trees (d=1 means a tree with only the root node).

  • alpha (DTypeLike[float, Float32[Array, '']], default: 0.95) – The a priori probability of the root node having children, conditional on it being possible.

  • beta (DTypeLike[float, Float32[Array, '']], default: 2.0) – The exponent of the power decay of the probability of having children with depth.

Returns:

Float32[Array, '{d}-1'] – An array of probabilities, one per tree level except the last.
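
To make the formula concrete, the following plain-Python sketch evaluates it for depths 0 to d - 2, matching the documented output length of d - 1. It is an illustration of the formula only, not the library's implementation, and the function name is hypothetical; the real function returns a JAX array rather than a list.

```python
def p_nonterminal_sketch(d, alpha=0.95, beta=2.0):
    """Evaluate P_nt(depth) = alpha / (1 + depth) ** beta for depth = 0..d-2.

    Hypothetical stand-in for illustration; returns a list of d - 1 floats.
    """
    if d < 1:
        raise ValueError("d must be at least 1")
    return [alpha / (1 + depth) ** beta for depth in range(d - 1)]


probs = p_nonterminal_sketch(4)  # one entry per level except the last
root_only = p_nonterminal_sketch(1)  # d=1: only a root node, so no entries
```

With the defaults, the probabilities decay quickly with depth (0.95 at the root, then 0.95 / 4, then 0.95 / 9), which is what keeps trees shallow a priori.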

bartz.mcmcstep.step(key, bart)[source]

Do one MCMC step.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • bart (State) – A BART MCMC state, as created by init.

Returns:

State – The new BART MCMC state.

Notes

The memory of the input state is reused for the output state, so the input state cannot be used any more after calling step. This caveat applies when step is called outside of jax.jit.
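
Because the input state is consumed, a driver loop should rebind the state variable on every iteration and never keep a reference to an old state. The sketch below shows that pattern in a library-agnostic way: run_steps is a hypothetical helper, and step_fn and split_fn are placeholders for the step function and a key-splitting function, passed in so the sketch does not depend on bartz or JAX being installed.

```python
def run_steps(step_fn, split_fn, key, state, num_steps):
    """Advance `state` by `num_steps` steps without reusing a consumed state.

    Hypothetical helper. `step_fn(key, state)` returns the new state and
    `split_fn(key)` returns a pair (carry_key, step_key).
    """
    for _ in range(num_steps):
        key, subkey = split_fn(key)
        # Rebind the name: the previous state's memory may have been reused.
        state = step_fn(subkey, state)
    return state


# Stand-ins for demonstration only: a counter as "state", integers as "keys".
def fake_split(key):
    return key + 1, key


def fake_step(key, state):
    return state + 1


final_state = run_steps(fake_step, fake_split, key=0, state=0, num_steps=5)
```

With bartz and JAX available, one would presumably call run_steps(bartz.mcmcstep.step, jax.random.split, key, state, num_steps), where state comes from init; that call is an assumption based on the signatures documented here and has not been verified.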