Coverage for src / bartz / _interface.py: 92%

333 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/_interface.py 

2# 

3# Copyright (c) 2025-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Main high-level interface of the package.""" 

26 

27import math 

28from collections.abc import Mapping, Sequence 

29from functools import cached_property, partial 

30from types import MappingProxyType 

31from typing import Any, Literal, Protocol, TypedDict 

32 

33import jax 

34import jax.numpy as jnp 

35from equinox import Module, error_if, field 

36from jax import Device, device_put, jit, lax, make_mesh 

37from jax.scipy.special import ndtr 

38from jax.sharding import AxisType, Mesh 

39from jaxtyping import ( 

40 Array, 

41 Bool, 

42 Float, 

43 Float32, 

44 Int32, 

45 Integer, 

46 Key, 

47 Real, 

48 Shaped, 

49 UInt, 

50) 

51from numpy import ndarray 

52 

53from bartz import mcmcloop, mcmcstep, prepcovars 

54from bartz.jaxext import is_key 

55from bartz.jaxext.scipy.special import ndtri 

56from bartz.jaxext.scipy.stats import invgamma 

57from bartz.mcmcloop import RunMCMCResult, compute_varcount, evaluate_trace, run_mcmc 

58from bartz.mcmcstep import make_p_nonterminal 

59from bartz.mcmcstep._state import get_num_chains 

60 

# A scalar hyperparameter: a Python float or a 0-d floating-point array.
FloatLike = float | Float[Any, '']

62 

63 

class DataFrame(Protocol):
    """Structural (duck) type for dataframes accepted by `Bart`."""

    columns: Sequence[str]
    """The names of the columns."""

    def to_numpy(self) -> ndarray:
        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
        ...

73 

74 

75class Series(Protocol): 

76 """Series duck-type for `Bart`.""" 

77 

78 name: str | None 

79 """The name of the series.""" 

80 

81 def to_numpy(self) -> ndarray: 

82 """Convert the series to a 1d numpy array.""" 

83 ... 

84 

85 

86class Bart(Module): 

