Coverage for src/bartz/stochtree/_stochtree.py: 87%
292 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/stochtree/_stochtree.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement class `BARTModel` that mimics the Python package stochtree."""
27from collections.abc import Mapping, Sequence
28from dataclasses import dataclass, field, fields
29from functools import partial
31# WORKAROUND(python<3.15): use frozendict instead of MappingProxyType
32from types import MappingProxyType
33from typing import Any, Literal, TypeVar, overload
35from jax import numpy as jnp
36from jax.scipy.special import ndtr, ndtri
37from jaxtyping import Array, Float, Float32, Key, Real, Shaped
39from bartz._interface import Bart, DataFrame, PredictKind, Series
40from bartz.mcmcstep._state import ArrayLike, FloatLike
41from bartz.prepcovars import RangeEvenBinner
42from bartz.stochtree._preprocess import _PreprocessorBase, make_preprocessor
# Generic type variable: lets `build_dataclass` return an instance of exactly
# the dataclass type it was handed.
T = TypeVar('T')

# Upper bound enforced on `MeanForestParams.max_depth`.
_MAX_DEPTH_LIMIT = 16
49@dataclass(frozen=True)
50class OutcomeModel:
51 """Outcome model specification, matching `stochtree.OutcomeModel`.
53 Only ``('continuous', 'identity')`` and ``('binary', 'probit')`` are
54 supported.
55 """
57 outcome: Literal['continuous', 'binary'] = 'continuous'
58 """Outcome family."""
60 link: Literal['identity', 'probit'] | None = None
61 """Link function. If `None`, defaults to ``'identity'`` for ``'continuous'`` and ``'probit'`` for ``'binary'``."""
63 def __post_init__(self) -> None:
64 if self.link is None:
65 default_link = {'continuous': 'identity', 'binary': 'probit'}.get(
66 self.outcome
67 )
68 object.__setattr__(self, 'link', default_link)
69 if (self.outcome, self.link) not in (
70 ('continuous', 'identity'),
71 ('binary', 'probit'),
72 ):
73 msg = (
74 f'unsupported outcome_model (outcome={self.outcome!r}, '
75 f"link={self.link!r}); only ('continuous', 'identity') "
76 "and ('binary', 'probit') are supported."
77 )
78 raise NotImplementedError(msg)
class NotSampledError(ValueError, AttributeError):
    """Raised when calling a method that requires `BARTModel.sample` to have been called.

    The class derives from both `ValueError` and `AttributeError`, so an
    ``except`` clause written for either of those types also catches it.
    """
@dataclass(frozen=True, kw_only=True)
class GeneralParams:
    """Mirror of stochtree's ``general_params`` dict, with the keys bartz handles.

    Frozen, keyword-only dataclass. `BARTModel.sample` builds an instance from
    the user-supplied ``general_params`` mapping through `build_dataclass`, so
    the field names below are exactly the accepted dict keys.
    """

    # --- outcome scaling and error-variance prior ---
    standardize: bool = True
    """Whether to standardize the outcome before fitting. Ignored for probit binary."""

    sigma2_init: FloatLike | None = None
    """Starting value of the global error variance. If `None` (default), uses ``var(resid_train)`` for continuous and ``1.0`` for probit."""

    sigma2_global_shape: FloatLike = 0.0
    """Shape parameter of the inverse-gamma prior on the global error variance. The default ``0`` is mapped to a near-improper prior, since bartz's scaled-inv-chi² cannot represent ``IG(0, 0)`` exactly."""

    sigma2_global_scale: FloatLike = 0.0
    """Scale parameter of the inverse-gamma prior on the global error variance. The default ``0`` is mapped to a near-improper prior, since bartz's scaled-inv-chi² cannot represent ``IG(0, 0)`` exactly."""

    # --- predictor selection ---
    variable_weights: Float[ArrayLike, ' p'] | None = None
    """Per-predictor sampling weights. Must be strictly positive; pass a small positive value to suppress a variable."""

    # --- sampler control ---
    random_seed: int | Key[Array, ''] | None = None
    """Seed for the random number generator. Unlike stochtree, the default
    `None` is deterministic (equivalent to seed ``0``) rather than drawing a
    random seed, so repeated fits reproduce by default."""

    keep_every: int = 1
    """Thinning factor for retained MCMC samples."""

    num_chains: int = 1
    """Number of independent MCMC chains."""

    # --- likelihood ---
    outcome_model: OutcomeModel = field(default_factory=OutcomeModel)
    """Outcome family and link specification. Defaults to continuous with
    identity link."""
