Coverage for src/bartz/stochtree/_stochtree.py: 87%

292 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/stochtree/_stochtree.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement class `BARTModel` that mimics the Python package stochtree.""" 

26 

27from collections.abc import Mapping, Sequence 

28from dataclasses import dataclass, field, fields 

29from functools import partial 

30 

31# WORKAROUND(python<3.15): use frozendict instead of MappingProxyType 

32from types import MappingProxyType 

33from typing import Any, Literal, TypeVar, overload 

34 

35from jax import numpy as jnp 

36from jax.scipy.special import ndtr, ndtri 

37from jaxtyping import Array, Float, Float32, Key, Real, Shaped 

38 

39from bartz._interface import Bart, DataFrame, PredictKind, Series 

40from bartz.mcmcstep._state import ArrayLike, FloatLike 

41from bartz.prepcovars import RangeEvenBinner 

42from bartz.stochtree._preprocess import _PreprocessorBase, make_preprocessor 

43 

# Generic placeholder for the dataclass type built by `build_dataclass`.
T = TypeVar('T')

# Largest `max_depth` accepted by `MeanForestParams`; deeper trees are rejected
# because tree storage grows as 2**max_depth.
_MAX_DEPTH_LIMIT = 16

47 

48 

@dataclass(frozen=True)
class OutcomeModel:
    """Describe the outcome family and link, mirroring `stochtree.OutcomeModel`.

    The only accepted combinations are ``('continuous', 'identity')`` and
    ``('binary', 'probit')``; any other pair raises `NotImplementedError` at
    construction time.
    """

    outcome: Literal['continuous', 'binary'] = 'continuous'
    """Outcome family."""

    link: Literal['identity', 'probit'] | None = None
    """Link function. `None` selects ``'identity'`` for ``'continuous'`` and ``'probit'`` for ``'binary'``."""

    def __post_init__(self) -> None:
        # Resolve the family's default link. The dataclass is frozen, so the
        # write goes through `object.__setattr__`. An unknown family resolves
        # to `None` and is rejected by the support check below.
        if self.link is None:
            implied_links = {'continuous': 'identity', 'binary': 'probit'}
            object.__setattr__(self, 'link', implied_links.get(self.outcome))

        supported = (('continuous', 'identity'), ('binary', 'probit'))
        if (self.outcome, self.link) in supported:
            return
        msg = (
            f'unsupported outcome_model (outcome={self.outcome!r}, '
            f"link={self.link!r}); only ('continuous', 'identity') "
            "and ('binary', 'probit') are supported."
        )
        raise NotImplementedError(msg)

79 

80 

class NotSampledError(ValueError, AttributeError):
    """Signal that a method needing a fitted model ran before `BARTModel.sample`.

    Deriving from both `ValueError` and `AttributeError` lets callers catch it
    under either base class.
    """

83 

84 

@dataclass(frozen=True, kw_only=True)
class GeneralParams:
    """Typed counterpart of stochtree's ``general_params`` dict, limited to the keys bartz handles."""

    standardize: bool = True
    """Standardize the outcome before fitting. Has no effect for probit binary outcomes."""

    sigma2_init: FloatLike | None = None
    """Value at which the global error variance chain starts. `None` (default) means ``var(resid_train)`` for continuous outcomes and ``1.0`` for probit."""

    sigma2_global_shape: FloatLike = 0.0
    """Shape of the inverse-gamma prior on the global error variance. The default ``0`` is mapped to a near-improper prior, because bartz's scaled-inv-chi² cannot represent ``IG(0, 0)`` exactly."""

    sigma2_global_scale: FloatLike = 0.0
    """Scale of the inverse-gamma prior on the global error variance. The default ``0`` is mapped to a near-improper prior, because bartz's scaled-inv-chi² cannot represent ``IG(0, 0)`` exactly."""

    variable_weights: Float[ArrayLike, ' p'] | None = None
    """Sampling weight of each predictor. Weights must be strictly positive; use a small positive value to suppress a variable."""

    random_seed: int | Key[Array, ''] | None = None
    """Seed for the random number generator. In contrast to stochtree, the
    default `None` is deterministic (same as seed ``0``) instead of drawing a
    random seed, so repeated fits reproduce by default."""

    keep_every: int = 1
    """Thinning factor applied to the retained MCMC samples."""

    num_chains: int = 1
    """How many independent MCMC chains to run."""

    outcome_model: OutcomeModel = field(default_factory=OutcomeModel)
    """Outcome family and link. The default is a continuous outcome with
    identity link."""

118 

119 

120@dataclass(frozen=True, kw_only=True) 

121class MeanForestParams: 

122 """Mirror of stochtree's ``mean_forest_params`` dict, restricted to the keys bartz handles.""" 

123 

124 num_trees: int = 200 

125 """Number of trees in the conditional mean ensemble.""" 

