bartz.stochtree.BARTModel

class bartz.stochtree.BARTModel

BART model with a stochtree-compatible interface, powered by bartz.

This class mimics stochtree.BARTModel so that bartz can be used as a drop-in reference implementation for testing. It targets the features that both libraries share: continuous regression (Gaussian outcome, identity link) and binary classification (probit link) on tabular covariates.

Use the same idiomatic pattern as stochtree.BARTModel:

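from bartz.stochtree import BARTModel

# X, y, X_test and X_new are user-supplied arrays or DataFrames.
m = BARTModel()
m.sample(
    X_train=X, y_train=y, X_test=X_test,
    num_gfr=0, num_mcmc=200,  # num_gfr has no default and must be 0
    mean_forest_params={'sample_sigma2_leaf': False},  # must be disabled explicitly
)
# Posterior mean prediction, one value per row of X_new.
yhat = m.predict(X_new, terms='y_hat', type='mean')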

See GeneralParams and MeanForestParams for the supported keys in the general_params and mean_forest_params dicts.

Notes

Differences from stochtree, by design:

  • num_gfr has no default and must be set explicitly to 0.

  • mean_forest_params['sample_sigma2_leaf'] must be False.

  • mean_forest_params['max_depth'] must be a non-negative integer at most 16; stochtree’s -1 (unbounded depth) sentinel is not accepted.

  • The deprecated general_params['probit_outcome_model'] flag is not accepted; pass outcome_model=OutcomeModel('binary', 'probit') instead.

  • general_params['cutpoint_grid_size'] is not accepted; bartz uses a fixed grid of 256 evenly spaced bins per predictor. stochtree uses this parameter only for the GFR sampler, which bartz does not support.

  • Leaf-basis regression, random effects, heteroskedastic variance forests, and warm-starting from a previous model are not supported.

  • bartz uses single-precision floats, so outputs differ from stochtree at the float32 precision level.

  • general_params['random_seed'] defaults to deterministic behavior (seed 0) when unset, whereas stochtree draws a random seed. This is intentional, to make repeated fits reproducible by default.
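
The restrictions listed above can be checked before calling sample. The following is a minimal sketch and not part of bartz: check_bartz_compatible is a hypothetical helper that inspects stochtree-style arguments and returns a list of problems. The message texts are illustrative, max_depth is checked only when it is supplied, and the sketch makes no claim about which exceptions bartz itself raises.

```python
from collections.abc import Mapping
from typing import Any, Optional


def check_bartz_compatible(
    num_gfr: Optional[int] = None,
    general_params: Optional[Mapping[str, Any]] = None,
    mean_forest_params: Optional[Mapping[str, Any]] = None,
) -> list[str]:
    """Hypothetical helper: list the documented restrictions that are violated."""
    general = dict(general_params or {})
    forest = dict(mean_forest_params or {})
    problems = []

    # num_gfr has no default and must be exactly 0.
    if num_gfr != 0:
        problems.append("num_gfr must be passed explicitly as 0")

    # Leaf-scale sampling must be disabled explicitly.
    if forest.get("sample_sigma2_leaf") is not False:
        problems.append("mean_forest_params['sample_sigma2_leaf'] must be False")

    # max_depth, if given, must be an integer in [0, 16]; -1 is rejected.
    if "max_depth" in forest:
        depth = forest["max_depth"]
        is_int = isinstance(depth, int) and not isinstance(depth, bool)
        if not is_int or not 0 <= depth <= 16:
            problems.append("mean_forest_params['max_depth'] must be an integer in [0, 16]")

    # Keys documented as not accepted.
    for key in ("probit_outcome_model", "cutpoint_grid_size"):
        if key in general:
            problems.append(f"general_params[{key!r}] is not accepted")

    return problems
```

An empty list means that none of the checked restrictions is violated; features such as random effects or variance forests are outside the scope of this sketch.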

References

Herren, A., Hahn, P. R., Murray, J., Carvalho, C. (2026). “StochTree: BART-based modeling in R and Python”. arXiv:2512.12051.

standardize: bool

Whether the outcome was standardized before fitting.

sample_sigma2_global: bool

Whether the global error variance is sampled (always True).

probit_outcome_model: bool

Whether the model uses a binary outcome with probit link.

outcome_model: OutcomeModel

Outcome family and link specification used during fitting.

num_gfr: int

Number of grow-from-root iterations (always 0).

num_burnin: int

Number of MCMC burn-in iterations.

num_mcmc: int

Number of retained MCMC iterations per chain.

num_chains: int

Number of independent MCMC chains.

num_samples: int

Total number of retained posterior samples (num_mcmc * num_chains).

sigma2_init: float | Float[Array, ''] | Float[ndarray, '']

Starting value of the global error variance actually used to seed the chain.

y_bar: Float32[Array, '']

Mean used to standardize the outcome (0 if not standardized).

y_std: Float32[Array, '']

Standard deviation used to standardize the outcome (1 if not standardized).

has_rfx: bool

Whether the model includes random effects (always False).

include_mean_forest: bool

Whether the model includes a conditional mean forest (always True).

include_variance_forest: bool

Whether the model includes a variance forest (always False).

y_hat_train: Float32[Array, 'n num_samples']

Posterior predictions at the training covariates, in the original outcome scale.

global_var_samples: Float32[Array, 'num_samples']

Posterior samples of the global error variance. For probit binary regression, an array of ones.

y_hat_test: Float32[Array, 'm num_samples'] | None

Posterior predictions at X_test if it was supplied to sample, else None.

sampled: bool

Whether sample has been called.
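
To make the shapes and scales of these attributes concrete, the sketch below builds a stand-in for y_hat_train with NumPy and summarizes it per observation. The draws are simulated, not produced by bartz, and the back-transformation y_bar + y_std * z is an assumption about how the stored mean and standard deviation relate the standardized scale to the original one.

```python
import numpy as np

rng = np.random.default_rng(0)

n, num_mcmc, num_chains = 5, 200, 4
num_samples = num_mcmc * num_chains  # total retained draws, as documented

# Stand-in for draws on the standardized scale (simulated, NOT from bartz).
z_draws = rng.normal(size=(n, num_samples)).astype(np.float32)

# Assumed back-transformation using the stored mean and standard deviation.
y_bar, y_std = np.float32(10.0), np.float32(3.0)
y_hat_train = y_bar + y_std * z_draws  # shape (n, num_samples), original scale

# Per-observation posterior summaries: reduce over the sample axis (axis 1).
post_mean = y_hat_train.mean(axis=1)
lower, upper = np.quantile(y_hat_train, [0.05, 0.95], axis=1)
```

With a fitted model, np.asarray(m.y_hat_train) would take the place of the simulated array; the reductions over axis 1 stay the same.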

is_sampled()

Return whether sample has been called.

Return type:

bool

sample(X_train, y_train, X_test=None, observation_weights=None, *, num_gfr, num_burnin=0, num_mcmc=100, general_params=None, mean_forest_params=None, bart_kwargs=MappingProxyType({}))

Fit the model.

The signature mirrors stochtree.BARTModel.sample, restricted to the keyword arguments bartz supports.

Parameters:
  • X_train (Real[Array, 'n p'] | Real[ndarray, 'n p'] | DataFrame) – Training covariates with shape (n, p).

  • y_train (Real[Array, 'n'] | Real[ndarray, 'n'] | Series) – Training outcomes of length n.

  • X_test (Real[Array, 'm p'] | Real[ndarray, 'm p'] | DataFrame | None, default: None) – Optional test covariates; if given, predictions are cached on them in y_hat_test.

  • observation_weights (Float[Array, 'n'] | Float[ndarray, 'n'] | Series | None, default: None) – Optional positive per-observation weights that scale the residual variance: y_i | X_i ~ N(mu(X_i), sigma^2 / w_i), so a larger weight means a less noisy observation.

  • num_gfr (int) – Number of grow-from-root iterations. Must be 0.

  • num_burnin (int, default: 0) – Number of MCMC burn-in iterations.

  • num_mcmc (int, default: 100) – Number of retained MCMC iterations per chain.

  • general_params (Mapping[str, Any] | None, default: None) – Optional override for the keys of GeneralParams.

  • mean_forest_params (Mapping[str, Any] | None, default: None) – Override for the keys of MeanForestParams. Must explicitly disable sample_sigma2_leaf.

  • bart_kwargs (Mapping[str, Any], default: MappingProxyType({})) – Additional arguments forwarded to bartz.Bart. Use this to set devices and rm_const=False when wrapping sample in jax.jit.

Raises:

NotImplementedError – If num_gfr is non-zero.

Return type:

None
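
The meaning of observation_weights can be illustrated without fitting anything. The sketch below simulates data from the weighted model y_i ~ N(mu(X_i), sigma^2 / w_i) with NumPy and checks that observations with weight 4 have a quarter of the residual variance of those with weight 1. The mean function and all numeric values are hypothetical choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
sigma = 2.0

x = rng.uniform(-1.0, 1.0, size=n)
mu = np.sin(3.0 * x)                # hypothetical mean function
w = rng.choice([1.0, 4.0], size=n)  # positive per-observation weights

# Noise standard deviation is sigma / sqrt(w_i), i.e. variance sigma^2 / w_i.
y = mu + rng.normal(size=n) * sigma / np.sqrt(w)

# Empirical residual variance within each weight group.
resid = y - mu
var_w1 = resid[w == 1.0].var()  # close to sigma^2
var_w4 = resid[w == 4.0].var()  # close to sigma^2 / 4
```

Arrays built this way, with x reshaped to (n, 1), have the shapes expected for X_train, y_train and observation_weights.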

predict(X, *, type='posterior', terms='all', scale='linear')

Predict at new covariates.

Parameters:
  • X (Real[Array, 'm p'] | Real[ndarray, 'm p'] | DataFrame) – New covariates with shape (m, p).

  • type (Literal['posterior', 'mean'], default: 'posterior') – 'posterior' returns one prediction per posterior sample, with shape (m, num_samples). 'mean' averages the posterior samples, returning a vector of shape (m,).

  • terms (Literal['y_hat', 'mean_forest', 'all'] | Sequence[Literal['y_hat', 'mean_forest', 'all']], default: 'all') – One of 'y_hat', 'mean_forest', 'all', or a list. Since random effects and a variance forest are not supported, 'y_hat' and 'mean_forest' produce the same result.

  • scale (Literal['linear', 'probability', 'class'], default: 'linear') – For probit binary regression: 'linear' returns the linear predictor eta, 'probability' returns Phi(eta), where Phi is the standard normal CDF, and 'class' returns 0 / 1 labels. Only 'linear' is valid for continuous outcomes.

Returns:

Shaped[Array, 'm num_samples'] | Shaped[Array, 'm'] | dict[str, Shaped[Array, 'm num_samples']] | dict[str, Shaped[Array, 'm']] – Either a single JAX array (for a single requested term) or a dict keyed by term name.

Raises:

NotSampledError – If sample has not been called yet.
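
The relationship between the scale and type options can be sketched with plain NumPy. The array eta below is a simulated stand-in for the 'linear' posterior output of a probit model, not a bartz result. Thresholding the probability at 0.5 to obtain classes is an assumption, as is the order in which averaging and the link transformation are applied for type='mean'.

```python
import math

import numpy as np

rng = np.random.default_rng(0)
m, num_samples = 4, 1000

# Stand-in for predict(X, type='posterior', scale='linear') on a probit model.
eta = rng.normal(loc=0.5, scale=1.0, size=(m, num_samples))

# 'probability': standard normal CDF applied elementwise.
phi = np.vectorize(lambda t: 0.5 * (1.0 + math.erf(t / math.sqrt(2.0))))
prob = phi(eta)

# 'class': assumed here to threshold the probability at 0.5.
cls = (prob > 0.5).astype(int)

# type='mean': average over the posterior-sample axis, giving shape (m,).
prob_mean = prob.mean(axis=1)
```

The same reduction over axis 1 turns any (m, num_samples) posterior array into the (m,) vector described for type='mean'.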