120@dataclass(frozen=True, kw_only=True)
121class MeanForestParams:
122 """Mirror of stochtree's ``mean_forest_params`` dict, restricted to the keys bartz handles."""
124 num_trees: int = 200
125 """Number of trees in the conditional mean ensemble."""
127 alpha: FloatLike = 0.95
128 """Tree split prior base."""
130 beta: FloatLike = 2.0
131 """Tree split prior decay."""
133 min_samples_leaf: int = 5
134 """Minimum number of training samples at a leaf."""
136 max_depth: int = 10
137 """Maximum tree depth. Must be a non-negative integer at most ``16``."""
139 sample_sigma2_leaf: bool = True
140 """Whether to sample the leaf-variance prior. Must be set to ``False``."""
142 sigma2_leaf_init: FloatLike | None = None
143 """Initial leaf-variance prior (held fixed since ``sample_sigma2_leaf=False``). If `None`, matches stochtree's defaults of ``var(resid_train) / num_trees`` for continuous and ``2 / num_trees`` for probit."""
145 def __post_init__(self) -> None:
146 if self.sample_sigma2_leaf:
147 msg = (
148 'sample_sigma2_leaf=True is not supported (bartz uses a fixed'
149 " leaf-variance prior); pass mean_forest_params={'sample_sigma2_leaf':"
150 ' False} to acknowledge this.'
151 )
152 raise NotImplementedError(msg)
153 if self.max_depth < 0: 153 ↛ 154line 153 didn't jump to line 154 because the condition on line 153 was never true
154 msg = (
155 f'max_depth={self.max_depth} is not supported; bartz stores trees'
156 ' as heap arrays of size 2**max_depth, so the stochtree'
157 ' convention max_depth=-1 (unbounded) is rejected. Pass a'
158 f' non-negative integer at most {_MAX_DEPTH_LIMIT}.'
159 )
160 raise NotImplementedError(msg)
161 if self.max_depth > _MAX_DEPTH_LIMIT: 161 ↛ 162line 161 didn't jump to line 162 because the condition on line 161 was never true
162 msg = (
163 f'max_depth={self.max_depth} exceeds {_MAX_DEPTH_LIMIT}; bartz'
164 ' stores trees as heap arrays of size 2**max_depth, so memory'
165 ' grows exponentially with depth.'
166 )
167 raise ValueError(msg)
def build_dataclass(cls: type[T], params: Mapping[str, Any] | None, name: str) -> T:
    """Instantiate dataclass `cls` from a user mapping, rejecting unknown keys.

    Parameters
    ----------
    cls
        Dataclass type to instantiate.
    params
        Mapping from field name to value; `None` stands for an empty mapping.
    name
        Name of the user-facing argument, used in the error message.

    Returns
    -------
    An instance of `cls` built with ``cls(**params)``.

    Raises
    ------
    ValueError
        If `params` contains a key that is not a field of `cls`.
    """
    supplied = {} if params is None else params
    valid_keys = {spec.name for spec in fields(cls)}
    unknown_keys = [key for key in supplied if key not in valid_keys]
    if not unknown_keys:
        return cls(**supplied)
    msg = (
        f'{name} contains unsupported key(s) {sorted(unknown_keys)}; valid keys'
        f' are {sorted(valid_keys)}'
    )
    raise ValueError(msg)
