bartz.Bart
- class bartz.Bart(x_train, y_train, *, outcome_type='continuous', sparse=False, theta=None, a=0.5, b=1.0, rho=None, varprob=None, binner=UniqueQuantileBinner, rm_const=True, sigma_df=3.0, sigma_scale='auto', sigma_init='auto', k=2.0, power=2.0, base=0.95, tau_num=None, offset=None, error_scale=None, missing=None, num_trees=200, n_save=1000, n_burn=1000, n_skip=1, printevery=100, pbar=True, num_chains=4, num_chain_devices='auto', num_data_devices=None, devices=None, seed=0, maxdepth=6, init_kw=MappingProxyType({}), run_mcmc_kw=MappingProxyType({}))
Nonparametric regression with Bayesian Additive Regression Trees (BART).
Regress y_train on x_train with a latent mean function represented as a sum of decision trees [2]. Inference is carried out by sampling the posterior distribution of the tree ensemble with Markov chain Monte Carlo (MCMC).
- Parameters:
x_train (Real[Array, 'p n'] | Real[ndarray, 'p n'] | DataFrame) – The training predictors.
y_train (Float32[Array, 'n'] | Float32[ndarray, 'n'] | Float32[Array, 'k n'] | Float32[ndarray, 'k n'] | Series | DataFrame) – The training responses. For univariate regression, a 1D array of shape (n,). For multivariate regression, a 2D array of shape (k, n), where k is the number of response components, as introduced in [3]. For binary regression, the convention is that non-zero values mean 1 and zero means 0, like booleans.
outcome_type (OutcomeType | str | Sequence[OutcomeType | str], default: 'continuous') – The type of regression: 'continuous' for continuous regression, 'binary' for binary regression with probit link. For multivariate regression, a scalar value applies to all components; alternatively, a sequence of per-component types (e.g., ['binary', 'continuous']) specifies mixed outcome types. Binary components in multivariate outcomes follow the multivariate probit BART formulation of [4].
sparse (bool, default: False) – Whether to activate variable selection on the predictors as done in [1].
theta (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None)
a (float | Float[Array, ''] | Float[ndarray, ''], default: 0.5)
b (float | Float[Array, ''] | Float[ndarray, ''], default: 1.0)
rho (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None) – Hyperparameters of the sparsity prior used for variable selection.
The prior distribution on the choice of predictor for each decision rule is
\[(s_1, \ldots, s_p) \sim \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).\]
If theta is not specified, it is a priori distributed according to
\[\frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim \operatorname{Beta}(\mathtt{a}, \mathtt{b}).\]
If not specified, rho is set to the number of predictors p. To tune the prior, consider setting a lower rho to prefer more sparsity. If setting theta directly, it should likewise be in the ballpark of p or lower.
varprob (Float[Array, 'p'] | Float[ndarray, 'p'] | None, default: None) – The a priori probability distribution over the p predictors for choosing the predictor to split on in a decision node. All entries must be > 0. It does not need to be normalized to sum to 1. If not specified, a uniform distribution is used. If sparse=True, this is used as the initial value for the MCMC.
binner (BinnerFactory, default: UniqueQuantileBinner) – A callable that, given the training predictors and a random key, returns a Binner instance. The default is UniqueQuantileBinner, which places cutpoints at the quantiles of each predictor. Other built-in options are RangeEvenBinner (evenly spaced cutpoints over the observed range) and GivenSplitsBinner (R BART xinfo format). To pass options, use functools.partial, e.g. binner=partial(UniqueQuantileBinner, max_bins=128).
rm_const (bool, default: True) – How to treat predictors with no associated decision rules (i.e., there are no available cutpoints for that predictor). If True (default), they are ignored. If False, an error is raised if there are any.
sigma_df (float | Float[Array, ''] | Float[ndarray, ''], default: 3.0) – The degrees of freedom of the prior on the error precision. For multivariate regression with k components, the Wishart degrees of freedom are set to sigma_df + k - 1.
sigma_scale (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k'] | Float[ndarray, 'k'] | Literal['auto'], default: 'auto') – The scale of the prior on the error precision. If 'auto' (default), the prior is scaled so that the error precision equals diag(1 / var(y_train)) in expectation; when error_scale is given, var(y_train) is a precision-weighted variance that estimates the unit-weight error variance. Otherwise, square(sigma_scale) is the prior harmonic mean of the error variance; for multivariate regression, a scalar is broadcast to all components. For mixed outcome types, binary components are ignored.
sigma_init (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k'] | Float[ndarray, 'k'] | Literal['auto'], default: 'auto') – The initial value of the error standard deviation in the MCMC. If 'auto' (default), the initial error precision is set to diag(1 / var(y_train)), using the same precision-weighted variance as sigma_scale when error_scale is given. Otherwise, the initial precision is diag(1 / square(sigma_init)); for multivariate regression, a scalar is broadcast to all components. For mixed outcome types, binary components are ignored.
k (float | Float[Array, ''] | Float[ndarray, ''], default: 2.0) – The inverse scale of the prior standard deviation on the latent mean function, relative to half the observed range of y_train. If y_train has fewer than two elements, k is ignored and the scale is set to 1.
power (float | Float[Array, ''] | Float[ndarray, ''], default: 2.0)
base (float | Float[Array, ''] | Float[ndarray, ''], default: 0.95) – Parameters of the prior on tree node generation. The probability that a node at depth d (0-based) is non-terminal is base / (1 + d) ** power.
tau_num (float | Float[Array, ''] | Float[ndarray, ''] | None, default: None) – The numerator in the expression that determines the prior standard deviation of leaves. If not specified, it defaults to (max(y_train) - min(y_train)) / 2 (or 1 if y_train has fewer than two elements) for continuous regression, and to 3 for binary regression. For multivariate regression, the range is computed per component. For mixed outcome types, each component uses the default for its type.
offset (float | Float[Array, ''] | Float[ndarray, ''] | Float[Array, 'k'] | Float[ndarray, 'k'] | None, default: None) – The prior mean of the latent mean function. If not specified, it is set to the mean of y_train for continuous regression, and to Phi^-1(mean(y_train != 0)) for binary regression. If y_train is empty, offset is set to 0. With binary regression, if y_train is all zero or all non-zero, it is set to Phi^-1(1/(n+1)) or Phi^-1(n/(n+1)), respectively. For multivariate regression, it can be a scalar (broadcast to all components) or a (k,) vector; if not specified, it is set to the per-component mean of y_train. For mixed outcome types, each component uses the default for its type.
error_scale (Float[Array, 'n'] | Float[ndarray, 'n'] | Float[Array, 'k n'] | Float[ndarray, 'k n'] | Series | DataFrame | None, default: None) – Coefficients that rescale the error standard deviation on each datapoint. Not specifying error_scale is equivalent to setting it to 1 for all datapoints. Shape (n,) applies the same scalar weight to every outcome component; for multivariate continuous regression, shape (k, n) instead supplies a separate weight for each component of each datapoint.
missing (Bool[Array, 'n'] | Bool[ndarray, 'n'] | Bool[Array, 'k n'] | Bool[ndarray, 'k n'] | Series | DataFrame | None, default: None) – Boolean mask with the same shape as y_train; True marks entries to be ignored by the MCMC. Values of y_train must be finite everywhere, including at masked positions. If the mask is 2-D, the error covariance must be diagonal.
num_trees (int, default: 200) – The number of trees used to represent the latent mean function.
n_save (int, default: 1000) – The number of MCMC samples to save after burn-in, per chain. The total trace length across all chains is num_chains * n_save.
n_burn (int, default: 1000) – The number of initial MCMC samples to discard as burn-in. This number of samples is discarded from each chain.
n_skip (int, default: 1) – The thinning factor for the MCMC samples after burn-in.
printevery (int | None, default: 100) – The number of iterations (including thinned-away ones) between log lines. Set to None to disable progress reporting entirely, in which case pbar is ignored. ^C interrupts the MCMC only every printevery iterations, so with reporting disabled the MCMC cannot be conveniently interrupted.
pbar (bool, default: True) – If True, show a tqdm progress bar instead of printing log lines. The bar advances every iteration and refreshes the acceptance statistics every printevery iterations. Ignored if printevery is None.
num_chains (int | None, default: 4) – The number of independent Markov chains to run.
The difference between num_chains=None and num_chains=1 is that, with the latter, the object attributes and the outputs of some methods have an explicit chain axis of size 1.
num_chain_devices (int | None | Literal['auto'], default: 'auto') – The number of devices to spread the chains across. Must be a divisor of num_chains; each device runs num_chains / num_chain_devices chains. If 'auto' (default) and running on CPU, the number of devices is picked automatically based on the number of cores and the number of available devices (all the virtual JAX CPU devices, or the devices list if set).
num_data_devices (int | None, default: None) – The number of devices to split datapoints across. Must be a divisor of n. This is useful only with very large n, roughly above 1,000,000. predict parallelizes across the same devices by splitting the test points, so the number of test points must also be a multiple of num_data_devices.
If both num_chain_devices and num_data_devices are specified, the total number of devices used is the product of the two.
devices (Literal['cpu', 'gpu'] | Device | Sequence[Device] | None, default: None) – One or more devices to run the MCMC on. If not specified, the computation follows the placement of the input arrays. If a sequence of devices is given, it can be longer than the number of devices needed.
seed (int | Key[Array, ''], default: 0) – The seed for the random number generator.
maxdepth (int, default: 6) – The maximum depth of the trees. This is 1-based, so with the default maxdepth=6, the depths of the levels range from 0 to 5.
init_kw (Mapping, default: MappingProxyType({})) – Additional arguments passed to bartz.mcmcstep.init.
run_mcmc_kw (Mapping, default: MappingProxyType({})) – Additional arguments passed to bartz.mcmcloop.run_mcmc.
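The sparsity prior described under theta, a, b and rho can be made concrete with a minimal pure-Python sketch of one draw from it. This is not bartz's implementation. sample_sparsity_prior is a hypothetical helper that only restates the two formulas: it draws u ~ Beta(a, b), solves u = theta / (theta + rho) for theta, and samples the Dirichlet by normalizing Gamma(theta/p, 1) variates.

```python
import random


def sample_sparsity_prior(p, a=0.5, b=1.0, rho=None, seed=0):
    """Hypothetical helper: one draw (theta, s) from the sparsity prior.

    Not bartz code; a standard-library restatement of the documented formulas.
    """
    rng = random.Random(seed)
    rho = float(p) if rho is None else float(rho)  # documented default: rho = p
    # u = theta / (theta + rho) ~ Beta(a, b); invert to recover theta.
    u = rng.betavariate(a, b)
    theta = rho * u / (1.0 - u)
    # (s_1, ..., s_p) ~ Dirichlet(theta/p, ..., theta/p), sampled by
    # normalizing independent Gamma(theta/p, 1) draws.
    g = [rng.gammavariate(theta / p, 1.0) for _ in range(p)]
    total = sum(g)
    if total == 0.0:
        # Very small concentrations can underflow every draw to 0.0; use the
        # degenerate limit of the Dirichlet (all mass on one predictor).
        g[rng.randrange(p)] = 1.0
        total = 1.0
    return theta, [x / total for x in g]


theta, s = sample_sparsity_prior(p=10)
```

Small draws of u give a small theta, which concentrates s on a few predictors. Because theta scales with rho, this is why a lower rho favors sparsity.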
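The binner argument expects a factory that is called with the training predictors and a random key. The sketch below shows how functools.partial fixes extra options on such a factory ahead of time. toy_range_even_binner is hypothetical. It only mimics the idea behind RangeEvenBinner and returns plain cutpoint lists rather than a Binner instance, so it cannot be passed to bartz.Bart as-is.

```python
from functools import partial


def toy_range_even_binner(x_train, key, max_bins=256):
    """Hypothetical factory: evenly spaced cutpoints over each predictor's range.

    x_train is a list of p rows (one per predictor, as in the 'p n' layout).
    key is accepted to mirror the documented call signature but is unused.
    Returns one list of cutpoints per predictor, not a real Binner.
    """
    cutpoints = []
    for row in x_train:
        lo, hi = min(row), max(row)
        if hi == lo:
            # A constant predictor admits no decision rule (see rm_const).
            cutpoints.append([])
            continue
        step = (hi - lo) / max_bins
        # max_bins bins over [lo, hi] need max_bins - 1 interior cutpoints.
        cutpoints.append([lo + i * step for i in range(1, max_bins)])
    return cutpoints


# Bind options up front, as with binner=partial(UniqueQuantileBinner, max_bins=128).
factory = partial(toy_range_even_binner, max_bins=4)
cuts = factory([[0.0, 1.0, 2.0, 3.0, 4.0], [7.0, 7.0, 7.0, 7.0, 7.0]], None)
```

The constant second predictor gets no cutpoints, which is the situation that rm_const governs.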
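The node-generation prior controlled by base and power can be tabulated together with maxdepth. This sketch assumes, based on the description of maxdepth, that nodes on the last level (depth maxdepth - 1) are always leaves. The expected node counts further assume that each node splits independently with the stated probability. Both function names are illustrative.

```python
def split_probabilities(base=0.95, power=2.0, maxdepth=6):
    """P(node at 0-based depth d is non-terminal) for d = 0 .. maxdepth - 1.

    Assumption: nodes on the last level cannot split, since the levels only
    go down to depth maxdepth - 1.
    """
    probs = [base / (1 + d) ** power for d in range(maxdepth - 1)]
    return probs + [0.0]


def expected_nodes_per_level(base=0.95, power=2.0, maxdepth=6):
    """Expected number of nodes at each depth under the prior."""
    counts = [1.0]  # the root
    for p_split in split_probabilities(base, power, maxdepth)[:-1]:
        # Each node at depth d that splits creates two children at depth d + 1.
        counts.append(2 * counts[-1] * p_split)
    return counts
```

With the defaults, the root splits with probability 0.95, but a node at depth 1 splits with probability only 0.95 / 4. This quickly decreasing probability keeps individual trees shallow.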
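The parameter list describes tau_num only as the numerator of the leaf prior standard deviation. Under the standard BART parametrization, which is assumed here and not confirmed by this page, each leaf has prior standard deviation tau_num / (k * sqrt(num_trees)). The sum of num_trees independent leaves then has standard deviation tau_num / k, which is consistent with the description of k as an inverse scale relative to half the range of y_train. The helpers below are illustrative and univariate.

```python
import math


def default_tau_num(y, outcome_type="continuous"):
    """Documented default for tau_num, for a single outcome component."""
    if outcome_type == "binary":
        return 3.0
    if len(y) < 2:
        return 1.0
    return (max(y) - min(y)) / 2


def leaf_prior_sdev(tau_num, k=2.0, num_trees=200):
    """ASSUMED standard-BART form: leaf sd = tau_num / (k * sqrt(num_trees)).

    Summing num_trees independent leaves then gives sd tau_num / k.
    """
    return tau_num / (k * math.sqrt(num_trees))
```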
References
- predict(x_test, *, kind='mean', key=None, error_scale=None)
Compute predictions at x_test.
- Parameters:
x_test (Real[Array, 'p m'] | Real[ndarray, 'p m'] | DataFrame | str) – The test predictors, or the string 'train' to compute predictions on the training data.
kind (PredictKind | str, default: 'mean') – The kind of output. See PredictKind for details.
key (Key[Array, ''] | None, default: None) – JAX random key, required when kind='outcome_samples'.
error_scale (Float[Array, 'm'] | Float[ndarray, 'm'] | Float[Array, 'k m'] | Float[ndarray, 'k m'] | Series | DataFrame | None, default: None) – Per-observation error scale for kind='outcome_samples'. Required when the model was fit with error_scale and x_test is new data. The shape must match the one used at fitting: (m,) for scalar weights, (k, m) for multivariate per-component weights.
- Returns:
Float32[Array, 'm'] | Float32[Array, 'k m'] | Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'] – Predictions at x_test in the requested format.
- Raises:
ValueError – If x_test has a different format than x_train; if error_scale is specified when it should be None, or is not specified when it is required; or if the model splits datapoints across devices (num_data_devices) and the number of test points is not a multiple of the number of data devices.
Notes
If the model splits datapoints across devices (num_data_devices), the test points and the returned predictions are split the same way.
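predict raises ValueError when the number of test points is not a multiple of num_data_devices. One possible workaround, sketched here under the assumption that a few redundant predictions are acceptable, is to pad x_test by repeating its last test point and then drop the extra entries from the output. For a deterministic kind such as the default 'mean', each test point is predicted independently, so the padding does not change the other predictions. pad_test_points is a hypothetical helper that operates on a list-of-rows 'p m' layout. If error_scale is also passed, it needs the same padding.

```python
def pad_test_points(x_test, num_data_devices):
    """Hypothetical helper: pad a 'p m' predictor matrix (a list of p rows)
    so that m becomes a multiple of num_data_devices.

    The last test point is repeated. The original m is returned so that the
    extra predictions can be sliced off afterwards.
    """
    m = len(x_test[0])
    pad = -m % num_data_devices  # 0 when m is already a multiple
    if pad == 0:
        return [list(row) for row in x_test], m
    padded = [list(row) + [row[-1]] * pad for row in x_test]
    return padded, m


# Intended use with a fitted model (not executed here; assumes jax.numpy as jnp):
# padded, m = pad_test_points(x_test_rows, num_data_devices=4)
# mean = model.predict(jnp.asarray(padded))[..., :m]  # test points are the last axis
```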
- dump(path)
Serialize the fitted model to a file with pickle.
Notes
Intended for short-term storage (e.g., caching across processes), not long-term archival: the format depends on the versions of bartz, jax and equinox. The arrays are copied to host memory and all device/sharding placement is dropped; load reconstructs a single-device model.
- property offset: Float32[Array, ''] | Float32[Array, 'k']
The prior mean of the latent mean function.
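The rules listed under the offset constructor parameter can be used to recompute the default offset for a single outcome component. default_offset is a hypothetical helper that restates those rules with the standard library, with NormalDist().inv_cdf playing the role of Phi^-1. It is not bartz code.

```python
from statistics import NormalDist, fmean


def default_offset(y, outcome_type="continuous"):
    """Hypothetical restatement of the documented default offset (one component)."""
    n = len(y)
    if n == 0:
        return 0.0
    if outcome_type == "continuous":
        return fmean(y)
    # Binary: non-zero values count as 1.
    frac = sum(1 for v in y if v != 0) / n
    # Documented clamping when y is all zero or all non-zero.
    if frac == 0.0:
        frac = 1 / (n + 1)
    elif frac == 1.0:
        frac = n / (n + 1)
    return NormalDist().inv_cdf(frac)  # Phi^-1
```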
- property ndpost: int
The total number of posterior samples after burn-in across all chains.
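ndpost follows from the constructor arguments, since the n_save description states that the total trace length across chains is num_chains * n_save. In the sketch below, num_chains=None is treated as a single chain. The second helper reads n_skip as thinning applied only after burn-in, so each chain runs n_burn + n_skip * n_save iterations. That formula is an interpretation of the parameter descriptions, not a documented guarantee, and both helpers are hypothetical.

```python
def ndpost(n_save=1000, num_chains=4):
    """Saved posterior draws across all chains (num_chains=None read as one chain)."""
    return (1 if num_chains is None else num_chains) * n_save


def iterations_per_chain(n_save=1000, n_burn=1000, n_skip=1):
    """MCMC iterations per chain, assuming thinning applies only after burn-in."""
    return n_burn + n_skip * n_save
```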
- get_latent_prec(only_continuous=False)
Return the MCMC samples of the latent error precision matrix, including burn-in.
- Parameters:
only_continuous (bool, default: False) – If True and the model has mixed binary-continuous outcomes, return only the submatrix for the continuous components.
- Returns:
Float32[Array, 'n_burn_plus_n_save'] | Float32[Array, 'n_burn_plus_n_save k k'] | Float32[Array, 'num_chains n_burn_plus_n_save'] | Float32[Array, 'num_chains n_burn_plus_n_save k k'] – MCMC samples of the error precision matrix.
- Raises:
ValueError – If only_continuous is True but the model has only binary outcomes, so there is no continuous submatrix to return.
Notes
This method is meant for checking convergence, so it returns the full MCMC trace, including burn-in, and does not concatenate the chains. For probit regression, it returns the precision of the latent error term, not the Bernoulli precision of the binary outcome. For heteroskedastic regression (error_scale specified), the returned value is the global precision parameter; dividing it by the squared error_scale of a datapoint gives the precision on that datapoint.
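For heteroskedastic models, one draw of the global precision can be turned into per-datapoint quantities by rescaling with error_scale, because error_scale multiplies the error standard deviation. The following minimal univariate sketch uses hypothetical helper names and plain lists in place of arrays.

```python
import math


def per_datapoint_precision(global_prec, error_scale):
    """Error precision on each datapoint from one global precision draw.

    error_scale multiplies the error sd, so it divides the precision by its square.
    """
    return [global_prec / w**2 for w in error_scale]


def per_datapoint_sdev(global_prec, error_scale):
    """Error standard deviation on each datapoint: w / sqrt(global_prec)."""
    return [w / math.sqrt(global_prec) for w in error_scale]
```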
- get_error_sdev(mean=False)
Return the error standard deviation after burn-in, with chains concatenated.
- Parameters:
mean (bool, default: False) – If True, average the error covariance matrix across samples before taking the square root, returning a single scalar or vector instead of posterior samples.
- Returns:
Float32[Array, 'ndpost'] | Float32[Array, 'ndpost k'] | Float32[Array, ''] | Float32[Array, 'k'] – Posterior samples (or a single estimate) of the error standard deviation; NaN for binary outcomes.
Notes
Binary outcomes do have a standard deviation, but this method does not return it because computing it requires evaluating predictions at a given X: the Bernoulli variance is p(1 - p), where p is the predicted probability.
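Two details of get_error_sdev can be checked with a small univariate sketch, using hypothetical helper names. With mean=True, the variance is averaged before taking the square root. By Jensen's inequality, this gives a value at least as large as the plain average of the standard-deviation samples. For a binary outcome, the standard deviation depends on the predicted probability p, as the notes explain.

```python
import math
from statistics import fmean


def sdev_of_mean_variance(sdev_samples):
    """Univariate analogue of mean=True: average the variance, then take sqrt."""
    return math.sqrt(fmean(s**2 for s in sdev_samples))


def mean_of_sdev(sdev_samples):
    """Plain posterior mean of the standard deviation, for comparison."""
    return fmean(sdev_samples)


def bernoulli_sdev(prob):
    """Standard deviation of a binary outcome with success probability prob."""
    return math.sqrt(prob * (1 - prob))
```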
- property varcount: Int32[Array, 'ndpost p']
Histogram of predictor usage for decision rules in the trees.
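This sketch assumes a simplified tree representation in which each tree is just the list of predictor indices used by its decision nodes. That structure is hypothetical and is not bartz's internal one. Under this assumption, one row of varcount is a per-predictor tally. A common summary in the sparse-BART literature is the fraction of posterior samples in which a predictor appears at least once, and the second helper computes it from varcount-like rows.

```python
def predictor_counts(trees, p):
    """Decision rules per predictor for one posterior sample.

    trees: hypothetical representation, one list per tree holding the
    predictor index used by each decision (internal) node.
    """
    counts = [0] * p
    for split_vars in trees:
        for j in split_vars:
            counts[j] += 1
    return counts


def inclusion_frequency(varcount_rows):
    """Fraction of posterior samples (rows) in which each predictor is used."""
    ndpost = len(varcount_rows)
    p = len(varcount_rows[0])
    return [sum(1 for row in varcount_rows if row[j] > 0) / ndpost for j in range(p)]
```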