Coverage for src / bartz / BART / _gbart.py: 93%

135 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/BART/_gbart.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package.""" 

26 

27from collections.abc import Hashable, Mapping 

28from functools import cached_property 

29from os import cpu_count 

30from types import MappingProxyType 

31from typing import Any, Literal 

32from warnings import warn 

33 

34from equinox import Module 

35from jax import device_count 

36from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real 

37 

38from bartz import mcmcloop, mcmcstep 

39from bartz._interface import Bart, DataFrame, FloatLike, Series 

40from bartz.jaxext import get_default_device, jit_active 

41 

42 

class mc_gbart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart' for
        binary regression with probit link.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision rule
        is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it's a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity.
        If setting `theta` directly, it should be in the ballpark of p or lower
        as well.
    varprob
        The probability distribution over the `p` predictors for choosing a
        predictor to split on in a decision node a priori. Must be > 0. It does
        not need to be normalized to sum to 1. If not specified, use a uniform
        distribution. If ``sparse=True``, this is used as initial value for the
        MCMC.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are less cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictors quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to set
        `lamda`. If not specified, it is estimated by linear regression (with
        intercept, and without taking into account `w`). If `y_train` has less
        than two elements, it is set to 1. If n <= p, it is set to the standard
        deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If `y_train`
        has less than two elements, `k` is ignored and the scale is set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that a
        node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, default to ``(max(y_train) -
        min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
        continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is set
        to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
        `offset` is set to 0. With binary regression, if `y_train` is all
        `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin the
        predictors, ranging between the minimum and maximum observed values
        (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best set
        to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains. `ndpost` is rounded up to the
        first multiple of `mc_cores`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This number
        of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default, 1
        for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each log
        line. Set to `None` to disable logging. ^C interrupts the MCMC only
        every `printevery` iterations, so with logging disabled it's impossible
        to kill the MCMC conveniently.
    mc_cores
        The number of independent MCMC chains.
    seed
        The seed for the random number generator.
    bart_kwargs
        Additional arguments passed to `bartz.Bart`.

    Notes
    -----
    This interface imitates the function ``mc_gbart`` from the R package `BART3
    <https://github.com/rsparapa/bnptools>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART3 switches to quantiles anyway if there are
      less predictor values than the required number of bins, while bartz
      always follows the specification.
    - Some functionality is missing.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - There are some additional attributes, and some missing.
    - The trees have a maximum depth of 6.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
    - If `rm_const=True` and some variables are dropped, the predictors
      matrix/dataframe passed to `predict` should still include them.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection". In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART:
       Bayesian additive regression trees," The Annals of Applied Statistics,
       Ann. Appl. Stat. 4(1), 266-298, (March 2010).
    """

    # The wrapped `Bart` object that performs the actual inference; every
    # public attribute/method of this class delegates to it.
    _bart: Bart

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        varprob: Float[Array, ' p'] | None = None,
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | None = None,
        ntree: int | None = None,
        numcut: int = 100,
        ndpost: int = 1000,
        nskip: int = 100,
        keepevery: int | None = None,
        printevery: int | None = 100,
        mc_cores: int = 2,
        seed: int | Key[Array, ''] = 0,
        bart_kwargs: Mapping = MappingProxyType({}),
    ) -> None:
        # set defaults that depend on type of regression
        if keepevery is None:
            keepevery = 10 if type == 'pbart' else 1
        if ntree is None:
            ntree = 50 if type == 'pbart' else 200

        # set most calling arguments for Bart (note the renames: ntree ->
        # num_trees, and the fixed maxdepth=6 documented in the class notes)
        kwargs: dict = dict(
            x_train=x_train,
            y_train=y_train,
            x_test=x_test,
            type=type,
            sparse=sparse,
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            varprob=varprob,
            xinfo=xinfo,
            usequants=usequants,
            rm_const=rm_const,
            sigest=sigest,
            sigdf=sigdf,
            sigquant=sigquant,
            k=k,
            power=power,
            base=base,
            lamda=lamda,
            tau_num=tau_num,
            offset=offset,
            w=w,
            num_trees=ntree,
            numcut=numcut,
            ndpost=ndpost,
            nskip=nskip,
            keepevery=keepevery,
            printevery=printevery,
            seed=seed,
            maxdepth=6,
            **process_mc_cores(y_train, mc_cores),
        )

        # set min_points_per_leaf unless the user set it already; copy the
        # mappings before mutating so the caller's `bart_kwargs` is untouched
        if 'min_points_per_leaf' not in bart_kwargs.get('init_kw', {}):
            bart_kwargs = dict(bart_kwargs)
            init_kw = dict(bart_kwargs.get('init_kw', {}))
            init_kw['min_points_per_leaf'] = 5
            bart_kwargs['init_kw'] = init_kw

        # add user arguments (they take precedence over the ones set above)
        kwargs.update(bart_kwargs)

        # invoke Bart
        self._bart = Bart(**kwargs)

    # Public attributes from Bart

    @property
    def ndpost(self) -> int:
        """The number of MCMC samples saved, after burn-in."""
        return self._bart.ndpost

    @property
    def offset(self) -> Float32[Array, '']:
        """The prior mean of the latent mean function."""
        return self._bart.offset

    @property
    def sigest(self) -> Float32[Array, ''] | None:
        """The estimated standard deviation of the error used to set `lamda`."""
        return self._bart.sigest

    @property
    def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The conditional posterior mean at `x_test` for each MCMC iteration."""
        return self._bart.yhat_test

    # Private attributes from Bart

    @property
    def _main_trace(self) -> mcmcloop.MainTrace:
        # trace of the post-burnin MCMC iterations
        return self._bart._main_trace  # noqa: SLF001

    @property
    def _burnin_trace(self) -> mcmcloop.BurninTrace:
        # trace of the burn-in MCMC iterations
        return self._bart._burnin_trace  # noqa: SLF001

    @property
    def _mcmc_state(self) -> mcmcstep.State:
        # final state of the MCMC after the last iteration
        return self._bart._mcmc_state  # noqa: SLF001

    @property
    def _splits(self) -> Real[Array, 'p max_num_splits']:
        # the cutpoints used to bin each predictor
        return self._bart._splits  # noqa: SLF001

    @property
    def _x_train_fmt(self) -> Hashable:
        # opaque tag describing the format of `x_train` (matrix vs. dataframe)
        return self._bart._x_train_fmt  # noqa: SLF001

    # Cached properties from Bart

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        return self._bart.prob_test

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        return self._bart.prob_test_mean

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        return self._bart.prob_train

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        return self._bart.prob_train_mean

    @cached_property
    def sigma(
        self,
    ) -> (
        Float32[Array, ' nskip+ndpost']
        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
        | None
    ):
        """The standard deviation of the error, including burn-in samples."""
        return self._bart.sigma

    @cached_property
    def sigma_(self) -> Float32[Array, 'ndpost'] | None:
        """The standard deviation of the error, only over the post-burnin samples and flattened."""
        return self._bart.sigma_

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burnin samples."""
        return self._bart.sigma_mean

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return self._bart.varcount

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self._bart.varcount_mean

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        return self._bart.varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self._bart.varprob_mean

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined with binary regression because it's error-prone, typically
        the right thing to consider would be `prob_test_mean`.
        """
        return self._bart.yhat_test_mean

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        return self._bart.yhat_train

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined with binary regression because it's error-prone, typically
        the right thing to consider would be `prob_train_mean`.
        """
        return self._bart.yhat_train_mean

    # Public methods from Bart

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        The conditional posterior mean at `x_test` for each MCMC iteration.
        """
        return self._bart.predict(x_test)

465 

466 

class gbart(mc_gbart):
    """Subclass of `mc_gbart` that forces `mc_cores=1`."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize like `mc_gbart`, but with `mc_cores` pinned to 1."""
        # Reject an explicit `mc_cores`: this class exists precisely to fix it,
        # so mimic Python's own error for a keyword the signature rejects.
        if 'mc_cores' in kwargs:
            msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'"
            raise TypeError(msg)
        super().__init__(*args, **kwargs, mc_cores=1)