class BARTModel:
    R"""
    BART model with a `stochtree`-compatible interface, powered by bartz.

    This class mimics `stochtree.BARTModel` so that bartz can be used as a
    drop-in reference implementation for testing. The intersection of features
    is targeted: continuous regression (Gaussian outcome, identity link) and
    binary classification (probit link) on tabular covariates.

    Use the same idiomatic pattern as `stochtree.BARTModel`::

        m = BARTModel()
        m.sample(
            X_train=X, y_train=y, X_test=X_test,
            num_gfr=0, num_mcmc=200,
            mean_forest_params={'sample_sigma2_leaf': False},
        )
        yhat = m.predict(X_new, terms='y_hat', type='mean')

    See `GeneralParams` and `MeanForestParams` for the supported keys in the
    ``general_params`` and ``mean_forest_params`` dicts.

    Notes
    -----
    Differences from `stochtree`, by design:

    - ``num_gfr`` has no default and must be set explicitly to ``0``.
    - ``mean_forest_params['sample_sigma2_leaf']`` must be ``False``.
    - ``mean_forest_params['max_depth']`` must be a non-negative integer at
      most ``16``; stochtree's ``-1`` (unbounded depth) sentinel is not
      accepted.
    - The deprecated ``general_params['probit_outcome_model']`` flag is not
      accepted; pass ``outcome_model=OutcomeModel('binary', 'probit')``
      instead.
    - ``general_params['cutpoint_grid_size']`` is not accepted; bartz uses a
      fixed grid of 256 evenly-spaced bins per predictor. stochtree only
      uses this parameter for the GFR sampler, which bartz does not support.
    - Leaf-basis regression, random effects, heteroskedastic variance
      forests, and warm-starting from a previous model are not supported.
    - bartz uses single-precision floats, so outputs differ from stochtree
      at the float32 precision level.
    - ``general_params['random_seed']`` defaults to deterministic behavior
      (seed ``0``) when unset, whereas stochtree draws a random seed. This is
      intentional, to make repeated fits reproducible by default.

    The `sampled` flag is only ``True`` once every public attribute has been
    populated: `sample` clears it before it starts mutating the instance and
    sets it again as its very last step, so a fit that fails part-way leaves
    the model reporting itself as not sampled.

    References
    ----------
    Herren, A., Hahn, P. R., Murray, J., Carvalho, C. (2026). "StochTree:
    BART-based modeling in R and Python". arXiv:2512.12051.
    """

    # public, set by sample()
    sampled: bool
    """Whether `sample` has been called."""

    standardize: bool
    """Whether the outcome was standardized before fitting."""

    sample_sigma2_global: bool
    """Whether the global error variance is sampled (always ``True``)."""

    probit_outcome_model: bool
    """Whether the model uses a binary outcome with probit link."""

    outcome_model: OutcomeModel
    """Outcome family and link specification used during fitting."""

    num_gfr: int
    """Number of grow-from-root iterations (always ``0``)."""

    num_burnin: int
    """Number of MCMC burn-in iterations."""

    num_mcmc: int
    """Number of retained MCMC iterations per chain."""

    num_chains: int
    """Number of independent MCMC chains."""

    num_samples: int
    """Total number of retained posterior samples (``num_mcmc * num_chains``)."""

    sigma2_init: FloatLike
    """Starting value of the global error variance actually used to seed the chain."""

    y_bar: Float32[Array, '']
    """Mean used to standardize the outcome (``0`` if not standardized)."""

    y_std: Float32[Array, '']
    """Standard deviation used to standardize the outcome (``1`` if not standardized)."""

    has_rfx: bool
    """Whether the model includes random effects (always ``False``)."""

    include_mean_forest: bool
    """Whether the model includes a conditional mean forest (always ``True``)."""

    include_variance_forest: bool
    """Whether the model includes a variance forest (always ``False``)."""

    y_hat_train: Float32[Array, 'n num_samples']
    """Posterior predictions at the training covariates, in the original outcome scale."""

    global_var_samples: Float32[Array, ' num_samples']
    """Posterior samples of the global error variance. For probit binary regression, an array of ones."""

    y_hat_test: Float32[Array, 'm num_samples'] | None
    """Posterior predictions at `X_test` if it was supplied to `sample`, else `None`."""

    _bart: Bart
    _preprocessor: _PreprocessorBase | None

    def __init__(self) -> None:
        self.sampled = False
        self._preprocessor = None

    def is_sampled(self) -> bool:
        """Return whether `sample` has been called."""
        return self.sampled

    def _prepare_training_inputs(
        self,
        X_train: Real[ArrayLike, 'n p'] | DataFrame,
        y_train: Real[ArrayLike, ' n'] | Series,
        gp: GeneralParams,
    ) -> tuple[Real[Array, 'n p'], Real[Array, ' n'], Float32[Array, ' p'] | None]:
        """Coerce inputs and build variable weights, fitting the DataFrame preprocessor if any.

        Stores the preprocessor returned by `make_preprocessor` (possibly
        `None`) in ``self._preprocessor`` and returns
        ``(X_train_arr, y_train_arr, varprob)``.

        Raises
        ------
        ValueError
            If preprocessing leaves no columns, or if `X_train` and `y_train`
            disagree on the number of observations.
        """
        y_train_arr = _coerce_response(y_train, name='y_train')

        self._preprocessor = make_preprocessor(X_train)
        if self._preprocessor is None:
            X_train_arr = check_X(X_train, name='X_train')
            _, p = X_train_arr.shape
            varprob = check_variable_weights(gp.variable_weights, p)
        else:
            # The preprocessor decides the default weights: uniform over the
            # *original* columns split across each one-hot expansion (so every
            # original variable keeps an equal splitting budget), or `None` when
            # nothing expands (deferring to bartz's native uniform fast-path).
            weights_np = self._preprocessor.fit(
                X_train, variable_weights=gp.variable_weights
            )
            X_train_np = self._preprocessor.transform(X_train)
            if X_train_np.shape[1] == 0:
                msg = 'X_train has no usable columns after preprocessing'
                raise ValueError(msg)
            X_train_arr = jnp.asarray(X_train_np)
            varprob = None if weights_np is None else jnp.asarray(weights_np)

        n, _ = X_train_arr.shape
        if y_train_arr.shape[0] != n:
            msg = (
                f'X_train and y_train length mismatch: X_train has {n} rows,'
                f' y_train has {y_train_arr.shape[0]} entries'
            )
            raise ValueError(msg)
        return X_train_arr, y_train_arr, varprob

    def sample(
        self,
        X_train: Real[ArrayLike, 'n p'] | DataFrame,
        y_train: Real[ArrayLike, ' n'] | Series,
        X_test: Real[ArrayLike, 'm p'] | DataFrame | None = None,
        observation_weights: Float[ArrayLike, ' n'] | Series | None = None,
        *,
        num_gfr: int,
        num_burnin: int = 0,
        num_mcmc: int = 100,
        general_params: Mapping[str, Any] | None = None,
        mean_forest_params: Mapping[str, Any] | None = None,
        bart_kwargs: Mapping[str, Any] = MappingProxyType({}),
    ) -> None:
        """Fit the model.

        The signature mirrors `stochtree.BARTModel.sample`, restricted to the
        keyword arguments bartz supports.

        Parameters
        ----------
        X_train
            Training covariates with shape ``(n, p)``.
        y_train
            Training outcomes of length ``n``.
        X_test
            Optional test covariates; if given, predictions are cached on
            them in `y_hat_test`.
        observation_weights
            Optional positive per-observation weights scaling the residual
            variance (``y_i | - ~ N(mu(X_i), sigma^2 / w_i)``).
        num_gfr
            Number of grow-from-root iterations. Must be ``0``.
        num_burnin
            Number of MCMC burn-in iterations.
        num_mcmc
            Number of retained MCMC iterations per chain.
        general_params
            Optional override for the keys of `GeneralParams`.
        mean_forest_params
            Override for the keys of `MeanForestParams`. Must explicitly
            disable ``sample_sigma2_leaf``.
        bart_kwargs
            Additional arguments forwarded to `bartz.Bart`. Use this to set
            ``devices`` and ``rm_const=False`` when wrapping `sample` in
            `jax.jit`.

        Raises
        ------
        NotImplementedError
            If ``num_gfr`` is non-zero.

        Notes
        -----
        Argument validation (``num_gfr`` and the two parameter dicts) happens
        before the instance is touched. After that point any earlier fit is
        invalidated: if the call fails later on, `is_sampled` returns
        ``False`` rather than exposing a partially updated model.
        """
        if num_gfr != 0:
            msg = (
                'num_gfr must be 0; the grow-from-root sampler is not available'
                ' in bartz.'
            )
            raise NotImplementedError(msg)

        gp = build_dataclass(GeneralParams, general_params, 'general_params')
        mfp = build_dataclass(
            MeanForestParams, mean_forest_params, 'mean_forest_params'
        )

        is_probit = gp.outcome_model.outcome == 'binary'

        # From here on the instance is mutated (`_preprocessor`, `_bart`, then
        # the public attributes). Drop the flag first so that a failure
        # part-way cannot leave a model that claims to be sampled while mixing
        # state from two different fits.
        self.sampled = False

        X_train_arr, y_train_arr, variable_weights = self._prepare_training_inputs(
            X_train, y_train, gp
        )

        y_bar, y_std, y_for_bartz = standardize_y(
            y_train_arr, is_probit, gp.standardize
        )

        bart_num_chains = None if gp.num_chains == 1 else gp.num_chains

        # variance of the standardized residual, matching stochtree
        # (np.var(resid_train) with ddof=0). For standardize=True it is exactly
        # 1.0; we hardcode that so the value stays trace-time concrete.
        if is_probit:
            var_resid_train: FloatLike = 1.0  # bartz ignores σ² for binary
        elif gp.standardize:
            var_resid_train = 1.0
        else:
            var_resid_train = jnp.var(y_for_bartz)

        # leaf-prior: bartz uses sigma_mu = tau_num / (k * sqrt(num_trees));
        # stochtree's sigma2_leaf is the leaf-variance prior. Hold k=2 and solve
        # for tau_num so that the two parameterizations agree.
        bartz_k = 2.0
        sigma2_leaf_init = resolve_sigma2_leaf_init(
            mfp.sigma2_leaf_init, mfp.num_trees, is_probit, var_resid_train
        )
        tau_num_arg = bartz_k * jnp.sqrt(mfp.num_trees * sigma2_leaf_init)

        if is_probit:
            # stochtree pins σ²=1 for probit; bartz binary branch ignores the
            # variance prior, so we leave the scale/init at their 'auto'
            # defaults (bartz rejects explicit values for binary outcomes).
            sigma_df_arg: FloatLike = 3.0
            sigma_scale_arg: FloatLike | Literal['auto'] = 'auto'
            sigma_init_arg: FloatLike | Literal['auto'] = 'auto'
            sigma2_init_stored: FloatLike = 1.0
        else:
            sigma_df_arg, sigma_scale_arg, sigma_init_arg, sigma2_init_stored = (
                resolve_variance_prior(
                    gp.sigma2_global_shape,
                    gp.sigma2_global_scale,
                    gp.sigma2_init,
                    var_resid_train,
                )
            )

        binner = partial(RangeEvenBinner, max_bins=256)

        seed = 0 if gp.random_seed is None else gp.random_seed

        kwargs: dict = dict(
            x_train=X_train_arr.T,
            y_train=y_for_bartz,
            outcome_type='binary' if is_probit else 'continuous',
            binner=binner,
            varprob=variable_weights,
            sigma_df=sigma_df_arg,
            sigma_scale=sigma_scale_arg,
            sigma_init=sigma_init_arg,
            k=bartz_k,
            power=mfp.beta,
            base=mfp.alpha,
            tau_num=tau_num_arg,
            # NOTE(review): the weights are forwarded verbatim; confirm that
            # `Bart` interprets `error_scale` as the variance weights `w_i`
            # documented above and accepts a Series without prior coercion.
            error_scale=observation_weights,
            num_trees=mfp.num_trees,
            n_save=num_mcmc,
            n_burn=num_burnin,
            n_skip=gp.keep_every,
            printevery=None,
            num_chains=bart_num_chains,
            seed=seed,
            maxdepth=mfp.max_depth + 1,
        )
        # user-supplied arguments override the translated stochtree settings
        kwargs.update(bart_kwargs)
        # match stochtree's gating: only acceptance-time veto on
        # min_samples_leaf, no per-leaf affluence filter (stochtree picks
        # leaves uniformly over all of them). User-supplied init_kw values
        # win on conflicts.
        kwargs = dict(
            kwargs,
            init_kw=dict(
                {
                    'min_points_per_leaf': mfp.min_samples_leaf,
                    'min_points_per_decision_node': None,
                },
                **kwargs.get('init_kw', {}),
            ),
        )
        self._bart = Bart(**kwargs)
        self._finalize_sample(
            outcome_model=gp.outcome_model,
            num_burnin=num_burnin,
            num_mcmc=num_mcmc,
            num_chains=gp.num_chains,
            sigma2_init=sigma2_init_stored,
            y_bar=y_bar,
            y_std=y_std,
            standardize=gp.standardize,
            X_test=X_test,
        )

    def _finalize_sample(
        self,
        *,
        outcome_model: OutcomeModel,
        num_burnin: int,
        num_mcmc: int,
        num_chains: int,
        sigma2_init: FloatLike,
        y_bar: Float32[Array, ''],
        y_std: Float32[Array, ''],
        standardize: bool,
        X_test: Real[ArrayLike, 'm p'] | DataFrame | None,
    ) -> None:
        """Populate the public attributes after `_bart` has been constructed.

        The `sampled` flag is raised only at the very end, after the cached
        predictions and variance samples have been computed; if any of those
        steps raises, the flag keeps the value it had on entry.
        """
        is_probit = outcome_model.outcome == 'binary'
        self.standardize = standardize
        self.sample_sigma2_global = True
        self.probit_outcome_model = is_probit
        self.outcome_model = outcome_model
        self.num_gfr = 0
        self.num_burnin = num_burnin
        self.num_mcmc = num_mcmc
        self.num_chains = num_chains
        self.num_samples = num_mcmc * num_chains
        self.sigma2_init = sigma2_init
        self.y_bar = y_bar
        self.y_std = y_std
        self.has_rfx = False
        self.include_mean_forest = True
        self.include_variance_forest = False

        # cached outputs in stochtree's (n, num_samples) layout, original scale
        self.y_hat_train = self._predict_y_hat_internal('train')
        if X_test is not None:
            # may raise, e.g. TypeError from `_prepare_x` on a raw array when
            # the model was fit on a DataFrame
            self.y_hat_test = self._predict_y_hat_internal(self._prepare_x(X_test).T)
        else:
            self.y_hat_test = None

        if is_probit:
            self.global_var_samples = jnp.ones((self.num_samples,))
        else:
            # undo the outcome standardization: sd scales with y_std
            sigma = self._bart.get_error_sdev()
            self.global_var_samples = (sigma * y_std) ** 2

        # Everything is in place: only now advertise the model as fitted.
        self.sampled = True

    @overload
    def predict(
        self,
        X: Real[ArrayLike, 'm p'] | DataFrame,
        *,
        type: Literal['posterior', 'mean'] = 'posterior',
        terms: Literal['y_hat', 'mean_forest'],
        scale: Literal['linear', 'probability', 'class'] = 'linear',
    ) -> Shaped[Array, 'm num_samples'] | Shaped[Array, ' m']: ...

    @overload
    def predict(
        self,
        X: Real[ArrayLike, 'm p'] | DataFrame,
        *,
        type: Literal['posterior', 'mean'] = 'posterior',
        terms: Literal['all'] = 'all',
        scale: Literal['linear', 'probability', 'class'] = 'linear',
    ) -> dict[str, Shaped[Array, 'm num_samples']] | dict[str, Shaped[Array, ' m']]: ...

    @overload
    def predict(
        self,
        X: Real[ArrayLike, 'm p'] | DataFrame,
        *,
        type: Literal['posterior', 'mean'] = 'posterior',
        terms: Sequence[Literal['y_hat', 'mean_forest', 'all']],
        scale: Literal['linear', 'probability', 'class'] = 'linear',
    ) -> (
        Shaped[Array, 'm num_samples']
        | Shaped[Array, ' m']
        | dict[str, Shaped[Array, 'm num_samples']]
        | dict[str, Shaped[Array, ' m']]
    ): ...

    def predict(
        self,
        X: Real[ArrayLike, 'm p'] | DataFrame,
        *,
        type: Literal['posterior', 'mean'] = 'posterior',  # noqa: A002
        terms: Literal['y_hat', 'mean_forest', 'all']
        | Sequence[Literal['y_hat', 'mean_forest', 'all']] = 'all',
        scale: Literal['linear', 'probability', 'class'] = 'linear',
    ) -> (
        Shaped[Array, 'm num_samples']
        | Shaped[Array, ' m']
        | dict[str, Shaped[Array, 'm num_samples']]
        | dict[str, Shaped[Array, ' m']]
    ):
        """Predict at new covariates.

        Parameters
        ----------
        X
            New covariates with shape ``(m, p)``.
        type
            ``'posterior'`` returns one prediction per posterior sample, with
            shape ``(m, num_samples)``. ``'mean'`` averages the posterior
            samples, returning a vector of shape ``(m,)``.
        terms
            One of ``'y_hat'``, ``'mean_forest'``, ``'all'``, or a list. Since
            random effects and a variance forest are not supported, ``'y_hat'``
            and ``'mean_forest'`` produce the same result.
        scale
            For probit binary regression: ``'linear'`` returns the eta values,
            ``'probability'`` returns ``Phi(eta)``, ``'class'`` returns 0 / 1.
            Only ``'linear'`` is valid for continuous outcomes.

        Returns
        -------
        Either a single jax array (for a single requested term) or a dict keyed by term name.

        Raises
        ------
        NotSampledError
            If `sample` has not been called yet.
        """
        if not self.sampled:
            msg = (
                "This BARTModel instance is not fitted yet. Call 'sample' before"
                ' using this model.'
            )
            raise NotSampledError(msg)
        terms_tuple = check_predict_args(type, scale, terms, self.probit_outcome_model)

        pred = self._predict_y_hat_internal(self._prepare_x(X).T)

        if self.probit_outcome_model and scale in ('probability', 'class'):
            prob = ndtr(pred)
            pred_out = jnp.where(prob < 0.5, 0, 1) if scale == 'class' else prob
        else:
            pred_out = pred

        if type == 'mean':
            # average over the posterior-sample axis
            pred_out = jnp.mean(pred_out, axis=1)

        # 'all' expands to both terms; exactly one requested term yields a bare
        # array, anything else a dict
        wants_y_hat = ('y_hat' in terms_tuple) or ('all' in terms_tuple)
        wants_mean_forest = ('mean_forest' in terms_tuple) or ('all' in terms_tuple)
        single = sum([wants_y_hat, wants_mean_forest]) == 1
        if single:
            return pred_out
        result: dict[str, Shaped[Array, '...']] = {}
        if wants_y_hat:
            result['y_hat'] = pred_out
        if wants_mean_forest:
            result['mean_forest_predictions'] = pred_out
        return result

    def _prepare_x(self, X: Real[ArrayLike, 'm p'] | DataFrame) -> Real[Array, 'm p']:
        """Convert covariates to a 2-D jax array, replaying the fitted preprocessor if any.

        Raises
        ------
        TypeError
            If the model was fit with a preprocessor (DataFrame input) but
            `make_preprocessor` returns `None` for `X`.
        """
        if self._preprocessor is None:
            return check_X(X)
        if make_preprocessor(X) is None:
            msg = (
                'this model was fit on a DataFrame, so prediction covariates must'
                ' also be a pandas/polars DataFrame with the same columns; got a'
                ' non-DataFrame. Passing a raw array would bypass the fitted'
                ' preprocessing (e.g. one-hot encoding) and silently misalign the'
                ' features.'
            )
            raise TypeError(msg)
        return jnp.asarray(self._preprocessor.transform(X))

    def _predict_y_hat_internal(
        self, x: Real[ArrayLike, 'p m'] | Literal['train']
    ) -> Float32[Array, 'm num_samples']:
        """Return predictions on the original outcome scale, layout ``(m, num_samples)``."""
        latent = self._bart.predict(x, kind=PredictKind.latent_samples)
        if self.probit_outcome_model:
            # bartz integrates the binary offset into latent; result already on probit scale.
            return latent.T
        if self.standardize:
            return (latent * self.y_std + self.y_bar).T
        return latent.T