87 R""" 

88 Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_. 

89 

90 Regress `y_train` on `x_train` with a latent mean function represented as 

91 a sum of decision trees. The inference is carried out by sampling the 

92 posterior distribution of the tree ensemble with an MCMC. 

93 

94 Parameters 

95 ---------- 

96 x_train 

97 The training predictors. 

98 y_train 

99 The training responses. 

100 x_test 

101 The test predictors. 

102 type 

103 The type of regression. 'wbart' for continuous regression, 'pbart' for 

104 binary regression with probit link. 

105 sparse 

106 Whether to activate variable selection on the predictors as done in 

107 [1]_. 

108 theta 

109 a 

110 b 

111 rho 

112 Hyperparameters of the sparsity prior used for variable selection. 

113 

114 The prior distribution on the choice of predictor for each decision rule 

115 is 

116 

117 .. math:: 

118 (s_1, \ldots, s_p) \sim 

119 \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p). 

120 

121 If `theta` is not specified, it's a priori distributed according to 

122 

123 .. math:: 

124 \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim 

125 \operatorname{Beta}(\mathtt{a}, \mathtt{b}). 

126 

127 If not specified, `rho` is set to the number of predictors p. To tune 

128 the prior, consider setting a lower `rho` to prefer more sparsity. 

129 If setting `theta` directly, it should be in the ballpark of p or lower 

130 as well. 

131 varprob 

132 The probability distribution over the `p` predictors for choosing a 

133 predictor to split on in a decision node a priori. Must be > 0. It does 

134 not need to be normalized to sum to 1. If not specified, use a uniform 

135 distribution. If ``sparse=True``, this is used as initial value for the 

136 MCMC. 

137 xinfo 

138 A matrix with the cutpoins to use to bin each predictor. If not 

139 specified, it is generated automatically according to `usequants` and 

140 `numcut`. 

141 

142 Each row shall contain a sorted list of cutpoints for a predictor. If 

143 there are less cutpoints than the number of columns in the matrix, 

144 fill the remaining cells with NaN. 

145 

146 `xinfo` shall be a matrix even if `x_train` is a dataframe. 

147 usequants 

148 Whether to use predictors quantiles instead of a uniform grid to bin 

149 predictors. Ignored if `xinfo` is specified. 

150 rm_const 

151 How to treat predictors with no associated decision rules (i.e., there 

152 are no available cutpoints for that predictor). If `True` (default), 

153 they are ignored. If `False`, an error is raised if there are any. 

154 sigest 

155 An estimate of the residual standard deviation on `y_train`, used to set 

156 `lamda`. If not specified, it is estimated by linear regression (with 

157 intercept, and without taking into account `w`). If `y_train` has less 

158 than two elements, it is set to 1. If n <= p, it is set to the standard 

159 deviation of `y_train`. Ignored if `lamda` is specified. 

160 sigdf 

161 The degrees of freedom of the scaled inverse-chisquared prior on the 

162 noise variance. 

163 sigquant 

164 The quantile of the prior on the noise variance that shall match 

165 `sigest` to set the scale of the prior. Ignored if `lamda` is specified. 

166 k 

167 The inverse scale of the prior standard deviation on the latent mean 

168 function, relative to half the observed range of `y_train`. If `y_train` 

169 has less than two elements, `k` is ignored and the scale is set to 1. 

170 power 

171 base 

172 Parameters of the prior on tree node generation. The probability that a 

173 node at depth `d` (0-based) is non-terminal is ``base / (1 + d) ** 

174 power``. 

175 lamda 

176 The prior harmonic mean of the error variance. (The harmonic mean of x 

177 is 1/mean(1/x).) If not specified, it is set based on `sigest` and 

178 `sigquant`. 

179 tau_num 

180 The numerator in the expression that determines the prior standard 

181 deviation of leaves. If not specified, default to ``(max(y_train) - 

182 min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for 

183 continuous regression, and 3 for binary regression. 

184 offset 

185 The prior mean of the latent mean function. If not specified, it is set 

186 to the mean of `y_train` for continuous regression, and to 

187 ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty, 

188 `offset` is set to 0. With binary regression, if `y_train` is all 

189 `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or 

190 ``Phi^-1(n/(n+1))``, respectively. 

191 w 

192 Coefficients that rescale the error standard deviation on each 

193 datapoint. Not specifying `w` is equivalent to setting it to 1 for all 

194 datapoints. Note: `w` is ignored in the automatic determination of 

195 `sigest`, so either the weights should be O(1), or `sigest` should be 

196 specified by the user. 

197 num_trees 

198 The number of trees used to represent the latent mean function. 

199 numcut 

200 If `usequants` is `False`: the exact number of cutpoints used to bin the 

201 predictors, ranging between the minimum and maximum observed values 

202 (excluded). 

203 

204 If `usequants` is `True`: the maximum number of cutpoints to use for 

205 binning the predictors. Each predictor is binned such that its 

206 distribution in `x_train` is approximately uniform across bins. The 

207 number of bins is at most the number of unique values appearing in 

208 `x_train`, or ``numcut + 1``. 

209 

210 Before running the algorithm, the predictors are compressed to the 

211 smallest integer type that fits the bin indices, so `numcut` is best set 

212 to the maximum value of an unsigned integer type, like 255. 

213 

214 Ignored if `xinfo` is specified. 

215 ndpost 

216 The number of MCMC samples to save, after burn-in. `ndpost` is the 

217 total number of samples across all chains. `ndpost` is rounded up to the 

218 first multiple of `mc_cores`. 

219 nskip 

220 The number of initial MCMC samples to discard as burn-in. This number 

221 of samples is discarded from each chain. 

222 keepevery 

223 The thinning factor for the MCMC samples, after burn-in. 

224 printevery 

225 The number of iterations (including thinned-away ones) between each log 

226 line. Set to `None` to disable logging. ^C interrupts the MCMC only 

227 every `printevery` iterations, so with logging disabled it's impossible 

228 to kill the MCMC conveniently. 

229 num_chains 

230 The number of independent Markov chains to run. 

231 

232 The difference between ``num_chains=None`` and ``num_chains=1`` is that 

233 in the latter case in the object attributes and some methods there will 

234 be an explicit chain axis of size 1. 

235 num_chain_devices 

236 The number of devices to spread the chains across. Must be a divisor of 

237 `num_chains`. Each device will run a fraction of the chains. 

238 num_data_devices 

239 The number of devices to split datapoints across. Must be a divisor of 

240 `n`. This is useful only with very high `n`, about > 1000_000. 

241 

242 If both num_chain_devices and num_data_devices are specified, the total 

243 number of devices used is the product of the two. 

244 devices 

245 One or more devices used to run the MCMC on. If not specified, the 

246 computation will follow the placement of the input arrays. If a list of 

247 devices, this argument can be longer than the number of devices needed. 

248 seed 

249 The seed for the random number generator. 

250 maxdepth 

251 The maximum depth of the trees. This is 1-based, so with the default 

252 ``maxdepth=6``, the depths of the levels range from 0 to 5. 

253 init_kw 

254 Additional arguments passed to `bartz.mcmcstep.init`. 

255 run_mcmc_kw 

256 Additional arguments passed to `bartz.mcmcloop.run_mcmc`. 

257 

258 References 

259 ---------- 

260 .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for 

261 High-Dimensional Prediction and Variable Selection”. In: Journal of the 

262 American Statistical Association 113.522, pp. 626-636. 

263 .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART: 

264 Bayesian additive regression trees," The Annals of Applied Statistics, 

265 Ann. Appl. Stat. 4(1), 266-298, (March 2010). 

266 """ 

267 

    # MCMC trace of the post-burn-in samples (see bartz.mcmcloop).
    _main_trace: mcmcloop.MainTrace
    # MCMC trace of the burn-in samples.
    _burnin_trace: mcmcloop.BurninTrace
    # Final state of the sampler, used by predict() and the properties.
    _mcmc_state: mcmcstep.State
    # Cutpoint grid per predictor, as produced by _determine_splits.
    _splits: Real[Array, 'p max_num_splits']
    # Input-format descriptor of x_train; static so it is not traced by jax.
    _x_train_fmt: Any = field(static=True)

    offset: Float32[Array, '']
    """The prior mean of the latent mean function."""

    sigest: Float32[Array, ''] | None = None
    """The estimated standard deviation of the error used to set `lamda`."""

    yhat_test: Float32[Array, 'ndpost m'] | None = None
    """The conditional posterior mean at `x_test` for each MCMC iteration."""

