# bartz/src/bartz/BART/_gbart.py
#
# Copyright (c) 2024-2026, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""

from collections.abc import Hashable, Mapping
from functools import cached_property
from os import cpu_count
from types import MappingProxyType
from typing import Any, Literal
from warnings import warn

import jax.numpy as jnp
from equinox import Module
from jax import device_count
from jax.scipy.special import ndtr
from jaxtyping import Array, Float, Float32, Int32, Key, Real

from bartz import mcmcloop, mcmcstep
from bartz._interface import Bart, DataFrame, FloatLike, PredictKind, Series
from bartz.jaxext import get_default_device, jit_active


class mc_gbart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart' for
        binary regression with probit link.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision
        rule is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it's a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity.
        If setting `theta` directly, it should be in the ballpark of p or lower
        as well.
    varprob
        The prior probability distribution over the `p` predictors used to
        choose the predictor to split on in a decision node. Each element must
        be > 0. It does not need to be normalized to sum to 1. If not
        specified, use a uniform distribution. If ``sparse=True``, this is
        used as the initial value for the MCMC.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression
        (with intercept, and without taking into account `w`). If `y_train`
        has fewer than two elements, it is set to 1. If n <= p, it is set to
        the standard deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is
        specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale is
        set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that
        a node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, it defaults to ``(max(y_train)
        - min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements)
        for continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is
        set to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0. With binary regression, if `y_train` is
        all `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin
        the predictors, ranging between the minimum and maximum observed
        values (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the smaller of the number of unique values
        appearing in `x_train` and ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best
        set to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains, rounded up to the next
        multiple of `mc_cores`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This number
        of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default, 1
        for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each
        log line. Set to `None` to disable logging. ^C interrupts the MCMC
        only every `printevery` iterations, so with logging disabled there is
        no convenient way to kill the MCMC.
    mc_cores
        The number of independent MCMC chains.
    seed
        The seed for the random number generator.
    bart_kwargs
        Additional arguments passed to `bartz.Bart`.

    Notes
    -----
    This interface imitates the function ``mc_gbart`` from the R package `BART3
    <https://github.com/rsparapa/bnptools>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART3 switches to quantiles anyway if there
      are fewer predictor values than the required number of bins, while bartz
      always follows the specification.
    - Some functionality is missing.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - There are some additional attributes, and some missing.
    - The trees have a maximum depth of 6.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
    - If `rm_const=True` and some variables are dropped, the predictors
      matrix/dataframe passed to `predict` should still include them.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection". In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    .. [2] Chipman, Hugh A., Edward I. George, and Robert E. McCulloch (2010).
       "BART: Bayesian additive regression trees". In: The Annals of Applied
       Statistics 4.1, pp. 266-298.
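
    Examples
    --------
    A minimal usage sketch on synthetic continuous data (the shapes follow
    the one-predictor-per-row convention documented above; the settings are
    illustrative, not recommendations):

    >>> import jax.numpy as jnp
    >>> from jax import random
    >>> key = random.key(0)
    >>> x_train = random.uniform(key, (2, 100))  # p=2 predictors, n=100
    >>> y_train = x_train[0] + jnp.sin(4 * x_train[1])
    >>> fit = mc_gbart(
    ...     x_train, y_train, ndpost=200, nskip=100, printevery=None, seed=0
    ... )
    >>> fit.yhat_train_mean.shape  # marginal posterior mean per datapoint
    (100,)
    >>> fit.sigma_mean is not None  # error sd is sampled in continuous case
    True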
224 """
226 _bart: Bart
227 _yhat_test: Float32[Array, 'ndpost m'] | None = None
229 def __init__(
230 self,
231 x_train: Real[Array, 'p n'] | DataFrame,
232 y_train: Float32[Array, ' n'] | Series,
233 *,
234 x_test: Real[Array, 'p m'] | DataFrame | None = None,
235 type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002
236 sparse: bool = False,
237 theta: FloatLike | None = None,
238 a: FloatLike = 0.5,
239 b: FloatLike = 1.0,
240 rho: FloatLike | None = None,
241 varprob: Float[Array, ' p'] | None = None,
242 xinfo: Float[Array, 'p n'] | None = None,
243 usequants: bool = False,
244 rm_const: bool = True,
245 sigest: FloatLike | None = None,
246 sigdf: FloatLike = 3.0,
247 sigquant: FloatLike = 0.9,
248 k: FloatLike = 2.0,
249 power: FloatLike = 2.0,
250 base: FloatLike = 0.95,
251 lamda: FloatLike | None = None,
252 tau_num: FloatLike | None = None,
253 offset: FloatLike | None = None,
254 w: Float[Array, ' n'] | None = None,
255 ntree: int | None = None,
256 numcut: int = 100,
257 ndpost: int = 1000,
258 nskip: int = 100,
259 keepevery: int | None = None,
260 printevery: int | None = 100,
261 mc_cores: int = 2,
262 seed: int | Key[Array, ''] = 0,
263 bart_kwargs: Mapping = MappingProxyType({}),
264 ) -> None:
265 # set defaults that depend on type of regression
266 if keepevery is None: 1be
267 keepevery = 10 if type == 'pbart' else 1 1e
268 if ntree is None: 1bie
269 ntree = 50 if type == 'pbart' else 200 1min
271 # set most calling arguments for Bart
272 kwargs: dict = dict( 1bomin
273 x_train=x_train,
274 y_train=y_train,
275 outcome_type='binary' if type == 'pbart' else 'continuous',
276 sparse=sparse,
277 theta=theta,
278 a=a,
279 b=b,
280 rho=rho,
281 varprob=varprob,
282 xinfo=xinfo,
283 usequants=usequants,
284 rm_const=rm_const,
285 sigest=sigest,
286 sigdf=sigdf,
287 sigquant=sigquant,
288 k=k,
289 power=power,
290 base=base,
291 lamda=lamda,
292 tau_num=tau_num,
293 offset=offset,
294 w=w,
295 num_trees=ntree,
296 numcut=numcut,
297 ndpost=ndpost,
298 nskip=nskip,
299 keepevery=keepevery,
300 printevery=printevery,
301 seed=seed,
302 maxdepth=6,
303 **process_mc_cores(y_train, mc_cores),
304 )
306 # set min_points_per_leaf unless the user set it already
307 if 'min_points_per_leaf' not in bart_kwargs.get('init_kw', {}): 1bog
308 bart_kwargs = dict(bart_kwargs) 1g
309 init_kw = dict(bart_kwargs.get('init_kw', {})) 1g
310 init_kw['min_points_per_leaf'] = 5 1g
311 bart_kwargs['init_kw'] = init_kw 1g
313 # add user arguments
314 kwargs.update(bart_kwargs) 1b
316 # invoke Bart
317 self._bart = Bart(**kwargs) 1b
319 # predict at test points
320 if x_test is not None: 1qb
321 self._yhat_test = self._bart.predict( 1b
322 x_test, kind=PredictKind.latent_samples
323 )

    # Public attributes from Bart

    @property
    def ndpost(self) -> int:
        """The number of MCMC samples saved, after burn-in."""
        return self._bart.ndpost

    @property
    def offset(self) -> Float32[Array, '']:
        """The prior mean of the latent mean function."""
        return self._bart.offset

    @property
    def sigest(self) -> Float32[Array, ''] | None:
        """The estimated standard deviation of the error used to set `lamda`."""
        return self._bart.sigest

    # Private attributes from Bart

    @property
    def _main_trace(self) -> mcmcloop.MainTrace:
        return self._bart._main_trace  # noqa: SLF001

    @property
    def _burnin_trace(self) -> mcmcloop.BurninTrace:
        return self._bart._burnin_trace  # noqa: SLF001

    @property
    def _mcmc_state(self) -> mcmcstep.State:
        return self._bart._mcmc_state  # noqa: SLF001

    @property
    def _splits(self) -> Real[Array, 'p max_num_splits']:
        return self._bart._splits  # noqa: SLF001

    @property
    def _x_train_fmt(self) -> Hashable:
        return self._bart._x_train_fmt  # noqa: SLF001

    # Properties

    @property
    def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The conditional posterior mean at `x_test` for each MCMC iteration."""
        return self._yhat_test

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        if self._yhat_test is None or self._mcmc_state.binary_y is None:
            return None
        return ndtr(self._yhat_test)

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        if self.prob_test is None:
            return None
        return self.prob_test.mean(axis=0)

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        if self._mcmc_state.binary_y is not None:
            return ndtr(self.yhat_train)
        else:
            return None

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        if self.prob_train is None:
            return None
        else:
            return self.prob_train.mean(axis=0)

    @cached_property
    def sigma(
        self,
    ) -> (
        Float32[Array, ' nskip+ndpost']
        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
        | None
    ):
        """The standard deviation of the error, including burn-in samples."""
        if self._mcmc_state.binary_y is not None:
            return None
        assert self._burnin_trace.error_cov_inv.ndim <= 2  # chains and samples
        return jnp.sqrt(
            jnp.reciprocal(
                jnp.concatenate(
                    [
                        self._burnin_trace.error_cov_inv.T,
                        self._main_trace.error_cov_inv.T,
                    ],
                    axis=0,
                )
            )
        )

    @cached_property
    def sigma_(self) -> Float32[Array, 'ndpost'] | None:
        """The standard deviation of the error, only over the post-burnin samples and flattened."""
        if self._mcmc_state.binary_y is not None:
            return None
        assert self._main_trace.error_cov_inv.ndim <= 2  # chains and samples
        return jnp.sqrt(jnp.reciprocal(self._main_trace.error_cov_inv)).reshape(-1)

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burnin samples."""
        if self.sigma_ is None:
            return None
        return self.sigma_.mean()

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return self._bart.varcount

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self._bart.varcount_mean

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        return self._bart.varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self._bart.varprob_mean

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined for binary regression because it is error-prone; typically
        the right quantity to consider is `prob_test_mean`.
        """
        if self._yhat_test is None or self._mcmc_state.binary_y is not None:
            return None
        return self._yhat_test.mean(axis=0)

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        return self._bart.predict('train', kind=PredictKind.latent_samples)

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined for binary regression because it is error-prone; typically
        the right quantity to consider is `prob_train_mean`.
        """
        if self._mcmc_state.binary_y is not None:
            return None
        else:
            return self.yhat_train.mean(axis=0)

    # Public methods from Bart

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Evaluate the sum-of-trees at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        Posterior samples of the latent function value at `x_test`. In the
        continuous case, this is the conditional mean.
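
        Examples
        --------
        A hypothetical continuation of the sketch in the class docstring,
        where ``fit`` is an `mc_gbart` fitted with ``ndpost=200`` on ``p=2``
        predictors:

        >>> x_test = random.uniform(random.key(1), (2, 7))  # p=2, m=7
        >>> fit.predict(x_test).shape  # one row per saved posterior sample
        (200, 7)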
504 """
505 return self._bart.predict(x_test, kind=PredictKind.latent_samples) 1s
508class gbart(mc_gbart):
509 """Subclass of `mc_gbart` that forces `mc_cores=1`."""
511 def __init__(self, *args: Any, **kwargs: Any) -> None:
512 if 'mc_cores' in kwargs: 512 ↛ 515line 512 didn't jump to line 515 because the condition on line 512 was always true1j
513 msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'" 1j
514 raise TypeError(msg) 1j
515 kwargs.update(mc_cores=1)
516 super().__init__(*args, **kwargs)
519def process_mc_cores(y_train: Array | Series, mc_cores: int) -> dict[str, Any]:
520 """Determine the arguments to pass to `Bart` to configure multiple chains."""
    # one chain, disable multichain altogether
    if abs(mc_cores) == 1:
        return dict(num_chains=None)

    # determine if we are on cpu; this point may raise an exception
    platform = get_platform(y_train, mc_cores)

    # set the num_chains argument
    mc_cores = abs(mc_cores)
    kwargs = dict(num_chains=mc_cores)

    # if on cpu, try to shard the chains across multiple virtual cpus
    if platform == 'cpu':
        # determine number of logical cpu cores
        num_cores = cpu_count()
        assert num_cores is not None, 'could not determine number of cpu cores'

        # determine the largest number of shards that evenly divides the chains
        for num_shards in range(num_cores, 0, -1):
            if mc_cores % num_shards == 0:
                break

        # handle the case where there are fewer jax cpu devices than that
        if num_shards > 1:
            num_jax_cpus = device_count('cpu')
            if num_jax_cpus < num_shards:
                for new_num_shards in range(num_jax_cpus, 0, -1):
                    if mc_cores % new_num_shards == 0:
                        break
                msg = (
                    f'`mc_gbart` would like to shard {mc_cores} chains across '
                    f'{num_shards} virtual jax cpu devices, but jax is set up '
                    f'with only {num_jax_cpus} cpu devices, so it will use '
                    f'{new_num_shards} devices instead. To enable '
                    'parallelization, please increase the limit with '
                    '`jax.config.update("jax_num_cpu_devices", <num_devices>)`.'
                )
                warn(msg)
                num_shards = new_num_shards

        # set the number of shards
        if num_shards > 1:
            kwargs.update(num_chain_devices=num_shards)

    return kwargs


def get_platform(y_train: Array | Series, mc_cores: int) -> str:
    """Get the platform for `process_mc_cores` from `y_train` or the default device."""
    if isinstance(y_train, Array) and hasattr(y_train, 'platform'):
        return y_train.platform()
    elif (
        not isinstance(y_train, Array) and not jit_active()
        # this condition means: y_train is not an array, but we are not under
        # jit, so y_train is going to be converted to an array on the default
        # device
    ) or mc_cores < 0:
        return get_default_device().platform
    else:
        msg = (
            'Could not determine the platform from `y_train`, maybe `mc_gbart` '
            'was used with a `jax.jit`ted function? The platform is needed to '
            'determine whether the computation is going to run on CPU to '
            'automatically shard the chains across multiple virtual CPU '
            'devices. To acknowledge this problem and circumvent it '
            'by using the current default jax device, negate `mc_cores`.'
        )
        raise RuntimeError(msg)