def standardize_y(
    y_train: Real[ArrayLike, ' n'], is_probit: bool, standardize: bool
) -> tuple[Float32[Array, ''], Float32[Array, ''], Float32[Array, ' n']]:
    """Compute the outcome centering/scaling and the response handed to bartz.

    Parameters
    ----------
    y_train
        Training outcomes; cast to float32.
    is_probit
        Whether the outcome is binary with probit link. Takes precedence
        over `standardize`.
    standardize
        Whether to center and scale a continuous outcome.

    Returns
    -------
    y_bar
        Probit: ``ndtri(mean(y))``. Continuous: the mean of `y_train` if
        standardizing, else ``0``.
    y_std
        The standard deviation of `y_train` when standardizing a continuous
        outcome (``1`` if that deviation is not positive), else ``1``.
    y_for_bartz
        Probit: the indicator ``y != 0`` as float32. Continuous:
        ``(y - y_bar) / y_std`` if standardizing, else `y_train` itself.
    """
    y = jnp.asarray(y_train, jnp.float32)
    unit = jnp.float32(1.0)

    if is_probit:
        indicator = (y != 0).astype(jnp.float32)
        return ndtri(y.mean()), unit, indicator

    if not standardize:
        return jnp.float32(0.0), unit, y

    center = y.mean()
    raw_spread = y.std()
    # a constant outcome has zero spread; fall back to 1 to avoid dividing by 0
    spread = jnp.where(raw_spread > 0, raw_spread, 1.0)
    return center, spread, (y - center) / spread
