# bartz/src/bartz/BART/_gbart.py
#
# Copyright (c) 2024-2026, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""

from collections.abc import Mapping
from functools import cached_property
from os import cpu_count
from types import MappingProxyType
from typing import Any, Literal
from warnings import warn

from equinox import Module
from jax import device_count
from jax import numpy as jnp
from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real

from bartz import mcmcloop, mcmcstep
from bartz._interface import Bart, DataFrame, FloatLike, Series
from bartz.jaxext import get_default_device


class mc_gbart(Module):
45 R"""
46 Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
48 Regress `y_train` on `x_train` with a latent mean function represented as
49 a sum of decision trees. The inference is carried out by sampling the
50 posterior distribution of the tree ensemble with an MCMC.
52 Parameters
53 ----------
54 x_train
55 The training predictors.
56 y_train
57 The training responses.
58 x_test
59 The test predictors.
60 type
61 The type of regression. 'wbart' for continuous regression, 'pbart' for
62 binary regression with probit link.
63 sparse
64 Whether to activate variable selection on the predictors as done in
65 [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision
        rule is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it's a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity. If
        setting `theta` directly, it should be in the ballpark of p or lower
        as well.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any. If
        `None`, no check is performed, and the output of the MCMC may not make
        sense if there are predictors without cutpoints. The option `None` is
        provided only to allow jax tracing.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression
        (with intercept, and without taking into account `w`). If `y_train`
        has fewer than two elements, it is set to 1. If n <= p, it is set to
        the standard deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf
        The degrees of freedom of the scaled inverse chi-squared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is
        specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale is
        set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that
        a node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, it defaults to ``(max(y_train)
        - min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements)
        for continuous regression, and to 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is
        set to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0. With binary regression, if `y_train` is
        all `False` or all `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin
        the predictors, evenly spaced between the minimum and maximum observed
        values (endpoints excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best
        set to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains, rounded up to the nearest
        multiple of `mc_cores`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This number
        of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default, 1
        for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each
        log line. Set to `None` to disable logging. ^C interrupts the MCMC
        only every `printevery` iterations, so with logging disabled there is
        no convenient way to kill the MCMC.
    mc_cores
        The number of independent MCMC chains.
    seed
        The seed for the random number generator.
    bart_kwargs
        Additional arguments passed to `bartz.Bart`.

    Notes
    -----
    This interface imitates the function ``mc_gbart`` from the R package
    `BART3 <https://github.com/rsparapa/bnptools>`_, but with these
    differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART3 switches to quantiles anyway if there
      are fewer predictor values than the required number of bins, while
      bartz always follows the specification.
    - Some functionality is missing.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - There are some additional attributes, and some missing ones.
    - The trees have a maximum depth of 8.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
    - If `rm_const=True` and some variables are dropped, the predictors
      matrix/dataframe passed to `predict` should still include them.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection". In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    .. [2] Chipman, Hugh A., Edward I. George and Robert E. McCulloch (2010).
       "BART: Bayesian Additive Regression Trees". In: The Annals of Applied
       Statistics 4.1, pp. 266-298.
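
    Examples
    --------
    A minimal sketch of continuous regression on synthetic data; the settings
    are illustrative, not recommendations. Note that `x_train` has one
    predictor per row:

    >>> import jax.numpy as jnp
    >>> from jax import random
    >>> p, n = 2, 100
    >>> x_train = random.uniform(random.key(0), (p, n))
    >>> y_train = x_train[0] + jnp.sin(2 * x_train[1])
    >>> fit = mc_gbart(
    ...     x_train, y_train, ndpost=500, nskip=100, printevery=None, seed=0
    ... )
    >>> assert fit.yhat_train.shape == (500, n)  # one row per saved sample
    >>> assert fit.yhat_train_mean.shape == (n,)
    >>> x_test = random.uniform(random.key(1), (p, 50))
    >>> yhat_test = fit.predict(x_test)  # shape (ndpost, m) = (500, 50)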
220 """
222 _bart: Bart
224 def __init__(
225 self,
226 x_train: Real[Array, 'p n'] | DataFrame,
227 y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
228 *,
229 x_test: Real[Array, 'p m'] | DataFrame | None = None,
230 type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002
231 sparse: bool = False,
232 theta: FloatLike | None = None,
233 a: FloatLike = 0.5,
234 b: FloatLike = 1.0,
235 rho: FloatLike | None = None,
236 xinfo: Float[Array, 'p n'] | None = None,
237 usequants: bool = False,
238 rm_const: bool | None = True,
239 sigest: FloatLike | None = None,
240 sigdf: FloatLike = 3.0,
241 sigquant: FloatLike = 0.9,
242 k: FloatLike = 2.0,
243 power: FloatLike = 2.0,
244 base: FloatLike = 0.95,
245 lamda: FloatLike | None = None,
246 tau_num: FloatLike | None = None,
247 offset: FloatLike | None = None,
248 w: Float[Array, ' n'] | None = None,
249 ntree: int | None = None,
250 numcut: int = 100,
251 ndpost: int = 1000,
252 nskip: int = 100,
253 keepevery: int | None = None,
254 printevery: int | None = 100,
255 mc_cores: int = 2,
256 seed: int | Key[Array, ''] = 0,
257 bart_kwargs: Mapping = MappingProxyType({}),
258 ):
        kwargs: dict = dict(
            x_train=x_train,
            y_train=y_train,
            x_test=x_test,
            type=type,
            sparse=sparse,
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            xinfo=xinfo,
            usequants=usequants,
            rm_const=rm_const,
            sigest=sigest,
            sigdf=sigdf,
            sigquant=sigquant,
            k=k,
            power=power,
            base=base,
            lamda=lamda,
            tau_num=tau_num,
            offset=offset,
            w=w,
            ntree=ntree,
            numcut=numcut,
            ndpost=ndpost,
            nskip=nskip,
            keepevery=keepevery,
            printevery=printevery,
            seed=seed,
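            # trees are capped at depth 8, as documented in the class Notes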
            maxdepth=8,
            **process_mc_cores(y_train, mc_cores),
        )
        kwargs.update(bart_kwargs)
        self._bart = Bart(**kwargs)

    # Public attributes from Bart

    @property
    def ndpost(self) -> int:
        """The number of MCMC samples saved, after burn-in."""
        return self._bart.ndpost

    @property
    def offset(self) -> Float32[Array, '']:
        """The prior mean of the latent mean function."""
        return self._bart.offset

    @property
    def sigest(self) -> Float32[Array, ''] | None:
        """The estimated standard deviation of the error used to set `lamda`."""
        return self._bart.sigest

    @property
    def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The conditional posterior mean at `x_test` for each MCMC iteration."""
        return self._bart.yhat_test

    # Private attributes from Bart

    @property
    def _main_trace(self) -> mcmcloop.MainTrace:
        return self._bart._main_trace  # noqa: SLF001

    @property
    def _burnin_trace(self) -> mcmcloop.BurninTrace:
        return self._bart._burnin_trace  # noqa: SLF001

    @property
    def _mcmc_state(self) -> mcmcstep.State:
        return self._bart._mcmc_state  # noqa: SLF001

    @property
    def _splits(self) -> Real[Array, 'p max_num_splits']:
        return self._bart._splits  # noqa: SLF001

    @property
    def _x_train_fmt(self) -> Any:
        return self._bart._x_train_fmt  # noqa: SLF001

    # Cached properties from Bart

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        return self._bart.prob_test

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        return self._bart.prob_test_mean

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        return self._bart.prob_train

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        return self._bart.prob_train_mean

    @cached_property
    def sigma(
        self,
    ) -> (
        Float32[Array, ' nskip+ndpost']
        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
        | None
    ):
        """The standard deviation of the error, including burn-in samples."""
        return self._bart.sigma

    @cached_property
    def sigma_(self) -> Float32[Array, ' ndpost'] | None:
        """The standard deviation of the error, only over the post-burnin samples and flattened."""
        return self._bart.sigma_

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burnin samples."""
        return self._bart.sigma_mean

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return self._bart.varcount

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self._bart.varcount_mean

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        return self._bart.varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self._bart.varprob_mean

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined for binary regression because it is error-prone; the
        right quantity to consider is typically `prob_test_mean`.
        """
        return self._bart.yhat_test_mean

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        return self._bart.yhat_train

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined for binary regression because it is error-prone; the
        right quantity to consider is typically `prob_train_mean`.
        """
        return self._bart.yhat_train_mean

    # Public methods from Bart

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        The conditional posterior mean at `x_test` for each MCMC iteration.
        """
        return self._bart.predict(x_test)


class gbart(mc_gbart):
    """Subclass of `mc_gbart` that forces `mc_cores=1`.

    def __init__(self, *args, **kwargs):
        if 'mc_cores' in kwargs:
            msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'"
            raise TypeError(msg)
        kwargs.update(mc_cores=1)
        super().__init__(*args, **kwargs)


def process_mc_cores(y_train: Array | Any, mc_cores: int) -> dict[str, Any]:
    """Determine the arguments to pass to `Bart` to configure multiple chains.

    # one chain, leave default configuration which is num_chains=None
    if abs(mc_cores) == 1:
        return {}

    # determine if we are on cpu; this point may raise an exception
    platform = get_platform(y_train, mc_cores)

    # set the num_chains argument
    mc_cores = abs(mc_cores)
    kwargs = dict(num_chains=mc_cores)

    # if on cpu, try to shard the chains across multiple virtual cpus
    if platform == 'cpu':
        # determine number of logical cpu cores
        num_cores = cpu_count()
        assert num_cores is not None, 'could not determine number of cpu cores'

        # determine number of shards that evenly divides chains
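        # i.e., the largest divisor of mc_cores that is <= num_cores
        # (e.g. with 8 logical cores: mc_cores=6 -> 6 shards, mc_cores=12 -> 6)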
        for num_shards in range(num_cores, 0, -1):
            if mc_cores % num_shards == 0:
                break

        # handle the case where there are fewer jax cpu devices than that
        if num_shards > 1:
            num_jax_cpus = device_count('cpu')
            if num_jax_cpus < num_shards:
                for new_num_shards in range(num_jax_cpus, 0, -1):
                    if mc_cores % new_num_shards == 0:
                        break
                msg = (
                    f'`mc_gbart` would like to shard {mc_cores} chains across '
                    f'{num_shards} virtual jax cpu devices, but jax is set up '
                    f'with only {num_jax_cpus} cpu devices, so it will use '
                    f'{new_num_shards} devices instead. To enable '
                    'parallelization, please increase the limit with '
                    '`jax.config.update("jax_num_cpu_devices", <num_devices>)`.'
                )
                warn(msg)
                num_shards = new_num_shards

        # set the number of shards
        if num_shards > 1:
            kwargs.update(num_chain_devices=num_shards)

    return kwargs


def get_platform(y_train: Array | Any, mc_cores: int) -> str:
    """Get the platform for `process_mc_cores` from `y_train` or the default device."""
    if isinstance(y_train, Array) and hasattr(y_train, 'platform'):
        return y_train.platform()
    elif (
        not isinstance(y_train, Array) and hasattr(jnp.zeros(()), 'platform')
        # this condition means: y_train is not an array, but we are not under
        # jit, so y_train is going to be converted to an array on the default
        # device
    ) or mc_cores < 0:
        return get_default_device().platform
    else:
        msg = (
            'Could not determine the platform from `y_train`; maybe `mc_gbart` '
            'was used within a `jax.jit`ted function? The platform is needed '
            'to determine whether the computation is going to run on CPU, to '
            'automatically shard the chains across multiple virtual CPU '
            'devices. To acknowledge this problem and circumvent it by using '
            'the current default jax device, negate `mc_cores`.'
        )
        raise RuntimeError(msg)