126 

127 alpha: FloatLike = 0.95 

128 """Tree split prior base.""" 

129 

130 beta: FloatLike = 2.0 

131 """Tree split prior decay.""" 

132 

133 min_samples_leaf: int = 5 

134 """Minimum number of training samples at a leaf.""" 

135 

136 max_depth: int = 10 

137 """Maximum tree depth. Must be a non-negative integer at most ``16``.""" 

138 

139 sample_sigma2_leaf: bool = True 

140 """Whether to sample the leaf-variance prior. Must be set to ``False``.""" 

141 

142 sigma2_leaf_init: FloatLike | None = None 

143 """Initial leaf-variance prior (held fixed since ``sample_sigma2_leaf=False``). If `None`, matches stochtree's defaults of ``var(resid_train) / num_trees`` for continuous and ``2 / num_trees`` for probit.""" 

144 

145 def __post_init__(self) -> None: 

146 if self.sample_sigma2_leaf: 

147 msg = ( 

148 'sample_sigma2_leaf=True is not supported (bartz uses a fixed' 

149 " leaf-variance prior); pass mean_forest_params={'sample_sigma2_leaf':" 

150 ' False} to acknowledge this.' 

151 ) 

152 raise NotImplementedError(msg) 

153 if self.max_depth < 0: 153 ↛ 154line 153 didn't jump to line 154 because the condition on line 153 was never true

154 msg = ( 

155 f'max_depth={self.max_depth} is not supported; bartz stores trees' 

156 ' as heap arrays of size 2**max_depth, so the stochtree' 

157 ' convention max_depth=-1 (unbounded) is rejected. Pass a' 

158 f' non-negative integer at most {_MAX_DEPTH_LIMIT}.' 

159 ) 

160 raise NotImplementedError(msg) 

161 if self.max_depth > _MAX_DEPTH_LIMIT: 161 ↛ 162line 161 didn't jump to line 162 because the condition on line 161 was never true

162 msg = ( 

163 f'max_depth={self.max_depth} exceeds {_MAX_DEPTH_LIMIT}; bartz' 

164 ' stores trees as heap arrays of size 2**max_depth, so memory' 

165 ' grows exponentially with depth.' 

166 ) 

167 raise ValueError(msg) 

168 

169 

def build_dataclass(cls: type[T], params: Mapping[str, Any] | None, name: str) -> T:
    """Instantiate dataclass `cls` from a user mapping, rejecting unknown keys.

    `params` may be `None`, which is treated as an empty mapping. `name` is the
    label used for the mapping in the error message. A `ValueError` listing the
    offending and the valid keys is raised when `params` has keys that are not
    fields of `cls`.
    """
    supplied = {} if params is None else params
    known = {f.name for f in fields(cls)}
    unknown = set(supplied).difference(known)
    if not unknown:
        return cls(**supplied)
    msg = (
        f'{name} contains unsupported key(s) {sorted(unknown)}; valid keys'
        f' are {sorted(known)}'
    )
    raise ValueError(msg)

183 

184 