282 

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        varprob: Float[Array, ' p'] | None = None,
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | Series | None = None,
        num_trees: int = 200,
        numcut: int = 255,
        ndpost: int = 1000,
        nskip: int = 1000,
        keepevery: int = 1,
        printevery: int | None = 100,
        num_chains: int | None = 4,
        num_chain_devices: int | None = None,
        num_data_devices: int | None = None,
        devices: Device | Sequence[Device] | None = None,
        seed: int | Key[Array, ''] = 0,
        maxdepth: int = 6,
        init_kw: Mapping = MappingProxyType({}),
        run_mcmc_kw: Mapping = MappingProxyType({}),
    ) -> None:
        # check data and put it in the right format
        x_train, x_train_fmt = self._process_predictor_input(x_train)
        y_train = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)
        if w is not None:
            # w is a per-datapoint vector, so it is validated like a response
            w = self._process_response_input(w)
            self._check_same_length(x_train, w)

        # check data types are correct for continuous/binary regression
        self._check_type_settings(y_train, type, w)
        # from here onwards, the type is determined by y_train.dtype == bool

        # process sparsity settings
        theta, a, b, rho = self._process_sparsity_settings(
            x_train, sparse, theta, a, b, rho
        )

        # process "standardization" settings
        offset = self._process_offset_settings(y_train, offset)
        sigma_mu = self._process_leaf_sdev_settings(y_train, k, num_trees, tau_num)
        lamda, sigest = self._process_error_variance_settings(
            x_train, y_train, sigest, sigdf, sigquant, lamda
        )

        # determine splits and replace predictors with their bin indices
        # (_determine_splits/_bin_predictors are helpers defined later in the class)
        splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
        x_train = self._bin_predictors(x_train, splits)

        # setup and run mcmc
        initial_state = self._setup_mcmc(
            x_train,
            y_train,
            offset,
            w,
            max_split,
            lamda,
            sigma_mu,
            sigdf,
            power,
            base,
            maxdepth,
            num_trees,
            init_kw,
            rm_const,
            theta,
            a,
            b,
            rho,
            varprob,
            num_chains,
            num_chain_devices,
            num_data_devices,
            devices,
            sparse,
            nskip,
        )
        result = self._run_mcmc(
            initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
        )

        # set public attributes
        # set offset from the state because of buffer donation
        self.offset = result.final_state.offset
        self.sigest = sigest

        # set private attributes
        self._main_trace = result.main_trace
        self._burnin_trace = result.burnin_trace
        self._mcmc_state = result.final_state
        self._splits = splits
        self._x_train_fmt = x_train_fmt

        # predict at test points
        if x_test is not None:
            self.yhat_test = self.predict(x_test)

399 

400 @property 

401 def ndpost(self) -> int: 

402 """The total number of posterior samples after burn-in across all chains. 

403 

404 May be larger than the initialization argument `ndpost` if it was not 

405 divisible by the number of chains. 

406 """ 

