New interface

class bartz.Bart(x_train, y_train, *, outcome_type='continuous', sparse=False, theta=None, a=0.5, b=1.0, rho=None, varprob=None, xinfo=None, usequants=False, rm_const=True, sigest=None, sigdf=3.0, sigquant=0.9, k=2.0, power=2.0, base=0.95, lamda=None, tau_num=None, offset=None, w=None, num_trees=200, numcut=255, ndpost=1000, nskip=1000, keepevery=1, printevery=100, num_chains=4, num_chain_devices=None, num_data_devices=None, devices=None, seed=0, maxdepth=6, init_kw=mappingproxy({}), run_mcmc_kw=mappingproxy({}))[source]

Nonparametric regression with Bayesian Additive Regression Trees (BART) [2].

Regress y_train on x_train with a latent mean function represented as a sum of decision trees. The inference is carried out by sampling the posterior distribution of the tree ensemble with an MCMC.

Parameters:
  • x_train (Real[Array, 'p n'] | DataFrame) – The training predictors.

  • y_train (Float32[Array, 'n'] | Float32[Array, 'k n'] | Series) – The training responses. For univariate regression, a 1D array of shape (n,). For multivariate regression, a 2D array of shape (k, n), where k is the number of response components, as introduced in [3]. For binary regression, the convention is that non-zero values mean 1 and zero means 0, like booleans.

  • outcome_type (OutcomeType | str | Sequence[OutcomeType | str], default: 'continuous') – The type of regression. 'continuous' for continuous regression, 'binary' for binary regression with probit link. For multivariate regression, a scalar value applies to all components; alternatively, a sequence of per-component types (e.g., ['binary', 'continuous']) specifies mixed outcome types.

  • sparse (bool, default: False) – Whether to activate variable selection on the predictors as done in [1].

  • theta (float | Float[Any, ''] | None, default: None)

  • a (float | Float[Any, ''], default: 0.5)

  • b (float | Float[Any, ''], default: 1.0)

  • rho (float | Float[Any, ''] | None, default: None) –

    Hyperparameters of the sparsity prior used for variable selection.

    The prior distribution on the choice of predictor for each decision rule is

    \[(s_1, \ldots, s_p) \sim \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).\]

    If theta is not specified, it’s a priori distributed according to

    \[\frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim \operatorname{Beta}(\mathtt{a}, \mathtt{b}).\]

    If not specified, rho is set to the number of predictors p. To tune the prior, consider setting a lower rho to prefer more sparsity. If setting theta directly, it should be in the ballpark of p or lower as well.

  • varprob (Float[Array, 'p'] | None, default: None) – The prior probability distribution over the p predictors used to choose the predictor to split on in a decision node. All entries must be > 0; the vector does not need to be normalized to sum to 1. If not specified, a uniform distribution is used. If sparse=True, this is used as the initial value for the MCMC.

  • xinfo (Float[Array, 'p n'] | None, default: None) –

    A matrix with the cutpoints used to bin each predictor. If not specified, it is generated automatically according to usequants and numcut.

    Each row shall contain a sorted list of cutpoints for a predictor. If there are fewer cutpoints than the number of columns in the matrix, fill the remaining cells with NaN.

    xinfo shall be a matrix even if x_train is a dataframe.

  • usequants (bool, default: False) – Whether to use predictor quantiles instead of a uniform grid to bin the predictors. Ignored if xinfo is specified.

  • rm_const (bool, default: True) – How to treat predictors with no associated decision rules (i.e., there are no available cutpoints for that predictor). If True (default), they are ignored. If False, an error is raised if there are any.

  • sigest (float | Float[Any, ''] | Float[Array, 'k'] | None, default: None) – An estimate of the residual standard deviation of y_train, used to set lamda. If not specified, it is estimated by linear regression (with intercept, and without taking w into account). If y_train has fewer than two elements, it is set to 1. If n <= p, it is set to the standard deviation of y_train. Ignored if lamda is specified. For multivariate regression, it can be a scalar (broadcast to all components) or a (k,) vector of per-component estimates. For mixed outcome types, binary component values are ignored.

  • sigdf (float | Float[Any, ''], default: 3.0) – The degrees of freedom of the scaled inverse chi-squared prior on the noise variance. For multivariate regression, the Inverse-Wishart degrees of freedom are set to sigdf + k - 1.

  • sigquant (float | Float[Any, ''], default: 0.9) – The quantile of the prior on the noise variance that shall match sigest to set the scale of the prior. Ignored if lamda is specified.

  • k (float | Float[Any, ''], default: 2.0) – The inverse scale of the prior standard deviation of the latent mean function, relative to half the observed range of y_train. If y_train has fewer than two elements, k is ignored and the scale is set to 1.

  • power (float | Float[Any, ''], default: 2.0)

  • base (float | Float[Any, ''], default: 0.95) – Parameters of the prior on tree node generation. The probability that a node at depth d (0-based) is non-terminal is base / (1 + d) ** power.

  • lamda (float | Float[Any, ''] | Float[Array, 'k'] | None, default: None) – The prior harmonic mean of the error variance. (The harmonic mean of x is 1/mean(1/x).) If not specified, it is set based on sigest and sigquant. For multivariate regression, can be a scalar (broadcast to all components) or a (k,) vector. For mixed outcome types, binary component values are ignored.

  • tau_num (float | Float[Any, ''] | None, default: None) – The numerator in the expression that determines the prior standard deviation of the leaves. If not specified, it defaults to (max(y_train) - min(y_train)) / 2 (or 1 if y_train has fewer than two elements) for continuous regression, and to 3 for binary regression. For multivariate regression, the range is computed per component. For mixed outcome types, each component uses the default for its type.

  • offset (float | Float[Any, ''] | Float[Array, 'k'] | None, default: None) – The prior mean of the latent mean function. If not specified, it is set to the mean of y_train for continuous regression, and to Phi^-1(mean(y_train != 0)) for binary regression; for multivariate regression, the default is computed per component, with each component of a mixed outcome using the default for its type. If y_train is empty, offset is set to 0. With binary regression, if y_train is all zero or all non-zero, it is set to Phi^-1(1/(n+1)) or Phi^-1(n/(n+1)), respectively. It can be a scalar (broadcast to all components) or a (k,) vector.

  • w (Float[Array, 'n'] | Series | None, default: None) – Coefficients that rescale the error standard deviation on each datapoint. Not specifying w is equivalent to setting it to 1 for all datapoints. Note: w is ignored in the automatic determination of sigest, so either the weights should be O(1), or sigest should be specified by the user. Not supported for multivariate regression.

  • num_trees (int, default: 200) – The number of trees used to represent the latent mean function.

  • numcut (int, default: 255) –

    If usequants is False: the exact number of cutpoints used to bin the predictors, evenly spaced between the minimum and maximum observed values (endpoints excluded).

    If usequants is True: the maximum number of cutpoints to use for binning the predictors. Each predictor is binned such that its distribution in x_train is approximately uniform across bins. The number of bins is at most the number of unique values appearing in x_train, or numcut + 1.

    Before running the algorithm, the predictors are compressed to the smallest integer type that fits the bin indices, so numcut is best set to the maximum value of an unsigned integer type, like 255.

    Ignored if xinfo is specified.

  • ndpost (int, default: 1000) – The number of MCMC samples to save, after burn-in. ndpost is the total number of samples across all chains; it is rounded up to the next multiple of num_chains.

  • nskip (int, default: 1000) – The number of initial MCMC samples to discard as burn-in. This number of samples is discarded from each chain.

  • keepevery (int, default: 1) – The thinning factor for the MCMC samples, after burn-in.

  • printevery (int | None, default: 100) – The number of iterations (including thinned-away ones) between log lines. Set to None to disable logging. Ctrl-C interrupts the MCMC only at printevery boundaries, so with logging disabled the MCMC cannot be interrupted conveniently.

  • num_chains (int | None, default: 4) –

    The number of independent Markov chains to run.

    With num_chains=None, a single chain is run and the chain axis is dropped; with num_chains=1, the object attributes and some method outputs keep an explicit chain axis of size 1.

  • num_chain_devices (int | None, default: None) – The number of devices to spread the chains across. Must be a divisor of num_chains. Each device will run a fraction of the chains.

  • num_data_devices (int | None, default: None) –

    The number of devices to split the datapoints across. Must be a divisor of n. This is useful only for very large n, roughly n > 1,000,000.

    If both num_chain_devices and num_data_devices are specified, the total number of devices used is the product of the two.

  • devices (Device | Sequence[Device] | None, default: None) – One or more devices used to run the MCMC on. If not specified, the computation will follow the placement of the input arrays. If a list of devices, this argument can be longer than the number of devices needed.

  • seed (int | Key[Array, ''], default: 0) – The seed for the random number generator.

  • maxdepth (int, default: 6) – The maximum depth of the trees. This is 1-based, so with the default maxdepth=6, node depths range from 0 to 5.

  • init_kw (Mapping, default: mappingproxy({})) – Additional arguments passed to bartz.mcmcstep.init.

  • run_mcmc_kw (Mapping, default: mappingproxy({})) – Additional arguments passed to bartz.mcmcloop.run_mcmc.