class BARTModel:
    R"""
    BART model with a `stochtree`-compatible interface, powered by bartz.

    This class mimics `stochtree.BARTModel` so that bartz can be used as a
    drop-in reference implementation for testing. The intersection of features
    is targeted: continuous regression (Gaussian outcome, identity link) and
    binary classification (probit link) on tabular covariates.

    Use the same idiomatic pattern as `stochtree.BARTModel`::

        m = BARTModel()
        m.sample(
            X_train=X, y_train=y, X_test=X_test,
            num_gfr=0, num_mcmc=200,
            mean_forest_params={'sample_sigma2_leaf': False},
        )
        yhat = m.predict(X_new, terms='y_hat', type='mean')

    See `GeneralParams` and `MeanForestParams` for the supported keys in the
    ``general_params`` and ``mean_forest_params`` dicts.

    Notes
    -----
    Differences from `stochtree`, by design:

    - ``num_gfr`` has no default and must be set explicitly to ``0``.
    - ``mean_forest_params['sample_sigma2_leaf']`` must be ``False``.
    - ``mean_forest_params['max_depth']`` must be a non-negative integer at
      most ``16``; stochtree's ``-1`` (unbounded depth) sentinel is not
      accepted.
    - The deprecated ``general_params['probit_outcome_model']`` flag is not
      accepted; pass ``outcome_model=OutcomeModel('binary', 'probit')``
      instead.
    - ``general_params['cutpoint_grid_size']`` is not accepted; bartz uses a
      fixed grid of 256 evenly-spaced bins per predictor. stochtree only
      uses this parameter for the GFR sampler, which bartz does not support.
    - Leaf-basis regression, random effects, heteroskedastic variance
      forests, and warm-starting from a previous model are not supported.
    - bartz uses single-precision floats, so outputs differ from stochtree
      at the float32 precision level.
    - ``general_params['random_seed']`` defaults to deterministic behavior
      (seed ``0``) when unset, whereas stochtree draws a random seed. This is
      intentional, to make repeated fits reproducible by default.

    References
    ----------
    Herren, A., Hahn, P. R., Murray, J., Carvalho, C. (2026). "StochTree:
    BART-based modeling in R and Python". arXiv:2512.12051.
    """

    # public, set by sample()
    sampled: bool
    """Whether the model holds a completed fit, i.e. the last `sample` call finished."""

    standardize: bool
    """Whether the outcome was standardized before fitting."""

    sample_sigma2_global: bool
    """Whether the global error variance is sampled (always ``True``)."""

    probit_outcome_model: bool
    """Whether the model uses a binary outcome with probit link."""

    outcome_model: OutcomeModel
    """Outcome family and link specification used during fitting."""

    num_gfr: int
    """Number of grow-from-root iterations (always ``0``)."""

    num_burnin: int
    """Number of MCMC burn-in iterations."""

    num_mcmc: int
    """Number of retained MCMC iterations per chain."""

    num_chains: int
    """Number of independent MCMC chains."""

    num_samples: int
    """Total number of retained posterior samples (``num_mcmc * num_chains``)."""

    sigma2_init: FloatLike
    """Starting value of the global error variance actually used to seed the chain."""

    y_bar: Float32[Array, '']
    """Mean used to standardize the outcome (``0`` if not standardized)."""

    y_std: Float32[Array, '']
    """Standard deviation used to standardize the outcome (``1`` if not standardized)."""

    has_rfx: bool
    """Whether the model includes random effects (always ``False``)."""

    include_mean_forest: bool
    """Whether the model includes a conditional mean forest (always ``True``)."""

    include_variance_forest: bool
    """Whether the model includes a variance forest (always ``False``)."""

    y_hat_train: Float32[Array, 'n num_samples']
    """Posterior predictions at the training covariates, in the original outcome scale."""

    global_var_samples: Float32[Array, ' num_samples']
    """Posterior samples of the global error variance. For probit binary regression, an array of ones."""

    y_hat_test: Float32[Array, 'm num_samples'] | None
    """Posterior predictions at `X_test` if it was supplied to `sample`, else `None`."""

    _bart: Bart
    _preprocessor: _PreprocessorBase | None

    def __init__(self) -> None:
        """Create an unfitted model; call `sample` to fit it."""
        self.sampled = False
        self._preprocessor = None

    def is_sampled(self) -> bool:
        """Return whether the model holds a completed fit."""
        return self.sampled

    def _prepare_training_inputs(
        self,
        X_train: Real[ArrayLike, 'n p'] | DataFrame,
        y_train: Real[ArrayLike, ' n'] | Series,
        gp: GeneralParams,
    ) -> tuple[Real[Array, 'n p'], Real[Array, ' n'], Float32[Array, ' p'] | None]:
        """Coerce the training data and build the variable weights.

        Fits the DataFrame preprocessor when `make_preprocessor` returns one
        for `X_train`, and stores it (or `None`) in ``self._preprocessor``.

        Returns
        -------
        The covariates as an ``(n, p)`` array, the response as a length ``n``
        array, and the variable weights (or `None`).

        Raises
        ------
        ValueError
            If preprocessing leaves no columns, or if `X_train` and `y_train`
            disagree on the number of observations.
        """
        y_train_arr = _coerce_response(y_train, name='y_train')

        preprocessor = make_preprocessor(X_train)
        # From this point the state of a previous fit (if any) starts being
        # overwritten, so the model must stop reporting itself as sampled
        # until the refit has completed.
        self.sampled = False
        self._preprocessor = preprocessor

        if preprocessor is None:
            X_train_arr = check_X(X_train, name='X_train')
            _, p = X_train_arr.shape
            varprob = check_variable_weights(gp.variable_weights, p)
        else:
            # The preprocessor decides the default weights: uniform over the
            # *original* columns split across each one-hot expansion (so every
            # original variable keeps an equal splitting budget), or `None` when
            # nothing expands (deferring to bartz's native uniform fast-path).
            weights_np = preprocessor.fit(
                X_train, variable_weights=gp.variable_weights
            )
            X_train_np = preprocessor.transform(X_train)
            if X_train_np.shape[1] == 0:
                msg = 'X_train has no usable columns after preprocessing'
                raise ValueError(msg)
            X_train_arr = jnp.asarray(X_train_np)
            varprob = None if weights_np is None else jnp.asarray(weights_np)

        n, _ = X_train_arr.shape
        if y_train_arr.shape[0] != n:
            msg = (
                f'X_train and y_train length mismatch: X_train has {n} rows,'
                f' y_train has {y_train_arr.shape[0]} entries'
            )
            raise ValueError(msg)
        return X_train_arr, y_train_arr, varprob

    def sample(
        self,
        X_train: Real[ArrayLike, 'n p'] | DataFrame,
        y_train: Real[ArrayLike, ' n'] | Series,
        X_test: Real[ArrayLike, 'm p'] | DataFrame | None = None,
        observation_weights: Float[ArrayLike, ' n'] | Series | None = None,
        *,
        num_gfr: int,
        num_burnin: int = 0,
        num_mcmc: int = 100,
        general_params: Mapping[str, Any] | None = None,
        mean_forest_params: Mapping[str, Any] | None = None,
        bart_kwargs: Mapping[str, Any] = MappingProxyType({}),
    ) -> None:
        """Fit the model.

        The signature mirrors `stochtree.BARTModel.sample`, restricted to the
        keyword arguments bartz supports.

        Parameters
        ----------
        X_train
            Training covariates with shape ``(n, p)``.
        y_train
            Training outcomes of length ``n``.
        X_test
            Optional test covariates; if given, predictions are cached on
            them in `y_hat_test`.
        observation_weights
            Optional positive per-observation weights scaling the residual
            variance (``y_i | - ~ N(mu(X_i), sigma^2 / w_i)``).
        num_gfr
            Number of grow-from-root iterations. Must be ``0``.
        num_burnin
            Number of MCMC burn-in iterations.
        num_mcmc
            Number of retained MCMC iterations per chain.
        general_params
            Optional override for the keys of `GeneralParams`.
        mean_forest_params
            Override for the keys of `MeanForestParams`. Must explicitly
            disable ``sample_sigma2_leaf``.
        bart_kwargs
            Additional arguments forwarded to `bartz.Bart`. Use this to set
            ``devices`` and ``rm_const=False`` when wrapping `sample` in
            `jax.jit`.

        Raises
        ------
        NotImplementedError
            If ``num_gfr`` is non-zero.
        """
        if num_gfr != 0:
            msg = (
                'num_gfr must be 0; the grow-from-root sampler is not available'
                ' in bartz.'
            )
            raise NotImplementedError(msg)

        gp = build_dataclass(GeneralParams, general_params, 'general_params')
        mfp = build_dataclass(
            MeanForestParams, mean_forest_params, 'mean_forest_params'
        )
        is_probit = gp.outcome_model.outcome == 'binary'

        X_train_arr, y_train_arr, variable_weights = self._prepare_training_inputs(
            X_train, y_train, gp
        )
        y_bar, y_std, y_for_bartz = standardize_y(
            y_train_arr, is_probit, gp.standardize
        )

        # Variance of the standardized residual, matching stochtree
        # (np.var(resid_train) with ddof=0). With standardize=True it is
        # exactly 1.0 and is hardcoded so the value stays trace-time concrete;
        # for probit bartz ignores σ², so 1.0 is used as well.
        var_resid_train: FloatLike
        if is_probit or gp.standardize:
            var_resid_train = 1.0
        else:
            var_resid_train = jnp.var(y_for_bartz)

        # Leaf prior: bartz uses sigma_mu = tau_num / (k * sqrt(num_trees)),
        # while stochtree's sigma2_leaf is the leaf-variance prior. Hold k=2
        # and solve for tau_num so that the two parameterizations agree.
        bartz_k = 2.0
        sigma2_leaf_init = resolve_sigma2_leaf_init(
            mfp.sigma2_leaf_init, mfp.num_trees, is_probit, var_resid_train
        )
        tau_num_arg = bartz_k * jnp.sqrt(mfp.num_trees * sigma2_leaf_init)

        sigma_df_arg: FloatLike
        sigma_scale_arg: FloatLike | Literal['auto']
        sigma_init_arg: FloatLike | Literal['auto']
        sigma2_init_stored: FloatLike
        if is_probit:
            # stochtree pins σ²=1 for probit; bartz's binary branch ignores the
            # variance prior, so the scale/init stay at their 'auto' defaults
            # (bartz rejects explicit values for binary outcomes).
            sigma_df_arg = 3.0
            sigma_scale_arg = 'auto'
            sigma_init_arg = 'auto'
            sigma2_init_stored = 1.0
        else:
            sigma_df_arg, sigma_scale_arg, sigma_init_arg, sigma2_init_stored = (
                resolve_variance_prior(
                    gp.sigma2_global_shape,
                    gp.sigma2_global_scale,
                    gp.sigma2_init,
                    var_resid_train,
                )
            )

        kwargs: dict[str, Any] = dict(
            x_train=X_train_arr.T,
            y_train=y_for_bartz,
            outcome_type='binary' if is_probit else 'continuous',
            binner=partial(RangeEvenBinner, max_bins=256),
            varprob=variable_weights,
            sigma_df=sigma_df_arg,
            sigma_scale=sigma_scale_arg,
            sigma_init=sigma_init_arg,
            k=bartz_k,
            power=mfp.beta,
            base=mfp.alpha,
            tau_num=tau_num_arg,
            # NOTE(review): forwarded unchanged as Bart's `error_scale`;
            # confirm that Bart interprets it as the variance weight w_i
            # (sigma^2 / w_i) documented above rather than as a multiplier of
            # the error standard deviation.
            error_scale=observation_weights,
            num_trees=mfp.num_trees,
            n_save=num_mcmc,
            n_burn=num_burnin,
            n_skip=gp.keep_every,
            printevery=None,
            # a single chain is requested from Bart as `None`
            num_chains=None if gp.num_chains == 1 else gp.num_chains,
            # unlike stochtree, an unset seed is deterministic (seed 0)
            seed=0 if gp.random_seed is None else gp.random_seed,
            maxdepth=mfp.max_depth + 1,
        )
        kwargs.update(bart_kwargs)
        # Match stochtree's gating: only an acceptance-time veto on
        # min_samples_leaf, no per-leaf affluence filter (stochtree picks
        # leaves uniformly over all of them). User-supplied init_kw values
        # win on conflicts.
        kwargs['init_kw'] = {
            'min_points_per_leaf': mfp.min_samples_leaf,
            'min_points_per_decision_node': None,
            **kwargs.get('init_kw', {}),
        }

        self._bart = Bart(**kwargs)
        self._finalize_sample(
            outcome_model=gp.outcome_model,
            num_burnin=num_burnin,
            num_mcmc=num_mcmc,
            num_chains=gp.num_chains,
            sigma2_init=sigma2_init_stored,
            y_bar=y_bar,
            y_std=y_std,
            standardize=gp.standardize,
            X_test=X_test,
        )

    def _finalize_sample(
        self,
        *,
        outcome_model: OutcomeModel,
        num_burnin: int,
        num_mcmc: int,
        num_chains: int,
        sigma2_init: FloatLike,
        y_bar: Float32[Array, ''],
        y_std: Float32[Array, ''],
        standardize: bool,
        X_test: Real[ArrayLike, 'm p'] | DataFrame | None,
    ) -> None:
        """Populate the public attributes after `_bart` has been constructed.

        The `sampled` flag is cleared on entry and raised only once every
        cached output has been computed, so an exception while predicting
        (e.g. on a malformed `X_test`) never leaves a half-populated model
        that claims to be fitted.
        """
        self.sampled = False

        is_probit = outcome_model.outcome == 'binary'
        self.standardize = standardize
        self.sample_sigma2_global = True
        self.probit_outcome_model = is_probit
        self.outcome_model = outcome_model
        self.num_gfr = 0
        self.num_burnin = num_burnin
        self.num_mcmc = num_mcmc
        self.num_chains = num_chains
        self.num_samples = num_mcmc * num_chains
        self.sigma2_init = sigma2_init
        self.y_bar = y_bar
        self.y_std = y_std
        self.has_rfx = False
        self.include_mean_forest = True
        self.include_variance_forest = False

        # cached outputs in stochtree's (n, num_samples) layout, original scale
        self.y_hat_train = self._predict_y_hat_internal('train')
        if X_test is None:
            self.y_hat_test = None
        else:
            self.y_hat_test = self._predict_y_hat_internal(self._prepare_x(X_test).T)

        if is_probit:
            self.global_var_samples = jnp.ones((self.num_samples,))
        else:
            # bring the error sdev back to the original outcome scale
            sigma = self._bart.get_error_sdev()
            self.global_var_samples = (sigma * y_std) ** 2

        # Everything is in place: only now does the model count as fitted.
        self.sampled = True

    @overload
    def predict(
        self,
        X: Real[ArrayLike, 'm p'] | DataFrame,
        *,
        type: Literal['posterior', 'mean'] = 'posterior',
        terms: Literal['y_hat', 'mean_forest'],
        scale: Literal['linear', 'probability', 'class'] = 'linear',
    ) -> Shaped[Array, 'm num_samples'] | Shaped[Array, ' m']: ...

    @overload
    def predict(
        self,
        X: Real[ArrayLike, 'm p'] | DataFrame,
        *,
        type: Literal['posterior', 'mean'] = 'posterior',
        terms: Literal['all'] = 'all',
        scale: Literal['linear', 'probability', 'class'] = 'linear',
    ) -> dict[str, Shaped[Array, 'm num_samples']] | dict[str, Shaped[Array, ' m']]: ...

    @overload
    def predict(
        self,
        X: Real[ArrayLike, 'm p'] | DataFrame,
        *,
        type: Literal['posterior', 'mean'] = 'posterior',
        terms: Sequence[Literal['y_hat', 'mean_forest', 'all']],
        scale: Literal['linear', 'probability', 'class'] = 'linear',
    ) -> (
        Shaped[Array, 'm num_samples']
        | Shaped[Array, ' m']
        | dict[str, Shaped[Array, 'm num_samples']]
        | dict[str, Shaped[Array, ' m']]
    ): ...

    def predict(
        self,
        X: Real[ArrayLike, 'm p'] | DataFrame,
        *,
        type: Literal['posterior', 'mean'] = 'posterior',  # noqa: A002
        terms: Literal['y_hat', 'mean_forest', 'all']
        | Sequence[Literal['y_hat', 'mean_forest', 'all']] = 'all',
        scale: Literal['linear', 'probability', 'class'] = 'linear',
    ) -> (
        Shaped[Array, 'm num_samples']
        | Shaped[Array, ' m']
        | dict[str, Shaped[Array, 'm num_samples']]
        | dict[str, Shaped[Array, ' m']]
    ):
        """Predict at new covariates.

        Parameters
        ----------
        X
            New covariates with shape ``(m, p)``.
        type
            ``'posterior'`` returns one prediction per posterior sample, with
            shape ``(m, num_samples)``. ``'mean'`` averages the posterior
            samples, returning a vector of shape ``(m,)``.
        terms
            One of ``'y_hat'``, ``'mean_forest'``, ``'all'``, or a list. Since
            random effects and a variance forest are not supported, ``'y_hat'``
            and ``'mean_forest'`` produce the same result.
        scale
            For probit binary regression: ``'linear'`` returns the eta values,
            ``'probability'`` returns ``Phi(eta)``, ``'class'`` returns 0 / 1.
            Only ``'linear'`` is valid for continuous outcomes.

        Returns
        -------
        Either a single jax array (when exactly one of the two terms is
        requested) or a dict with keys ``'y_hat'`` and
        ``'mean_forest_predictions'``.

        Raises
        ------
        NotSampledError
            If `sample` has not been called yet.
        """
        if not self.sampled:
            msg = (
                "This BARTModel instance is not fitted yet. Call 'sample' before"
                ' using this model.'
            )
            raise NotSampledError(msg)
        terms_tuple = check_predict_args(type, scale, terms, self.probit_outcome_model)

        pred_out = self._predict_y_hat_internal(self._prepare_x(X).T)

        # Probability / class scales only exist for the probit model; the
        # argument check above already rejected them for continuous outcomes.
        if self.probit_outcome_model and scale in ('probability', 'class'):
            prob = ndtr(pred_out)
            pred_out = jnp.where(prob < 0.5, 0, 1) if scale == 'class' else prob

        if type == 'mean':
            # average over the posterior samples (second axis)
            pred_out = jnp.mean(pred_out, axis=1)

        wants_all = 'all' in terms_tuple
        wants_y_hat = wants_all or 'y_hat' in terms_tuple
        wants_mean_forest = wants_all or 'mean_forest' in terms_tuple
        if wants_y_hat != wants_mean_forest:
            # exactly one term requested: return the bare array
            return pred_out

        # Both terms (or none, for an empty `terms`) requested: return a dict.
        # The two entries share the same array since the model has no other
        # additive component.
        result: dict[str, Shaped[Array, '...']] = {}
        if wants_y_hat:
            result['y_hat'] = pred_out
        if wants_mean_forest:
            result['mean_forest_predictions'] = pred_out
        return result

    def _prepare_x(self, X: Real[ArrayLike, 'm p'] | DataFrame) -> Real[Array, 'm p']:
        """Convert covariates to a 2-D jax array, replaying the fitted preprocessor if any.

        Raises
        ------
        TypeError
            If the model was fit with a preprocessor but `make_preprocessor`
            finds none for `X` (i.e. `X` is not a DataFrame).
        """
        if self._preprocessor is None:
            return check_X(X)
        if make_preprocessor(X) is None:
            msg = (
                'this model was fit on a DataFrame, so prediction covariates must'
                ' also be a pandas/polars DataFrame with the same columns; got a'
                ' non-DataFrame. Passing a raw array would bypass the fitted'
                ' preprocessing (e.g. one-hot encoding) and silently misalign the'
                ' features.'
            )
            raise TypeError(msg)
        return jnp.asarray(self._preprocessor.transform(X))

    def _predict_y_hat_internal(
        self, x: Real[ArrayLike, 'p m'] | Literal['train']
    ) -> Float32[Array, 'm num_samples']:
        """Return predictions on the original outcome scale, layout ``(m, num_samples)``.

        `x` is either covariates in bartz's ``(p, m)`` layout or the literal
        ``'train'`` to predict at the training points.
        """
        latent = self._bart.predict(x, kind=PredictKind.latent_samples)
        # For probit, bartz integrates the binary offset into the latent, so
        # the result is already on the probit scale; only a standardized
        # continuous outcome needs to be mapped back.
        if self.standardize and not self.probit_outcome_model:
            latent = latent * self.y_std + self.y_bar
        return latent.T