476 

477 

478def process_mc_cores(y_train: Array | Series, mc_cores: int) -> dict[str, Any]: 

479 """Determine the arguments to pass to `Bart` to configure multiple chains.""" 

480 # one chain, disable multichain altogether 

481 if abs(mc_cores) == 1: 1!Nnk$Ss=5I4tZT7lg*0z%UuA1B[VO?6J'Km]2W+3C,FD(Xv:LGwebPid9Ej#Qo;-./8h)Rxyfc@YMpqHr

482 return dict(num_chains=None) 1!$=7*%[?']+,(:wP9#;-./8)y@

483 

484 # determine if we are on cpu; this point may raise an exception 

485 platform = get_platform(y_train, mc_cores) 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

486 

487 # set the num_chains argument 

488 mc_cores = abs(mc_cores) 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

489 kwargs = dict(num_chains=mc_cores) 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

490 

491 # if on cpu, try to shard the chains across multiple virtual cpus 

492 if platform == 'cpu': 492 ↛ 524line 492 didn't jump to line 524 because the condition on line 492 was always true1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

493 # determine number of logical cpu cores 

494 num_cores = cpu_count() 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

495 assert num_cores is not None, 'could not determine number of cpu cores' 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

496 

497 # determine number of shards that evenly divides chains 

498 for num_shards in range(num_cores, 0, -1): 498 ↛ 503line 498 didn't jump to line 503 because the loop on line 498 didn't complete1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

499 if mc_cores % num_shards == 0: 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

500 break 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

501 

502 # handle the case where there are less jax cpu devices that that 

503 if num_shards > 1: 503 ↛ 521line 503 didn't jump to line 521 because the condition on line 503 was always true1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

504 num_jax_cpus = device_count('cpu') 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

505 if num_jax_cpus < num_shards: 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

506 for new_num_shards in range(num_jax_cpus, 0, -1): 506 ↛ 509line 506 didn't jump to line 509 because the loop on line 506 didn't complete14t

507 if mc_cores % new_num_shards == 0: 507 ↛ 506line 507 didn't jump to line 506 because the condition on line 507 was always true14t

508 break 14t

509 msg = ( 14t

510 f'`mc_gbart` would like to shard {mc_cores} chains across ' 

511 f'{num_shards} virtual jax cpu devices, but jax is set up ' 

512 f'with only {num_jax_cpus} cpu devices, so it will use ' 

513 f'{new_num_shards} devices instead. To enable ' 

514 'parallelization, please increase the limit with ' 

515 '`jax.config.update("jax_num_cpu_devices", <num_devices>)`.' 

516 ) 

517 warn(msg) 14t

518 num_shards = new_num_shards 14t

519 

520 # set the number of shards 

521 if num_shards > 1: 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

522 kwargs.update(num_chain_devices=num_shards) 1NnkSs5ItZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

523 

524 return kwargs 1NnkSs5I4tZTlg0zUuA1BVO6JKm2W3CFDXvLGebidEjQohRxfcYMpqHr

525 

526 

def get_platform(y_train: Array | Series, mc_cores: int) -> str:
    """Get the platform for `process_mc_cores` from `y_train` or the default device."""
    # A concrete jax array can report the platform it lives on directly.
    if isinstance(y_train, Array) and hasattr(y_train, 'platform'):
        return y_train.platform()

    # y_train is not an array, and we are not under jit, so it is going to be
    # converted to an array on the default device; alternatively, a negative
    # mc_cores means the caller explicitly accepts the default device.
    concrete_non_array = not isinstance(y_train, Array) and not jit_active()
    if concrete_non_array or mc_cores < 0:
        return get_default_device().platform

    msg = (
        'Could not determine the platform from `y_train`, maybe `mc_gbart` '
        'was used with a `jax.jit`ted function? The platform is needed to '
        'determine whether the computation is going to run on CPU to '
        'automatically shard the chains across multiple virtual CPU '
        'devices. To acknowledge this problem and circumvent it '
        'by using the current default jax device, negate `mc_cores`.'
    )
    raise RuntimeError(msg)