References

offset: Float32[Array, ''] | Float32[Array, 'k']

The prior mean of the latent mean function.

sigest: Float32[Array, ''] | Float32[Array, 'k'] | None = None

The estimated standard deviation of the error used to set lamda.

predict(x_test, *, kind='mean', key=None, w=None)[source]

Compute predictions at x_test.

Parameters:
  • x_test (Real[Array, 'p m'] | DataFrame | str) – The test predictors, or the string 'train' to compute predictions on the training data.

  • kind (PredictKind | str, default: 'mean') – The kind of output. See PredictKind for details.

  • key (Key[Array, ''] | None, default: None) – Jax random key, required when kind='outcome_samples'.

  • w (Float[Array, 'm'] | Series | None, default: None) – Per-observation error scale for kind='outcome_samples'. Required when the model was fit with weights and x_test is new data.

Returns:

Float32[Array, 'm'] | Float32[Array, 'k m'] | Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] – Predictions at x_test in the requested format.

Raises:

ValueError – If x_test has a different format than x_train, or if w is specified when it should be None, or if w is not specified when it is required.

property ndpost: int[source]

The total number of posterior samples after burn-in across all chains.

May be larger than the initialization argument ndpost if it was not divisible by the number of chains.

property num_trees: int[source]

Return the number of trees used in the model.