690 

691 

def standardize_y(
    y_train: Real[ArrayLike, ' n'], is_probit: bool, standardize: bool
) -> tuple[Float32[Array, ''], Float32[Array, ''], Float32[Array, ' n']]:
    """Return ``(y_bar, y_std, y_for_bartz)`` matching stochtree's standardization.

    Parameters
    ----------
    y_train
        Training outcomes; converted to float32.
    is_probit
        Whether the outcome is binary with a probit link.
    standardize
        Whether to center and scale a continuous outcome. Ignored for probit.

    Returns
    -------
    y_bar : Float32[Array, '']
        Probit: the probit transform of the proportion of non-zero outcomes.
        Continuous: the sample mean if `standardize`, else ``0``.
    y_std : Float32[Array, '']
        The sample standard deviation (ddof=0) if a continuous outcome is
        standardized, with ``1`` substituted when it is zero; ``1`` otherwise.
    y_for_bartz : Float32[Array, ' n']
        The outcome to fit: the 0/1 indicator ``y != 0`` for probit, the
        standardized outcome if `standardize`, else the outcome unchanged.
    """
    y = jnp.asarray(y_train, jnp.float32)
    if is_probit:
        # Binarize first, so the offset is computed from the same 0/1 outcome
        # that is fitted. Using the raw mean would give inf/NaN for a non-0/1
        # coding of the classes (e.g. {0, 2}); for 0/1 input it is identical.
        y_binary = (y != 0).astype(jnp.float32)
        return ndtri(y_binary.mean()), jnp.float32(1.0), y_binary
    if not standardize:
        return jnp.float32(0.0), jnp.float32(1.0), y
    y_bar = y.mean()
    y_std_val = y.std()
    # a constant outcome has zero spread; fall back to 1 to avoid dividing by 0
    y_std = jnp.where(y_std_val > 0, y_std_val, 1.0)
    return y_bar, y_std, (y - y_bar) / y_std

