Coverage for src / bartz / BART / _gbart.py: 93%
135 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/BART/_gbart.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""
27from collections.abc import Hashable, Mapping
28from functools import cached_property
29from os import cpu_count
30from types import MappingProxyType
31from typing import Any, Literal
32from warnings import warn
34from equinox import Module
35from jax import device_count
36from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real
38from bartz import mcmcloop, mcmcstep
39from bartz._interface import Bart, DataFrame, FloatLike, Series
40from bartz.jaxext import get_default_device, jit_active
class mc_gbart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart' for
        binary regression with probit link.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision rule
        is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it's a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity.
        If setting `theta` directly, it should be in the ballpark of p or lower
        as well.
    varprob
        The probability distribution over the `p` predictors for choosing a
        predictor to split on in a decision node a priori. Must be > 0. It does
        not need to be normalized to sum to 1. If not specified, use a uniform
        distribution. If ``sparse=True``, this is used as initial value for the
        MCMC.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictors quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to set
        `lamda`. If not specified, it is estimated by linear regression (with
        intercept, and without taking into account `w`). If `y_train` has less
        than two elements, it is set to 1. If n <= p, it is set to the standard
        deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If `y_train`
        has less than two elements, `k` is ignored and the scale is set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that a
        node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, default to ``(max(y_train) -
        min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
        continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is set
        to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
        `offset` is set to 0. With binary regression, if `y_train` is all
        `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin the
        predictors, ranging between the minimum and maximum observed values
        (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best set
        to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains. `ndpost` is rounded up to the
        first multiple of `mc_cores`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This number
        of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default, 1
        for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each log
        line. Set to `None` to disable logging. ^C interrupts the MCMC only
        every `printevery` iterations, so with logging disabled it's impossible
        to kill the MCMC conveniently.
    mc_cores
        The number of independent MCMC chains.
    seed
        The seed for the random number generator.
    bart_kwargs
        Additional arguments passed to `bartz.Bart`.

    Notes
    -----
    This interface imitates the function ``mc_gbart`` from the R package `BART3
    <https://github.com/rsparapa/bnptools>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART3 switches to quantiles anyway if there are
      less predictor values than the required number of bins, while bartz
      always follows the specification.
    - Some functionality is missing.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - There are some additional attributes, and some missing.
    - The trees have a maximum depth of 6.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
    - If `rm_const=True` and some variables are dropped, the predictors
      matrix/dataframe passed to `predict` should still include them.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection". In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART:
       Bayesian additive regression trees," The Annals of Applied Statistics,
       Ann. Appl. Stat. 4(1), 266-298, (March 2010).
    """

    # the wrapped `Bart` object that performs all the actual work; every
    # public attribute and method below delegates to it
    _bart: Bart

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        varprob: Float[Array, ' p'] | None = None,
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | None = None,
        ntree: int | None = None,
        numcut: int = 100,
        ndpost: int = 1000,
        nskip: int = 100,
        keepevery: int | None = None,
        printevery: int | None = 100,
        mc_cores: int = 2,
        seed: int | Key[Array, ''] = 0,
        bart_kwargs: Mapping = MappingProxyType({}),
    ) -> None:
        # set defaults that depend on type of regression
        if keepevery is None:
            keepevery = 10 if type == 'pbart' else 1
        if ntree is None:
            ntree = 50 if type == 'pbart' else 200

        # set most calling arguments for Bart
        kwargs: dict = dict(
            x_train=x_train,
            y_train=y_train,
            x_test=x_test,
            type=type,
            sparse=sparse,
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            varprob=varprob,
            xinfo=xinfo,
            usequants=usequants,
            rm_const=rm_const,
            sigest=sigest,
            sigdf=sigdf,
            sigquant=sigquant,
            k=k,
            power=power,
            base=base,
            lamda=lamda,
            tau_num=tau_num,
            offset=offset,
            w=w,
            num_trees=ntree,
            numcut=numcut,
            ndpost=ndpost,
            nskip=nskip,
            keepevery=keepevery,
            printevery=printevery,
            seed=seed,
            maxdepth=6,
            **process_mc_cores(y_train, mc_cores),
        )

        # set min_points_per_leaf unless the user set it already; copy the
        # (possibly read-only) mapping before modifying it
        if 'min_points_per_leaf' not in bart_kwargs.get('init_kw', {}):
            bart_kwargs = dict(bart_kwargs)
            init_kw = dict(bart_kwargs.get('init_kw', {}))
            init_kw['min_points_per_leaf'] = 5
            bart_kwargs['init_kw'] = init_kw

        # add user arguments (these intentionally override the defaults above)
        kwargs.update(bart_kwargs)

        # invoke Bart
        self._bart = Bart(**kwargs)

    # Public attributes from Bart

    @property
    def ndpost(self) -> int:
        """The number of MCMC samples saved, after burn-in."""
        return self._bart.ndpost

    @property
    def offset(self) -> Float32[Array, '']:
        """The prior mean of the latent mean function."""
        return self._bart.offset

    @property
    def sigest(self) -> Float32[Array, ''] | None:
        """The estimated standard deviation of the error used to set `lamda`."""
        return self._bart.sigest

    @property
    def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The conditional posterior mean at `x_test` for each MCMC iteration."""
        return self._bart.yhat_test

    # Private attributes from Bart

    @property
    def _main_trace(self) -> mcmcloop.MainTrace:
        return self._bart._main_trace  # noqa: SLF001

    @property
    def _burnin_trace(self) -> mcmcloop.BurninTrace:
        return self._bart._burnin_trace  # noqa: SLF001

    @property
    def _mcmc_state(self) -> mcmcstep.State:
        return self._bart._mcmc_state  # noqa: SLF001

    @property
    def _splits(self) -> Real[Array, 'p max_num_splits']:
        return self._bart._splits  # noqa: SLF001

    @property
    def _x_train_fmt(self) -> Hashable:
        # opaque tag describing the format of `x_train` (e.g. dataframe vs. matrix)
        return self._bart._x_train_fmt  # noqa: SLF001

    # Cached properties from Bart

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        return self._bart.prob_test

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        return self._bart.prob_test_mean

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        return self._bart.prob_train

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        return self._bart.prob_train_mean

    @cached_property
    def sigma(
        self,
    ) -> (
        Float32[Array, ' nskip+ndpost']
        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
        | None
    ):
        """The standard deviation of the error, including burn-in samples."""
        return self._bart.sigma

    @cached_property
    def sigma_(self) -> Float32[Array, 'ndpost'] | None:
        """The standard deviation of the error, only over the post-burnin samples and flattened."""
        return self._bart.sigma_

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burnin samples."""
        return self._bart.sigma_mean

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return self._bart.varcount

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self._bart.varcount_mean

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        return self._bart.varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self._bart.varprob_mean

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined with binary regression because it's error-prone, typically
        the right thing to consider would be `prob_test_mean`.
        """
        return self._bart.yhat_test_mean

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        return self._bart.yhat_train

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined with binary regression because it's error-prone, typically
        the right thing to consider would be `prob_train_mean`.
        """
        return self._bart.yhat_train_mean

    # Public methods from Bart

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        The conditional posterior mean at `x_test` for each MCMC iteration.
        """
        return self._bart.predict(x_test)
class gbart(mc_gbart):
    """Subclass of `mc_gbart` that forces `mc_cores=1`."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # mimic Python's own message for a keyword this subclass refuses
        if 'mc_cores' in kwargs:
            msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'"
            raise TypeError(msg)
        # safe to pass directly: the guard above ensures no duplicate keyword
        super().__init__(*args, mc_cores=1, **kwargs)
def process_mc_cores(y_train: Array | Series, mc_cores: int) -> dict[str, Any]:
    """Determine the arguments to pass to `Bart` to configure multiple chains."""
    # a single chain: switch the multi-chain machinery off entirely
    if abs(mc_cores) == 1:
        return dict(num_chains=None)

    # find out whether we run on cpu; this call may raise an exception
    platform = get_platform(y_train, mc_cores)

    # the sign of mc_cores is only a flag for get_platform; drop it here
    mc_cores = abs(mc_cores)
    kwargs = dict(num_chains=mc_cores)

    # on cpu, try to shard the chains across multiple virtual cpu devices
    if platform == 'cpu':
        # number of logical cpu cores on this machine
        num_cores = cpu_count()
        assert num_cores is not None, 'could not determine number of cpu cores'

        # largest divisor of the chain count that does not exceed the core count
        num_shards = next(
            d for d in range(num_cores, 0, -1) if mc_cores % d == 0
        )

        # jax may expose fewer cpu devices than we would like to use
        if num_shards > 1:
            num_jax_cpus = device_count('cpu')
            if num_jax_cpus < num_shards:
                # fall back to the largest divisor within the jax device limit
                new_num_shards = next(
                    d for d in range(num_jax_cpus, 0, -1) if mc_cores % d == 0
                )
                msg = (
                    f'`mc_gbart` would like to shard {mc_cores} chains across '
                    f'{num_shards} virtual jax cpu devices, but jax is set up '
                    f'with only {num_jax_cpus} cpu devices, so it will use '
                    f'{new_num_shards} devices instead. To enable '
                    'parallelization, please increase the limit with '
                    '`jax.config.update("jax_num_cpu_devices", <num_devices>)`.'
                )
                warn(msg)
                num_shards = new_num_shards

        # request sharding only when it actually splits the chains
        if num_shards > 1:
            kwargs.update(num_chain_devices=num_shards)

    return kwargs
def get_platform(y_train: Array | Series, mc_cores: int) -> str:
    """Get the platform for `process_mc_cores` from `y_train` or the default device."""
    is_array = isinstance(y_train, Array)

    # a concrete array that exposes its platform settles the question directly
    if is_array and hasattr(y_train, 'platform'):
        return y_train.platform()

    # outside of jit, a non-array y_train is going to be converted to an array
    # on the default device; a negative mc_cores is the caller's explicit
    # opt-in to use the default device regardless
    if (not is_array and not jit_active()) or mc_cores < 0:
        return get_default_device().platform

    msg = (
        'Could not determine the platform from `y_train`, maybe `mc_gbart` '
        'was used with a `jax.jit`ted function? The platform is needed to '
        'determine whether the computation is going to run on CPU to '
        'automatically shard the chains across multiple virtual CPU '
        'devices. To acknowledge this problem and circumvent it '
        'by using the current default jax device, negate `mc_cores`.'
    )
    raise RuntimeError(msg)