get_latent_prec(only_continuous=False)[source]

Return the posterior samples of the latent error precision matrix.

Parameters:

only_continuous (bool, default: False) – If True and the model has mixed binary-continuous outcomes, return only the submatrix for the continuous components.

Returns:

Float32[Array, 'nskip+ndpost'] | Float32[Array, 'nskip+ndpost k k'] | Float32[Array, 'num_chains nskip+ndpost/num_chains'] | Float32[Array, 'num_chains nskip+ndpost/num_chains k k'] – MCMC samples of the error precision matrix.

Notes

This method is meant for convergence checking, so it returns the full MCMC trace and does not concatenate chains together. For probit regression, it returns the precision of the latent error term, not the Bernoulli precision of the binary outcome. For heteroskedastic regression, the returned precision is the global precision parameter, which would have to be divided by the squared weight to get the precision at a given datapoint.

Raises:

ValueError – If only_continuous is True but the model has only binary outcomes, so there is no continuous submatrix to return.

get_error_sdev(mean=False)[source]

Return the error standard deviation, post-burnin, chains concatenated.

Parameters:

mean (bool, default: False) – If True, average the precision matrix across samples first (harmonic mean at the covariance matrix level), returning a single scalar or vector instead of posterior samples.

Returns:

Float32[Array, 'ndpost'] | Float32[Array, 'ndpost k'] | Float32[Array, ''] | Float32[Array, 'k'] – Posterior samples (or a single estimate) of the error standard deviation; NaN for binary outcomes.

Notes

Binary outcomes do have a standard deviation, of course, but it is not returned by this method because that would require evaluating predictions at given predictor values, since the Bernoulli variance is p(1-p).

property varcount: Int32[Array, 'ndpost p'][source]

Histogram of predictor usage for decision rules in the trees.

property varcount_mean: Float32[Array, 'p'][source]

Average of varcount across MCMC iterations.

property varprob: Float32[Array, 'ndpost p'][source]

Posterior samples of the probability of choosing each predictor for a decision rule.

property varprob_mean: Float32[Array, 'p'][source]

The marginal posterior probability of each predictor being chosen for a decision rule.

check_trees(error=False)[source]

Apply bartz.grove.check_trace to all the tree draws.

Parameters:

error (bool, default: False) – If True, throw an error if any invalid trees are found.

Returns:

UInt[Array, 'num_chains ndpost/num_chains num_trees'] – An array where non-zero entries indicate invalid trees.

Raises:

RuntimeError – If error is True and any invalid trees are found.

check_replicated_trees()[source]

Check that the trees are equal across data-sharded devices.

If the data is sharded across devices, verify that the trees (which should be replicated) are identical on all shards.

Raises:

RuntimeError – If the trees differ across devices.

Return type:

None

compare_resid(y=None)[source]

Re-compute residuals to compare them with the updated ones.

Parameters:

y (Float32[Array, 'n'] | Float32[Array, 'k n'] | None, default: None) – The response variable. Required for continuous regression (since State does not store y in continuous mode). Ignored for binary regression (where State.z is used instead).

Returns:

  • resid1 – The final state of the residuals updated during the MCMC.

  • resid2 – The residuals computed from the final state of the trees.

depth_distr()[source]

Histogram of tree depths for each state of the trees.

Returns:

Int32[Array, '*num_chains ndpost/num_chains d'] – A matrix where each row contains a histogram of tree depths.

points_per_decision_node_distr()[source]

Histogram of number of points belonging to parent-of-leaf nodes.

Returns:

Int32[Array, '*num_chains ndpost/num_chains n+1'] – For each chain, a matrix where each row contains a histogram of the number of points.

points_per_leaf_distr()[source]

Histogram of number of points belonging to leaves.

Returns:

Int32[Array, '*num_chains ndpost/num_chains n+1'] – A matrix where each row contains a histogram of the number of points.

class bartz.PredictKind(*values)[source]

Kind of output of Bart.predict.

mean = 'mean'[source]

The posterior mean of the conditional mean, shape (m,) (or (k, m) for multivariate regression).

mean_samples = 'mean_samples'[source]

Per-sample conditional mean, shape (ndpost, m) (or (ndpost, k, m)). For binary regression, this is the probit-transformed sum-of-trees.

outcome_samples = 'outcome_samples'[source]

Samples of the outcome variable, shape (ndpost, m) (or (ndpost, k, m)). For binary regression, these are Bernoulli draws. For continuous regression, these are Gaussian draws with the posterior noise variance.

latent_samples = 'latent_samples'[source]

Raw sum-of-trees values, shape (ndpost, m) (or (ndpost, k, m)).