407 return self._main_trace.grow_prop_count.size 1YX`{^}_mbgafqe-89.E5Aj

408 

409 @property 

410 def num_trees(self) -> int: 

411 """Return the number of trees used in the model.""" 

412 return self._mcmc_state.forest.split_tree.shape[-2] 1K!G

413 

414 @cached_property 

415 def prob_test(self) -> Float32[Array, 'ndpost m'] | None: 

416 """The posterior probability of y being True at `x_test` for each MCMC iteration.""" 

417 if self.yhat_test is None or self._mcmc_state.y.dtype != bool: 1|imhbgaqdjc

418 return None 1ihbadc

419 else: 

420 return ndtr(self.yhat_test) 1|mgqj

421 

422 @cached_property 

423 def prob_test_mean(self) -> Float32[Array, ' m'] | None: 

424 """The marginal posterior probability of y being True at `x_test`.""" 

425 if self.prob_test is None: 1bgaqx

426 return None 1ba

427 else: 

428 return self.prob_test.mean(axis=0) 1gqx

429 

430 @cached_property 

431 def prob_train(self) -> Float32[Array, 'ndpost n'] | None: 

432 """The posterior probability of y being True at `x_train` for each MCMC iteration.""" 

433 if self._mcmc_state.y.dtype == bool: 1|}imhbgaqdjc

434 return ndtr(self.yhat_train) 1|}mgqj

435 else: 

436 return None 1ihbadc

437 

438 @cached_property 

439 def prob_train_mean(self) -> Float32[Array, ' n'] | None: 

440 """The marginal posterior probability of y being True at `x_train`.""" 

441 if self.prob_train is None: 1bgaqxdjc

442 return None 1badc

443 else: 

444 return self.prob_train.mean(axis=0) 1gqxj

445 

446 @cached_property 

447 def sigma( 

448 self, 

449 ) -> ( 

450 Float32[Array, ' nskip+ndpost'] 

451 | Float32[Array, 'nskip+ndpost/mc_cores mc_cores'] 

452 | None 

453 ): 

454 """The standard deviation of the error, including burn-in samples.""" 

455 if self._burnin_trace.error_cov_inv is None: 1^_imhbgafeu3rxlkdjc

456 return None 1mg3xj

457 assert self._main_trace.error_cov_inv is not None 1^_ihbafeurlkdc

458 return jnp.sqrt( 1^_ihbafeurlkdc

459 jnp.reciprocal( 

460 jnp.concatenate( 

461 [ 

462 self._burnin_trace.error_cov_inv.T, 

463 self._main_trace.error_cov_inv.T, 

464 ], 

465 axis=0, 

466 # error_cov_inv has shape (chains? samples) in the trace 

467 ) 

468 ) 

469 ) 

470 

471 @cached_property 

472 def sigma_(self) -> Float32[Array, 'ndpost'] | None: 

473 """The standard deviation of the error, only over the post-burnin samples and flattened.""" 

474 error_cov_inv = self._main_trace.error_cov_inv 1`{imhbgafelkdjc

475 if error_cov_inv is None: 1`{imhbgafexlkdjc

476 return None 1mgxj

477 else: 

478 return jnp.sqrt(jnp.reciprocal(error_cov_inv)).reshape(-1) 1`{ihbafelkdc

479 

480 @cached_property 

481 def sigma_mean(self) -> Float32[Array, ''] | None: 

482 """The mean of `sigma`, only over the post-burnin samples.""" 

483 if self.sigma_ is None: 1`{bgafexlkdjc

484 return None 1gxj

485 return self.sigma_.mean() 1`{bafelkdc

486 

487 @cached_property 

488 def varcount(self) -> Int32[Array, 'ndpost p']: 

489 """Histogram of predictor usage for decision rules in the trees.""" 

490 p = self._mcmc_state.forest.max_split.size 1yZw|^_imhbgafqe89.djc

491 return varcount(p, self._main_trace) 1yZw|^_imhbgafqe89.djc

492 

493 @cached_property 

494 def varcount_mean(self) -> Float32[Array, ' p']: 

495 """Average of `varcount` across MCMC iterations.""" 

496 return self.varcount.mean(axis=0) 1yZw`{bgafqedjc

497 

498 @cached_property 

499 def varprob(self) -> Float32[Array, 'ndpost p']: 

500 """Posterior samples of the probability of choosing each predictor for a decision rule.""" 

501 max_split = self._mcmc_state.forest.max_split 1vYsX^_imhbgafqedjcLB

502 p = max_split.size 1vYsX^_imhbgafqedjcLB

503 varprob = self._main_trace.varprob 1vYsX^_imhbgafqedjcLB

504 if varprob is None: 1vYsX^_imhbgafqexdjcLB

505 peff = jnp.count_nonzero(max_split) 1YXmgqxj

506 varprob = jnp.where(max_split, 1 / peff, 0) 1YXmgqxj

507 varprob = jnp.broadcast_to(varprob, (self.ndpost, p)) 1YXmgqxj

508 else: 

509 varprob = varprob.reshape(-1, p) 1vs^_ihbafedcLB

510 return varprob 1vYsX^_imhbgafqexdjcLB

511 

512 @cached_property 

513 def varprob_mean(self) -> Float32[Array, ' p']: 

514 """The marginal posterior probability of each predictor being chosen for a decision rule.""" 

515 return self.varprob.mean(axis=0) 1vYsX`{bgafqedjcLB

516 

517 @cached_property 

518 def yhat_test_mean(self) -> Float32[Array, ' m'] | None: 

519 """The marginal posterior mean at `x_test`. 

520 

521 Not defined with binary regression because it's error-prone, typically 

522 the right thing to consider would be `prob_test_mean`. 

523 """ 

524 if self.yhat_test is None or self._mcmc_state.y.dtype == bool: 1`{bgafexlkdjc

525 return None 1gxj

526 else: 

527 return self.yhat_test.mean(axis=0) 1`{bafelkdc

528 

529 @cached_property 

530 def yhat_train(self) -> Float32[Array, 'ndpost n']: 

531 """The conditional posterior mean at `x_train` for each MCMC iteration.""" 

532 x_train = self._mcmc_state.X 1^}_imho7nO(J0+Mbgafqeu3rHxDlkE5Adjc

533 return self._predict(x_train) 1^}_imho7nO(J0+Mbgafqeu3rHxDlkE5Adjc

534 

535 @cached_property 

536 def yhat_train_mean(self) -> Float32[Array, ' n'] | None: 

537 """The marginal posterior mean at `x_train`. 

538 

539 Not defined with binary regression because it's error-prone, typically 

540 the right thing to consider would be `prob_train_mean`. 

541 """ 

542 if self._mcmc_state.y.dtype == bool: 1`{bgafexlkdjc

543 return None 1gxj

544 else: 

545 return self.yhat_train.mean(axis=0) 1`{bafelkdc

546 

547 def predict( 

548 self, x_test: Real[Array, 'p m'] | DataFrame 

549 ) -> Float32[Array, 'ndpost m']: 

550 """ 

551 Compute the posterior mean at `x_test` for each MCMC iteration. 

552 

553 Parameters 

554 ---------- 

555 x_test 

556 The test predictors. 

557 

558 Returns 

559 ------- 

560 The conditional posterior mean at `x_test` for each MCMC iteration. 

561 

562 Raises 

563 ------ 

564 ValueError 

565 If `x_test` has a different format than `x_train`. 

566 """ 

567 x_test, x_test_fmt = self._process_predictor_input(x_test) 1vYsyZw]t#pimhS%NF6zI2Co7nO(JT)PU*Q0+MK!G1,Vbgafqeu3rHxDlkE5AdjcW/R

568 if x_test_fmt != self._x_train_fmt: 1vYsyZw]t#pimhS%NF6zI2Co7nO(JT)PU*Q0+MK!G1,Vbgafqeu3rHxDlkE5AdjcW/R

569 msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}' 1F6z

570 raise ValueError(msg) 1F6z

571 x_test = self._bin_predictors(x_test, self._splits) 1vYsyZw]t#pimhS%NF6zI2Co7nO(JT)PU*Q0+MK!G1,Vbgafqeu3rHxDlkE5AdjcW/R

572 return self._predict(x_test) 1vYsyZw]t#pimhS%NF6zI2Co7nO(JT)PU*Q0+MK!G1,Vbgafqeu3rHxDlkE5AdjcW/R

573 

574 @staticmethod 

575 def _process_predictor_input( 

576 x: Real[Any, 'p n'] | DataFrame, 

577 ) -> tuple[Shaped[Array, 'p n'], Any]: 

578 if hasattr(x, 'columns'): 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

579 fmt = dict(kind='dataframe', columns=x.columns) 1F6zu3r

580 x = x.to_numpy().T 1F6zu3r

581 else: 

582 fmt = dict(kind='array', num_covar=x.shape[0]) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

583 x = jnp.asarray(x) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

584 assert x.ndim == 2 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

585 return x, fmt 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

586 

587 @staticmethod 

588 def _process_response_input(y: Shaped[Array, ' n'] | Series) -> Shaped[Array, ' n']: 

589 if hasattr(y, 'to_numpy'): 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

590 y = y.to_numpy() 1zu3r

591 y = jnp.asarray(y) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

592 assert y.ndim == 1 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

593 return y 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

594 

595 @staticmethod 

596 def _check_same_length(x1: Array, x2: Array) -> None: 

597 get_length = lambda x: x.shape[-1] 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

598 assert get_length(x1) == get_length(x2) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

599 

600 @classmethod 

601 def _process_error_variance_settings( 

602 cls, 

603 x_train: Shaped[Array, 'p n'], 

604 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

605 sigest: FloatLike | None, 

606 sigdf: FloatLike, 

607 sigquant: FloatLike, 

608 lamda: FloatLike | None, 

609 ) -> tuple[Float32[Array, ''] | None, ...]: 

610 """Return (lamda, sigest).""" 

611 if y_train.dtype == bool: 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

612 if sigest is not None: 612 ↛ 613line 612 didn't jump to line 613 because the condition on line 612 was never true1YZ@#m%627;()*+!,gq3x5j/

613 msg = 'Let `sigest=None` for binary regression' 

614 raise ValueError(msg) 

615 if lamda is not None: 615 ↛ 616line 615 didn't jump to line 616 because the condition on line 615 was never true1YZ@#m%627;()*+!,gq3x5j/

616 msg = 'Let `lamda=None` for binary regression' 

617 raise ValueError(msg) 

618 return None, None 1YZ@#m%627;()*+!,gq3x5j/

619 elif lamda is not None: 619 ↛ 620line 619 didn't jump to line 620 because the condition on line 619 was never true1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

620 if sigest is not None: 

621 msg = 'Let `sigest=None` if `lamda` is specified' 

622 raise ValueError(msg) 

623 return lamda, None 

624 else: 

625 if sigest is not None: 625 ↛ 626line 625 didn't jump to line 626 because the condition on line 625 was never true1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

626 sigest2 = jnp.square(sigest) 

627 elif y_train.size < 2: 1vsXyw=?]tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