705 

706 

def resolve_sigma2_leaf_init(
    sigma2_leaf_init: FloatLike | None,
    num_trees: int,
    is_probit: bool,
    var_resid_train: FloatLike,
) -> FloatLike:
    """Resolve the leaf-variance prior, applying stochtree's default when unset.

    An explicit `sigma2_leaf_init` is returned unchanged. Otherwise the default
    is ``2 / num_trees`` for probit and ``var_resid_train / num_trees`` for a
    continuous outcome.
    """
    if sigma2_leaf_init is None:
        numerator = 2.0 if is_probit else var_resid_train
        return numerator / num_trees
    return sigma2_leaf_init

719 

720 

def resolve_variance_prior(
    shape: FloatLike,
    scale: FloatLike,
    sigma2_init: FloatLike | None,
    var_resid_train: FloatLike,
) -> tuple[Float32[Array, ''], Float32[Array, ''], Float32[Array, ''], FloatLike]:
    """Map stochtree's IG(shape, scale) prior onto bartz's error variance prior.

    An IG(shape, scale) prior on σ² equals a scaled-inverse-χ² with
    ``sigma_df = 2*shape`` and prior harmonic mean ``square(sigma_scale) =
    scale/shape``. The chain start is `sigma2_init` (default
    ``var(resid_train)``) and does not depend on the prior. No Python branching
    is done on `shape` / `scale`, so they may be traced; the unrepresentable
    IG(0, scale>0) (positive rate, zero df) is not rejected here but produces
    a non-finite value that surfaces downstream.

    Parameters
    ----------
    shape
        Stochtree's ``sigma2_global_shape``.
    scale
        Stochtree's ``sigma2_global_scale``.
    sigma2_init
        Stochtree's ``sigma2_init``. If `None`, defaults to `var_resid_train`.
    var_resid_train
        Variance of the residual, the default chain start for σ².

    Returns
    -------
    sigma_df : Float32[Array, '']
        Degrees of freedom of bartz's error variance prior.
    sigma_scale : Float32[Array, '']
        Scale of bartz's prior (sqrt of the prior harmonic mean of the variance).
    sigma_init : Float32[Array, '']
        Initial error standard deviation seeding the chain.
    sigma2_init_stored : FloatLike
        The chain starting value of σ², suitable for ``BARTModel.sigma2_init``.
    """
    shape_f32 = jnp.asarray(shape, jnp.float32)
    scale_f32 = jnp.asarray(scale, jnp.float32)

    if sigma2_init is None:
        sigma2_start = var_resid_train
    else:
        sigma2_start = sigma2_init
    sigma_init = jnp.sqrt(jnp.asarray(sigma2_start, jnp.float32))

    # Harmonic mean of the equivalent scaled-inv-chi2 prior. The `scale > 0`
    # selection pins IG(0, 0) to harmonic mean 0 (sidestepping 0/0), while
    # IG(0, scale>0) is left to overflow so that it is flagged as invalid.
    ratio = scale_f32 / shape_f32
    harmonic_mean = jnp.where(scale_f32 > 0, ratio, 0.0)

    sigma_df = 2.0 * shape_f32
    sigma_scale = jnp.sqrt(harmonic_mean)
    return sigma_df, sigma_scale, sigma_init, sigma2_start

