bartz.Bart

class bartz.Bart(x_train, y_train, *, outcome_type='continuous', sparse=False, theta=None, a=0.5, b=1.0, rho=None, varprob=None, binner=UniqueQuantileBinner, rm_const=True, sigma_df=3.0, sigma_scale='auto', sigma_init='auto', k=2.0, power=2.0, base=0.95, tau_num=None, offset=None, error_scale=None, missing=None, num_trees=200, n_save=1000, n_burn=1000, n_skip=1, printevery=100, pbar=True, num_chains=4, num_chain_devices='auto', num_data_devices=None, devices=None, seed=0, maxdepth=6, init_kw=MappingProxyType({}), run_mcmc_kw=MappingProxyType({}))

Nonparametric regression with Bayesian Additive Regression Trees (BART).

Regress y_train on x_train with a latent mean function represented as a sum of decision trees [2]. Inference is carried out by sampling the posterior distribution of the tree ensemble with Markov chain Monte Carlo (MCMC).
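A minimal usage sketch, assuming bartz and numpy are installed. The toy data and the helper names make_toy_data and fit_and_predict are hypothetical; only the constructor arguments and predict(..., kind='mean') come from this page.

```python
import random


def make_toy_data(n=200, p=3, seed=0):
    """Hypothetical toy data: y depends on the first predictor only."""
    rng = random.Random(seed)
    # Predictors are laid out as p rows of n values, matching the 'p n' shape.
    x = [[rng.uniform(-1.0, 1.0) for _ in range(n)] for _ in range(p)]
    y = [x[0][i] ** 2 + rng.gauss(0.0, 0.1) for i in range(n)]
    return x, y


def fit_and_predict(x, y, x_test):
    """Fit bartz.Bart and return posterior mean predictions at x_test."""
    # Imported lazily so that make_toy_data works without bartz installed.
    import numpy as np
    import bartz

    x_train = np.asarray(x, dtype=np.float32)  # shape (p, n)
    y_train = np.asarray(y, dtype=np.float32)  # shape (n,)
    # Shorter chains than the defaults, only to keep the sketch quick.
    model = bartz.Bart(x_train, y_train, n_save=200, n_burn=200, seed=0)
    return model.predict(np.asarray(x_test, dtype=np.float32), kind='mean')
```

With x, y = make_toy_data(), calling fit_and_predict(x, y, x) should return one posterior mean per column of the test predictors.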

Parameters:
  • x_train (Real[Array, 'p n'] | Real[ndarray, 'p n'] | DataFrame) – The training predictors.

  • y_train (Float32[Array, 'n'] | Float32[ndarray, 'n'] | Float32[Array, 'k n'] | Float32[ndarray, 'k n'] | Series | DataFrame) – The training responses. For univariate regression, a 1D array of shape (n,). For multivariate regression, a 2D array of shape (k, n) where k is the number of response components, as introduced in [3]. For binary regression, the convention is that non-zero values mean 1 and zero means 0, like booleans.

  • outcome_type (OutcomeType | str | Sequence[OutcomeType | str], default: 'continuous') – The type of regression. 'continuous' for continuous regression, 'binary' for binary regression with probit link. For multivariate regression, a scalar value applies to all components; alternatively, a sequence of per-component types (e.g., ['binary', 'continuous']) specifies mixed outcome types. Binary components in multivariate outcomes follow the multivariate probit BART formulation of [4].

  • sparse (bool, default: False) – Whether to activate variable selection on the predictors as done in [1].

  • theta (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None)

  • a (float | Float[Array, ''] | Float[ndarray, ''], default: 0.5)

  • b (float | Float[Array, ''] | Float[ndarray, ''], default: 1.0)

  • rho (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None) –

    Hyperparameters of the sparsity prior used for variable selection.

    The prior distribution on the choice of predictor for each decision rule is

    \[(s_1, \ldots, s_p) \sim \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).\]

    If theta is not specified, it’s a priori distributed according to

    \[\frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim \operatorname{Beta}(\mathtt{a}, \mathtt{b}).\]

    If not specified, rho is set to the number of predictors p. To tune the prior, consider setting a lower rho to prefer more sparsity. If setting theta directly, it should be in the ballpark of p or lower as well.

  • varprob (Float[Array, 'p'] | Float[ndarray, 'p'] | None, default: None) – The probability distribution over the p predictors for choosing a predictor to split on in a decision node a priori. Must be > 0. It does not need to be normalized to sum to 1. If not specified, use a uniform distribution. If sparse=True, this is used as initial value for the MCMC.

  • binner (BinnerFactory, default: UniqueQuantileBinner) – A callable that, given the training predictors and a random key, returns a Binner instance. The default is UniqueQuantileBinner, which places cutpoints at the quantiles of each predictor. Other built-in options are RangeEvenBinner (evenly-spaced cutpoints over the observed range) and GivenSplitsBinner (R BART xinfo format). To pass options, use functools.partial, e.g. binner=partial(UniqueQuantileBinner, max_bins=128).

  • rm_const (bool, default: True) – How to treat predictors with no associated decision rules (i.e., there are no available cutpoints for that predictor). If True (default), they are ignored. If False, an error is raised if there are any.

  • sigma_df (float | Float[Array, ''] | Float[ndarray, ''], default: 3.0) – The degrees of freedom of the prior on the error precision. For multivariate regression with k components, the Wishart degrees of freedom are set to sigma_df + k - 1.

  • sigma_scale (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k'] | Float[ndarray, 'k'] | Literal['auto'], default: 'auto') – Sets the scale of the prior on the error precision. If ‘auto’ (default), the prior is scaled so that the error precision equals diag(1 / var(y_train)) in expectation, where with weights error_scale the variance is a precision-weighted one that estimates the unit-weight error variance. Otherwise, square(sigma_scale) is the prior harmonic mean of the error variance; for multivariate regression a scalar is broadcast to all components. For mixed outcome types, binary components are ignored.

  • sigma_init (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k'] | Float[ndarray, 'k'] | Literal['auto'], default: 'auto') – The initial value of the error standard deviation in the MCMC. If ‘auto’ (default), the initial error precision is set to diag(1 / var(y_train)), with the same precision-weighted variance as sigma_scale when weights are given. Otherwise, the initial precision is diag(1 / square(sigma_init)); for multivariate regression a scalar is broadcast to all components. For mixed outcome types, binary components are ignored.

  • k (float | Float[Array, ''] | Float[ndarray, ''], default: 2.0) – The inverse scale of the prior standard deviation on the latent mean function, relative to half the observed range of y_train. If y_train has fewer than two elements, k is ignored and the scale is set to 1.

  • power (float | Float[Array, ''] | Float[ndarray, ''], default: 2.0)

  • base (float | Float[Array, ''] | Float[ndarray, ''], default: 0.95) – Parameters of the prior on tree node generation. The probability that a node at depth d (0-based) is non-terminal is base / (1 + d) ** power.

  • tau_num (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None) – The numerator in the expression that determines the prior standard deviation of leaves. If not specified, it defaults to (max(y_train) - min(y_train)) / 2 (or 1 if y_train has fewer than two elements) for continuous regression, and 3 for binary regression. For multivariate regression, the range is computed per component. For mixed outcome types, each component uses the default for its type.

  • offset (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k'] | Float[ndarray, 'k'] | None, default: None) – The prior mean of the latent mean function. If not specified, it is set to the mean of y_train for continuous regression, and to Phi^-1(mean(y_train != 0)) for binary regression. If y_train is empty, offset is set to 0. With binary regression, if y_train is all zero or all non-zero, it is set to Phi^-1(1/(n+1)) or Phi^-1(n/(n+1)), respectively. For multivariate regression, can be a scalar (broadcast to all components) or a (k,) vector. If not specified, it is set to the per-component mean of y_train. For mixed outcome types, each component uses the default for its type.

  • error_scale (Float[Array, 'n'] | Float[ndarray, 'n'] | Float[Array, 'k n'] | Float[ndarray, 'k n'] | Series | DataFrame | None, default: None) – Coefficients that rescale the error standard deviation on each datapoint. Not specifying error_scale is equivalent to setting it to 1 for all datapoints. Shape (n,) applies the same scalar weight to every outcome component; for multivariate continuous regression, (k, n) instead supplies a per-component weight per datapoint.

  • missing (Bool[Array, 'n'] | Bool[ndarray, 'n'] | Bool[Array, 'k n'] | Bool[ndarray, 'k n'] | Series | DataFrame | None, default: None) – Boolean mask with the same shape as y_train; True marks entries to be ignored by the MCMC. Values of y_train must be finite everywhere, including at masked positions. If 2-D, the error covariance must be diagonal.

  • num_trees (int, default: 200) – The number of trees used to represent the latent mean function.

  • n_save (int, default: 1000) – The number of MCMC samples to save, after burn-in, per chain. The total trace length across all chains is num_chains * n_save.

  • n_burn (int, default: 1000) – The number of initial MCMC samples to discard as burn-in. This number of samples is discarded from each chain.

  • n_skip (int, default: 1) – The thinning factor for the MCMC samples, after burn-in.

  • printevery (int | None, default: 100) – The number of iterations (including thinned-away ones) between each log line. Set to None to disable progress reporting entirely (this ignores pbar). ^C interrupts the MCMC only every printevery iterations, so with reporting disabled it’s impossible to kill the MCMC conveniently.

  • pbar (bool, default: True) – If True, show a tqdm progress bar instead of printing log lines. The bar advances every iteration and refreshes the acceptance statistics every printevery iterations. Ignored if printevery is None.

  • num_chains (int | None, default: 4) –

    The number of independent Markov chains to run.

    num_chains=None and num_chains=1 both run a single chain. The difference is that with num_chains=1 the object attributes and the outputs of some methods carry an explicit chain axis of size 1.

  • num_chain_devices (int | None | Literal['auto'], default: 'auto') – The number of devices to spread the chains across. Must be a divisor of num_chains. Each device will run a fraction of the chains. If ‘auto’ (default) and running on cpu, the number of devices is picked automatically based on the number of cores and the number of available devices (all the virtual jax cpu devices, or the devices list if set).

  • num_data_devices (int | None, default: None) –

    The number of devices to split datapoints across. Must be a divisor of n. This is useful only with very high n, roughly n > 1_000_000. predict parallelizes across the same devices, splitting the test points; the number of test points must be a multiple of num_data_devices as well.

    If both num_chain_devices and num_data_devices are specified, the total number of devices used is the product of the two.

  • devices (Literal['cpu', 'gpu'] | Device | Sequence[Device] | None, default: None) – One or more devices used to run the MCMC on. If not specified, the computation will follow the placement of the input arrays. If a list of devices, this argument can be longer than the number of devices needed.

  • seed (int | Key[Array, ''], default: 0) – The seed for the random number generator.

  • maxdepth (int, default: 6) – The maximum depth of the trees. This is 1-based, so with the default maxdepth=6, the depths of the levels range from 0 to 5.

  • init_kw (Mapping, default: MappingProxyType({})) – Additional arguments passed to bartz.mcmcstep.init.

  • run_mcmc_kw (Mapping, default: MappingProxyType({})) – Additional arguments passed to bartz.mcmcloop.run_mcmc.
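To make the tree prior set by base and power concrete, the sketch below evaluates the documented formula base / (1 + d) ** power at each depth allowed by the default maxdepth=6. The function name p_nonterminal is hypothetical, and the sketch ignores any truncation applied at the deepest level.

```python
def p_nonterminal(depth, base=0.95, power=2.0):
    """Prior probability that a node at the given 0-based depth is non-terminal."""
    return base / (1 + depth) ** power


maxdepth = 6  # default; depths then range from 0 to maxdepth - 1
for d in range(maxdepth):
    print(d, round(p_nonterminal(d), 4))
```

The probability falls from 0.95 at the root to 0.2375 at depth 1, so the prior strongly favors shallow trees.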

References

predict(x_test, *, kind='mean', key=None, error_scale=None)

Compute predictions at x_test.

Parameters:
  • x_test (Real[Array, 'p m'] | Real[ndarray, 'p m'] | DataFrame | str) – The test predictors, or the string 'train' to compute predictions on the training data.

  • kind (PredictKind | str, default: 'mean') – The kind of output. See PredictKind for details.

  • key (Key[Array, ''] | None, default: None) – Jax random key, required when kind='outcome_samples'.

  • error_scale (Float[Array, 'm'] | Float[ndarray, 'm'] | Float[Array, 'k m'] | Float[ndarray, 'k m'] | Series | DataFrame | None, default: None) – Per-observation error scale for kind='outcome_samples'. Required when the model was fit with weights and x_test is new data. Shape matches the shape used at fitting: (m,) for scalar weights, (k, m) for multivariate vector weights.

Returns:

Float32[Array, 'm'] | Float32[Array, 'k m'] | Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] – Predictions at x_test in the requested format.

Raises:

ValueError – If x_test has a different format than x_train, or if error_scale is specified when it should be None, or if error_scale is not specified when it is required, or if the model splits datapoints across devices (num_data_devices) and the number of test points is not a multiple of the number of data devices.

Notes

If the model splits datapoints across devices (num_data_devices), the test points and the returned predictions are split the same way.
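When num_data_devices is set, the number of test points must be a multiple of it. One possible workaround, sketched here with hypothetical helpers, is to pad the test predictors by repeating the last column and to drop the extra predictions afterwards. Whether padding is appropriate is a choice left to the caller; it is not something this page prescribes.

```python
def pad_count(m, num_data_devices):
    """Number of extra test points needed to reach a multiple of num_data_devices."""
    return (-m) % num_data_devices


def pad_columns(x_test, num_data_devices):
    """Pad a 'p m' nested list by repeating its last column.

    Returns the padded predictors and the original number of test points m,
    so that predictions can be truncated back to the first m entries.
    """
    m = len(x_test[0])
    extra = pad_count(m, num_data_devices)
    padded = [row + [row[-1]] * extra for row in x_test]
    return padded, m
```

After calling predict on the padded predictors, keep only the first m entries along the last axis of the result.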

dump(path)

Serialize the fitted model to a file with pickle.

Parameters:

path (str | PathLike) – The file to write to.

Return type:

None

Notes

Intended for short-term storage (e.g. caching across processes), not long-term archival: the format depends on the versions of bartz, jax and equinox. The arrays are copied to host memory and all device/sharding placement is dropped; load reconstructs a single-device model.

classmethod load(path)

Load a model saved with dump.

Parameters:

path (str | PathLike) – The file to read from.

Returns:

Bart – The deserialized model, on host memory with no device placement.

Raises:

TypeError – If the file does not contain a Bart instance.
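Since dump and load are intended for short-term caching, a typical pattern is to load the model if the file exists and fit it otherwise. The sketch below is generic: load_or_fit is a hypothetical helper that only assumes the model class provides dump(path) and a load(path) classmethod, as Bart does according to this page.

```python
from pathlib import Path


def load_or_fit(path, model_cls, fit):
    """Return a cached model from path, or fit one and cache it.

    model_cls must provide a load(path) classmethod; fit is a zero-argument
    callable returning a fitted model that has a dump(path) method.
    """
    path = Path(path)
    if path.exists():
        return model_cls.load(path)
    model = fit()
    model.dump(path)
    return model
```

With bartz this could be called as load_or_fit('model.pkl', bartz.Bart, lambda: bartz.Bart(x_train, y_train)), where the file name is arbitrary. Because the format depends on the installed versions of bartz, jax and equinox, the cache should be deleted after upgrading any of them.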

property offset: Float32[Array, ''] | Float32[Array, 'k']

The prior mean of the latent mean function.
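The default offset for binary regression described under the offset parameter can be reproduced by hand. The sketch below follows the documented rules (Phi^-1 of the fraction of non-zero responses, with the stated edge cases); default_binary_offset is a hypothetical name and this is not the library's own code.

```python
from statistics import NormalDist


def default_binary_offset(y):
    """Documented default offset for univariate binary regression."""
    n = len(y)
    if n == 0:
        return 0.0  # empty y_train: offset is 0
    ones = sum(1 for v in y if v != 0)  # non-zero values count as 1
    if ones == 0:
        prob = 1 / (n + 1)  # all zero
    elif ones == n:
        prob = n / (n + 1)  # all non-zero
    else:
        prob = ones / n
    # Phi^-1 is the standard normal quantile function.
    return NormalDist().inv_cdf(prob)
```

The edge cases keep the offset finite: without them, an all-zero or all-non-zero response would map to minus or plus infinity.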

property n_save: int

The number of posterior samples after burn-in saved per chain.

property num_chains: int | None

The number of chains, or None if the model has no explicit chain axis (num_chains=None).

property ndpost: int

The total number of posterior samples after burn-in across all chains.
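The sizes involved can be worked out from the constructor arguments. The sketch below uses the documented relation ndpost = num_chains * n_save and reads the length of the full trace from the n_burn_plus_n_save axis in the return type of get_latent_prec, for the univariate case. The helper name trace_sizes is hypothetical.

```python
def trace_sizes(n_save, n_burn, num_chains):
    """Return (ndpost, shape of the full univariate precision trace)."""
    chains = 1 if num_chains is None else num_chains
    ndpost = chains * n_save  # post-burn-in samples, chains concatenated
    if num_chains is None:
        full_shape = (n_burn + n_save,)  # no chain axis
    else:
        full_shape = (num_chains, n_burn + n_save)
    return ndpost, full_shape


print(trace_sizes(n_save=1000, n_burn=1000, num_chains=4))
```

With the defaults this gives 4000 posterior samples and a full trace of shape (4, 2000).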

property num_trees: int

Return the number of trees used in the model.

get_latent_prec(only_continuous=False)

Return the posterior samples of the latent error precision matrix.

Parameters:

only_continuous (bool, default: False) – If True and the model has mixed binary-continuous outcomes, return only the submatrix for the continuous components.

Returns:

Float32[Array, 'n_burn_plus_n_save'] | Float32[Array, 'n_burn_plus_n_save k k'] | Float32[Array, 'num_chains n_burn_plus_n_save'] | Float32[Array, 'num_chains n_burn_plus_n_save k k'] – MCMC samples of the error precision matrix.

Raises:

ValueError – If only_continuous is True but the model has only binary outcomes, so there is no continuous submatrix to return.

Notes

This method is meant to check for convergence, so it returns the full MCMC trace and does not concatenate chains together. For probit regression, this returns the precision of the latent error term, not the Bernoulli precision for the binary outcome. For heteroskedastic regression, the returned precision is the global precision parameter, which would have to be divided by a squared weight to get the precision on a given datapoint.
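As one way to use the per-chain trace for a convergence check, the sketch below computes the classic Gelman-Rubin potential scale reduction factor for a scalar trace of shape (num_chains, n_burn + n_save), after dropping the burn-in. The diagnostic is a common choice, not something this page prescribes, and gelman_rubin is a hypothetical helper.

```python
from statistics import mean, variance


def gelman_rubin(chains, n_burn=0):
    """Classic R-hat for a sequence of equally long scalar chains."""
    kept = [list(c)[n_burn:] for c in chains]  # discard burn-in samples
    n = len(kept[0])
    # Average within-chain sample variance.
    within = mean([variance(c) for c in kept])
    # Between-chain variance of the chain means, scaled by chain length.
    between = n * variance([mean(c) for c in kept])
    var_plus = (n - 1) / n * within + between / n
    return (var_plus / within) ** 0.5
```

Values close to 1 suggest the chains agree; values well above 1 suggest they have not mixed. For a univariate model fit with several chains, the trace could be passed as gelman_rubin(model.get_latent_prec().tolist(), n_burn=1000), assuming the default n_burn.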

get_error_sdev(mean=False)

Return the error standard deviation, post-burnin, chains concatenated.

Parameters:

mean (bool, default: False) – If True, average the error covariance matrix across samples before taking the square root, returning a single scalar or vector instead of posterior samples.

Returns:

Float32[Array, 'ndpost'] | Float32[Array, 'ndpost k'] | Float32[Array, ''] | Float32[Array, 'k'] – Posterior samples (or single estimate) of the error standard deviation; NaN for binary outcomes.

Notes

Binary outcomes do have a standard deviation, of course, but this method does not return it because that would require evaluating predictions at a given X, since the Bernoulli variance is p(1-p).
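To illustrate why the binary case depends on the predictors, the sketch below computes the Bernoulli standard deviation from a value of the latent mean function, assuming the probit link p = Phi(f(x)). The helper bernoulli_sdev is hypothetical and works on a single latent value.

```python
from statistics import NormalDist


def bernoulli_sdev(latent_mean):
    """Standard deviation of a binary outcome with probit probability."""
    p = NormalDist().cdf(latent_mean)  # p = Phi(f(x))
    return (p * (1 - p)) ** 0.5


for f in (-2.0, 0.0, 2.0):
    print(f, bernoulli_sdev(f))
```

The standard deviation peaks at 0.5 when the latent mean is 0 and shrinks as the probability approaches 0 or 1, so no single number summarizes it across datapoints.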

property varcount: Int32[Array, 'ndpost p']

Histogram of predictor usage for decision rules in the trees.

property varcount_mean: Float32[Array, 'p']

Average of varcount across MCMC iterations.

property varprob: Float32[Array, 'ndpost p']

Posterior samples of the probability of choosing each predictor for a decision rule.

property varprob_mean: Float32[Array, 'p']

The marginal posterior probability of each predictor being chosen for a decision rule.
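The relation between varcount and varcount_mean, and one way to use them for ranking predictors, can be sketched on a small hand-made count table of shape (ndpost, p). The counts and the helper names below are hypothetical; with a fitted model, the nested list would come from the varcount property.

```python
def column_means(varcount):
    """Average usage count of each predictor across MCMC iterations."""
    ndpost = len(varcount)
    p = len(varcount[0])
    return [sum(row[j] for row in varcount) / ndpost for j in range(p)]


def rank_predictors(varcount):
    """Predictor indices sorted from most to least used on average."""
    means = column_means(varcount)
    return sorted(range(len(means)), key=lambda j: means[j], reverse=True)


# Two posterior samples, three predictors.
toy_varcount = [[2, 0, 1], [4, 0, 1]]
print(column_means(toy_varcount))
print(rank_predictors(toy_varcount))
```

A predictor that is rarely used in decision rules, like the second one here, is a candidate for being irrelevant, although usage counts alone are only a rough indicator.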