bartz.mcmcstep.init
- bartz.mcmcstep.init(*, X, y, outcome_type='continuous', offset, max_split, num_trees, p_nonterminal, leaf_prior_cov_inv, error_cov_inv=None, error_scale=None, missing=None, min_points_per_decision_node=None, resid_reduction_config=AutoBatchedReduction(), count_reduction_config=AutoOneHotReduction(), prec_reduction_config=AutoOneHotReduction(), prec_count_num_trees='auto', sequential_unroll=2, save_ratios=False, filter_splitless_vars=0, min_points_per_leaf=None, log_s=None, theta=None, a=None, b=None, rho=None, sparse_on_at=None, num_chains=None, mesh=None)
Make a BART posterior sampling MCMC initial state.
- Parameters:
X (UInt[Array, 'p n'] | UInt[ndarray, 'p n']) – The predictors. Note that this is transposed compared to the usual convention: rows are variables, columns are data points.
y (Float32[Array, 'n'] | Float32[ndarray, 'n'] | Float32[Array, 'k n'] | Float32[ndarray, 'k n']) – The response. If two-dimensional, the outcome is multivariate, with the first axis indexing the components. For binary data, non-zero means 1 and zero means 0.
outcome_type (OutcomeType | str | Sequence[OutcomeType | str], default: 'continuous') – Whether the regression is continuous or binary (probit). Can also be a sequence of OutcomeType values, one per outcome component, for mixed binary-continuous multivariate regression.
offset (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k'] | Float[ndarray, 'k']) – Constant shift added to the sum of trees. 0 if not specified.
max_split (UInt[Array, 'p'] | UInt[ndarray, 'p']) – The maximum split index for each variable. All split ranges start at 1.
num_trees (int) – The number of trees in the forest.
p_nonterminal (Float32[Array, 'd_minus_1'] | Float32[ndarray, 'd_minus_1']) – The probability of a nonterminal node at each depth. The maximum depth of the trees is fixed by the length of this array. Use make_p_nonterminal to set it with the conventional formula.
leaf_prior_cov_inv (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k k'] | Float[ndarray, 'k k']) – The prior precision matrix of a leaf, conditional on the tree structure. In the univariate case (k=1), this is a scalar (the inverse variance). The prior covariance of the sum of trees is num_trees * leaf_prior_cov_inv^-1. The prior mean of the leaves is always zero.
error_cov_inv (Wishart | None, default: None) – The Wishart prior on the inverse error covariance, together with its initial value (see Wishart). Leave it unspecified for binary regression. The mixed binary-continuous and partial-missing diagonal modes require a DiagWishart; in the mixed case the binary components must have an initial precision of 1 (see DiagWishart).
error_scale (Float32[Array, 'n'] | Float32[ndarray, 'n'] | Float32[Array, 'k n'] | Float32[ndarray, 'k n'] | None, default: None) – Each error is scaled by the corresponding factor in error_scale. If error_scale[..., i] is a scalar, each error variance or covariance matrix is multiplied by error_scale[..., i] ** 2. If error_scale[:, i] is a vector, the covariance matrix is rescaled by its outer product. Not supported for binary or mixed binary-continuous regression. If not specified, the scale is 1 for all points, and the corresponding calculations may be skipped.
missing (Bool[Array, 'n'] | Bool[ndarray, 'n'] | Bool[Array, 'k n'] | Bool[ndarray, 'k n'] | None, default: None) – Boolean mask with the same shape as y; True marks entries to be ignored by the MCMC. Values of y must be finite everywhere, including at masked positions. If 2-D, error_cov_inv.rate must be diagonal.
min_points_per_decision_node (int | Integer[Array, ''] | Integer[ndarray, ''] | None, default: None) – The minimum number of data points in a decision node. 0 if not specified.
resid_reduction_config (ReductionConfig, default: AutoBatchedReduction())
count_reduction_config (ReductionConfig, default: AutoOneHotReduction())
prec_reduction_config (ReductionConfig, default: AutoOneHotReduction()) – How to sum the residuals, count the datapoints, and sum the likelihood precisions in each leaf, respectively. See ReductionConfig and its subclasses.
prec_count_num_trees (int | None | Literal['auto'], default: 'auto') – The number of trees to process at a time when counting datapoints or computing the likelihood precision. If None, all trees are processed at once, which may use too much memory. If 'auto' (default), it is chosen automatically.
sequential_unroll (int | bool, default: 2) – How much to unroll the sequential accept/reject loop over trees in step. See the unroll argument of jax.lax.scan. Unrolling may speed up the MCMC at the cost of longer compilation. 1 means no unrolling; the default is 2.
save_ratios (bool, default: False) – Whether to save the Metropolis-Hastings ratios.
filter_splitless_vars (int, default: 0) – The maximum number of variables without splits that can be ignored. If there are more, init raises an exception.
min_points_per_leaf (int | Integer[Array, ''] | Integer[ndarray, ''] | None, default: None) – The minimum number of datapoints in a leaf node. 0 if not specified. Unlike min_points_per_decision_node, this constraint is not taken into account in the Metropolis-Hastings ratio because it would be expensive to compute. Grow moves that would violate this constraint are vetoed. This parameter is independent of min_points_per_decision_node and there is no check that they are coherent. It makes sense to set min_points_per_decision_node >= 2 * min_points_per_leaf.
log_s (Float32[Array, 'p'] | Float32[ndarray, 'p'] | None, default: None) – The logarithm of the prior probability for choosing a variable to split along in a decision rule, conditional on the ancestors. Not normalized. If not specified, a uniform distribution is used. If not specified but theta, or rho, a, and b, are specified, it is initialized automatically.
theta (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None) – The concentration parameter for the Dirichlet prior on s. Required only to update log_s. If not specified but rho, a, and b are specified, it is initialized automatically.
a (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None)
b (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None)
rho (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None) – Parameters of the prior on theta. Required only to sample theta.
sparse_on_at (int | Integer[Array, ''] | Integer[ndarray, ''] | None, default: None) – After how many MCMC steps to turn on variable selection.
num_chains (int | None, default: None) – The number of independent MCMC chains to represent in the state. Single chain with scalar values if not specified.
mesh (Mesh | dict[str, int] | None, default: None) – A jax mesh used to shard data and computation across multiple devices. If it has a ‘chains’ axis, that axis is used to shard the chains. If it has a ‘data’ axis, that axis is used to shard the datapoints.
As a shorthand, if a dictionary mapping axis names to axis sizes is passed, the corresponding mesh is created. For example, dict(chains=4, data=2) will let jax pick 8 devices to split the chains (whose number must be a multiple of 4) across 4 pairs of devices, where in each pair the data is split in two.
Note: if a mesh is passed, the arrays are always sharded according to it. In particular, even if the mesh has no ‘chains’ or ‘data’ axis, the arrays will be replicated on all devices in the mesh.
- Returns:
State – An initialized BART MCMC state.
- Raises:
ValueError – If arguments unused in binary regression are set.
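The descriptions of p_nonterminal and leaf_prior_cov_inv can be made concrete with a little arithmetic. The following is a minimal NumPy sketch, not bartz code: both helper names are hypothetical, the formula alpha / (1 + d) ** beta with alpha = 0.95 and beta = 2 is assumed to be the "conventional formula" (it is the usual choice in the BART literature), and in practice make_p_nonterminal should be preferred. The second helper inverts the relation stated above, prior variance of the sum of trees = num_trees / leaf_prior_cov_inv, for the univariate case.

```python
import numpy as np


def conventional_p_nonterminal(max_depth, alpha=0.95, beta=2.0):
    """Hypothetical helper, not part of bartz.

    Probability that a node at depth d is nonterminal, alpha / (1 + d) ** beta,
    for d = 0, ..., max_depth - 2, i.e. an array of length max_depth - 1.
    """
    depth = np.arange(max_depth - 1)
    return (alpha / (1.0 + depth) ** beta).astype(np.float32)


def leaf_precision_for_total_sd(total_sd, num_trees):
    """Hypothetical helper, not part of bartz (univariate case only).

    Scalar leaf precision such that the prior variance of the sum of trees,
    num_trees / precision, equals total_sd ** 2.
    """
    return num_trees / total_sd**2


num_trees = 200
p_nonterminal = conventional_p_nonterminal(max_depth=6)
leaf_prior_cov_inv = leaf_precision_for_total_sd(total_sd=0.5, num_trees=num_trees)

# Sanity check of the stated relation: variance of the sum of trees.
sum_of_trees_var = num_trees / leaf_prior_cov_inv
```

With these values, p_nonterminal has five entries, so the trees have maximum depth 6 under the assumed meaning of 'd_minus_1', and the prior standard deviation of the sum of trees is 0.5.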
Notes
In decision nodes, the values in X[i, :] are compared to a cutpoint out of the range [1, 2, ..., max_split[i]]. A point belongs to the left child iff X[i, j] < cutpoint. Thus it makes sense for X[i, :] to be integers in the range [0, 1, ..., max_split[i]].
In general the arrays passed to this function as arguments may be donated, invalidating them. Create copies before passing them to init if this happens and you need them again.
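To illustrate the integer encoding described in these notes, here is a minimal NumPy sketch of one way to turn continuous predictors into X and max_split. The helper bin_predictors and its quantile-based cutpoints are assumptions made for illustration; they are not the library's own binning routine, which should be preferred if available. The uint8 dtype assumes at most 255 cutpoints per variable.

```python
import numpy as np


def bin_predictors(x, num_cutpoints):
    """Hypothetical helper, not part of bartz.

    x has shape (p, n), already transposed as init expects. Returns integer
    bin indices X of shape (p, n), max_split of shape (p,), and the list of
    real-valued cutpoints used for each variable.
    """
    p, n = x.shape
    X = np.empty((p, n), dtype=np.uint8)
    max_split = np.empty(p, dtype=np.uint8)
    cutpoints = []
    # Interior quantile levels, excluding 0 and 1.
    q = np.linspace(0.0, 1.0, num_cutpoints + 2)[1:-1]
    for i in range(p):
        c = np.unique(np.quantile(x[i], q))
        # Number of cutpoints <= value, an integer in 0, ..., len(c).
        X[i] = np.searchsorted(c, x[i], side="right")
        max_split[i] = c.size
        cutpoints.append(c)
    return X, max_split, cutpoints


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 100))  # p = 3 variables, n = 100 points
X, max_split, cutpoints = bin_predictors(x, num_cutpoints=10)

# Decision rule on variable i with split index k in 1, ..., max_split[i]:
# a point goes to the left child iff X[i, j] < k.
i, k = 0, 4
goes_left = X[i] < k
```

Under this encoding, X[i, j] < k holds exactly when the original value x[i, j] is below the k-th smallest cutpoint of variable i, so split index k corresponds to the real-valued threshold cutpoints[i][k - 1].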