bartz.stochtree.BARTModel
- class bartz.stochtree.BARTModel
BART model with a stochtree-compatible interface, powered by bartz.

This class mimics stochtree.BARTModel so that bartz can be used as a drop-in reference implementation for testing. The intersection of features is targeted: continuous regression (Gaussian outcome, identity link) and binary classification (probit link) on tabular covariates.

Use the same idiomatic pattern as stochtree.BARTModel:

    m = BARTModel()
    m.sample(
        X_train=X,
        y_train=y,
        X_test=X_test,
        num_gfr=0,
        num_mcmc=200,
        mean_forest_params={'sample_sigma2_leaf': False},
    )
    yhat = m.predict(X_new, terms='y_hat', type='mean')
See GeneralParams and MeanForestParams for the supported keys in the general_params and mean_forest_params dicts.

Notes
Differences from stochtree, by design:
  - num_gfr has no default and must be set explicitly to 0.
  - mean_forest_params['sample_sigma2_leaf'] must be False.
  - mean_forest_params['max_depth'] must be a non-negative integer at most 16; stochtree’s -1 (unbounded depth) sentinel is not accepted.
  - The deprecated general_params['probit_outcome_model'] flag is not accepted; pass outcome_model=OutcomeModel('binary', 'probit') instead.
  - general_params['cutpoint_grid_size'] is not accepted; bartz uses a fixed grid of 256 evenly-spaced bins per predictor. stochtree only uses this parameter for the GFR sampler, which bartz does not support.
  - Leaf-basis regression, random effects, heteroskedastic variance forests, and warm-starting from a previous model are not supported.
  - bartz uses single-precision floats, so outputs differ from stochtree at the float32 precision level.
  - general_params['random_seed'] defaults to deterministic behavior (seed 0) when unset, whereas stochtree draws a random seed. This is intentional, to make repeated fits reproducible by default.
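When porting a stochtree call, it may help to check the arguments against the restrictions listed above before fitting. The following is a minimal sketch, not part of bartz: the helper name check_bartz_compat is hypothetical, it covers only the restrictions that can be checked from the arguments alone, and it assumes that an omitted max_depth is acceptable (the notes only constrain the value when one is given).

```python
from typing import Any, Mapping

MAX_DEPTH_LIMIT = 16  # upper bound on max_depth stated in the notes


def check_bartz_compat(
    num_gfr: int,
    general_params: Mapping[str, Any],
    mean_forest_params: Mapping[str, Any],
) -> list[str]:
    """Hypothetical pre-flight check; returns a list of problems (empty if none)."""
    problems = []

    # The GFR sampler is not supported, so num_gfr must be exactly 0.
    if num_gfr != 0:
        problems.append("num_gfr must be 0")

    # sample_sigma2_leaf must be disabled explicitly; a missing key also fails.
    if mean_forest_params.get("sample_sigma2_leaf") is not False:
        problems.append("mean_forest_params['sample_sigma2_leaf'] must be False")

    # Assumption: max_depth may be omitted; if given, it must lie in [0, 16].
    depth = mean_forest_params.get("max_depth")
    if depth is not None and not (
        isinstance(depth, int) and 0 <= depth <= MAX_DEPTH_LIMIT
    ):
        problems.append("mean_forest_params['max_depth'] must be an int in [0, 16]")

    # Keys that the notes list as not accepted.
    for key in ("probit_outcome_model", "cutpoint_grid_size"):
        if key in general_params:
            problems.append(f"general_params[{key!r}] is not accepted")

    return problems


print(check_bartz_compat(0, {}, {"sample_sigma2_leaf": False}))
```

Returning a list rather than raising on the first problem lets a test report every incompatibility in a ported call at once.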
References
Herren, A., Hahn, P. R., Murray, J., Carvalho, C. (2026). “StochTree: BART-based modeling in R and Python”. arXiv:2512.12051.
- outcome_model: OutcomeModel
Outcome family and link specification used during fitting.
- sigma2_init: float | Float[Array, ''] | Float[ndarray, '']
Starting value of the global error variance actually used to seed the chain.
- y_bar: Float32[Array, '']
Mean used to standardize the outcome (0 if not standardized).
- y_std: Float32[Array, '']
Standard deviation used to standardize the outcome (1 if not standardized).
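The roles of y_bar and y_std can be illustrated with a small round trip. This is a sketch under assumptions, not the bartz implementation: it assumes the usual affine standardization z = (y - y_bar) / y_std and uses the population standard deviation, which may differ from what the library computes. The function names are hypothetical.

```python
from statistics import fmean, pstdev


def standardize(y):
    """Return (z, y_bar, y_std) with z = (y - y_bar) / y_std."""
    y_bar = fmean(y)
    y_std = pstdev(y)  # assumption: population standard deviation
    z = [(v - y_bar) / y_std for v in y]
    return z, y_bar, y_std


def to_original_scale(z, y_bar, y_std):
    """Map standardized values back to the original outcome scale."""
    return [y_bar + y_std * v for v in z]


y = [1.0, 2.0, 4.0, 7.0]
z, y_bar, y_std = standardize(y)
print(to_original_scale(z, y_bar, y_std))
```

With y_bar = 0 and y_std = 1, the back-transform is the identity, which matches the documented values for the non-standardized case.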
- y_hat_train: Float32[Array, 'n num_samples']
Posterior predictions at the training covariates, in the original outcome scale.
- global_var_samples: Float32[Array, 'num_samples']
Posterior samples of the global error variance. For probit binary regression, an array of ones.
- y_hat_test: Float32[Array, 'm num_samples'] | None
Posterior predictions at X_test if it was supplied to sample, else None.
- sample(X_train, y_train, X_test=None, observation_weights=None, *, num_gfr, num_burnin=0, num_mcmc=100, general_params=None, mean_forest_params=None, bart_kwargs=MappingProxyType({}))
Fit the model.
The signature mirrors stochtree.BARTModel.sample, restricted to the keyword arguments bartz supports.
- Parameters:
  - X_train (Real[Array, 'n p'] | Real[ndarray, 'n p'] | DataFrame) – Training covariates with shape (n, p).
  - y_train (Real[Array, 'n'] | Real[ndarray, 'n'] | Series) – Training outcomes of length n.
  - X_test (Real[Array, 'm p'] | Real[ndarray, 'm p'] | DataFrame | None, default: None) – Optional test covariates; if given, predictions are cached on them in y_hat_test.
  - observation_weights (Float[Array, 'n'] | Float[ndarray, 'n'] | Series | None, default: None) – Optional positive per-observation weights scaling the residual variance (y_i | - ~ N(mu(X_i), sigma^2 / w_i)).
  - num_gfr (int) – Number of grow-from-root iterations. Must be 0.
  - num_burnin (int, default: 0) – Number of MCMC burn-in iterations.
  - num_mcmc (int, default: 100) – Number of retained MCMC iterations per chain.
  - general_params (Mapping[str, Any] | None, default: None) – Optional override for the keys of GeneralParams.
  - mean_forest_params (Mapping[str, Any] | None, default: None) – Override for the keys of MeanForestParams. Must explicitly disable sample_sigma2_leaf.
  - bart_kwargs (Mapping[str, Any], default: MappingProxyType({})) – Additional arguments forwarded to bartz.Bart. Use this to set devices and rm_const=False when wrapping sample in jax.jit.
- Raises:
  NotImplementedError – If num_gfr is non-zero.
- Return type:
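The observation-weight model stated for sample, y_i | - ~ N(mu(X_i), sigma^2 / w_i), means that a larger weight shrinks the noise standard deviation of that observation by a factor of sqrt(w_i). The sketch below simulates outcomes from that model using only the standard library. The function names are hypothetical and nothing here calls bartz.

```python
import math
import random


def noise_sd(sigma2: float, weight: float) -> float:
    """Noise standard deviation for one observation: sqrt(sigma^2 / w_i)."""
    if weight <= 0:
        raise ValueError("observation weights must be positive")
    return math.sqrt(sigma2 / weight)


def simulate_outcomes(mu, weights, sigma2, seed=0):
    """Draw y_i ~ N(mu_i, sigma^2 / w_i) for each observation."""
    rng = random.Random(seed)  # fixed seed so repeated calls agree
    return [rng.gauss(m, noise_sd(sigma2, w)) for m, w in zip(mu, weights)]


# A weight of 4 halves the noise standard deviation relative to a weight of 1.
print(noise_sd(4.0, 1.0), noise_sd(4.0, 4.0))
print(simulate_outcomes([0.0, 1.0, 2.0], [1.0, 4.0, 16.0], sigma2=4.0))
```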
- predict(X, *, type='posterior', terms='all', scale='linear')
Predict at new covariates.
- Parameters:
  - X (Real[Array, 'm p'] | Real[ndarray, 'm p'] | DataFrame) – New covariates with shape (m, p).
  - type (Literal['posterior', 'mean'], default: 'posterior') – 'posterior' returns one prediction per posterior sample, with shape (m, num_samples). 'mean' averages the posterior samples, returning a vector of shape (m,).
  - terms (Literal['y_hat', 'mean_forest', 'all'] | Sequence[Literal['y_hat', 'mean_forest', 'all']], default: 'all') – One of 'y_hat', 'mean_forest', 'all', or a list. Since random effects and a variance forest are not supported, 'y_hat' and 'mean_forest' produce the same result.
  - scale (Literal['linear', 'probability', 'class'], default: 'linear') – For probit binary regression: 'linear' returns the eta values, 'probability' returns Phi(eta), 'class' returns 0 / 1. Only 'linear' is valid for continuous outcomes.
- Returns:
  Shaped[Array, 'm num_samples'] | Shaped[Array, 'm'] | dict[str, Shaped[Array, 'm num_samples']] | dict[str, Shaped[Array, 'm']] – Either a single jax array (for a single requested term) or a dict keyed by term name.
- Raises:
  NotSampledError – If sample has not been called yet.
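The scale and type options of predict correspond to simple transformations of the linear predictor eta. The sketch below shows them with the standard library only. It is an illustration under assumptions: the 0.5 classification threshold is assumed rather than taken from the documentation, the order in which the library averages and transforms is not stated here, and the function names are hypothetical.

```python
import math


def probit_probability(eta: float) -> float:
    """Standard normal CDF Phi(eta), the 'probability' scale."""
    return 0.5 * (1.0 + math.erf(eta / math.sqrt(2.0)))


def to_class(prob: float, threshold: float = 0.5) -> int:
    """Map a probability to 0 / 1; the 0.5 threshold is an assumption."""
    return 1 if prob >= threshold else 0


def posterior_mean(draws):
    """Average each row of an (m, num_samples) nested list, giving length m."""
    return [sum(row) / len(row) for row in draws]


# Two test points with three posterior draws of eta each.
eta_draws = [[-1.0, 0.0, 1.0], [0.5, 1.0, 1.5]]
eta_mean = posterior_mean(eta_draws)                # analogous to type='mean'
probs = [probit_probability(e) for e in eta_mean]   # 'probability' scale
print(eta_mean, probs, [to_class(p) for p in probs])
```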