# bartz/src/bartz/_interface.py
#
# Copyright (c) 2025-2026, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Main high-level interface of the package."""

import math
from collections.abc import Mapping, Sequence
from dataclasses import replace
from enum import Enum
from functools import cached_property, partial
from types import MappingProxyType
from typing import Any, Literal, Protocol, TypedDict

import jax
import jax.numpy as jnp
from equinox import Module, error_if, field
from jax import Device, debug_nans, device_put, jit, lax, make_mesh, random, tree
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import ndtr
from jax.sharding import AxisType, Mesh, PartitionSpec
from jaxtyping import (
    Array,
    Bool,
    Float,
    Float32,
    Int32,
    Integer,
    Key,
    Real,
    Shaped,
    UInt,
)
from numpy import ndarray

from bartz import mcmcloop, mcmcstep, prepcovars
from bartz.grove import (
    TreesTrace,
    check_trace,
    evaluate_forest,
    forest_depth_distr,
    points_per_node_distr,
)
from bartz.jaxext import equal_shards, is_key
from bartz.jaxext.scipy.special import ndtri
from bartz.jaxext.scipy.stats import invgamma
from bartz.mcmcloop import RunMCMCResult, compute_varcount, evaluate_trace, run_mcmc
from bartz.mcmcstep import OutcomeType, make_p_nonterminal
from bartz.mcmcstep._state import (
    _inv_via_chol_with_gersh,
    chol_with_gersh,
    get_num_chains,
)

FloatLike = float | Float[Any, '']


class PredictKind(Enum):
    """Kind of output of `Bart.predict`."""

    mean = 'mean'
    """The posterior mean of the conditional mean, shape ``(m,)`` (or
    ``(k, m)`` for multivariate regression)."""

    mean_samples = 'mean_samples'
    """Per-sample conditional mean, shape ``(ndpost, m)`` (or ``(ndpost,
    k, m)``). For binary regression, this is the probit-transformed
    sum-of-trees."""

    outcome_samples = 'outcome_samples'
    """Samples of the outcome variable, shape ``(ndpost, m)`` (or
    ``(ndpost, k, m)``). For binary regression, these are Bernoulli
    draws. For continuous regression, these are Gaussian draws with the
    posterior noise variance."""

    latent_samples = 'latent_samples'
    """Raw sum-of-trees values, shape ``(ndpost, m)`` (or ``(ndpost, k,
    m)``)."""


class DataFrame(Protocol):
    """DataFrame duck-type for `Bart`."""

    columns: Sequence[str]
    """The names of the columns."""

    def to_numpy(self) -> ndarray:
        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
        ...


class Series(Protocol):
    """Series duck-type for `Bart`."""

    name: str | None
    """The name of the series."""

    def to_numpy(self) -> ndarray:
        """Convert the series to a 1d numpy array."""
        ...

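# Note: these protocols are matched by duck typing, so e.g. pandas objects
# work as-is: `pandas.DataFrame` has `columns` and `to_numpy()`, and
# `pandas.Series` has `name` and `to_numpy()`. Other dataframe libraries
# exposing the same attributes should work too.
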

class Bart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses. For univariate regression, a 1D array of shape
        `(n,)`. For multivariate regression, a 2D array of shape `(k, n)` where
        `k` is the number of response components, as introduced in [3]_. For
        binary regression, the convention is that non-zero values mean 1 and
        zero values mean 0, like booleans.
    outcome_type
        The type of regression. ``'continuous'`` for continuous regression,
        ``'binary'`` for binary regression with probit link. For multivariate
        regression, a scalar value applies to all components; alternatively, a
        sequence of per-component types (e.g., ``['binary', 'continuous']``)
        specifies mixed outcome types.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision rule
        is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it's a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity.
        If setting `theta` directly, it should be in the ballpark of p or lower
        as well.
    varprob
        The prior probability distribution over the `p` predictors, used to
        choose the predictor to split on in a decision node. Must be > 0. It
        does not need to be normalized to sum to 1. If not specified, a
        uniform distribution is used. If ``sparse=True``, this is used as the
        initial value for the MCMC.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to set
        `lamda`. If not specified, it is estimated by linear regression (with
        intercept, and without taking into account `w`). If `y_train` has fewer
        than two elements, it is set to 1. If n <= p, it is set to the standard
        deviation of `y_train`. Ignored if `lamda` is specified. For
        multivariate regression, can be a scalar (broadcast to all components)
        or a `(k,)` vector of per-component estimates. For mixed outcome types,
        binary component values are ignored.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance. For multivariate regression, the Inverse-Wishart
        degrees of freedom are set to `sigdf + k - 1`.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If `y_train`
        has fewer than two elements, `k` is ignored and the scale is set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that a
        node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`. For multivariate regression, can be a scalar (broadcast
        to all components) or a `(k,)` vector. For mixed outcome types, binary
        component values are ignored.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, defaults to ``(max(y_train) -
        min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
        continuous regression, and 3 for binary regression. For multivariate
        regression, the range is computed per component. For mixed outcome
        types, each component uses the default for its type.
    offset
        The prior mean of the latent mean function. If not specified, it is set
        to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train != 0))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0. With binary regression, if `y_train` is
        all zero or all non-zero, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively. For multivariate regression, can be
        a scalar (broadcast to all components) or a `(k,)` vector; if not
        specified, each component uses the default for its type.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user. Not supported for multivariate regression.
    num_trees
        The number of trees used to represent the latent mean function.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin the
        predictors, ranging between the minimum and maximum observed values
        (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best set
        to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains. `ndpost` is rounded up to the
        first multiple of `num_chains`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This number
        of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in.
    printevery
        The number of iterations (including thinned-away ones) between each log
        line. Set to `None` to disable logging. ^C interrupts the MCMC only
        every `printevery` iterations, so with logging disabled it's impossible
        to kill the MCMC conveniently.
    num_chains
        The number of independent Markov chains to run.

        The difference between ``num_chains=None`` and ``num_chains=1`` is that
        in the latter case the object attributes and some method outputs have
        an explicit chain axis of size 1.
    num_chain_devices
        The number of devices to spread the chains across. Must be a divisor of
        `num_chains`. Each device will run a fraction of the chains.
    num_data_devices
        The number of devices to split datapoints across. Must be a divisor of
        `n`. This is useful only with very high `n`, roughly > 1,000,000.

        If both `num_chain_devices` and `num_data_devices` are specified, the
        total number of devices used is the product of the two.
    devices
        One or more devices used to run the MCMC on. If not specified, the
        computation will follow the placement of the input arrays. If a list of
        devices, this argument can be longer than the number of devices needed.
    seed
        The seed for the random number generator.
    maxdepth
        The maximum depth of the trees. This is 1-based, so with the default
        ``maxdepth=6``, the depths of the levels range from 0 to 5.
    init_kw
        Additional arguments passed to `bartz.mcmcstep.init`.
    run_mcmc_kw
        Additional arguments passed to `bartz.mcmcloop.run_mcmc`.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
        High-Dimensional Prediction and Variable Selection". In: Journal of the
        American Statistical Association 113.522, pp. 626-636.
    .. [2] Chipman, Hugh A., Edward I. George, and Robert E. McCulloch (2010).
        "BART: Bayesian additive regression trees". In: The Annals of Applied
        Statistics 4.1, pp. 266-298.
    .. [3] Um, Seungha, Antonio R. Linero, Debajyoti Sinha, and Dipankar
        Bandyopadhyay (2023). "Bayesian additive regression trees for
        multivariate skewed responses". In: Statistics in Medicine 42.3,
        pp. 246-263.
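
    Examples
    --------
    A minimal usage sketch on made-up data (illustration only; the small MCMC
    settings here keep it fast, not accurate):

    >>> from jax import random
    >>> kx, ky = random.split(random.key(0))
    >>> X = random.uniform(kx, (5, 100))  # (p, n): 5 predictors, 100 points
    >>> y = X[0] + 0.1 * random.normal(ky, (100,))  # (n,) responses
    >>> bart = Bart(X, y, ndpost=100, nskip=100, printevery=None)
    >>> yhat = bart.predict('train')  # posterior mean at x_train, shape (100,)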

    """

    _main_trace: mcmcloop.MainTrace
    _burnin_trace: mcmcloop.BurninTrace
    _mcmc_state: mcmcstep.State
    _splits: Real[Array, 'p max_num_splits']
    _binary_mask: Bool[Array, ''] | Bool[Array, ' k']
    _x_train_fmt: Any = field(static=True)

    offset: Float32[Array, ''] | Float32[Array, ' k']
    """The prior mean of the latent mean function."""

    sigest: Float32[Array, ''] | Float32[Array, ' k'] | None = None
    """The estimated standard deviation of the error used to set `lamda`."""

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Series,
        *,
        outcome_type: OutcomeType | str | Sequence[OutcomeType | str] = 'continuous',
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        varprob: Float[Array, ' p'] | None = None,
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool = True,
        sigest: FloatLike | Float[Array, ' k'] | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | Float[Array, ' k'] | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | Float[Array, ' k'] | None = None,
        w: Float[Array, ' n'] | Series | None = None,
        num_trees: int = 200,
        numcut: int = 255,
        ndpost: int = 1000,
        nskip: int = 1000,
        keepevery: int = 1,
        printevery: int | None = 100,
        num_chains: int | None = 4,
        num_chain_devices: int | None = None,
        num_data_devices: int | None = None,
        devices: Device | Sequence[Device] | None = None,
        seed: int | Key[Array, ''] = 0,
        maxdepth: int = 6,
        init_kw: Mapping = MappingProxyType({}),
        run_mcmc_kw: Mapping = MappingProxyType({}),
    ) -> None:
        # check data and put it in the right format
        x_train, x_train_fmt = self._process_predictor_input(x_train)
        y_train = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)

        if w is not None:
            w = self._process_response_input(w)
            self._check_same_length(x_train, w)

        # check data types are correct for continuous/binary/multivariate regression
        outcome_type, binary_mask = self._check_type_settings(y_train, outcome_type, w)

        # process sparsity settings
        theta, a, b, rho = self._process_sparsity_settings(
            x_train, sparse, theta, a, b, rho
        )

        # process "standardization" settings
        offset = self._process_offset_settings(y_train, binary_mask, offset)
        leaf_prior_cov_inv = self._process_leaf_variance_settings(
            y_train, binary_mask, k, num_trees, tau_num
        )
        error_cov_df, error_cov_scale, sigest = self._process_error_variance_settings(
            x_train, y_train, outcome_type, binary_mask, sigest, sigdf, sigquant, lamda
        )

        # determine splits
        splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
        x_train = self._bin_predictors(x_train, splits)

        # setup and run mcmc
        initial_state = self._setup_mcmc(
            x_train,
            y_train,
            outcome_type,
            offset,
            w,
            max_split,
            leaf_prior_cov_inv,
            error_cov_df,
            error_cov_scale,
            power,
            base,
            maxdepth,
            num_trees,
            init_kw,
            rm_const,
            theta,
            a,
            b,
            rho,
            varprob,
            num_chains,
            num_chain_devices,
            num_data_devices,
            devices,
            sparse,
            nskip,
        )
        result = self._run_mcmc(
            initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
        )

        # set public attributes
        # set offset from the state because of buffer donation
        self.offset = result.final_state.offset
        self.sigest = sigest

        # set private attributes
        self._main_trace = result.main_trace
        self._burnin_trace = result.burnin_trace
        self._mcmc_state = result.final_state
        self._splits = splits
        self._x_train_fmt = x_train_fmt
        self._binary_mask = binary_mask

    def predict(
        self,
        x_test: Real[Array, 'p m'] | DataFrame | str,
        *,
        kind: PredictKind | str = 'mean',
        key: Key[Array, ''] | None = None,
        w: Float[Array, ' m'] | Series | None = None,
    ) -> (
        Float32[Array, ' m']
        | Float32[Array, 'k m']
        | Float32[Array, 'ndpost m']
        | Float32[Array, 'ndpost k m']
    ):
        """
        Compute predictions at `x_test`.

        Parameters
        ----------
        x_test
            The test predictors, or the string ``'train'`` to compute
            predictions on the training data.
        kind
            The kind of output. See `PredictKind` for details.
        key
            JAX random key, required when ``kind='outcome_samples'``.
        w
            Per-observation error scale for ``kind='outcome_samples'``.
            Required when the model was fit with weights and ``x_test`` is
            new data.

        Returns
        -------
        Predictions at `x_test` in the requested format.

        Raises
        ------
        ValueError
            If `x_test` has a different format than `x_train`, or if `w`
            is specified when it should be `None`, or if `w` is not
            specified when it is required.

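        Examples
        --------
        A sketch of the output kinds, given a fitted `bart` object (see the
        class docstring) and a JAX random key:

        >>> from jax import random
        >>> post_mean = bart.predict('train')                   # (m,)
        >>> draws = bart.predict('train', kind='mean_samples')  # (ndpost, m)
        >>> ysim = bart.predict(
        ...     'train', kind='outcome_samples', key=random.key(1)
        ... )                                                   # (ndpost, m)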

        """
        # parse arguments
        kind = PredictKind(kind)
        if kind is PredictKind.outcome_samples and key is None:
            msg = '`key` not specified'
            raise ValueError(msg)
        w = self._process_w_test(x_test, kind, w)
        x_test = self._process_x_test(x_test, w)

        # get latent i.e. bare sum-of-trees predictions
        latent = self._predict(x_test)
        if kind is PredictKind.latent_samples:
            return latent

        # sample posterior (uses latent directly, no probit squash needed)
        binary_indices = self._mcmc_state.binary_indices
        if kind is PredictKind.outcome_samples:
            return self._sample_outcome(key, latent, binary_indices, w)

        # squash predictions to (0, 1) if probit
        if binary_indices is not None:
            indexing = jnp.s_[..., binary_indices, :]
            mean_samples = latent.at[indexing].set(ndtr(latent[indexing]))
        elif self._mcmc_state.binary_y is not None:
            mean_samples = ndtr(latent)
        else:
            mean_samples = latent

        # take mean or return samples
        if kind is PredictKind.mean:
            return mean_samples.mean(axis=0)
        return mean_samples

    @property
    def ndpost(self) -> int:
        """The total number of posterior samples after burn-in across all chains.

        May be larger than the initialization argument `ndpost` if it was not
        divisible by the number of chains.
        """
        return self._main_trace.grow_prop_count.size

    @property
    def num_trees(self) -> int:
        """Return the number of trees used in the model."""
        return self._mcmc_state.forest.split_tree.shape[-2]

    def get_latent_prec(
        self, only_continuous: bool = False
    ) -> (
        Float32[Array, ' nskip+ndpost']
        | Float32[Array, 'nskip+ndpost k k']
        | Float32[Array, 'num_chains nskip+ndpost/num_chains']
        | Float32[Array, 'num_chains nskip+ndpost/num_chains k k']
    ):
        """Return the posterior samples of the latent error precision matrix.

        Parameters
        ----------
        only_continuous
            If `True` and the model has mixed binary-continuous outcomes,
            return only the submatrix for the continuous components.

        Returns
        -------
        MCMC samples of the error precision matrix.

        Notes
        -----
        This method is meant to check for convergence, so it returns the full
        MCMC trace and does not concatenate chains together. For probit
        regression, this returns the precision of the latent error term, not
        the Bernoulli precision for the binary outcome. For heteroskedastic
        regression, the returned precision is the global precision parameter,
        which has to be divided by a squared weight to get the precision on a
        given datapoint.

        Raises
        ------
        ValueError
            If `only_continuous` is `True` but the model has only binary
            outcomes, so there is no continuous submatrix to return.
        """
        binary_indices = self._mcmc_state.binary_indices
        if (
            only_continuous
            and binary_indices is None
            and self._mcmc_state.binary_y is not None
        ):
            msg = 'Model has only binary outcomes, so there is no continuous submatrix to return.'
            raise ValueError(msg)

        burnin = self._burnin_trace.error_cov_inv
        main = self._main_trace.error_cov_inv
        # trace shape is (chains?, samples, ...) where chains is optional
        # first axis; samples is the axis to concatenate along
        num_chains = get_num_chains(self._mcmc_state)
        sample_axis = 1 if num_chains is not None else 0
        prec = jnp.concatenate([burnin, main], axis=sample_axis)

        if only_continuous and binary_indices is not None:
            *_, k, _ = prec.shape
            mask = jnp.ones(k, dtype=bool).at[binary_indices].set(False)
            cont_indices = jnp.arange(k)[mask]
            prec = prec[..., cont_indices[:, None], cont_indices[None, :]]

        return prec

    def get_error_sdev(
        self, mean: bool = False
    ) -> (
        Float32[Array, 'ndpost']
        | Float32[Array, 'ndpost k']
        | Float32[Array, '']
        | Float32[Array, ' k']
    ):
        """Return the error standard deviation, post-burnin, chains concatenated.

        Parameters
        ----------
        mean
            If `True`, average the precision matrix across samples first
            (harmonic mean at the covariance matrix level), returning a single
            scalar or vector instead of posterior samples.

        Returns
        -------
        Posterior samples (or a single estimate) of the error standard deviation; NaN for binary outcomes.

        Notes
        -----
        Binary outcomes do have a standard deviation, of course, but it's not
        returned by this method because that would require evaluating
        predictions at a given X, since the Bernoulli variance is p(1-p).
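
        Examples
        --------
        A sketch, assuming a fitted univariate continuous model `bart`:

        >>> sd_draws = bart.get_error_sdev()         # (ndpost,) samples
        >>> sd_hat = bart.get_error_sdev(mean=True)  # single estimate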

        """
        # reshape operations
        error_cov_inv = self._main_trace.error_cov_inv
        if error_cov_inv.ndim in (2, 4):
            # shape (chains, samples) or (chains, samples, k, k), concatenate chains
            error_cov_inv = lax.collapse(error_cov_inv, 0, 2)
        is_uv = error_cov_inv.ndim == 1
        if mean:
            error_cov_inv = error_cov_inv.mean(0)
        if is_uv:
            # univariate case, reshape to 1x1 matrix
            error_cov_inv = error_cov_inv[..., None, None]

        # compute sdev and fill in nans for binary outcomes
        cov = _inv_via_chol_with_gersh(error_cov_inv)
        sdev = jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1))
        if is_uv:
            sdev = sdev.squeeze(-1)
        with debug_nans(False):
            return jnp.where(self._binary_mask, jnp.nan, sdev)

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        p = self._mcmc_state.forest.max_split.size
        return varcount(p, self._main_trace)

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self.varcount.mean(axis=0)

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        max_split = self._mcmc_state.forest.max_split
        p = max_split.size
        varprob = self._main_trace.varprob
        if varprob is None:
            peff = jnp.count_nonzero(max_split)
            varprob = jnp.where(max_split, 1 / peff, 0)
            varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
        else:
            varprob = varprob.reshape(-1, p)
        return varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self.varprob.mean(axis=0)

    def _sample_outcome(
        self,
        key: Key[Array, ''],
        latent: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'],
        binary_indices: Int32[Array, ' kb'] | None,
        w: Float32[Array, ' m'] | None,
    ) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
        """Sample from the posterior predictive distribution."""
        if latent.ndim > 2:  # multivariate case
            error_cov_inv = self._main_trace.error_cov_inv
            error_cov_inv = lax.collapse(error_cov_inv, 0, -2)

            # Cholesky of precision: error_cov_inv = L @ L^T
            L = chol_with_gersh(error_cov_inv)  # (ndpost, k, k)

            # Sample z ~ N(0, I) and solve L^T @ error = z
            # so error = L^{-T} z ~ N(0, L^{-T} L^{-1}) = N(0, Sigma)
            z = random.normal(key, latent.shape)  # (ndpost, k, m)
            error = solve_triangular(L, z, trans='T', lower=True)
        elif self._mcmc_state.binary_y is not None:
            # pure binary UV: probit has sigma = 1
            error = random.normal(key, latent.shape)
        else:  # univariate continuous
            sigma = jnp.sqrt(jnp.reciprocal(self._main_trace.error_cov_inv)).reshape(-1)
            error = sigma[..., None] * random.normal(key, latent.shape)
            if w is not None:
                error *= w[None, :]

        outcome = latent + error

        # convert binary outcomes via latent probit thresholding
        if binary_indices is not None:
            idx = jnp.s_[..., binary_indices, :]
            outcome = outcome.at[idx].set(jnp.where(outcome[idx] > 0, 1.0, 0.0))
        elif self._mcmc_state.binary_y is not None:
            outcome = jnp.where(outcome > 0, 1.0, 0.0)

        return outcome

    def _process_w_test(
        self,
        x_test: Real[Array, 'p m'] | DataFrame | str,
        kind: PredictKind,
        w: Float[Array, ' m'] | Series | None,
    ) -> Float32[Array, ' m'] | None:
        """Validate and resolve the error weights for prediction.

        Parameters
        ----------
        x_test
            The raw (not yet processed) test predictors, or ``'train'``.
        kind
            The prediction kind.
        w
            User-provided per-observation error scale, or `None`.

        Returns
        -------
        The resolved error scale as a float32 array, or `None` if weights
        are not applicable.

        Raises
        ------
        ValueError
            If `w` is specified when it should be `None`, or missing when
            required.

        """
        x_test_is_train = isinstance(x_test, str) and x_test == 'train'
        has_train_weights = self._mcmc_state.prec_scale is not None
        is_binary = self._mcmc_state.binary_y is not None
        is_multivariate = self._mcmc_state.offset.ndim == 1
        needs_weights = (
            kind is PredictKind.outcome_samples
            and not is_binary
            and not is_multivariate
            and has_train_weights
        )

        if not needs_weights:
            if w is not None:
                msg = (
                    '`w` must be `None` in this configuration'
                    " (it is used only with kind='outcome_samples',"
                    ' univariate continuous regression fitted with'
                    ' weights)'
                )
                raise ValueError(msg)
            return None

        if x_test_is_train:
            if w is not None:
                msg = (
                    "`w` must be `None` when x_test='train'"
                    ' (training weights are used automatically)'
                )
                raise ValueError(msg)
            return jnp.reciprocal(jnp.sqrt(self._mcmc_state.prec_scale))

        # new test data, model was fit with weights
        if w is None:
            msg = (
                '`w` is required because the model was fit with'
                ' weights and x_test is new data'
            )
            raise ValueError(msg)
        return self._process_response_input(w)

    def _process_x_test(
        self,
        x_test: Real[Array, 'p m'] | DataFrame | str,
        w: Float32[Array, ' m'] | None,
    ) -> UInt[Array, 'p m']:
        """Convert x_test to binned format suitable for prediction."""
        if isinstance(x_test, str):
            if x_test != 'train':
                msg = (
                    f"x_test must be an array, a DataFrame, or 'train', got {x_test!r}"
                )
                raise ValueError(msg)
            return self._mcmc_state.X
        x_test, x_test_fmt = self._process_predictor_input(x_test)
        if x_test_fmt != self._x_train_fmt:
            msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
            raise ValueError(msg)
        if w is not None:
            self._check_same_length(w, x_test)
        return self._bin_predictors(x_test, self._splits)

    @staticmethod
    def _process_predictor_input(
        x: Real[Any, 'p n'] | DataFrame,
    ) -> tuple[Shaped[Array, 'p n'], Any]:
        if hasattr(x, 'columns'):
            fmt = dict(kind='dataframe', columns=x.columns)
            x = x.to_numpy().T
        else:
            fmt = dict(kind='array', num_covar=x.shape[0])
        x = jnp.asarray(x)
        assert x.ndim == 2
        return x, fmt

    @staticmethod
    def _process_response_input(
        y: Shaped[Array, ' n'] | Shaped[Array, 'k n'] | Series,
    ) -> Float32[Array, ' n'] | Float32[Array, 'k n']:
        if hasattr(y, 'to_numpy'):
            y = y.to_numpy()
        y = jnp.asarray(y, jnp.float32)
        if y.ndim < 1 or y.ndim > 2:
            msg = f'y_train must be 1D (n,) or 2D (k, n). Got {y.ndim=}.'
            raise ValueError(msg)
        return y

    @staticmethod
    def _check_same_length(x1: Array, x2: Array) -> None:
        get_length = lambda x: x.shape[-1]
        assert get_length(x1) == get_length(x2)

    @classmethod
    def _process_error_variance_settings(
        cls,
        x_train: Shaped[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        outcome_type: OutcomeType | tuple[OutcomeType, ...],
        binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
        sigest: FloatLike | Float[Array, ' k'] | None,
        sigdf: FloatLike,
        sigquant: FloatLike,
        lamda: FloatLike | Float[Array, ' k'] | None,
    ) -> tuple[
        Float32[Array, ''] | None,
        Float32[Array, ''] | Float32[Array, 'k k'] | None,
        Float32[Array, ''] | Float32[Array, ' k'] | None,
    ]:
        """Return (error_cov_df, error_cov_scale, sigest)."""
        if outcome_type is OutcomeType.binary:
            if sigest is not None or lamda is not None:
                msg = 'Let `sigest=None` and `lamda=None` for binary regression'
                raise ValueError(msg)
            return None, None, None

        if lamda is None:
            # estimate sigest²
            sigest2 = cls._estimate_sigest2(x_train, y_train, sigest, binary_mask)
            sigest = jnp.sqrt(sigest2)

            # lamda from sigest²
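            # (i.e., set lamda such that the sigquant quantile of the prior
            # sigma^2 ~ sigdf * lamda / chi2(sigdf) equals sigest^2; invchi2
            # below is the sigquant quantile of an inverse chi-squared with
            # sigdf degrees of freedom, so lamda = sigest^2 / (sigdf * invchi2))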

            alpha = sigdf / 2
            invchi2 = invgamma.ppf(sigquant, alpha) / 2
            invchi2rid = invchi2 * sigdf
            lamda = sigest2 / invchi2rid

        elif sigest is not None:
            msg = 'Let `sigest=None` if `lamda` is specified'
            raise ValueError(msg)

        else:
            lamda = jnp.where(binary_mask, 0.0, lamda)

        # params written in multivariate form
        if y_train.ndim == 2:
            k = y_train.shape[0]
            lamda = jnp.broadcast_to(lamda, (k,))
            error_cov_df = jnp.asarray(sigdf) + k - 1
            error_cov_scale = jnp.diag(sigdf * lamda)
        else:
            error_cov_df = jnp.asarray(sigdf)
            error_cov_scale = jnp.asarray(sigdf * lamda)

        return error_cov_df, error_cov_scale, sigest

    @classmethod
    def _estimate_sigest2(
        cls,
        x_train: Shaped[Array, 'p n'],
        y_train: Float32[Array, '*k n'],
        sigest: float | Shaped[Array, '*k'] | None,
        binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
    ) -> Float32[Array, '*k']:
        n = y_train.shape[-1]
        if sigest is not None:
            sigest2 = jnp.square(jnp.asarray(sigest, dtype=jnp.float32))
            sigest2 = jnp.broadcast_to(sigest2, y_train.shape[:-1])
        elif n < 2:
            sigest2 = jnp.ones(y_train.shape[:-1])
        elif n <= x_train.shape[0]:
            sigest2 = jnp.var(y_train, axis=-1)
        else:
            sigest2 = cls._linear_regression(x_train, y_train)
        return jnp.where(binary_mask, 0.0, sigest2)

    @staticmethod
    @jit
    def _linear_regression(
        x_train: Shaped[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    ) -> Float32[Array, ''] | Float32[Array, ' k']:
        """Return the error variance estimated with OLS with intercept."""
        x_centered = x_train.T - x_train.mean(axis=1)
        y_centered = y_train.T - y_train.mean(axis=-1)
        # centering is equivalent to adding an intercept column
        _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
        chisq = chisq.reshape(y_train.shape[:-1])
        dof = y_train.shape[-1] - rank
        return chisq / dof

    @staticmethod
    def _check_type_settings(
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        outcome_type: OutcomeType | str | Sequence[OutcomeType | str],
        w: Float[Array, ' n'] | None,
    ) -> tuple[
        OutcomeType | tuple[OutcomeType, ...], Bool[Array, ''] | Bool[Array, ' k']
    ]:
        # standardize outcome_type to OutcomeType or tuple[OutcomeType, ...]
        if isinstance(outcome_type, Sequence) and not isinstance(outcome_type, str):
            outcome_type = tuple(OutcomeType(t) for t in outcome_type)
            num_types = len(outcome_type)
            if len(set(outcome_type)) == 1:
                outcome_type = outcome_type[0]
        else:
            num_types = None
            outcome_type = OutcomeType(outcome_type)

        # validation
        if num_types is not None and (
            y_train.ndim != 2 or num_types != y_train.shape[0]
        ):
            msg = (
                f'Sequence outcome_type of length {num_types}'
                f' requires y_train.shape=({num_types}, n),'
                f' found {y_train.shape=}.'
            )
            raise ValueError(msg)
        if w is not None and not (
            outcome_type is OutcomeType.continuous and y_train.ndim == 1
        ):
            msg = 'Weights are only supported for univariate continuous regression.'
            raise ValueError(msg)

        if isinstance(outcome_type, tuple):
            binary_mask = jnp.array([t is OutcomeType.binary for t in outcome_type])
        else:
            binary_mask = jnp.bool_(outcome_type is OutcomeType.binary)
        binary_mask = jnp.broadcast_to(binary_mask, y_train.shape[:-1])

        return outcome_type, binary_mask

    @staticmethod
    def _process_sparsity_settings(
        x_train: Real[Array, 'p n'],
        sparse: bool,
        theta: FloatLike | None,
        a: FloatLike,
        b: FloatLike,
        rho: FloatLike | None,
    ) -> (
        tuple[None, None, None, None]
        | tuple[FloatLike, None, None, None]
        | tuple[None, FloatLike, FloatLike, FloatLike]
    ):
        """Return (theta, a, b, rho)."""
        if not sparse:
            return None, None, None, None
        elif theta is not None:
            return theta, None, None, None
        else:
            if rho is None:
                p, _ = x_train.shape
                rho = float(p)
            return None, a, b, rho

    @staticmethod
    def _process_offset_settings(
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
        offset: float | Float32[Any, ''] | Float32[Any, ' k'] | None,
    ) -> Float32[Array, ''] | Float32[Array, ' k']:
        """Return offset."""
        if offset is not None:
            off = jnp.asarray(offset, jnp.float32)
            return jnp.broadcast_to(off, y_train.shape[:-1])
        if y_train.shape[-1] < 1:
            return jnp.zeros(y_train.shape[:-1])

        bound = 1 / (1 + y_train.shape[-1])
        binary_offset = ndtri(jnp.clip((y_train != 0).mean(-1), bound, 1 - bound))
        continuous_offset = y_train.mean(-1)
        return jnp.where(binary_mask, binary_offset, continuous_offset)

    @staticmethod
    def _process_leaf_variance_settings(
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
        k: FloatLike,
        num_trees: int,
        tau_num: FloatLike | None,
    ) -> Float32[Array, ''] | Float32[Array, 'k k']:
        """Return `leaf_prior_cov_inv`."""
        # determine `tau_num` if not specified
        if tau_num is None:
            if y_train.shape[-1] < 2:
                continuous_tau = jnp.ones(y_train.shape[:-1])
            else:
                continuous_tau = (y_train.max(-1) - y_train.min(-1)) / 2
            tau_num = jnp.where(binary_mask, 3.0, continuous_tau)

        # leaf prior standard deviation
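        # (dividing by k * sqrt(num_trees) makes the sum of num_trees
        # independent leaves have prior standard deviation tau_num / k,
        # i.e., `k` counts how many prior standard deviations fit in half
        # the observed range of y_train)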

        sigma_mu = tau_num / (k * math.sqrt(num_trees))

        # leaf prior precision matrix
        leaf_prior_cov_inv = jnp.reciprocal(jnp.square(sigma_mu))
        if y_train.ndim == 2:
            leaf_prior_cov_inv = jnp.diag(
                jnp.broadcast_to(leaf_prior_cov_inv, y_train.shape[:-1])
            )
        return leaf_prior_cov_inv

    @staticmethod
    def _determine_splits(
        x_train: Real[Array, 'p n'],
        usequants: bool,
        numcut: int,
        xinfo: Float[Array, 'p n'] | None,
    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
        if xinfo is not None:
            if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
                msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
                raise ValueError(msg)
            return prepcovars.parse_xinfo(xinfo)
        elif usequants:
            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
        else:
            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

    @staticmethod
    def _bin_predictors(
        x: Real[Array, 'p n'], splits: Real[Array, 'p max_num_splits']
    ) -> UInt[Array, 'p n']:
        return prepcovars.bin_predictors(x, splits)

    @staticmethod
    def _setup_mcmc(
        x_train: Real[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        outcome_type: OutcomeType | tuple[OutcomeType, ...],
        offset: Float32[Array, ''] | Float32[Array, ' k'],
        w: Float[Array, ' n'] | None,
        max_split: UInt[Array, ' p'],
        leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
        error_cov_df: FloatLike | None,
        error_cov_scale: FloatLike | Float32[Array, 'k k'] | None,
        power: FloatLike,
        base: FloatLike,
        maxdepth: int,
        num_trees: int,
        init_kw: Mapping[str, Any],
        rm_const: bool,
        theta: FloatLike | None,
        a: FloatLike | None,
        b: FloatLike | None,
        rho: FloatLike | None,
        varprob: Float[Any, ' p'] | None,
        num_chains: int | None,
        num_chain_devices: int | None,
        num_data_devices: int | None,
        devices: Device | Sequence[Device] | None,
        sparse: bool,
        nskip: int,
    ) -> mcmcstep.State:
        p_nonterminal = make_p_nonterminal(maxdepth, base, power)

        # process device settings
        device_kw, device = process_device_settings(
            y_train, num_chains, num_chain_devices, num_data_devices, devices
        )

        kw: dict = dict(
            X=x_train,
            # copy y_train because it's going to be donated in the mcmc loop
            y=jnp.array(y_train),
            outcome_type=outcome_type,
            offset=offset,
            # copy w because it's going to be donated in init
            error_scale=None if w is None else jnp.array(w),
            max_split=max_split,
            num_trees=num_trees,
            p_nonterminal=p_nonterminal,
            leaf_prior_cov_inv=leaf_prior_cov_inv,
            error_cov_df=error_cov_df,
            error_cov_scale=error_cov_scale,
            min_points_per_decision_node=10,
            log_s=process_varprob(varprob, max_split),
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            sparse_on_at=nskip // 2 if sparse else None,
            **device_kw,
        )

        if rm_const:
            n_empty = jnp.sum(max_split == 0).item()
            kw.update(filter_splitless_vars=n_empty)

        kw.update(init_kw)

        state = mcmcstep.init(**kw)

        # put state on device if requested explicitly by the user
        if device is not None:
            state = device_put(state, device, donate=True)

        return state

    @classmethod
    def _run_mcmc(
        cls,
        mcmc_state: mcmcstep.State,
        ndpost: int,
        nskip: int,
        keepevery: int,
        printevery: int | None,
        seed: int | Integer[Array, ''] | Key[Array, ''],
        run_mcmc_kw: Mapping,
    ) -> RunMCMCResult:
        # prepare random generator seed
        if is_key(seed):
            key = jnp.copy(seed)
        else:
            key = jax.random.key(seed)

        # round up ndpost
        num_chains = get_num_chains(mcmc_state)
        if num_chains is None:
            num_chains = 1
        n_save = ndpost // num_chains + bool(ndpost % num_chains)
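        # (ceil division: each chain saves n_save draws, so the total number
        # of saved samples is ndpost rounded up to a multiple of num_chains)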

        # prepare arguments
        kw: dict = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
        kw.update(
            mcmcloop.make_default_callback(
                mcmc_state,
                dot_every=None if printevery is None or printevery == 1 else 1,
                report_every=printevery,
            )
        )
        kw.update(run_mcmc_kw)

        return run_mcmc(key, mcmc_state, n_save, **kw)

    def _predict(
        self, x: UInt[Array, 'p m']
    ) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
        """Evaluate trees on already quantized `x`."""
        return predict(x, self._main_trace)

    def check_trees(
        self, error: bool = False
    ) -> UInt[Array, 'num_chains ndpost/num_chains num_trees']:
        """Apply `bartz.grove.check_trace` to all the tree draws.

        Parameters
        ----------
        error
            If `True`, throw an error if any invalid trees are found.

        Returns
        -------
        An array where non-zero entries indicate invalid trees.

        Raises
        ------
        RuntimeError
            If `error` is `True` and any invalid trees are found.
        """
        out: UInt[Array, '*chains samples num_trees']
        out = check_trace(self._main_trace, self._mcmc_state.forest.max_split)
        if out.ndim < 3:
            out = out[None, :, :]
        if error:
            bad_count = jnp.count_nonzero(out)
            if bad_count > 0:
                msg = f'Found {bad_count} invalid trees in the MCMC trace.'
                raise RuntimeError(msg)
        return out

    def check_replicated_trees(self) -> None:
        """Check that the trees are equal across data-sharded devices.

        If the data is sharded across devices, verify that the trees (which
        should be replicated) are identical on all shards.

        Raises
        ------
        RuntimeError
            If the trees differ across devices.
        """
        state = self._mcmc_state
        mesh = state.config.mesh
        if mesh is not None and 'data' in mesh.axis_names:
            replicated_forest = replace(state.forest, leaf_indices=None)
            equal = equal_shards(
                replicated_forest, 'data', in_specs=PartitionSpec(), mesh=mesh
            )
            equal_array = jnp.stack(tree.leaves(equal))
            all_equal = jnp.all(equal_array)
            if not all_equal.item():
                msg = 'The trees differ across data-sharded devices.'
                raise RuntimeError(msg)

    def compare_resid(
        self, y: Float32[Array, ' n'] | Float32[Array, 'k n'] | None = None
    ) -> tuple[
        Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
        Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
    ]:
        """Re-compute residuals to compare them with the updated ones.

        Parameters
        ----------
        y
            The response variable. Required for continuous regression (since
            ``State`` does not store ``y`` in continuous mode). Ignored for
            binary regression (where ``State.z`` is used instead).

        Returns
        -------
        resid1
            The final state of the residuals updated during the MCMC.
        resid2
            The residuals computed from the final state of the trees.
        """
        state = self._mcmc_state
        resid1 = state.resid

        forests = TreesTrace.from_dataclass(state.forest)
        trees = evaluate_forest(state.X, forests, sum_batch_axis=-1)

        if state.binary_indices is not None:
            # mixed binary-continuous: z has only binary rows, y has all rows
            assert y is not None, 'y is required for mixed regression'
            ref = jnp.asarray(y)
            ref = jnp.broadcast_to(ref, state.resid.shape)
            ref = ref.at[..., state.binary_indices, :].set(state.z)
        elif state.z is not None:
            ref = state.z
        else:
            assert y is not None, 'y is required for continuous regression'
            ref = jnp.asarray(y)
        resid2 = ref - (trees + state.offset[..., None])

        return resid1, resid2

    def depth_distr(self) -> Int32[Array, '*num_chains ndpost/num_chains d']:
        """Histogram of tree depths for each state of the trees.

        Returns
        -------
        A matrix where each row contains a histogram of tree depths.
        """
        out: Int32[Array, '*chains samples d']
        out = forest_depth_distr(self._main_trace.split_tree)
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def _points_per_node_distr(
        self, node_type: str
    ) -> Int32[Array, '*num_chains ndpost/num_chains n+1']:
        out: Int32[Array, '*chains samples n+1']
        out = points_per_node_distr(
            self._mcmc_state.X,
            self._main_trace.var_tree,
            self._main_trace.split_tree,
            node_type,
            sum_batch_axis=-1,
        )
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def points_per_decision_node_distr(
        self,
    ) -> Int32[Array, '*num_chains ndpost/num_chains n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes.

        Returns
        -------
        For each chain, a matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf-parent')

    def points_per_leaf_distr(
        self,
    ) -> Int32[Array, '*num_chains ndpost/num_chains n+1']:
        """Histogram of number of points belonging to leaves.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf')


@partial(jit, static_argnames='p')
# this is jitted such that lax.collapse below does not create a copy
def varcount(p: int, trace: mcmcloop.MainTrace) -> Int32[Array, 'ndpost p']:
    """Histogram of predictor usage for decision rules in the trees, squashing chains."""
    varcount: Int32[Array, '*chains samples p']
    varcount = compute_varcount(p, trace)
    return lax.collapse(varcount, 0, -1)


@jit
# this is jitted such that lax.collapse below does not create a copy
def predict(
    x: UInt[Array, 'p m'], trace: mcmcloop.MainTrace
) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
    """Evaluate trees on already quantized `x`, and squash chains."""
    out = evaluate_trace(x, trace)
    # For MV, out has shape (*trace_shape, k, n); for UV, (*trace_shape, n).
    # We must collapse only the chain/sample dims, not k.
    # Detect MV: leaf_tree has an extra axis compared to split_tree.
    is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
    end = -2 if is_mv else -1
    return lax.collapse(out, 0, end)


class DeviceKwArgs(TypedDict):
    num_chains: int | None
    mesh: Mesh | None
    target_platform: Literal['cpu', 'gpu'] | None


def process_device_settings(
    y_train: Array,
    num_chains: int | None,
    num_chain_devices: int | None,
    num_data_devices: int | None,
    devices: Device | Sequence[Device] | None,
) -> tuple[DeviceKwArgs, Device | None]:
    """Return the arguments for `mcmcstep.init` related to devices, and an optional device where to put the state."""
    # determine devices
    if devices is not None:
        if not hasattr(devices, '__len__'):
            devices = (devices,)
        device = devices[0]
        platform = device.platform
    elif hasattr(y_train, 'platform'):
        platform = y_train.platform()
        device = None
        # set device=None because if the devices were not specified explicitly
        # we may be in the case where computation will follow data placement,
        # do not disturb jax as the user may be playing with vmap, jit, reshard...
        devices = jax.devices(platform)
    else:
        msg = 'not possible to infer device from `y_train`, please set `devices`'
        raise ValueError(msg)

    # create mesh
    if num_chain_devices is None and num_data_devices is None:
        mesh = None
    else:
        mesh = dict()
        if num_chain_devices is not None:
            mesh.update(chains=num_chain_devices)
        if num_data_devices is not None:
            mesh.update(data=num_data_devices)
        mesh = make_mesh(
            axis_shapes=tuple(mesh.values()),
            axis_names=tuple(mesh),
            axis_types=(AxisType.Auto,) * len(mesh),
            devices=devices,
        )
        device = None
        # set device=None because `mcmcstep.init` will `device_put` with the
        # mesh already, we don't want to undo its work

    # prepare arguments to `init`
    settings = DeviceKwArgs(
        num_chains=num_chains,
        mesh=mesh,
        target_platform=None
        if mesh is not None or hasattr(y_train, 'platform')
        else platform,
        # here we don't take into account the case where the user has set both
        # batch sizes; since the user has to be playing with `init_kw` to do
        # that, we'll let `init` throw the error and the user set
        # `target_platform` themselves so they have a clearer idea how the
        # thing works.
    )

    return settings, device


def process_varprob(
    varprob: Float[Any, ' p'] | None, max_split: UInt[Array, ' p']
) -> Float32[Array, ' p'] | None:
    """Convert varprob to log_s."""
    if varprob is None:
        return None
    varprob = jnp.asarray(varprob)
    assert varprob.shape == max_split.shape, 'varprob must have shape (p,)'
    varprob = error_if(varprob, varprob <= 0, 'varprob must be > 0')
    return jnp.log(varprob)