bartz.mcmcstep.init

bartz.mcmcstep.init(*, X, y, outcome_type='continuous', offset, max_split, num_trees, p_nonterminal, leaf_prior_cov_inv, error_cov_inv=None, error_scale=None, missing=None, min_points_per_decision_node=None, resid_reduction_config=AutoBatchedReduction(), count_reduction_config=AutoOneHotReduction(), prec_reduction_config=AutoOneHotReduction(), prec_count_num_trees='auto', sequential_unroll=2, save_ratios=False, filter_splitless_vars=0, min_points_per_leaf=None, log_s=None, theta=None, a=None, b=None, rho=None, sparse_on_at=None, augment=True, num_chains=None, mesh=None)

Make a BART posterior sampling MCMC initial state.

Parameters:
  • X (UInt[Array, 'p n'] | UInt[ndarray, 'p n']) – The predictors. Note that this is transposed with respect to the usual convention: the first axis indexes the p predictors and the second axis the n datapoints.

  • y (Float32[Array, 'n'] | Float32[ndarray, 'n'] | Float32[Array, 'k n'] | Float32[ndarray, 'k n']) – The response. If two-dimensional, the outcome is multivariate, and the first axis indexes the k components. For binary data, non-zero values mean 1 and zero means 0.

  • outcome_type (OutcomeType | str | Sequence[OutcomeType | str], default: 'continuous') – Whether the regression is continuous or binary (probit). Can also be a sequence of OutcomeType values, one per outcome component, for mixed binary-continuous multivariate regression.

  • offset (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k'] | Float[ndarray, 'k']) – Constant shift added to the sum of trees. 0 if not specified.

  • max_split (UInt[Array, 'p'] | UInt[ndarray, 'p']) – The maximum split index for each variable. All split ranges start at 1.

  • num_trees (int) – The number of trees in the forest.

  • p_nonterminal (Float32[Array, 'd_minus_1'] | Float32[ndarray, 'd_minus_1']) – The probability of a nonterminal node at each depth. The maximum depth of trees is fixed by the length of this array. Use make_p_nonterminal to set it with the conventional formula.

  • leaf_prior_cov_inv (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k k'] | Float[ndarray, 'k k']) – The prior precision matrix of a leaf, conditional on the tree structure. For the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1. The prior mean of leaves is always zero.

  • error_cov_inv (Wishart | None, default: None) – The Wishart prior on the inverse error covariance, together with its initial value (see Wishart). Leave it unspecified for binary regression. The mixed binary-continuous and partial-missing diagonal modes require a DiagWishart; in the mixed case the binary components must have an initial precision of 1 (see DiagWishart).

  • error_scale (Float32[Array, 'n'] | Float32[ndarray, 'n'] | Float32[Array, 'k n'] | Float32[ndarray, 'k n'] | None, default: None) – Each error is scaled by the corresponding factor in error_scale. If error_scale[..., i] is a scalar (error_scale has shape (n,)), the error variance or covariance matrix of datapoint i is multiplied by error_scale[..., i] ** 2. If error_scale[:, i] is a vector (error_scale has shape (k, n)), the covariance matrix of datapoint i is rescaled elementwise by the outer product of that vector with itself. Not supported for binary or mixed binary-continuous regression. If not specified, it defaults to 1 for all points, in which case some calculations may be skipped.

  • missing (Bool[Array, 'n'] | Bool[ndarray, 'n'] | Bool[Array, 'k n'] | Bool[ndarray, 'k n'] | None, default: None) – Boolean mask, same shape as y; True marks entries to be ignored by the MCMC. Values of y must be finite everywhere, including at masked positions. If 2-D, error_cov_inv.rate must be diagonal.

  • min_points_per_decision_node (int | Integer[Array, ''] | Integer[ndarray, ''] | None, default: None) – The minimum number of data points in a decision node. 0 if not specified.

  • resid_reduction_config (ReductionConfig, default: AutoBatchedReduction())

  • count_reduction_config (ReductionConfig, default: AutoOneHotReduction())

  • prec_reduction_config (ReductionConfig, default: AutoOneHotReduction()) – How to sum the residuals, count the datapoints, and sum the likelihood precisions in each leaf, respectively. See ReductionConfig and its subclasses.

  • prec_count_num_trees (int | None | Literal['auto'], default: 'auto') – The number of trees to process at a time when counting datapoints or computing the likelihood precision. If None, do all trees at once, which may use too much memory. If ‘auto’ (default), it’s chosen automatically.

  • sequential_unroll (int | bool, default: 2) – How much to unroll the sequential accept/reject loop over trees in step. See the unroll argument of jax.lax.scan. Unrolling may speed up the MCMC at the cost of longer compilation. 1 means no unrolling; the default is 2.

  • save_ratios (bool, default: False) – Whether to save the Metropolis-Hastings ratios.

  • filter_splitless_vars (int, default: 0) – The maximum number of variables without splits that can be ignored. If there are more, init raises an exception.

  • min_points_per_leaf (int | Integer[Array, ''] | Integer[ndarray, ''] | None, default: None) – The minimum number of datapoints in a leaf node. 0 if not specified. Unlike min_points_per_decision_node, this constraint is not taken into account in the Metropolis-Hastings ratio because it would be expensive to compute. Grow moves that would violate this constraint are vetoed. This parameter is independent of min_points_per_decision_node and there is no check that they are coherent. It makes sense to set min_points_per_decision_node >= 2 * min_points_per_leaf.

  • log_s (Float32[Array, 'p'] | Float32[ndarray, 'p'] | None, default: None) – The logarithm of the prior probability of choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If not specified, a uniform distribution is used; if it is not specified but theta, or rho, a, and b, are, it is initialized automatically.

  • theta (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None) – The concentration parameter of the Dirichlet prior on s. Required only to update log_s. If not specified but rho, a, and b are, it is initialized automatically.

  • a (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None)

  • b (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None)

  • rho (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None) – Parameters of the prior on theta. Required only to sample theta.

  • sparse_on_at (int | Integer[Array, ''] | Integer[ndarray, ''] | None, default: None) – After how many MCMC steps to turn on variable selection.

  • augment (bool, default: True) – Whether to account exactly, via data augmentation, for the decision rules forbidden by the ancestors of each node when updating log_s. If False, those rules are ignored, which is faster but only approximate.

  • num_chains (int | None, default: None) – The number of independent MCMC chains to represent in the state. If not specified, the state holds a single chain, and its values are scalars rather than arrays with a chains axis.

  • mesh (Mesh | dict[str, int] | None, default: None) –

    A jax mesh used to shard data and computation across multiple devices. If it has a ‘chains’ axis, that axis is used to shard the chains. If it has a ‘data’ axis, that axis is used to shard the datapoints.

    As a shorthand, if a dictionary mapping axis names to axis sizes is passed, the corresponding mesh is created. For example, dict(chains=4, data=2) lets jax pick 8 devices arranged as 4 pairs: the chains (whose number must be a multiple of 4) are split across the 4 pairs, and within each pair the data is split in two.

    Note: if a mesh is passed, the arrays are always sharded according to it. In particular, even if the mesh has no ‘chains’ or ‘data’ axis, the arrays are replicated on all devices in the mesh.
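The p_nonterminal parameter mentions a "conventional formula". In the standard BART tree prior of Chipman, George and McCulloch, a node at depth d is nonterminal with probability alpha * (1 + d) ** -beta, usually with alpha = 0.95 and beta = 2. The sketch below builds such an array by hand. It is a hypothetical helper, not make_p_nonterminal itself (whose signature is not documented here), and it assumes the root has depth 0 and that the array covers depths 0, ..., max_depth - 2, so that its length is max_depth - 1.

```python
def conventional_p_nonterminal(max_depth, alpha=0.95, beta=2.0):
    """Hypothetical helper, not part of bartz.

    Returns alpha * (1 + d) ** -beta for d = 0, ..., max_depth - 2, i.e. a
    list of length max_depth - 1 (the 'd_minus_1' axis). Assumes the root
    has depth 0 and nodes at depth max_depth - 1 are always leaves.
    """
    if max_depth < 2:
        raise ValueError("max_depth must be at least 2")
    # probability that a node at depth d is split further
    return [alpha * (1 + d) ** -beta for d in range(max_depth - 1)]
```

Before passing such a list to init, it would be converted to a Float32 array, e.g. with np.asarray(p, dtype=np.float32).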
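The description of leaf_prior_cov_inv states that the prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1. Inverting that relation gives one way to choose the parameter from a target covariance C for the sum of trees: leaf_prior_cov_inv = num_trees * C^-1. The helper below is a hypothetical sketch of that arithmetic, not a bartz function; choosing C itself is left to the user.

```python
import numpy as np


def leaf_precision_for_target(target_cov, num_trees):
    """Hypothetical helper: pick leaf_prior_cov_inv so that the prior
    covariance of the sum of trees, num_trees * inv(leaf_prior_cov_inv),
    equals target_cov.

    Solving num_trees * inv(P) = C for P gives P = num_trees * inv(C).
    """
    target_cov = np.asarray(target_cov, dtype=float)
    if target_cov.ndim == 0:
        # univariate case (k=1): the precision is a scalar inverse variance
        return num_trees / target_cov
    # multivariate case: a k x k precision matrix
    return num_trees * np.linalg.inv(target_cov)
```

For example, with 200 trees and a target prior variance of 0.25 for the sum of trees, the leaf precision is 200 / 0.25 = 800.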
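The error_scale rules can be checked numerically. Scaling the components of an error vector by a vector s multiplies its covariance elementwise by the outer product of s with itself, which is the same as diag(s) @ cov @ diag(s); a scalar factor multiplies the whole covariance by its square. The function below is an illustrative sketch of that arithmetic only; it is not how bartz computes the likelihood internally.

```python
import numpy as np


def scaled_error_cov(error_cov, scale_i):
    """Error covariance of datapoint i after rescaling by error_scale[..., i].

    A scalar factor multiplies the variance or covariance by scale_i ** 2;
    a length-k vector rescales the covariance elementwise by the outer
    product, which equals diag(scale_i) @ error_cov @ diag(scale_i).
    """
    error_cov = np.asarray(error_cov, dtype=float)
    scale_i = np.asarray(scale_i, dtype=float)
    if scale_i.ndim == 0:
        return error_cov * scale_i**2
    return error_cov * np.outer(scale_i, scale_i)
```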
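Since missing requires y to be finite everywhere, including at masked positions, data that encodes missing responses as NaN has to be split into a finite y and a Boolean mask first. The helper below is a hypothetical sketch of that step; the fill value is arbitrary, because masked entries are ignored by the MCMC. Recall that with a 2-D mask, error_cov_inv.rate must be diagonal.

```python
import numpy as np


def prepare_missing(y, fill_value=0.0):
    """Hypothetical helper: turn non-finite (NaN or inf) responses into the
    (y, missing) pair described for init.

    Non-finite entries are replaced by a finite fill value and flagged True
    in the mask, which has the same shape as y.
    """
    y = np.asarray(y, dtype=np.float32)
    missing = ~np.isfinite(y)
    y_filled = np.where(missing, np.float32(fill_value), y)
    return y_filled, missing
```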
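Because log_s is an unnormalized log-probability, a prior favouring some predictors can be written as the log of relative weights, and adding a constant to every entry has no effect. The sketch below maps such a vector to selection probabilities with a softmax. It only illustrates the normalization and ignores the conditioning on ancestors mentioned in the log_s description; it is not bartz code.

```python
import numpy as np


def split_probabilities(log_s):
    """Normalize unnormalized log-weights into variable-choice probabilities.

    This is a softmax, so adding a constant to every entry of log_s leaves
    the resulting distribution unchanged.
    """
    log_s = np.asarray(log_s, dtype=float)
    w = np.exp(log_s - log_s.max())  # subtract the max for numerical stability
    return w / w.sum()
```

A vector of zeros, for instance, gives the uniform distribution used when log_s is not specified.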

Returns:

State – An initialized BART MCMC state.

Raises:

ValueError – If arguments unused in binary regression are set.

Notes

In decision nodes, the values in X[i, :] are compared to a cutpoint taken from the range [1, 2, ..., max_split[i]]. A point j belongs to the left child iff X[i, j] < cutpoint. Thus it makes sense for X[i, :] to be integers in the range [0, 1, ..., max_split[i]].
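One way to produce such an X from real-valued predictors stored in the usual (n, p) layout is to pick sorted thresholds per variable and set X[i, j] to the number of thresholds not exceeding the raw value. Then, for each cutpoint c, X[i, j] < c holds exactly when the raw value is below the c-th threshold. The helper below is a hypothetical preprocessing sketch, not a bartz function: it uses midpoints between distinct values as thresholds, does not cap their number, and picks uint32 arbitrarily. A constant column gets max_split 0, i.e. a variable without splits, presumably the kind filter_splitless_vars refers to.

```python
import numpy as np


def bin_predictors(x_raw):
    """Hypothetical preprocessing helper, not part of bartz.

    x_raw: float array of shape (n, p), the usual datapoints-by-predictors layout.
    Returns X of shape (p, n) (the transposed layout init expects), max_split
    of shape (p,), and per-variable thresholds t such that, for every cutpoint
    c in 1..max_split[i], X[i, j] < c  iff  x_raw[j, i] < t[i][c - 1].
    """
    x_raw = np.asarray(x_raw, dtype=float)
    n, p = x_raw.shape
    X = np.empty((p, n), dtype=np.uint32)
    max_split = np.empty(p, dtype=np.uint32)
    thresholds = []
    for i in range(p):
        col = x_raw[:, i]
        uniq = np.unique(col)
        # candidate split points: midpoints between consecutive distinct values
        t = (uniq[:-1] + uniq[1:]) / 2
        # number of thresholds <= each value, an integer in 0..len(t)
        X[i] = np.searchsorted(t, col, side="right")
        max_split[i] = len(t)
        thresholds.append(t)
    return X, max_split, thresholds
```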

In general, the arrays passed as arguments to this function may be donated, which invalidates them. If this happens and you need them again afterwards, pass copies to init instead.