628 sigest2 = 1 10M1V-89.[:

629 elif y_train.size <= x_train.shape[0]: 1vsXyw=?]tpihSNFzICon'$OJTPUQKGbafeurHDlkEAdcWRLB

630 sigest2 = jnp.var(y_train) 1WR

631 else: 

632 sigest2 = cls._linear_regression(x_train, y_train) 1vsXyw=?]tpihSNFzICon'$OJTPUQKGbafeurHDlkEAdcLB

633 alpha = sigdf / 2 1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

634 invchi2 = invgamma.ppf(sigquant, alpha) / 2 1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

635 invchi2rid = invchi2 * sigdf 1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

636 return sigest2 / invchi2rid, jnp.sqrt(sigest2) 1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

637 

638 @staticmethod 

639 @jit 

640 def _linear_regression( 

641 x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n'] 

642 ) -> Float32[Array, '']: 

643 """Return the error variance estimated with OLS with intercept.""" 

644 x_centered = x_train.T - x_train.mean(axis=1) 1]tponLB

645 y_centered = y_train - y_train.mean() 1]tponLB

646 # centering is equivalent to adding an intercept column 

647 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered) 1]tponLB

648 chisq = chisq.squeeze(0) 1]tponLB

649 dof = len(y_train) - rank 1]tponLB

650 return chisq / dof 1]tponLB