767 

768 

def check_variable_weights(
    variable_weights: Float[ArrayLike, ' p'] | None, p: int
) -> Float32[Array, ' p'] | None:
    """Convert `variable_weights` to a float32 jax array of shape ``(p,)``.

    `None` is passed through unchanged. A `ValueError` is raised when the
    converted array does not have exactly one entry per predictor. Only the
    shape is checked here, not the values.
    """
    if variable_weights is None:
        return None
    weights = jnp.asarray(variable_weights, jnp.float32)
    if weights.shape == (p,):
        return weights
    msg = f'variable_weights must have shape (p,)=({p},), got {weights.shape}'
    raise ValueError(msg)

780 

781 

def check_predict_args(
    type_: Literal['posterior', 'mean'],
    scale: Literal['linear', 'probability', 'class'],
    terms: Literal['y_hat', 'mean_forest', 'all']
    | Sequence[Literal['y_hat', 'mean_forest', 'all']],
    probit_outcome_model: bool,
) -> tuple[str, ...]:
    """Validate the arguments of `BARTModel.predict` and normalize `terms`.

    Checks, in order: `scale` and `type_` are known values; a non-linear scale
    is used only with the probit model; ``scale='class'`` is not combined with
    ``type='mean'``; every term is known; ``scale='class'`` is requested only
    for the single ``'y_hat'`` term.

    Returns
    -------
    The requested terms as a tuple (a lone string becomes a 1-tuple).

    Raises
    ------
    ValueError
        If any of the checks above fails.
    """
    valid_scales = ('linear', 'probability', 'class')
    valid_types = ('posterior', 'mean')
    valid_terms = ('y_hat', 'mean_forest', 'all')

    if scale not in valid_scales:
        msg = f"scale must be 'linear', 'probability', or 'class'; got {scale!r}"
        raise ValueError(msg)
    if type_ not in valid_types:
        msg = f"type must be 'posterior' or 'mean'; got {type_!r}"
        raise ValueError(msg)
    if scale != 'linear' and not probit_outcome_model:
        msg = (
            "scale must be 'linear' for non-probit (continuous) regression;"
            f' got {scale!r}'
        )
        raise ValueError(msg)

    class_scale = scale == 'class'
    if class_scale and type_ == 'mean':
        msg = "scale='class' is incompatible with type='mean'"
        raise ValueError(msg)

    requested = (terms,) if isinstance(terms, str) else tuple(terms)
    for term in requested:
        if term in valid_terms:
            continue
        msg = f'unknown term {term!r}; valid terms are y_hat, mean_forest, all'
        raise ValueError(msg)

    # Match stochtree: 'class' converts only the single 'y_hat' term, so it
    # rejects 'mean_forest' and 'all' (the latter also pulls in mean_forest).
    if class_scale and set(requested) != {'y_hat'}:
        msg = "scale='class' is only supported when requesting a single 'y_hat' term"
        raise ValueError(msg)
    return requested

