Coverage for src/bartz/BART/_gbart.py: 99%
174 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/BART/_gbart.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""
27from collections.abc import Mapping
28from functools import cached_property, partial
29from types import MappingProxyType
30from typing import Any, Literal
32import jax.numpy as jnp
33from equinox import Module, field
34from jax.scipy.special import ndtr
35from jaxtyping import Array, Float, Float32, Int32, Key, Real, Shaped
37from bartz._interface import (
38 ArrayLike,
39 Bart,
40 DataFrame,
41 FloatLike,
42 PredictKind,
43 Series,
44 SparseConfig,
45 _process_predictor_input,
46 _process_response_input,
47)
48from bartz._jaxext.scipy.stats import invgamma
49from bartz.mcmcloop import BurninTrace, MainTrace
50from bartz.mcmcstep._axes import chain_to_axis, chain_vmap_axes
51from bartz.mcmcstep._state import State
52from bartz.prepcovars import GivenSplitsBinner, RangeEvenBinner, UniqueQuantileBinner
53from bartz.prepcovars._prepcovars import _sigma2_from_ols
56class mc_gbart(Module):
57 R"""
58 Nonparametric regression with Bayesian Additive Regression Trees (BART).
60 Regress `y_train` on `x_train` with a latent mean function represented as
61 a sum of decision trees [2]_. The inference is carried out by sampling the
62 posterior distribution of the tree ensemble with an MCMC.
64 Parameters
65 ----------
66 x_train
67 The training predictors.
68 y_train
69 The training responses.
70 x_test
71 The test predictors.
72 type
73 The type of regression. 'wbart' for continuous regression, 'pbart' for
74 binary regression with probit link.
75 sparse
76 Whether to activate variable selection on the predictors as done in
77 [1]_.
78 theta
79 a
80 b
81 rho
82 Hyperparameters of the sparsity prior used for variable selection.
84 The prior distribution on the choice of predictor for each decision rule
85 is
87 .. math::
88 (s_1, \ldots, s_p) \sim
89 \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
91 If `theta` is not specified, it's a priori distributed according to
93 .. math::
94 \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
95 \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
97 If not specified, `rho` is set to the number of predictors p. To tune
98 the prior, consider setting a lower `rho` to prefer more sparsity.
99 If setting `theta` directly, it should be in the ballpark of p or lower
100 as well.
101 augment
102 Whether to account exactly for the decision rules forbidden by the
103 ancestors of each node when updating the variable selection
104 probabilities, using data augmentation. Only relevant if ``sparse=True``.
105 Like the ``augment`` option of R BART3, but sampling the exact full
106 conditional rather than substituting expected counts.
107 varprob
108 The probability distribution over the `p` predictors for choosing a
109 predictor to split on in a decision node a priori. Must be > 0. It does
110 not need to be normalized to sum to 1. If not specified, use a uniform
111 distribution. If ``sparse=True``, this is used as initial value for the
112 MCMC.
113 xinfo
114 A matrix with the cutpoints to use to bin each predictor. If not
115 specified, it is generated automatically according to `usequants` and
116 `numcut`.
118 Each row shall contain a sorted list of cutpoints for a predictor. If
119 there are fewer cutpoints than the number of columns in the matrix,
120 fill the remaining cells with NaN.
122 `xinfo` shall be a matrix even if `x_train` is a dataframe.
123 usequants
124 Whether to use predictors quantiles instead of a uniform grid to bin
125 predictors. Ignored if `xinfo` is specified.
126 rm_const
127 How to treat predictors with no associated decision rules (i.e., there
128 are no available cutpoints for that predictor). If `True` (default),
129 they are ignored. If `False`, an error is raised if there are any.
130 sigest
131 An estimate of the residual standard deviation on `y_train`, used to set
132 `lambda_`. If not specified, it is estimated by linear regression (with
133 intercept, and without taking into account `w`). Ignored if `lambda_` is
134 specified.
135 sigdf
136 The degrees of freedom of the scaled inverse-chisquared prior on the
137 noise variance.
138 sigquant
139 The quantile of the prior on the noise variance that shall match
140 `sigest` to set the scale of the prior. Ignored if `lambda_` is specified.
141 k
142 The inverse scale of the prior standard deviation on the latent mean
143 function, relative to half the observed range of `y_train`. If `y_train`
144 has less than two elements, `k` is ignored and the scale is set to 1.
145 power
146 base
147 Parameters of the prior on tree node generation. The probability that a
148 node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
149 power``.
150 lambda_
151 The prior harmonic mean of the error variance. (The harmonic mean of x
152 is 1/mean(1/x).) If not specified, it is set based on `sigest` and
153 `sigquant`.
154 tau_num
155 The numerator in the expression that determines the prior standard
156 deviation of leaves. If not specified, default to ``(max(y_train) -
157 min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
158 continuous regression, and 3 for binary regression.
159 offset
160 The prior mean of the latent mean function. If not specified, it is set
161 to the mean of `y_train` for continuous regression, and to
162 ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
163 `offset` is set to 0. With binary regression, if `y_train` is all
164 `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
165 ``Phi^-1(n/(n+1))``, respectively.
166 w
167 Coefficients that rescale the error standard deviation on each
168 datapoint. Not specifying `w` is equivalent to setting it to 1 for all
169 datapoints. Note: `w` is ignored in the automatic determination of
170 `sigest`, so either the weights should be O(1), or `sigest` should be
171 specified by the user.
172 ntree
173 The number of trees used to represent the latent mean function. By
174 default 200 for continuous regression and 50 for binary regression.
175 numcut
176 If `usequants` is `False`: the exact number of cutpoints used to bin the
177 predictors, ranging between the minimum and maximum observed values
178 (excluded).
180 If `usequants` is `True`: the maximum number of cutpoints to use for
181 binning the predictors. Each predictor is binned such that its
182 distribution in `x_train` is approximately uniform across bins. The
183 number of bins is at most the number of unique values appearing in
184 `x_train`, or ``numcut + 1``.
186 Before running the algorithm, the predictors are compressed to the
187 smallest integer type that fits the bin indices, so `numcut` is best set
188 to the maximum value of an unsigned integer type, like 255.
190 Ignored if `xinfo` is specified.
191 ndpost
192 The number of MCMC samples to save, after burn-in. `ndpost` is the
193 total number of samples across all chains. `ndpost` is rounded up to the
194 first multiple of `mc_cores`.
195 nskip
196 The number of initial MCMC samples to discard as burn-in. This number
197 of samples is discarded from each chain.
198 keepevery
199 The thinning factor for the MCMC samples, after burn-in. By default, 1
200 for continuous regression and 10 for binary regression.
201 printevery
202 The number of iterations (including thinned-away ones) between each log
203 line. Set to `None` to disable logging. ^C interrupts the MCMC only
204 every `printevery` iterations, so with logging disabled it's impossible
205 to kill the MCMC conveniently.
206 mc_cores
207 The number of independent MCMC chains.
208 seed
209 The seed for the random number generator.
210 bart_kwargs
211 Additional arguments passed to `bartz.Bart`.
213 Notes
214 -----
215 This interface imitates the function ``mc_gbart`` from the R package `BART3
216 <https://github.com/rsparapa/bnptools>`_, but with these differences:
218 - If ``usequants=False``, R BART3 switches to quantiles anyway if there are
219 less predictor values than the required number of bins, while bartz
220 always follows the specification.
221 - Some functionality is missing.
222 - The error variance parameter is called `lambda_` instead of `lambda`,
223 since the latter is a reserved word in Python.
224 - There are some additional attributes, and some missing.
225 - The trees have a maximum depth of 6.
226 - `rm_const` refers to predictors without decision rules instead of
227 predictors that are constant in `x_train`.
228 - If `rm_const=True` and some variables are dropped, the predictors
229 matrix/dataframe passed to `predict` should still include them.
231 References
232 ----------
233 .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
234 High-Dimensional Prediction and Variable Selection". In: Journal of the
235 American Statistical Association 113.522, pp. 626-636.
236 .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART:
237 Bayesian additive regression trees," The Annals of Applied Statistics,
238 Ann. Appl. Stat. 4(1), 266-298, (March 2010).
239 """
241 _bart: Bart
242 _x_train_fmt: Any = field(static=True, default=None)
243 _yhat_test: Float32[Array, 'ndpost m'] | None = None
245 sigest: Float32[Array, ''] | None = None
246 """The estimated standard deviation of the error used to set `lambda_`."""
248 def __init__(
249 self,
250 x_train: Real[ArrayLike, 'n p'] | DataFrame,
251 y_train: Float32[ArrayLike, ' n'] | Series,
252 *,
253 x_test: Real[ArrayLike, 'm p'] | DataFrame | None = None,
254 type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002
255 sparse: bool = False,
256 theta: FloatLike | None = None,
257 a: FloatLike = 0.5,
258 b: FloatLike = 1.0,
259 rho: FloatLike | None = None,
260 augment: bool = False,
261 varprob: Float[ArrayLike, ' p'] | None = None,
262 xinfo: Float[ArrayLike, 'p ncut'] | None = None,
263 usequants: bool = False,
264 rm_const: bool = True,
265 sigest: FloatLike | None = None,
266 sigdf: FloatLike = 3.0,
267 sigquant: FloatLike = 0.9,
268 k: FloatLike = 2.0,
269 power: FloatLike = 2.0,
270 base: FloatLike = 0.95,
271 lambda_: FloatLike | None = None,
272 tau_num: FloatLike | None = None,
273 offset: FloatLike | None = None,
274 w: Float[ArrayLike, ' n'] | Series | None = None,
275 ntree: int | None = None,
276 numcut: int = 100,
277 ndpost: int = 1000,
278 nskip: int = 100,
279 keepevery: int | None = None,
280 printevery: int | None = 100,
281 mc_cores: int = 2,
282 seed: int | Key[Array, ''] = 0,
283 bart_kwargs: Mapping = MappingProxyType({}),
284 ) -> None:
285 # set defaults that depend on type of regression
286 if keepevery is None:
287 keepevery = 10 if type == 'pbart' else 1
288 if ntree is None:
289 ntree = 50 if type == 'pbart' else 200
291 # pre-process the data to numeric arrays once, so the OLS estimate of
292 # `sigest` and `Bart` share a single copy of the (memory-heavy) X matrix.
293 # `Bart` records the format as plain arrays, so `predict` re-implements
294 # the input-format consistency check against the original format here.
295 x_train, self._x_train_fmt = _process_bart3_predictor_input(x_train)
296 y_train = _process_response_input(y_train)
298 # map the BART3 error-variance settings to Bart's sigma prior, estimating
299 # `sigest` by linear regression on x_train when needed
300 sigma_kw, self.sigest = _resolve_sigma_prior(
301 x_train,
302 y_train,
303 type=type,
304 sigest=sigest,
305 sigdf=sigdf,
306 sigquant=sigquant,
307 lambda_=lambda_,
308 )
310 # convert to per-chain n_save for Bart
311 num_chains = None if mc_cores == 1 else mc_cores
312 actual_num_chains = num_chains or 1
313 n_save = ndpost // actual_num_chains + bool(ndpost % actual_num_chains)
315 # translate xinfo/usequants/numcut to a binner factory
316 if xinfo is not None:
317 binner = partial(GivenSplitsBinner, xinfo=jnp.asarray(xinfo))
318 elif usequants:
319 binner = partial(
320 UniqueQuantileBinner, max_bins=numcut + 1, max_subsample=None
321 )
322 else:
323 binner = partial(RangeEvenBinner, max_bins=numcut + 1)
325 # set most calling arguments for Bart
326 kwargs: dict = dict(
327 x_train=x_train,
328 y_train=y_train,
329 outcome_type=dict(wbart='continuous', pbart='binary')[type],
330 sparse=SparseConfig(
331 enabled=sparse, theta=theta, a=a, b=b, rho=rho, augment=augment
332 ),
333 varprob=varprob,
334 binner=binner,
335 rm_const=rm_const,
336 **sigma_kw,
337 k=k,
338 power=power,
339 base=base,
340 tau_num=tau_num,
341 offset=offset,
342 error_scale=w,
343 num_trees=ntree,
344 n_save=n_save,
345 n_burn=nskip,
346 n_skip=keepevery,
347 printevery=printevery,
348 seed=seed,
349 maxdepth=6,
350 num_chains=num_chains,
351 )
353 # default min_points_per_leaf to 5 (unless set by the user) to match
354 # BART3's hard-coded nl>=5 && nr>=5 birth check.
355 # min_points_per_decision_node keeps the Bart default of 10
356 # (= 2 * min_points_per_leaf): it makes the proposal efficient by not
357 # trying to grow leaves too small to split, without changing the target
358 # posterior, which thus matches BART3.
359 if 'min_points_per_leaf' not in bart_kwargs.get('init_kw', {}):
360 bart_kwargs = dict(
361 bart_kwargs,
362 init_kw=dict(bart_kwargs.get('init_kw', {}), min_points_per_leaf=5),
363 )
365 # add user arguments
366 kwargs.update(bart_kwargs)
368 # invoke Bart
369 self._bart = Bart(**kwargs)
371 # predict at test points
372 if x_test is not None:
373 self._yhat_test = self.predict(x_test)
375 # Public attributes from Bart
377 @property
378 def ndpost(self) -> int:
379 """The number of MCMC samples saved, after burn-in."""
380 return self._bart.ndpost
382 @property
383 def offset(self) -> Float32[Array, '']:
384 """The prior mean of the latent mean function."""
385 return self._bart.offset
387 # Private attributes from Bart
389 @property
390 def _main_trace(self) -> MainTrace:
391 return self._bart._main_trace # noqa: SLF001
393 @property
394 def _burnin_trace(self) -> BurninTrace:
395 return self._bart._burnin_trace # noqa: SLF001
397 @property
398 def _mcmc_state(self) -> State:
399 return self._bart._mcmc_state # noqa: SLF001
401 @property
402 def _splits(self) -> Real[Array, 'p max_num_splits']:
403 return self._bart._binner._splits # noqa: SLF001
405 # Properties
407 @property
408 def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
409 """The conditional posterior mean at `x_test` for each MCMC iteration."""
410 return self._yhat_test
412 @cached_property
413 def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
414 """The posterior probability of y being True at `x_test` for each MCMC iteration."""
415 if self._yhat_test is None or self._mcmc_state.binary_y is None:
416 return None
417 return ndtr(self._yhat_test)
419 @cached_property
420 def prob_test_mean(self) -> Float32[Array, ' m'] | None:
421 """The marginal posterior probability of y being True at `x_test`."""
422 if self.prob_test is None:
423 return None
424 return self.prob_test.mean(axis=0)
426 @cached_property
427 def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
428 """The posterior probability of y being True at `x_train` for each MCMC iteration."""
429 if self._mcmc_state.binary_y is not None:
430 return ndtr(self.yhat_train)
431 else:
432 return None
434 @cached_property
435 def prob_train_mean(self) -> Float32[Array, ' n'] | None:
436 """The marginal posterior probability of y being True at `x_train`."""
437 if self.prob_train is None:
438 return None
439 else:
440 return self.prob_train.mean(axis=0)
442 @cached_property
443 def sigma(
444 self,
445 ) -> (
446 Float32[Array, ' nskip_plus_ndpost']
447 | Float32[Array, 'nskip_plus_ndpost_per_core mc_cores']
448 | None
449 ):
450 """The standard deviation of the error, including burn-in samples."""
451 if self._mcmc_state.binary_y is not None:
452 return None
453 assert self._burnin_trace.error_cov_inv.ndim <= 2 # chains and samples
454 tc = chain_vmap_axes(self._main_trace).error_cov_inv
456 def arrange(arr: Shaped[Array, '...']) -> Shaped[Array, '...']:
457 # Public output is (nskip+ndpost, mc_cores) = (samples, chains).
458 return chain_to_axis(arr, tc, target=-1)
460 return jnp.sqrt(
461 jnp.reciprocal(
462 jnp.concatenate(
463 [
464 arrange(self._burnin_trace.error_cov_inv),
465 arrange(self._main_trace.error_cov_inv),
466 ],
467 axis=0,
468 )
469 )
470 )
472 @cached_property
473 def sigma_(self) -> Float32[Array, 'ndpost'] | None:
474 """The standard deviation of the error, only over the post-burnin samples and flattened."""
475 if self._mcmc_state.binary_y is not None:
476 return None
477 assert self._main_trace.error_cov_inv.ndim <= 2 # chains and samples
478 arr = chain_to_axis(
479 self._main_trace.error_cov_inv,
480 chain_vmap_axes(self._main_trace).error_cov_inv,
481 )
482 return jnp.sqrt(jnp.reciprocal(arr)).reshape(-1)
484 @cached_property
485 def sigma_mean(self) -> Float32[Array, ''] | None:
486 """The mean of `sigma`, only over the post-burnin samples."""
487 if self.sigma_ is None:
488 return None
489 return self.sigma_.mean()
491 @cached_property
492 def varcount(self) -> Int32[Array, 'ndpost p']:
493 """Histogram of predictor usage for decision rules in the trees."""
494 return self._bart.varcount
496 @cached_property
497 def varcount_mean(self) -> Float32[Array, ' p']:
498 """Average of `varcount` across MCMC iterations."""
499 return self._bart.varcount_mean
501 @cached_property
502 def varprob(self) -> Float32[Array, 'ndpost p']:
503 """Posterior samples of the probability of choosing each predictor for a decision rule."""
504 return self._bart.varprob
506 @cached_property
507 def varprob_mean(self) -> Float32[Array, ' p']:
508 """The marginal posterior probability of each predictor being chosen for a decision rule."""
509 return self._bart.varprob_mean
511 @cached_property
512 def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
513 """The marginal posterior mean at `x_test`.
515 Not defined with binary regression because it's error-prone, typically
516 the right thing to consider would be `prob_test_mean`.
517 """
518 if self._yhat_test is None or self._mcmc_state.binary_y is not None:
519 return None
520 return self._yhat_test.mean(axis=0)
522 @cached_property
523 def yhat_train(self) -> Float32[Array, 'ndpost n']:
524 """The conditional posterior mean at `x_train` for each MCMC iteration."""
525 return self._bart.predict('train', kind=PredictKind.latent_samples)
527 @cached_property
528 def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
529 """The marginal posterior mean at `x_train`.
531 Not defined with binary regression because it's error-prone, typically
532 the right thing to consider would be `prob_train_mean`.
533 """
534 if self._mcmc_state.binary_y is not None:
535 return None
536 else:
537 return self.yhat_train.mean(axis=0)
539 # Public methods from Bart
541 def predict(
542 self, x_test: Real[ArrayLike, 'm p'] | DataFrame
543 ) -> Float32[Array, 'ndpost m']:
544 """
545 Evaluate the sum-of-trees at `x_test` for each MCMC iteration.
547 Parameters
548 ----------
549 x_test
550 The test predictors.
552 Returns
553 -------
554 Posterior samples of the latent function value at `x_test`. In the continuous case, this is the conditional mean.
556 Raises
557 ------
558 ValueError
559 If `x_test` has a different format than `x_train`.
560 """
561 # pre-process and check the format matches x_train; Bart only sees plain
562 # arrays, so this consistency check is re-implemented here
563 x_test, x_test_fmt = _process_bart3_predictor_input(x_test)
564 if x_test_fmt != self._x_train_fmt:
565 msg = (
566 f'Input format mismatch: {x_test_fmt=} '
567 f'!= x_train_fmt={self._x_train_fmt!r}'
568 )
569 raise ValueError(msg)
570 return self._bart.predict(x_test, kind=PredictKind.latent_samples)
class gbart(mc_gbart):
    """Single-chain version of `mc_gbart`: `mc_cores` is fixed to 1 and can not be passed."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # The number of chains is not a parameter of this class: reject it
        # with the same error Python gives for an unknown keyword argument.
        if 'mc_cores' in kwargs:
            msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'"
            raise TypeError(msg)
        super().__init__(*args, **kwargs, mc_cores=1)
def _process_bart3_predictor_input(
    x: Real[ArrayLike, 'n p'] | DataFrame,
) -> tuple[Shaped[Array, 'p n'], Any]:
    """Convert BART3-style predictors (one predictor per column) to the bartz layout.

    BART3 stores predictor matrices with one predictor per column, while
    `bartz.Bart` expects a (p, n) array, so plain arrays are transposed before
    being handed to `_process_predictor_input`. Dataframes already have one
    column per predictor and are passed on as they are.

    Returns whatever `_process_predictor_input` returns; callers in this
    module unpack it as the processed predictors and the input format.
    """
    if isinstance(x, DataFrame):
        return _process_predictor_input(x)
    transposed = jnp.asarray(x).T
    return _process_predictor_input(transposed)
def _resolve_sigma_prior(
    x_train: Shaped[Array, 'p n'],
    y_train: Float32[Array, ' n'],
    *,
    type: Literal['wbart', 'pbart'],  # noqa: A002
    sigest: FloatLike | None,
    sigdf: FloatLike,
    sigquant: FloatLike,
    lambda_: FloatLike | None,
) -> tuple[dict, Float32[Array, ''] | None]:
    """Translate the BART3 error-variance arguments into Bart's sigma prior.

    Parameters
    ----------
    x_train
        The training predictors in (p, n) layout, used only if `sigest` has to
        be estimated by OLS.
    y_train
        The training responses, used only if `sigest` has to be estimated.
    type
        'pbart' for binary regression; any other value is handled as
        continuous regression.
    sigest
        User-provided estimate of the error standard deviation, or `None`.
    sigdf
        Degrees of freedom of the prior on the error variance.
    sigquant
        Quantile of the prior that shall match ``sigest ** 2``.
    lambda_
        User-provided scale of the prior, or `None` to derive it from `sigest`
        and `sigquant`.

    Returns
    -------
    sigma_kw : dict
        Keyword arguments ``sigma_df``, ``sigma_scale`` and ``sigma_init`` for
        `Bart`; empty for binary regression.
    sigest : array or None
        The error standard deviation estimate used to set `lambda_`, or `None`
        for binary regression or when `lambda_` is given.

    Raises
    ------
    ValueError
        If `sigest` or `lambda_` are set for binary regression, or if both are
        set for continuous regression.
    """
    # binary regression has no error variance to put a prior on
    if type == 'pbart':
        if not (sigest is None and lambda_ is None):
            msg = 'Do not set `sigest` or `lambda_` for binary regression, they are ignored'
            raise ValueError(msg)
        return {}, None

    if lambda_ is not None:
        # the scale is given directly, an estimate of sigma would be unused
        if sigest is not None:
            msg = "Do not set `sigest` if `lambda_` is specified, it's ignored"
            raise ValueError(msg)
        sigest_out = None
        lambda_ = jnp.asarray(lambda_, jnp.float32)
    else:
        sigest2 = (
            _sigest2_ols(x_train, y_train)
            if sigest is None
            else jnp.square(jnp.asarray(sigest, jnp.float32))
        )
        sigest_out = jnp.sqrt(sigest2)
        # pick lambda_ such that the sigquant quantile of the prior is sigest²
        invchi2 = invgamma.ppf(sigquant, sigdf / 2) / 2
        lambda_ = sigest2 / (invchi2 * sigdf)

    # Bart's prior reduces to scaled-inv-χ²(sigma_df, sigma_scale²) on the
    # error variance, which matches BART3's scaled-inv-χ²(sigdf, lambda_);
    # using the same value for sigma_init keeps the initial precision at the
    # prior mean nu/rate = 1 / lambda_
    sigma_scale = jnp.sqrt(lambda_)
    sigma_kw = {
        'sigma_df': sigdf,
        'sigma_scale': sigma_scale,
        'sigma_init': sigma_scale,
    }
    return sigma_kw, sigest_out
def _sigest2_ols(
    x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n']
) -> Float32[Array, '']:
    """Estimate the error variance by OLS with intercept.

    The estimate itself is delegated to `_sigma2_from_ols`; this wrapper only
    refuses inputs without more datapoints than predictors.

    Raises
    ------
    ValueError
        If the number of datapoints does not exceed the number of predictors.
    """
    num_predictors, num_points = x_train.shape
    if num_points > num_predictors:
        return _sigma2_from_ols(x_train, y_train)

    msg = (
        f'cannot estimate `sigest` by OLS with {num_points} datapoints and '
        f'{num_predictors} predictors (it requires more datapoints than '
        'predictors); specify `sigest` or `lambda_` explicitly'
    )
    raise ValueError(msg)