Coverage for src/bartz/BART/_gbart.py: 99%

174 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/BART/_gbart.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package.""" 

26 

27from collections.abc import Mapping 

28from functools import cached_property, partial 

29from types import MappingProxyType 

30from typing import Any, Literal 

31 

32import jax.numpy as jnp 

33from equinox import Module, field 

34from jax.scipy.special import ndtr 

35from jaxtyping import Array, Float, Float32, Int32, Key, Real, Shaped 

36 

37from bartz._interface import ( 

38 ArrayLike, 

39 Bart, 

40 DataFrame, 

41 FloatLike, 

42 PredictKind, 

43 Series, 

44 SparseConfig, 

45 _process_predictor_input, 

46 _process_response_input, 

47) 

48from bartz._jaxext.scipy.stats import invgamma 

49from bartz.mcmcloop import BurninTrace, MainTrace 

50from bartz.mcmcstep._axes import chain_to_axis, chain_vmap_axes 

51from bartz.mcmcstep._state import State 

52from bartz.prepcovars import GivenSplitsBinner, RangeEvenBinner, UniqueQuantileBinner 

53from bartz.prepcovars._prepcovars import _sigma2_from_ols 

54 

55 

class mc_gbart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART).

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees [2]_. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart' for
        binary regression with probit link.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision rule
        is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it's a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity.
        If setting `theta` directly, it should be in the ballpark of p or lower
        as well.
    augment
        Whether to account exactly for the decision rules forbidden by the
        ancestors of each node when updating the variable selection
        probabilities, using data augmentation. Only relevant if ``sparse=True``.
        Like the ``augment`` option of R BART3, but sampling the exact full
        conditional rather than substituting expected counts.
    varprob
        The probability distribution over the `p` predictors for choosing a
        predictor to split on in a decision node a priori. Must be > 0. It does
        not need to be normalized to sum to 1. If not specified, use a uniform
        distribution. If ``sparse=True``, this is used as initial value for the
        MCMC.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictors quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to set
        `lambda_`. If not specified, it is estimated by linear regression (with
        intercept, and without taking into account `w`). Ignored if `lambda_` is
        specified.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lambda_` is specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If `y_train`
        has fewer than two elements, `k` is ignored and the scale is set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that a
        node at depth `d` (0-based) is non-terminal is
        ``base / (1 + d) ** power``.
    lambda_
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, default to
        ``(max(y_train) - min(y_train)) / 2`` (or 1 if `y_train` has fewer than
        two elements) for continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is set
        to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
        `offset` is set to 0. With binary regression, if `y_train` is all
        `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin the
        predictors, ranging between the minimum and maximum observed values
        (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best set
        to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains. `ndpost` is rounded up to the
        first multiple of `mc_cores`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This number
        of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default, 1
        for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each log
        line. Set to `None` to disable logging. ^C interrupts the MCMC only
        every `printevery` iterations, so with logging disabled it's impossible
        to kill the MCMC conveniently.
    mc_cores
        The number of independent MCMC chains.
    seed
        The seed for the random number generator.
    bart_kwargs
        Additional arguments passed to `bartz.Bart`.

    Notes
    -----
    This interface imitates the function ``mc_gbart`` from the R package `BART3
    <https://github.com/rsparapa/bnptools>`_, but with these differences:

    - If ``usequants=False``, R BART3 switches to quantiles anyway if there are
      fewer predictor values than the required number of bins, while bartz
      always follows the specification.
    - Some functionality is missing.
    - The error variance parameter is called `lambda_` instead of `lambda`,
      since the latter is a reserved word in Python.
    - There are some additional attributes, and some missing.
    - The trees have a maximum depth of 6.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
    - If `rm_const=True` and some variables are dropped, the predictors
      matrix/dataframe passed to `predict` should still include them.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection". In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART:
       Bayesian additive regression trees," The Annals of Applied Statistics,
       Ann. Appl. Stat. 4(1), 266-298, (March 2010).
    """

    _bart: Bart
    _x_train_fmt: Any = field(static=True, default=None)
    _yhat_test: Float32[Array, 'ndpost m'] | None = None

    sigest: Float32[Array, ''] | None = None
    """The estimated standard deviation of the error used to set `lambda_`."""

    def __init__(
        self,
        x_train: Real[ArrayLike, 'n p'] | DataFrame,
        y_train: Float32[ArrayLike, ' n'] | Series,
        *,
        x_test: Real[ArrayLike, 'm p'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        augment: bool = False,
        varprob: Float[ArrayLike, ' p'] | None = None,
        xinfo: Float[ArrayLike, 'p ncut'] | None = None,
        usequants: bool = False,
        rm_const: bool = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lambda_: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[ArrayLike, ' n'] | Series | None = None,
        ntree: int | None = None,
        numcut: int = 100,
        ndpost: int = 1000,
        nskip: int = 100,
        keepevery: int | None = None,
        printevery: int | None = 100,
        mc_cores: int = 2,
        seed: int | Key[Array, ''] = 0,
        bart_kwargs: Mapping = MappingProxyType({}),
    ) -> None:
        # Translate the BART3-style arguments (documented on the class) into
        # the keyword arguments of `Bart`, then build the `Bart` object.

        # defaults whose value depends on the kind of regression
        is_binary = type == 'pbart'
        if keepevery is None:
            keepevery = 10 if is_binary else 1
        if ntree is None:
            ntree = 50 if is_binary else 200

        # Convert the data to numeric arrays a single time, so that the OLS
        # estimate of `sigest` and `Bart` work on the same copy of the
        # (memory-heavy) X matrix. `Bart` then records the format as plain
        # arrays, which is why `predict` re-implements the input-format
        # consistency check against the format remembered here.
        x_train, x_train_fmt = _process_bart3_predictor_input(x_train)
        self._x_train_fmt = x_train_fmt
        y_train = _process_response_input(y_train)

        # express the BART3 error-variance settings as Bart's sigma prior;
        # `sigest` is estimated by linear regression on x_train if needed
        sigma_kw, sigest_out = _resolve_sigma_prior(
            x_train,
            y_train,
            type=type,
            sigest=sigest,
            sigdf=sigdf,
            sigquant=sigquant,
            lambda_=lambda_,
        )
        self.sigest = sigest_out

        # `ndpost` counts samples across all chains, while Bart wants the
        # per-chain count: divide rounding up
        num_chains = mc_cores if mc_cores != 1 else None
        chains_for_split = num_chains or 1
        whole, leftover = divmod(ndpost, chains_for_split)
        n_save = whole + bool(leftover)

        # turn xinfo/usequants/numcut into a binner factory
        if xinfo is not None:
            binner = partial(GivenSplitsBinner, xinfo=jnp.asarray(xinfo))
        else:
            max_bins = numcut + 1
            if usequants:
                binner = partial(
                    UniqueQuantileBinner, max_bins=max_bins, max_subsample=None
                )
            else:
                binner = partial(RangeEvenBinner, max_bins=max_bins)

        # arguments for Bart determined by this interface
        kwargs: dict = {
            'x_train': x_train,
            'y_train': y_train,
            'outcome_type': {'wbart': 'continuous', 'pbart': 'binary'}[type],
            'sparse': SparseConfig(
                enabled=sparse, theta=theta, a=a, b=b, rho=rho, augment=augment
            ),
            'varprob': varprob,
            'binner': binner,
            'rm_const': rm_const,
            **sigma_kw,
            'k': k,
            'power': power,
            'base': base,
            'tau_num': tau_num,
            'offset': offset,
            'error_scale': w,
            'num_trees': ntree,
            'n_save': n_save,
            'n_burn': nskip,
            'n_skip': keepevery,
            'printevery': printevery,
            'seed': seed,
            'maxdepth': 6,
            'num_chains': num_chains,
        }

        # Unless the user chose a value, set min_points_per_leaf to 5 to match
        # BART3's hard-coded nl>=5 && nr>=5 birth check.
        # min_points_per_decision_node is left at the Bart default of 10
        # (= 2 * min_points_per_leaf): it makes the proposal efficient by not
        # trying to grow leaves too small to split, without changing the target
        # posterior, which thus matches BART3.
        user_init_kw = bart_kwargs.get('init_kw', {})
        if 'min_points_per_leaf' not in user_init_kw:
            bart_kwargs = dict(
                bart_kwargs,
                init_kw=dict(user_init_kw, min_points_per_leaf=5),
            )

        # user-provided arguments take precedence over the ones set above
        kwargs.update(bart_kwargs)

        # build the underlying Bart object
        self._bart = Bart(**kwargs)

        # evaluate at the test points, if given
        if x_test is not None:
            self._yhat_test = self.predict(x_test)

    # Public attributes from Bart

    @property
    def ndpost(self) -> int:
        """The number of MCMC samples saved, after burn-in."""
        return self._bart.ndpost

    @property
    def offset(self) -> Float32[Array, '']:
        """The prior mean of the latent mean function."""
        return self._bart.offset

    # Private attributes from Bart

    @property
    def _main_trace(self) -> MainTrace:
        return self._bart._main_trace  # noqa: SLF001

    @property
    def _burnin_trace(self) -> BurninTrace:
        return self._bart._burnin_trace  # noqa: SLF001

    @property
    def _mcmc_state(self) -> State:
        return self._bart._mcmc_state  # noqa: SLF001

    @property
    def _splits(self) -> Real[Array, 'p max_num_splits']:
        return self._bart._binner._splits  # noqa: SLF001

    # Properties

    @property
    def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The conditional posterior mean at `x_test` for each MCMC iteration."""
        return self._yhat_test

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        latent = self._yhat_test
        # defined only if test points were given and the regression is binary
        if latent is None or self._mcmc_state.binary_y is None:
            return None
        return ndtr(latent)

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        samples = self.prob_test
        return None if samples is None else samples.mean(axis=0)

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        if self._mcmc_state.binary_y is None:
            return None
        return ndtr(self.yhat_train)

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        samples = self.prob_train
        return None if samples is None else samples.mean(axis=0)

    @cached_property
    def sigma(
        self,
    ) -> (
        Float32[Array, ' nskip_plus_ndpost']
        | Float32[Array, 'nskip_plus_ndpost_per_core mc_cores']
        | None
    ):
        """The standard deviation of the error, including burn-in samples."""
        if self._mcmc_state.binary_y is not None:
            return None
        assert self._burnin_trace.error_cov_inv.ndim <= 2  # chains and samples
        chain_axis = chain_vmap_axes(self._main_trace).error_cov_inv

        # Public output is (nskip+ndpost, mc_cores) = (samples, chains), so the
        # chain axis goes last and burn-in and main samples are stacked along
        # the first axis.
        precisions = [
            chain_to_axis(trace.error_cov_inv, chain_axis, target=-1)
            for trace in (self._burnin_trace, self._main_trace)
        ]
        stacked = jnp.concatenate(precisions, axis=0)
        return jnp.sqrt(jnp.reciprocal(stacked))

    @cached_property
    def sigma_(self) -> Float32[Array, 'ndpost'] | None:
        """The standard deviation of the error, only over the post-burnin samples and flattened."""
        if self._mcmc_state.binary_y is not None:
            return None
        precision = self._main_trace.error_cov_inv
        assert precision.ndim <= 2  # chains and samples
        chain_axis = chain_vmap_axes(self._main_trace).error_cov_inv
        arranged = chain_to_axis(precision, chain_axis)
        return jnp.sqrt(jnp.reciprocal(arranged)).reshape(-1)

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burnin samples."""
        flat = self.sigma_
        return None if flat is None else flat.mean()

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return self._bart.varcount

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self._bart.varcount_mean

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        return self._bart.varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self._bart.varprob_mean

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined with binary regression because it's error-prone, typically
        the right thing to consider would be `prob_test_mean`.
        """
        latent = self._yhat_test
        if latent is None or self._mcmc_state.binary_y is not None:
            return None
        return latent.mean(axis=0)

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        return self._bart.predict('train', kind=PredictKind.latent_samples)

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined with binary regression because it's error-prone, typically
        the right thing to consider would be `prob_train_mean`.
        """
        if self._mcmc_state.binary_y is not None:
            return None
        return self.yhat_train.mean(axis=0)

    # Public methods from Bart

    def predict(
        self, x_test: Real[ArrayLike, 'm p'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Evaluate the sum-of-trees at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        Posterior samples of the latent function value at `x_test`. In the continuous case, this is the conditional mean.

        Raises
        ------
        ValueError
            If `x_test` has a different format than `x_train`.
        """
        # Bart only ever sees plain arrays, so the check that `x_test` comes in
        # the same format as `x_train` has to be done here after pre-processing
        x_test, x_test_fmt = _process_bart3_predictor_input(x_test)
        if x_test_fmt == self._x_train_fmt:
            return self._bart.predict(x_test, kind=PredictKind.latent_samples)
        msg = (
            f'Input format mismatch: {x_test_fmt=} '
            f'!= x_train_fmt={self._x_train_fmt!r}'
        )
        raise ValueError(msg)

571 

572 

class gbart(mc_gbart):
    """Single-chain variant of `mc_gbart`: `mc_cores` is always set to 1."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # `mc_cores` is fixed by this class, so reject it the same way Python
        # rejects an unknown keyword argument
        if 'mc_cores' in kwargs:
            msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'"
            raise TypeError(msg)
        super().__init__(*args, mc_cores=1, **kwargs)

582 

583 

def _process_bart3_predictor_input(
    x: Real[ArrayLike, 'n p'] | DataFrame,
) -> tuple[Shaped[Array, 'p n'], Any]:
    """Convert BART3-style predictors (one predictor per column) to bartz layout.

    BART3 stores predictor matrices with one predictor per column, while
    `bartz.Bart` expects a (p, n) layout, so plain arrays are transposed before
    further processing. Dataframes already have one column per predictor and
    are passed on unchanged.

    Returns whatever `_process_predictor_input` returns for the converted input.
    """
    if isinstance(x, DataFrame):
        return _process_predictor_input(x)
    transposed = jnp.asarray(x).T
    return _process_predictor_input(transposed)

596 

597 

def _resolve_sigma_prior(
    x_train: Shaped[Array, 'p n'],
    y_train: Float32[Array, ' n'],
    *,
    type: Literal['wbart', 'pbart'],  # noqa: A002
    sigest: FloatLike | None,
    sigdf: FloatLike,
    sigquant: FloatLike,
    lambda_: FloatLike | None,
) -> tuple[dict, Float32[Array, ''] | None]:
    """Translate the BART3 error-variance settings into Bart's sigma prior.

    Returns the pair ``(sigma_kwargs, sigest)``. ``sigma_kwargs`` holds the
    `sigma_df`, `sigma_scale` and `sigma_init` arguments for `Bart` (it is
    empty for binary regression). ``sigest`` is the estimate of the error
    standard deviation, or None for binary regression or when `lambda_` is
    given.

    Raises ValueError if `sigest` or `lambda_` is set for binary regression,
    or if `sigest` is set together with `lambda_`.
    """
    if type == 'pbart':
        # binary regression takes no error-variance settings
        if not (sigest is None and lambda_ is None):
            msg = 'Do not set `sigest` or `lambda_` for binary regression, they are ignored'
            raise ValueError(msg)
        return {}, None

    if lambda_ is not None:
        # explicit prior scale: an estimate of sigma would go unused
        if sigest is not None:
            msg = "Do not set `sigest` if `lambda_` is specified, it's ignored"
            raise ValueError(msg)
        sigest_out = None
        lambda_ = jnp.asarray(lambda_, jnp.float32)
    else:
        if sigest is not None:
            sigest2 = jnp.square(jnp.asarray(sigest, jnp.float32))
        else:
            sigest2 = _sigest2_ols(x_train, y_train)
        sigest_out = jnp.sqrt(sigest2)
        # pick lambda_ so that the sigquant quantile of the prior equals sigest²
        invchi2 = invgamma.ppf(sigquant, sigdf / 2) / 2
        lambda_ = sigest2 / (invchi2 * sigdf)

    # Bart's prior reduces to scaled-inv-χ²(sigma_df, sigma_scale²) on the error
    # variance, which matches BART3's scaled-inv-χ²(sigdf, lambda_); using the
    # same value for sigma_init keeps the initial precision at the prior mean
    # nu/rate = 1 / lambda_
    sigma_scale = jnp.sqrt(lambda_)
    return (
        dict(sigma_df=sigdf, sigma_scale=sigma_scale, sigma_init=sigma_scale),
        sigest_out,
    )

641 

642 

def _sigest2_ols(
    x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n']
) -> Float32[Array, '']:
    """Estimate the error variance by OLS with intercept.

    `x_train` is laid out as (predictors, datapoints). The estimate itself is
    delegated to `_sigma2_from_ols`.

    Raises ValueError if the datapoints do not outnumber the predictors.
    """
    num_predictors, num_points = x_train.shape
    if num_points > num_predictors:
        return _sigma2_from_ols(x_train, y_train)
    msg = (
        f'cannot estimate `sigest` by OLS with {num_points} datapoints and '
        f'{num_predictors} '
        'predictors (it requires more datapoints than predictors); '
        'specify `sigest` or `lambda_` explicitly'
    )
    raise ValueError(msg)

655 return _sigma2_from_ols(x_train, y_train)