651 

652 @staticmethod 

653 def _check_type_settings( 

654 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

655 type: str, # noqa: A002 

656 w: Float[Array, ' n'] | None, 

657 ) -> None: 

658 match type: 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

659 case 'wbart': 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

660 if y_train.dtype != jnp.float32: 660 ↛ 661line 660 didn't jump to line 661 because the condition on line 660 was never true1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

661 msg = ( 

662 'Continuous regression requires y_train.dtype=float32,' 

663 f' got {y_train.dtype=} instead.' 

664 ) 

665 raise TypeError(msg) 1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

666 case 'pbart': 666 ↛ 676line 666 didn't jump to line 676 because the pattern on line 666 always matched1YZ@#m%627;()*+!,gq3x5j/

667 if w is not None: 667 ↛ 668line 667 didn't jump to line 668 because the condition on line 667 was never true1YZ@#m%627;()*+!,gq3x5j/

668 msg = 'Binary regression does not support weights, set `w=None`' 

669 raise ValueError(msg) 

670 if y_train.dtype != bool: 670 ↛ 671line 670 didn't jump to line 671 because the condition on line 670 was never true1YZ@#m%627;()*+!,gq3x5j/

671 msg = ( 

672 'Binary regression requires y_train.dtype=bool,' 

673 f' got {y_train.dtype=} instead.' 

674 ) 

675 raise TypeError(msg) 1YZ@#m%627;()*+!,gq3x5j/

676 case _: 

677 msg = f'Invalid {type=}' 

678 raise ValueError(msg) 

679 

680 @staticmethod 

681 def _process_sparsity_settings( 

682 x_train: Real[Array, 'p n'], 

683 sparse: bool, 

684 theta: FloatLike | None, 

685 a: FloatLike, 

686 b: FloatLike, 

687 rho: FloatLike | None, 

688 ) -> ( 

689 tuple[None, None, None, None] 

690 | tuple[FloatLike, None, None, None] 

691 | tuple[None, FloatLike, FloatLike, FloatLike] 

692 ): 

693 """Return (theta, a, b, rho).""" 

694 if not sparse: 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

695 return None, None, None, None 1YXyZw@#m%627;()*+!,gq3x-89.5j/[:

696 elif theta is not None: 1vs=?]tpihSNFzICon'$OJTPUQ0MKG1VbafeurHDlkEAdcWRLB

697 return theta, None, None, None 1s?]phNzCn$JPQMGVaerDkAcRL

698 else: 

699 if rho is None: 699 ↛ 702line 699 didn't jump to line 702 because the condition on line 699 was always true1v=tiSFIo'OTU0K1bfuHlEdWB

700 p, _ = x_train.shape 1v=tiSFIo'OTU0K1bfuHlEdWB

701 rho = float(p) 1v=tiSFIo'OTU0K1bfuHlEdWB

702 return None, a, b, rho 1v=tiSFIo'OTU0K1bfuHlEdWB

703 

704 @staticmethod 

705 def _process_offset_settings( 

706 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

707 offset: float | Float32[Any, ''] | None, 

708 ) -> Float32[Array, '']: 

709 """Return offset.""" 

710 if offset is not None: 710 ↛ 711line 710 didn't jump to line 711 because the condition on line 710 was never true1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

711 return jnp.asarray(offset) 

712 elif y_train.size < 1: 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

713 return jnp.array(0.0) 10+M-89.[:

714 else: 

715 mean = y_train.mean() 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*QK!G1,Vbgafqeu3rHxDlkE5AdjcW/RLB

716 

717 if y_train.dtype == bool: 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*QK!G1,Vbgafqeu3rHxDlkE5AdjcW/RLB

718 bound = 1 / (1 + y_train.size) 1YZ@#m%627;()*!,gq3x5j/

719 mean = jnp.clip(mean, bound, 1 - bound) 1YZ@#m%627;()*!,gq3x5j/

720 return ndtri(mean) 1YZ@#m%627;()*!,gq3x5j/

721 else: 

722 return mean 1vsXyw=?]tpihSNFzICon'$OJTPUQKG1VbafeurHDlkEAdcWRLB

723 

724 @staticmethod 

725 def _process_leaf_sdev_settings( 

726 y_train: Float32[Array, ' n'] | Bool[Array, ' n'], 

727 k: FloatLike, 

728 num_trees: int, 

729 tau_num: FloatLike | None, 

730 ) -> FloatLike: 

731 """Return sigma_mu.""" 

732 if tau_num is None: 732 ↛ 740line 732 didn't jump to line 740 because the condition on line 732 was always true1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

733 if y_train.dtype == bool: 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

734 tau_num = 3.0 1YZ@#m%627;()*+!,gq3x5j/

735 elif y_train.size < 2: 1vsXyw=?]tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:

736 tau_num = 1.0 10M1V-89.[:

737 else: 

738 tau_num = (y_train.max() - y_train.min()) / 2 1vsXyw=?]tpihSNFzICon'$OJTPUQKGbafeurHDlkEAdcWRLB

739 

740 return tau_num / (k * math.sqrt(num_trees)) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

741 

742 @staticmethod 

