# bartz/src/bartz/BART/_gbart.py
#
# Copyright (c) 2024-2026, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""

from collections.abc import Hashable, Mapping
from functools import cached_property
from os import cpu_count
from types import MappingProxyType
from typing import Any, Literal
from warnings import warn

import jax.numpy as jnp
from equinox import Module
from jax import device_count
from jax.scipy.special import ndtr
from jaxtyping import Array, Float, Float32, Int32, Key, Real

from bartz import mcmcloop, mcmcstep
from bartz._interface import Bart, DataFrame, FloatLike, PredictKind, Series
from bartz.jaxext import get_default_device, jit_active


class mc_gbart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart'
        for binary regression with probit link.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision
        rule is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it is a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To
        tune the prior, consider setting a lower `rho` to prefer more
        sparsity. If setting `theta` directly, it should likewise be in the
        ballpark of p or lower.
    varprob
        The prior probability distribution over the `p` predictors, used to
        choose the predictor to split on in each decision node. Must be > 0.
        It does not need to be normalized to sum to 1. If not specified, a
        uniform distribution is used. If ``sparse=True``, this is used as
        the initial value for the MCMC.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants`
        and `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor.
        If there are fewer cutpoints than the number of columns in the
        matrix, fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e.,
        there are no available cutpoints for that predictor). If `True`
        (default), they are ignored. If `False`, an error is raised if there
        are any.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression
        (with intercept, and without taking into account `w`). If `y_train`
        has fewer than two elements, it is set to 1. If n <= p, it is set to
        the standard deviation of `y_train`. Ignored if `lamda` is
        specified.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is
        specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale
        is set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability
        that a node at depth `d` (0-based) is non-terminal is
        ``base / (1 + d) ** power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of
        x is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, defaults to ``(max(y_train) -
        min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements)
        for continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is
        set to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0. With binary regression, if `y_train` is
        all `False` or all `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for
        all datapoints. Note: `w` is ignored in the automatic determination
        of `sigest`, so either the weights should be O(1), or `sigest`
        should be specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin
        the predictors, ranging between the minimum and maximum observed
        values (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most ``numcut + 1``, and no more than the
        number of unique values appearing in `x_train`.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best
        set to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains, rounded up to the first
        multiple of `mc_cores`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This
        number of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default,
        1 for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each
        log line. Set to `None` to disable logging. ^C interrupts the MCMC
        only every `printevery` iterations, so with logging disabled there
        is no convenient way to kill the MCMC.
    mc_cores
        The number of independent MCMC chains.
    seed
        The seed for the random number generator.
    bart_kwargs
        Additional arguments passed to `bartz.Bart`.

    Notes
    -----
    This interface imitates the function ``mc_gbart`` from the R package
    `BART3 <https://github.com/rsparapa/bnptools>`_, but with these
    differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per
      row instead of per column.
    - If ``usequants=False``, R BART3 switches to quantiles anyway if there
      are fewer predictor values than the required number of bins, while
      bartz always follows the specification.
    - Some functionality is missing.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - There are some additional attributes, and some missing.
    - The trees have a maximum depth of 6.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
    - If `rm_const=True` and some variables are dropped, the predictors
      matrix/dataframe passed to `predict` should still include them.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection". In: Journal of
       the American Statistical Association 113.522, pp. 626-636.
    .. [2] Chipman, Hugh A., Edward I. George, and Robert E. McCulloch
       (2010). "BART: Bayesian Additive Regression Trees". In: The Annals of
       Applied Statistics 4.1, pp. 266-298.
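
    Examples
    --------
    A minimal sketch on synthetic data, composed from the parameters and
    attributes documented above (one predictor per row):

    >>> from jax import random
    >>> key = random.key(0)
    >>> X = random.uniform(random.fold_in(key, 0), (2, 100))  # p=2, n=100
    >>> y = X[0] - X[1] + 0.1 * random.normal(random.fold_in(key, 1), (100,))
    >>> fit = mc_gbart(X, y, ndpost=100, nskip=50, printevery=None)
    >>> fit.yhat_train_mean.shape  # marginal posterior mean at x_train
    (100,)
    >>> fit.predict(X).shape  # posterior samples of the latent mean
    (100, 100)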

    """

    _bart: Bart
    _yhat_test: Float32[Array, 'ndpost m'] | None = None

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        varprob: Float[Array, ' p'] | None = None,
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | None = None,
        ntree: int | None = None,
        numcut: int = 100,
        ndpost: int = 1000,
        nskip: int = 100,
        keepevery: int | None = None,
        printevery: int | None = 100,
        mc_cores: int = 2,
        seed: int | Key[Array, ''] = 0,
        bart_kwargs: Mapping = MappingProxyType({}),
    ) -> None:
        # set defaults that depend on type of regression
        if keepevery is None:
            keepevery = 10 if type == 'pbart' else 1
        if ntree is None:
            ntree = 50 if type == 'pbart' else 200

        # set most calling arguments for Bart
        kwargs: dict = dict(
            x_train=x_train,
            y_train=y_train,
            outcome_type='binary' if type == 'pbart' else 'continuous',
            sparse=sparse,
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            varprob=varprob,
            xinfo=xinfo,
            usequants=usequants,
            rm_const=rm_const,
            sigest=sigest,
            sigdf=sigdf,
            sigquant=sigquant,
            k=k,
            power=power,
            base=base,
            lamda=lamda,
            tau_num=tau_num,
            offset=offset,
            w=w,
            num_trees=ntree,
            numcut=numcut,
            ndpost=ndpost,
            nskip=nskip,
            keepevery=keepevery,
            printevery=printevery,
            seed=seed,
            maxdepth=6,
            **process_mc_cores(y_train, mc_cores),
        )

        # set min_points_per_leaf unless the user set it already
        if 'min_points_per_leaf' not in bart_kwargs.get('init_kw', {}):
            bart_kwargs = dict(bart_kwargs)
            init_kw = dict(bart_kwargs.get('init_kw', {}))
            init_kw['min_points_per_leaf'] = 5
            bart_kwargs['init_kw'] = init_kw

        # add user arguments
        kwargs.update(bart_kwargs)

        # invoke Bart
        self._bart = Bart(**kwargs)

        # predict at test points
        if x_test is not None:
            self._yhat_test = self._bart.predict(
                x_test, kind=PredictKind.latent_samples
            )

    # Public attributes from Bart

    @property
    def ndpost(self) -> int:
        """The number of MCMC samples saved, after burn-in."""
        return self._bart.ndpost

    @property
    def offset(self) -> Float32[Array, '']:
        """The prior mean of the latent mean function."""
        return self._bart.offset

    @property
    def sigest(self) -> Float32[Array, ''] | None:
        """The estimated standard deviation of the error used to set `lamda`."""
        return self._bart.sigest

    # Private attributes from Bart

    @property
    def _main_trace(self) -> mcmcloop.MainTrace:
        return self._bart._main_trace  # noqa: SLF001

    @property
    def _burnin_trace(self) -> mcmcloop.BurninTrace:
        return self._bart._burnin_trace  # noqa: SLF001

    @property
    def _mcmc_state(self) -> mcmcstep.State:
        return self._bart._mcmc_state  # noqa: SLF001

    @property
    def _splits(self) -> Real[Array, 'p max_num_splits']:
        return self._bart._splits  # noqa: SLF001

    @property
    def _x_train_fmt(self) -> Hashable:
        return self._bart._x_train_fmt  # noqa: SLF001

    # Properties

    @property
    def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The conditional posterior mean at `x_test` for each MCMC iteration."""
        return self._yhat_test

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        if self._yhat_test is None or self._mcmc_state.binary_y is None:
            return None
        return ndtr(self._yhat_test)

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        if self.prob_test is None:
            return None
        return self.prob_test.mean(axis=0)

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        if self._mcmc_state.binary_y is not None:
            return ndtr(self.yhat_train)
        else:
            return None

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        if self.prob_train is None:
            return None
        else:
            return self.prob_train.mean(axis=0)

    @cached_property
    def sigma(
        self,
    ) -> (
        Float32[Array, ' nskip+ndpost']
        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
        | None
    ):
        """The standard deviation of the error, including burn-in samples."""
        if self._mcmc_state.binary_y is not None:
            return None
        assert self._burnin_trace.error_cov_inv.ndim <= 2  # chains and samples
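        # each trace stores samples along its last axis, with chains (if
        # any) along the first; transposing puts samples first, so burn-in
        # and post-burn-in draws stack along axis 0, and the error precision
        # is then converted to a standard deviation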

        return jnp.sqrt(
            jnp.reciprocal(
                jnp.concatenate(
                    [
                        self._burnin_trace.error_cov_inv.T,
                        self._main_trace.error_cov_inv.T,
                    ],
                    axis=0,
                )
            )
        )

    @cached_property
    def sigma_(self) -> Float32[Array, ' ndpost'] | None:
        """The standard deviation of the error, only over the post-burn-in samples and flattened."""
        if self._mcmc_state.binary_y is not None:
            return None
        assert self._main_trace.error_cov_inv.ndim <= 2  # chains and samples
        return jnp.sqrt(jnp.reciprocal(self._main_trace.error_cov_inv)).reshape(-1)

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burn-in samples."""
        if self.sigma_ is None:
            return None
        return self.sigma_.mean()

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return self._bart.varcount

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self._bart.varcount_mean

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        return self._bart.varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self._bart.varprob_mean

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined with binary regression because it's error-prone;
        typically the right thing to consider would be `prob_test_mean`.
        """
        if self._yhat_test is None or self._mcmc_state.binary_y is not None:
            return None
        return self._yhat_test.mean(axis=0)

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        return self._bart.predict('train', kind=PredictKind.latent_samples)

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined with binary regression because it's error-prone;
        typically the right thing to consider would be `prob_train_mean`.
        """
        if self._mcmc_state.binary_y is not None:
            return None
        else:
            return self.yhat_train.mean(axis=0)

    # Public methods from Bart

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Evaluate the sum-of-trees at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        Posterior samples of the latent function value at `x_test`. In the
        continuous case, this is the conditional mean.
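
        Examples
        --------
        Sketch, where ``fit`` stands for an already fitted `mc_gbart` and
        ``X_new`` for a hypothetical ``(p, m)`` predictor matrix::

            samples = fit.predict(X_new)      # shape (ndpost, m)
            post_mean = samples.mean(axis=0)  # marginal posterior mean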

        """
        return self._bart.predict(x_test, kind=PredictKind.latent_samples)


class gbart(mc_gbart):
    """Subclass of `mc_gbart` that forces `mc_cores=1`."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        if 'mc_cores' in kwargs:
            msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'"
            raise TypeError(msg)
        kwargs.update(mc_cores=1)
        super().__init__(*args, **kwargs)


def process_mc_cores(y_train: Array | Series, mc_cores: int) -> dict[str, Any]:
    """Determine the arguments to pass to `Bart` to configure multiple chains."""
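    # Behaviour sketch (assuming a cpu default device): mc_cores == 1 or -1
    # yields {'num_chains': None}; mc_cores == 4 yields {'num_chains': 4},
    # plus num_chain_devices set to the largest divisor of 4 that fits
    # within the available jax cpu devices, when that divisor exceeds 1.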

    # one chain, disable multichain altogether
    if abs(mc_cores) == 1:
        return dict(num_chains=None)

    # determine if we are on cpu; this point may raise an exception
    platform = get_platform(y_train, mc_cores)

    # set the num_chains argument
    mc_cores = abs(mc_cores)
    kwargs = dict(num_chains=mc_cores)

    # if on cpu, try to shard the chains across multiple virtual cpus
    if platform == 'cpu':
        # determine number of logical cpu cores
        num_cores = cpu_count()
        assert num_cores is not None, 'could not determine number of cpu cores'

        # find the largest number of shards (at most num_cores) that evenly
        # divides the chains
        for num_shards in range(num_cores, 0, -1):
            if mc_cores % num_shards == 0:
                break

        # handle the case where there are fewer jax cpu devices than that
        if num_shards > 1:
            num_jax_cpus = device_count('cpu')
            if num_jax_cpus < num_shards:
                for new_num_shards in range(num_jax_cpus, 0, -1):
                    if mc_cores % new_num_shards == 0:
                        break
                msg = (
                    f'`mc_gbart` would like to shard {mc_cores} chains across '
                    f'{num_shards} virtual jax cpu devices, but jax is set up '
                    f'with only {num_jax_cpus} cpu devices, so it will use '
                    f'{new_num_shards} devices instead. To enable '
                    'parallelization, please increase the limit with '
                    '`jax.config.update("jax_num_cpu_devices", <num_devices>)`.'
                )
                warn(msg)
                num_shards = new_num_shards

        # set the number of shards
        if num_shards > 1:
            kwargs.update(num_chain_devices=num_shards)

    return kwargs


def get_platform(y_train: Array | Series, mc_cores: int) -> str:
    """Get the platform for `process_mc_cores` from `y_train` or the default device."""
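    # a concrete jax array knows its platform; a non-array (e.g., a pandas
    # Series) outside of jit will be placed on the default device; under jit
    # the platform cannot be determined here, so the caller must opt in to
    # using the default device by negating mc_cores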

    if isinstance(y_train, Array) and hasattr(y_train, 'platform'):
        return y_train.platform()
    elif (
        not isinstance(y_train, Array) and not jit_active()
        # this condition means: y_train is not an array, but we are not under
        # jit, so y_train is going to be converted to an array on the default
        # device
    ) or mc_cores < 0:
        return get_default_device().platform
    else:
        msg = (
            'Could not determine the platform from `y_train`, maybe `mc_gbart` '
            'was used with a `jax.jit`ted function? The platform is needed to '
            'determine whether the computation is going to run on CPU to '
            'automatically shard the chains across multiple virtual CPU '
            'devices. To acknowledge this problem and circumvent it '
            'by using the current default jax device, negate `mc_cores`.'
        )
        raise RuntimeError(msg)