def resolve_sigma2_leaf_init(
    sigma2_leaf_init: FloatLike | None,
    num_trees: int,
    is_probit: bool,
    var_resid_train: FloatLike,
) -> FloatLike:
    """Pick the leaf-variance prior: the user's value, or the stochtree default.

    An explicit `sigma2_leaf_init` is returned unchanged. Otherwise the
    default is ``2 / num_trees`` for probit and
    ``var_resid_train / num_trees`` for continuous outcomes.
    """
    if sigma2_leaf_init is None:
        numerator = 2.0 if is_probit else var_resid_train
        return numerator / num_trees
    return sigma2_leaf_init
def resolve_variance_prior(
    shape: FloatLike,
    scale: FloatLike,
    sigma2_init: FloatLike | None,
    var_resid_train: FloatLike,
) -> tuple[Float32[Array, ''], Float32[Array, ''], Float32[Array, ''], FloatLike]:
    """Map stochtree's IG(shape, scale) variance prior onto bartz's parameterization.

    An IG(shape, scale) prior on σ² is a scaled-inverse-χ² with
    ``sigma_df = 2*shape`` and prior harmonic mean ``square(sigma_scale) =
    scale/shape``. The chain start is set independently of the prior, from
    `sigma2_init` (falling back to `var_resid_train`). No Python branching on
    `shape` / `scale` is used, so both may be traced; the unrepresentable
    IG(0, scale>0) is not rejected here but propagates a non-finite value.

    Parameters
    ----------
    shape
        Stochtree's ``sigma2_global_shape``.
    scale
        Stochtree's ``sigma2_global_scale``.
    sigma2_init
        Stochtree's ``sigma2_init``. If `None`, defaults to `var_resid_train`.
    var_resid_train
        Variance of the residual, the default chain start for σ².

    Returns
    -------
    sigma_df : Float32[Array, '']
        Degrees of freedom of bartz's error variance prior.
    sigma_scale : Float32[Array, '']
        Scale of bartz's prior (sqrt of the prior harmonic mean of the variance).
    sigma_init : Float32[Array, '']
        Initial error standard deviation seeding the chain.
    sigma2_init_stored : FloatLike
        The chain starting value of σ², suitable for ``BARTModel.sigma2_init``.
    """
    shape_f32 = jnp.asarray(shape, jnp.float32)
    scale_f32 = jnp.asarray(scale, jnp.float32)

    # chain start, decoupled from the prior
    sigma2_start = var_resid_train if sigma2_init is None else sigma2_init
    sigma_init = jnp.sqrt(jnp.asarray(sigma2_start, jnp.float32))

    # prior: df = 2*shape, harmonic mean = scale/shape. Selecting 0 whenever
    # scale is not positive keeps IG(0, 0) away from the 0/0 result, while
    # IG(0, scale>0) is left to blow up and flag itself as invalid.
    sigma_df = 2.0 * shape_f32
    ratio = scale_f32 / shape_f32
    sigma_scale = jnp.sqrt(jnp.where(scale_f32 > 0, ratio, 0.0))

    return sigma_df, sigma_scale, sigma_init, sigma2_start