743 def _determine_splits( 

744 x_train: Real[Array, 'p n'], 

745 usequants: bool, 

746 numcut: int, 

747 xinfo: Float[Array, 'p n'] | None, 

748 ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]: 

749 if xinfo is not None: 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:

750 if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]: 1X0+M,V-89.[:

751 msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)' 1[

752 raise ValueError(msg) 1[

753 return prepcovars.parse_xinfo(xinfo) 1X0+M,V-89.:

754 elif usequants: 1vYsyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*QK!G1bgafqeu3rHxDlkE5AdjcW/RLB

755 return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1) 1YsZw@?]#pmh%N6z2C7n;$(J)P*Q!Ggaqe3rxDk5Ajc/R

756 else: 

757 return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1) 1vy=tiSFIo'OTUK1bfuHlEdWLB

758 

759 @staticmethod 

760 def _bin_predictors( 

761 x: Real[Array, 'p n'], splits: Real[Array, 'p max_num_splits'] 

762 ) -> UInt[Array, 'p n']: 

763 return prepcovars.bin_predictors(x, splits) 1vYsXyZw=@?`|{t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:

764 

    @staticmethod
    def _setup_mcmc(
        x_train: Real[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        offset: Float32[Array, ''],
        w: Float[Array, ' n'] | None,
        max_split: UInt[Array, ' p'],
        lamda: Float32[Array, ''] | None,
        sigma_mu: FloatLike,
        sigdf: FloatLike,
        power: FloatLike,
        base: FloatLike,
        maxdepth: int,
        num_trees: int,
        init_kw: Mapping[str, Any],
        rm_const: bool,
        theta: FloatLike | None,
        a: FloatLike | None,
        b: FloatLike | None,
        rho: FloatLike | None,
        varprob: Float[Any, ' p'] | None,
        num_chains: int | None,
        num_chain_devices: int | None,
        num_data_devices: int | None,
        devices: Device | Sequence[Device] | None,
        sparse: bool,
        nskip: int,
    ) -> mcmcstep.State:
        """Assemble the keyword arguments for `mcmcstep.init` and return the initial MCMC state."""
        p_nonterminal = make_p_nonterminal(maxdepth, base, power)

        # binary outcome: no inverse-gamma prior parameters are passed to init
        if y_train.dtype == bool:
            error_cov_df = None
            error_cov_scale = None
        else:
            # continuous outcome: lamda must have been computed upstream
            assert lamda is not None
            # inverse gamma prior: alpha = df / 2, beta = scale / 2
            error_cov_df = sigdf
            error_cov_scale = lamda * sigdf

        # process device settings
        device_kw, device = process_device_settings(
            y_train, num_chains, num_chain_devices, num_data_devices, devices
        )

        kw: dict = dict(
            X=x_train,
            # copy y_train because it's going to be donated in the mcmc loop
            y=jnp.array(y_train),
            offset=offset,
            error_scale=w,
            max_split=max_split,
            num_trees=num_trees,
            p_nonterminal=p_nonterminal,
            # leaf prior precision = 1 / sigma_mu^2
            leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)),
            error_cov_df=error_cov_df,
            error_cov_scale=error_cov_scale,
            min_points_per_decision_node=10,
            log_s=process_varprob(varprob, max_split),
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            # sparsity activates halfway through burn-in, if enabled at all
            sparse_on_at=nskip // 2 if sparse else None,
            **device_kw,
        )

        if rm_const:
            # count of predictors with no available cutpoint; presumably
            # `mcmcstep.init` uses it to drop them -- confirm against init
            n_empty = jnp.sum(max_split == 0).item()
            kw.update(filter_splitless_vars=n_empty)

        # user-provided overrides are applied last so they win over everything
        kw.update(init_kw)

        state = mcmcstep.init(**kw)

        # put state on device if requested explicitly by the user
        if device is not None:
            state = device_put(state, device, donate=True)

        return state

844 

845 @classmethod 

846 def _run_mcmc( 

847 cls, 

848 mcmc_state: mcmcstep.State, 

849 ndpost: int, 

850 nskip: int, 

851 keepevery: int, 

852 printevery: int | None, 

853 seed: int | Integer[Array, ''] | Key[Array, ''], 

854 run_mcmc_kw: Mapping, 

855 ) -> RunMCMCResult: 

856 # prepare random generator seed 

857 if is_key(seed): 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:

858 key = jnp.copy(seed) 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB

859 else: 

860 key = jax.random.key(seed) 1:

861 

862 # round up ndpost 

863 num_chains = get_num_chains(mcmc_state) 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:

864 if num_chains is None: 1vYsXyZw]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:

865 num_chains = 1 1vyiSFo'OTU0K1bfuH-89.lEdW

866 n_save = ndpost // num_chains + bool(ndpost % num_chains) 1vYsXyZw]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:

867 

868 # prepare arguments 

869 kw: dict = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery) 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:

870 kw.update( 1vYsXyZw]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:

871 mcmcloop.make_default_callback( 

872 mcmc_state, 

873 dot_every=None if printevery is None or printevery == 1 else 1, 

874 report_every=printevery, 

875 ) 

876 ) 

877 kw.update(run_mcmc_kw) 1vYsXyZw]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:

878 

879 return run_mcmc(key, mcmc_state, n_save, **kw) 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:

880 

881 def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']: 

882 """Evaluate trees on already quantized `x`.""" 

883 return predict(x, self._main_trace) 1vYsyZw^}_t#pimhS%NF6zI2Co7nO(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/R

884 

885 

@partial(jit, static_argnames='p')
# this is jitted such that lax.collapse below does not create a copy
def varcount(p: int, trace: mcmcloop.MainTrace) -> Int32[Array, 'ndpost p']:
    """Histogram of predictor usage for decision rules in the trees, squashing chains."""
    # use a local name distinct from the function to avoid shadowing it
    counts: Int32[Array, '*chains samples p']
    counts = compute_varcount(p, trace)
    # merge the leading chain axes and the sample axis into one
    return lax.collapse(counts, 0, -1)

893 

894 

@jit
# this is jitted such that lax.collapse below does not create a copy
def predict(
    x: UInt[Array, 'p m'], trace: mcmcloop.MainTrace
) -> Float32[Array, 'ndpost m']:
    """Evaluate trees on already quantized `x`, and squash chains."""
    per_chain = evaluate_trace(x, trace)
    # merge the leading chain axes and the sample axis into one
    return lax.collapse(per_chain, 0, -1)

903 

904 

class DeviceKwArgs(TypedDict):
    """Device-related keyword arguments for `mcmcstep.init`.

    Built by `process_device_settings` and splatted into the `mcmcstep.init`
    call in `_setup_mcmc`.
    """

    # requested number of chains, forwarded unchanged (may be None)
    num_chains: int | None
    # device mesh over 'chains'/'data' axes, or None if no device counts given
    mesh: Mesh | None
    # platform hint, set only when there is no mesh and the platform cannot be
    # inferred from `y_train`'s placement
    target_platform: Literal['cpu', 'gpu'] | None

909 

910 

def process_device_settings(
    y_train: Array,
    num_chains: int | None,
    num_chain_devices: int | None,
    num_data_devices: int | None,
    devices: Device | Sequence[Device] | None,
) -> tuple[DeviceKwArgs, Device | None]:
    """Return the arguments for `mcmcstep.init` related to devices, and an optional device where to put the state."""
    # determine devices
    if devices is not None:
        # accept a single Device by wrapping it into a 1-tuple
        if not hasattr(devices, '__len__'):
            devices = (devices,)
        device = devices[0]
        platform = device.platform
    elif hasattr(y_train, 'platform'):
        # NOTE(review): assumes `y_train.platform` is callable (jax-array-like
        # input); plain numpy arrays lack it and fall to the error branch below
        platform = y_train.platform()
        device = None
        # set device=None because if the devices were not specified explicitly
        # we may be in the case where computation will follow data placement,
        # do not disturb jax as the user may be playing with vmap, jit, reshard...
        devices = jax.devices(platform)
    else:
        msg = 'not possible to infer device from `y_train`, please set `devices`'
        raise ValueError(msg)

    # create mesh
    if num_chain_devices is None and num_data_devices is None:
        mesh = None
    else:
        # build an axis-name -> size dict first, then turn it into a Mesh
        mesh = dict()
        if num_chain_devices is not None:
            mesh.update(chains=num_chain_devices)
        if num_data_devices is not None:
            mesh.update(data=num_data_devices)
        mesh = make_mesh(
            axis_shapes=tuple(mesh.values()),
            axis_names=tuple(mesh),
            axis_types=(AxisType.Auto,) * len(mesh),
            devices=devices,
        )
        device = None
        # set device=None because `mcmcstep.init` will `device_put` with the
        # mesh already, we don't want to undo its work

    # prepare arguments to `init`
    settings = DeviceKwArgs(
        num_chains=num_chains,
        mesh=mesh,
        # pass the platform only when it cannot be inferred from data placement
        target_platform=None
        if mesh is not None or hasattr(y_train, 'platform')
        else platform,
        # here we don't take into account the case where the user has set both
        # batch sizes; since the user has to be playing with `init_kw` to do
        # that, we'll let `init` throw the error and the user set
        # `target_platform` themselves so they have a clearer idea how the
        # thing works.
    )

    return settings, device

970 

971 

def process_varprob(
    varprob: Float[Any, ' p'] | None, max_split: UInt[Array, ' p']
) -> Float32[Array, ' p'] | None:
    """Convert varprob to log_s.

    Parameters
    ----------
    varprob
        Per-predictor weights (presumably prior variable-selection weights --
        see caller), or None to skip the conversion.
    max_split
        Per-predictor cutpoint counts; used only for its shape here.

    Returns
    -------
    The elementwise logarithm of `varprob`, or None if `varprob` is None.

    Raises
    ------
    ValueError
        If `varprob` does not have the same shape as `max_split`.
    """
    if varprob is None:
        return None
    varprob = jnp.asarray(varprob)
    # raise instead of assert: input validation must survive `python -O`,
    # which strips assert statements
    if varprob.shape != max_split.shape:
        msg = f'varprob must have shape (p,), got {varprob.shape=} != {max_split.shape=}'
        raise ValueError(msg)
    # error_if performs the data-dependent positivity check (works under jit)
    varprob = error_if(varprob, varprob <= 0, 'varprob must be > 0')
    return jnp.log(varprob)