# bartz/src/bartz/BART/_gbart.py
#
# Copyright (c) 2024-2026, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""

from collections.abc import Mapping
from functools import cached_property
from os import cpu_count
from types import MappingProxyType
from typing import Any, Literal
from warnings import warn

from equinox import Module
from jax import device_count
from jax import numpy as jnp
from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real

from bartz import mcmcloop, mcmcstep
from bartz._interface import Bart, DataFrame, FloatLike, Series
from bartz.jaxext import get_default_device


class mc_gbart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart' for
        binary regression with probit link.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision
        rule is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it is a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity. If
        setting `theta` directly, it should be in the ballpark of p or lower
        as well.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any. If
        `None`, no check is performed, and the output of the MCMC may not make
        sense if there are predictors without cutpoints. The option `None` is
        provided only to allow jax tracing.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to
        set `lamda`. If not specified, it is estimated by linear regression
        (with intercept, and without taking into account `w`). If `y_train`
        has fewer than two elements, it is set to 1. If n <= p, it is set to
        the standard deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is
        specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If
        `y_train` has fewer than two elements, `k` is ignored and the scale is
        set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that
        a node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, it defaults to ``(max(y_train)
        - min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements)
        for continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is
        set to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0. With binary regression, if `y_train` is
        all `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin
        the predictors, ranging between the minimum and maximum observed
        values (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best
        set to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains. `ndpost` is rounded up to
        the first multiple of `mc_cores`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This number
        of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default, 1
        for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each
        log line. Set to `None` to disable logging. ^C interrupts the MCMC
        only every `printevery` iterations, so with logging disabled it's
        impossible to kill the MCMC conveniently.
    mc_cores
        The number of independent MCMC chains.
    seed
        The seed for the random number generator.
    bart_kwargs
        Additional arguments passed to `bartz.Bart`.

    Notes
    -----
    This interface imitates the function ``mc_gbart`` from the R package
    `BART3 <https://github.com/rsparapa/bnptools>`_, but with these
    differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART3 switches to quantiles anyway if there
      are fewer predictor values than the required number of bins, while bartz
      always follows the specification.
    - Some functionality is missing.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - There are some additional attributes, and some missing.
    - The trees have a maximum depth of 8.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
    - If `rm_const=True` and some variables are dropped, the predictors
      matrix/dataframe passed to `predict` should still include them.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection". In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    .. [2] Chipman, Hugh A., Edward I. George, and Robert E. McCulloch (2010).
       "BART: Bayesian additive regression trees". In: The Annals of Applied
       Statistics 4.1, pp. 266-298.
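
    Examples
    --------
    A minimal sketch of continuous regression (logging is disabled; the fit
    itself is not echoed since the MCMC output varies run to run):

    >>> import jax.numpy as jnp
    >>> from jax import random
    >>> X = random.uniform(random.key(0), (2, 100))  # 2 predictors, 1 per row
    >>> y = X[0] + jnp.sin(3 * X[1])
    >>> fit = mc_gbart(X, y, ndpost=100, nskip=50, printevery=None)
    >>> fit.yhat_train_mean.shape
    (100,)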

    """

    _bart: Bart

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool | None = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | None = None,
        ntree: int | None = None,
        numcut: int = 100,
        ndpost: int = 1000,
        nskip: int = 100,
        keepevery: int | None = None,
        printevery: int | None = 100,
        mc_cores: int = 2,
        seed: int | Key[Array, ''] = 0,
        bart_kwargs: Mapping = MappingProxyType({}),
    ):
        kwargs: dict = dict(
            x_train=x_train,
            y_train=y_train,
            x_test=x_test,
            type=type,
            sparse=sparse,
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            xinfo=xinfo,
            usequants=usequants,
            rm_const=rm_const,
            sigest=sigest,
            sigdf=sigdf,
            sigquant=sigquant,
            k=k,
            power=power,
            base=base,
            lamda=lamda,
            tau_num=tau_num,
            offset=offset,
            w=w,
            ntree=ntree,
            numcut=numcut,
            ndpost=ndpost,
            nskip=nskip,
            keepevery=keepevery,
            printevery=printevery,
            seed=seed,
            maxdepth=8,
            **process_mc_cores(y_train, mc_cores),
        )
        kwargs.update(bart_kwargs)
        self._bart = Bart(**kwargs)

    # Public attributes from Bart

    @property
    def ndpost(self) -> int:
        """The number of MCMC samples saved, after burn-in."""
        return self._bart.ndpost

    @property
    def offset(self) -> Float32[Array, '']:
        """The prior mean of the latent mean function."""
        return self._bart.offset

    @property
    def sigest(self) -> Float32[Array, ''] | None:
        """The estimated standard deviation of the error used to set `lamda`."""
        return self._bart.sigest

    @property
    def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The conditional posterior mean at `x_test` for each MCMC iteration."""
        return self._bart.yhat_test

    # Private attributes from Bart

    @property
    def _main_trace(self) -> mcmcloop.MainTrace:
        return self._bart._main_trace  # noqa: SLF001

    @property
    def _burnin_trace(self) -> mcmcloop.BurninTrace:
        return self._bart._burnin_trace  # noqa: SLF001

    @property
    def _mcmc_state(self) -> mcmcstep.State:
        return self._bart._mcmc_state  # noqa: SLF001

    @property
    def _splits(self) -> Real[Array, 'p max_num_splits']:
        return self._bart._splits  # noqa: SLF001

    @property
    def _x_train_fmt(self) -> Any:
        return self._bart._x_train_fmt  # noqa: SLF001

    # Cached properties from Bart

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        return self._bart.prob_test

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        return self._bart.prob_test_mean

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        return self._bart.prob_train

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        return self._bart.prob_train_mean

    @cached_property
    def sigma(
        self,
    ) -> (
        Float32[Array, ' nskip+ndpost']
        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
        | None
    ):
        """The standard deviation of the error, including burn-in samples."""
        return self._bart.sigma

    @cached_property
    def sigma_(self) -> Float32[Array, 'ndpost'] | None:
        """The standard deviation of the error, only the post-burn-in samples, flattened."""
        return self._bart.sigma_

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burn-in samples."""
        return self._bart.sigma_mean

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return self._bart.varcount

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self._bart.varcount_mean

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        return self._bart.varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self._bart.varprob_mean

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined with binary regression because it's error-prone; typically
        the right quantity to consider is `prob_test_mean`.
        """
        return self._bart.yhat_test_mean

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        return self._bart.yhat_train

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined with binary regression because it's error-prone; typically
        the right quantity to consider is `prob_train_mean`.
        """
        return self._bart.yhat_train_mean

    # Public methods from Bart

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        The conditional posterior mean at `x_test` for each MCMC iteration.
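
        Examples
        --------
        A sketch continuing the example in the `mc_gbart` class docstring
        (`fit` is the fitted model, `X` the 2-by-100 training matrix):

        >>> fit.predict(X).shape  # one row per saved MCMC iteration
        (100, 100)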

        """
        return self._bart.predict(x_test)


class gbart(mc_gbart):
    """Subclass of `mc_gbart` that forces `mc_cores=1`."""

    def __init__(self, *args, **kwargs):
        if 'mc_cores' in kwargs:
            msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'"
            raise TypeError(msg)
        kwargs.update(mc_cores=1)
        super().__init__(*args, **kwargs)


def process_mc_cores(y_train: Array | Any, mc_cores: int) -> dict[str, Any]:
    """Determine the arguments to pass to `Bart` to configure multiple chains."""
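    # Sketch of the two regimes: with abs(mc_cores) == 1 this returns {} and
    # `Bart` keeps its single-chain default; with e.g. mc_cores=4 on cpu it
    # returns {'num_chains': 4}, plus 'num_chain_devices' when the chains can
    # be sharded evenly across the virtual jax cpu devices.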

    # one chain, leave default configuration which is num_chains=None
    if abs(mc_cores) == 1:
        return {}

    # determine if we are on cpu; this point may raise an exception
    platform = get_platform(y_train, mc_cores)

    # set the num_chains argument
    mc_cores = abs(mc_cores)
    kwargs = dict(num_chains=mc_cores)

    # if on cpu, try to shard the chains across multiple virtual cpus
    if platform == 'cpu':
        # determine number of logical cpu cores
        num_cores = cpu_count()
        assert num_cores is not None, 'could not determine number of cpu cores'

        # determine number of shards that evenly divides chains
        for num_shards in range(num_cores, 0, -1):
            if mc_cores % num_shards == 0:
                break

        # handle the case where there are fewer jax cpu devices than that
        if num_shards > 1:
            num_jax_cpus = device_count('cpu')
            if num_jax_cpus < num_shards:
                for new_num_shards in range(num_jax_cpus, 0, -1):
                    if mc_cores % new_num_shards == 0:
                        break
                msg = (
                    f'`mc_gbart` would like to shard {mc_cores} chains across '
                    f'{num_shards} virtual jax cpu devices, but jax is set up '
                    f'with only {num_jax_cpus} cpu devices, so it will use '
                    f'{new_num_shards} devices instead. To enable '
                    'parallelization, please increase the limit with '
                    '`jax.config.update("jax_num_cpu_devices", <num_devices>)`.'
                )
                warn(msg)
                num_shards = new_num_shards

        # set the number of shards
        if num_shards > 1:
            kwargs.update(num_chain_devices=num_shards)

    return kwargs


def get_platform(y_train: Array | Any, mc_cores: int) -> str:
    """Get the platform for `process_mc_cores` from `y_train` or the default device."""
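    # A concrete array carries its platform; a tracer (e.g. under `jax.jit`)
    # may not, hence the fallback to the default device and the explicit
    # opt-in via a negative `mc_cores` below.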

    if isinstance(y_train, Array) and hasattr(y_train, 'platform'):
        return y_train.platform()
    elif (
        not isinstance(y_train, Array) and hasattr(jnp.zeros(()), 'platform')
        # this condition means: y_train is not an array, but we are not under
        # jit, so y_train is going to be converted to an array on the default
        # device
    ) or mc_cores < 0:
        return get_default_device().platform
    else:
        msg = (
            'Could not determine the platform from `y_train`, maybe `mc_gbart` '
            'was used with a `jax.jit`ted function? The platform is needed to '
            'determine whether the computation is going to run on CPU to '
            'automatically shard the chains across multiple virtual CPU '
            'devices. To acknowledge this problem and circumvent it '
            'by using the current default jax device, negate `mc_cores`.'
        )
        raise RuntimeError(msg)