def check_variable_weights(
    variable_weights: Float[ArrayLike, ' p'] | None, p: int
) -> Float32[Array, ' p'] | None:
    """Convert `variable_weights` to a float32 jax array of shape ``(p,)``.

    `None` is passed through untouched. Only the shape is checked.

    Raises
    ------
    ValueError
        If the converted array does not have shape ``(p,)``.
    """
    if variable_weights is None:
        return None
    weights = jnp.asarray(variable_weights, jnp.float32)
    if weights.shape == (p,):
        return weights
    msg = f'variable_weights must have shape (p,)=({p},), got {weights.shape}'
    raise ValueError(msg)
def check_predict_args(
    type_: Literal['posterior', 'mean'],
    scale: Literal['linear', 'probability', 'class'],
    terms: Literal['y_hat', 'mean_forest', 'all']
    | Sequence[Literal['y_hat', 'mean_forest', 'all']],
    probit_outcome_model: bool,
) -> tuple[str, ...]:
    """Check the arguments of `BARTModel.predict` and normalize `terms`.

    Parameters
    ----------
    type_
        Requested summary, ``'posterior'`` or ``'mean'``.
    scale
        Requested output scale.
    terms
        A single term name or a sequence of term names.
    probit_outcome_model
        Whether the fitted model is a probit binary model.

    Returns
    -------
    The requested terms as a tuple (a lone string becomes a 1-tuple).

    Raises
    ------
    ValueError
        On an unknown `scale`, `type_` or term, on a non-linear scale for a
        continuous model, on ``scale='class'`` combined with ``type='mean'``,
        or on ``scale='class'`` with terms other than just ``'y_hat'``.
    """
    known_scales = ('linear', 'probability', 'class')
    known_types = ('posterior', 'mean')
    known_terms = ('y_hat', 'mean_forest', 'all')

    if scale not in known_scales:
        msg = f"scale must be 'linear', 'probability', or 'class'; got {scale!r}"
        raise ValueError(msg)
    if type_ not in known_types:
        msg = f"type must be 'posterior' or 'mean'; got {type_!r}"
        raise ValueError(msg)
    if scale != 'linear' and not probit_outcome_model:
        msg = (
            "scale must be 'linear' for non-probit (continuous) regression;"
            f' got {scale!r}'
        )
        raise ValueError(msg)

    class_scale = scale == 'class'
    if class_scale and type_ == 'mean':
        msg = "scale='class' is incompatible with type='mean'"
        raise ValueError(msg)

    terms_tuple = (terms,) if isinstance(terms, str) else tuple(terms)
    for term in terms_tuple:
        if term in known_terms:
            continue
        msg = f'unknown term {term!r}; valid terms are y_hat, mean_forest, all'
        raise ValueError(msg)

    if class_scale and set(terms_tuple) != {'y_hat'}:
        # match stochtree: 'class' converts only the single 'y_hat' term, so it
        # rejects 'mean_forest' and 'all' (the latter also pulls in mean_forest)
        msg = "scale='class' is only supported when requesting a single 'y_hat' term"
        raise ValueError(msg)
    return terms_tuple