816 

817 

def check_X(
    X: Real[ArrayLike, 'n p'] | DataFrame, *, name: str = 'X'
) -> Real[Array, 'n p']:
    """Convert a DataFrame or array-like to a 2-D jax array in ``(n, p)`` layout.

    A 1-D input is treated as a single column and reshaped to ``(n, 1)``.
    `name` labels the argument in the error message. A `ValueError` is raised
    when the input is neither 1-D nor 2-D.
    """
    raw = X.to_numpy() if isinstance(X, DataFrame) else X
    arr = jnp.asarray(raw)
    if arr.ndim == 1:
        # promote a lone column to (n, 1)
        arr = arr[:, None]
    if arr.ndim == 2:
        return arr
    msg = f'{name} must be 2D (n, p); got shape {arr.shape}'
    raise ValueError(msg)

831 

832 

def _coerce_response(
    y: Real[ArrayLike, ' n'] | Series, *, name: str
) -> Real[Array, ' n']:
    """Convert a Series or array-like response to a 1-D jax array.

    `name` labels the argument in the error message. A `ValueError` is raised
    when the converted array is not 1-D.
    """
    raw = y.to_numpy() if isinstance(y, Series) else y
    arr = jnp.asarray(raw)
    if arr.ndim == 1:
        return arr
    msg = f'{name} must be 1D (n,); got shape {arr.shape}'
    raise ValueError(msg)