def check_X(
    X: Real[ArrayLike, 'n p'] | DataFrame, *, name: str = 'X'
) -> Real[Array, 'n p']:
    """Turn covariates into a 2-D jax array laid out as ``(n, p)``.

    A `DataFrame` is first converted with ``to_numpy()``. A 1-D input is
    treated as a single column and reshaped to ``(n, 1)``.

    Parameters
    ----------
    X
        Covariates as a DataFrame or array-like.
    name
        Argument name used in the error message.

    Raises
    ------
    ValueError
        If the array is neither 1-D nor 2-D.
    """
    raw = X.to_numpy() if isinstance(X, DataFrame) else X
    arr = jnp.asarray(raw)
    if arr.ndim == 1:
        arr = arr[:, None]
    if arr.ndim == 2:
        return arr
    msg = f'{name} must be 2D (n, p); got shape {arr.shape}'
    raise ValueError(msg)
def _coerce_response(
    y: Real[ArrayLike, ' n'] | Series, *, name: str
) -> Real[Array, ' n']:
    """Turn a response given as a Series or array-like into a 1-D jax array.

    Parameters
    ----------
    y
        Response values; a `Series` is first converted with ``to_numpy()``.
    name
        Argument name used in the error message.

    Raises
    ------
    ValueError
        If the resulting array is not 1-D.
    """
    raw = y.to_numpy() if isinstance(y, Series) else y
    arr = jnp.asarray(raw)
    if arr.ndim == 1:
        return arr
    msg = f'{name} must be 1D (n,); got shape {arr.shape}'
    raise ValueError(msg)