Coverage for src/bartz/_interface.py: 95%

628 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/_interface.py 

2# 

3# Copyright (c) 2025-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Main high-level interface of the package.""" 

26 

27import math 

28import pickle 

29from collections.abc import Collection, Hashable, Mapping, Sequence 

30from dataclasses import replace 

31from enum import Enum 

32from functools import cached_property 

33from os import PathLike, cpu_count 

34from pathlib import Path 

35 

36# WORKAROUND(python<3.15): use frozendict instead of MappingProxyType 

37from types import MappingProxyType 

38from typing import Any, Literal, Protocol, TypedDict, overload, runtime_checkable 

39from warnings import warn 

40 

41import jax 

42import jax.numpy as jnp 

43from equinox import Module, error_if, field, tree_at 

44from jax import Device, debug_nans, device_put, lax, make_mesh, random, tree 

45from jax.scipy.linalg import solve_triangular 

46from jax.scipy.special import ndtr, ndtri 

47from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec 

48from jax.typing import DTypeLike 

49from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real, Shaped, UInt 

50from numpy import ndarray 

51 

52from bartz._jaxext import equal_shards, is_key, jit, split 

53from bartz.grove import ( 

54 TreeHeaps, 

55 TreesTrace, 

56 check_trace, 

57 evaluate_forest, 

58 forest_depth_distr, 

59 format_tree, 

60 points_per_node_distr, 

61) 

62from bartz.mcmcloop import ( 

63 BurninTrace, 

64 MainTrace, 

65 RunMCMCResult, 

66 compute_varcount, 

67 evaluate_trace, 

68 make_print_callback, 

69 make_tqdm_callback, 

70 run_mcmc, 

71) 

72from bartz.mcmcstep import DiagWishart, OutcomeType, Wishart, make_p_nonterminal 

73from bartz.mcmcstep._axes import ( 

74 chain_to_axis, 

75 chain_vmap_axes, 

76 chainful_axis, 

77 get_has_chains, 

78 trace_sample_axes, 

79) 

80from bartz.mcmcstep._state import ( 

81 ArrayLike, 

82 FloatLike, 

83 State, 

84 _inv_via_chol_with_gersh, 

85 _leaf_partition_spec, 

86 chol_with_gersh, 

87 init, 

88) 

89from bartz.prepcovars import Binner, BinnerFactory, UniqueQuantileBinner 

90 

91 

class PredictKind(Enum):
    """Enumeration of the output formats supported by `Bart.predict`."""

    mean = 'mean'
    """Posterior mean of the conditional mean. Shape ``(m,)``, or ``(k, m)``
    for multivariate regression."""

    mean_samples = 'mean_samples'
    """Conditional mean evaluated at each posterior sample. Shape
    ``(num_chains * n_save, m)``, or ``(num_chains * n_save, k, m)``. For
    binary regression, this is the probit-transformed sum-of-trees."""

    outcome_samples = 'outcome_samples'
    """Draws of the outcome variable. Shape ``(num_chains * n_save, m)``, or
    ``(num_chains * n_save, k, m)``. Bernoulli draws for binary regression;
    Gaussian draws with the posterior noise variance for continuous
    regression."""

    latent_samples = 'latent_samples'
    """Untransformed sum-of-trees values. Shape ``(num_chains * n_save, m)``,
    or ``(num_chains * n_save, k, m)``."""

113 

114 

@runtime_checkable
class DataFrame(Protocol):
    """Structural (duck) type of the dataframe-like inputs accepted by `Bart`."""

    @property
    def columns(self) -> Collection[str]:
        """Return the names of the columns."""
        ...

    def to_numpy(self) -> Shaped[ndarray, '*shape']:
        """Return the content as a 2d numpy array, with columns along the second axis."""
        ...

127 

128 

@runtime_checkable
class Series(Protocol):
    """Structural (duck) type of the series-like inputs accepted by `Bart`."""

    @property
    def name(self) -> Hashable:
        """Return the name of the series."""
        ...

    def to_numpy(self) -> Shaped[ndarray, '*shape']:
        """Return the content as a 1d numpy array."""
        ...

141 

142 

class SparseConfig(Module):
    R"""
    Settings of the sparsity-inducing prior used for variable selection.

    Implements the prior of [1]_. Variable selection on the predictors is
    activated by passing an instance as the `sparse` argument of `Bart`. The
    probabilities of picking each predictor for a decision rule have prior

    .. math::
        (s_1, \ldots, s_p) \sim
        \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

    When `theta` is left unspecified, its prior distribution is given by

    .. math::
        \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
        \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

    References
    ----------
    .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection”. In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    """

    theta: FloatLike | None = None
    """Concentration of the Dirichlet prior. When omitted, it is sampled from
    the Beta prior parametrized by `a`, `b` and `rho`. When given explicitly,
    it should be around the number of predictors p, or lower."""

    a: FloatLike = 0.5
    """First shape parameter of the Beta prior on ``theta / (theta + rho)``."""

    b: FloatLike = 1.0
    """Second shape parameter of the Beta prior on ``theta / (theta + rho)``."""

    rho: FloatLike | None = None
    """Scale of the Beta prior on `theta`. Defaults to the number of
    predictors p when omitted. Smaller values favor sparser solutions."""

    augment: bool = field(default=True, static=True)
    """Whether the update of the variable selection probabilities accounts
    exactly, through data augmentation, for the decision rules that the
    ancestors of each node forbid. Enabled by default. With `False` the
    forbidden rules are ignored, which is faster but approximate. The
    difference matters most with few predictors having few cutpoints each,
    where a predictor cannot be re-used down a branch."""

    enabled: bool = field(default=True, static=True)
    """Whether variable selection is active."""

193 

194 

195class Bart(Module): 

196 R""" 

197 Nonparametric regression with Bayesian Additive Regression Trees (BART). 

198 

199 Regress `y_train` on `x_train` with a latent mean function represented as 

200 a sum of decision trees [2]_. The inference is carried out by sampling the 

201 posterior distribution of the tree ensemble with an MCMC. 

202 

203 Parameters 

204 ---------- 

205 x_train 

206 The training predictors. 

207 y_train 

208 The training responses. For univariate regression, a 1D array of shape 

209 `(n,)`. For multivariate regression, a 2D array of shape `(k, n)` where 

210 `k` is the number of response components, as introduced in [3]_. For 

211 binary regression, the convention is that non-zero values mean 1, zero 

212 mean 0, like booleans. 

213 outcome_type 

214 The type of regression. ``'continuous'`` for continuous regression, 

215 ``'binary'`` for binary regression with probit link. For multivariate 

216 regression, a scalar value applies to all components; alternatively, a 

217 sequence of per-component types (e.g., ``['binary', 'continuous']``) 

218 specifies mixed outcome types. Binary components in multivariate 

219 outcomes follow the multivariate probit BART formulation of [4]_. 

220 sparse 

221 A `SparseConfig` for the sparsity-inducing variable selection prior of 

222 [1]_. Disabled by default; pass a `SparseConfig` to enable it. 

223 varprob 

224 The probability distribution over the `p` predictors for choosing a 

225 predictor to split on in a decision node a priori. Must be > 0. It does 

226 not need to be normalized to sum to 1. If not specified, use a uniform 

227 distribution. If `sparse` is enabled, this is used as initial value for 

228 the MCMC. 

229 binner 

230 A callable that, given the training predictors and a random key, 

231 returns a `~bartz.prepcovars.Binner` instance. The default is 

232 `~bartz.prepcovars.UniqueQuantileBinner`, which places cutpoints at 

233 the quantiles of each predictor. Other built-in options are 

234 `~bartz.prepcovars.RangeEvenBinner` (evenly-spaced cutpoints over the 

235 observed range) and `~bartz.prepcovars.GivenSplitsBinner` (R BART 

236 ``xinfo`` format). To pass options, use `functools.partial`, e.g. 

237 ``binner=partial(UniqueQuantileBinner, max_bins=128)``. 

238 rm_const 

239 How to treat predictors with no associated decision rules (i.e., there 

240 are no available cutpoints for that predictor). If `True` (default), 

241 they are ignored. If `False`, an error is raised if there are any. 

242 sigma_df 

243 The degrees of freedom of the prior on the error precision. For 

244 multivariate regression with `k` components, the Wishart degrees of 

245 freedom are set to ``sigma_df + k - 1``. 

246 sigma_scale 

247 Sets the scale of the prior on the error precision. If 'auto' (default), 

248 the prior is scaled so that the error precision equals 

249 ``diag(1 / var(y_train))`` in expectation, where with weights `error_scale` 

250 the variance is a precision-weighted one that estimates the unit-weight error 

251 variance. Otherwise, ``square(sigma_scale)`` is the prior harmonic mean of 

252 the error variance; for multivariate regression a scalar is broadcast to 

253 all components. For mixed outcome types, binary components are ignored. 

254 sigma_init 

255 The initial value of the error standard deviation in the MCMC. If 'auto' 

256 (default), the initial error precision is set to ``diag(1 / var(y_train))``, 

257 with the same precision-weighted variance as `sigma_scale` when weights are 

258 given. Otherwise, the initial precision is ``diag(1 / square(sigma_init))``; 

259 for multivariate regression a scalar is broadcast to all components. For 

260 mixed outcome types, binary components are ignored. 

261 k 

262 The inverse scale of the prior standard deviation on the latent mean 

263 function, relative to half the observed range of `y_train`. If `y_train` 

264 has less than two elements, `k` is ignored and the scale is set to 1. 

265 power 

266 base 

267 Parameters of the prior on tree node generation. The probability that a 

268 node at depth `d` (0-based) is non-terminal is ``base / (1 + d) ** 

269 power``. 

270 tau_num 

271 The numerator in the expression that determines the prior standard 

272 deviation of leaves. If not specified, default to ``(max(y_train) - 

273 min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for 

274 continuous regression, and 3 for binary regression. For multivariate 

275 regression, the range is computed per component. For mixed outcome 

276 types, each component uses the default for its type. 

277 offset 

278 The prior mean of the latent mean function. If not specified, it is set 

279 to the mean of `y_train` for continuous regression, and to 

280 ``Phi^-1(mean(y_train != 0))`` for binary regression. If `y_train` is 

281 empty, `offset` is set to 0. With binary regression, if `y_train` is 

282 all zero or all non-zero, it is set to ``Phi^-1(1/(n+1))`` or 

283 ``Phi^-1(n/(n+1))``, respectively. For multivariate regression, can be 

284 a scalar (broadcast to all components) or a `(k,)` vector. If not 

285 specified, it is set to the per-component mean of `y_train`. For mixed 

286 outcome types, each component uses the default for its type. 

287 error_scale 

288 Coefficients that rescale the error standard deviation on each 

289 datapoint. Not specifying `error_scale` is equivalent to setting it to 1 

290 for all datapoints. Shape ``(n,)`` applies the same scalar weight to every 

291 outcome component; for multivariate continuous regression, ``(k, n)`` 

292 instead supplies a per-component weight per datapoint. 

293 missing 

294 Boolean mask with the same shape as `y_train`; `True` marks entries 

295 to be ignored by the MCMC. Values of `y_train` must be finite 

296 everywhere, including at masked positions. If 2-D, the error 

297 covariance must be diagonal. 

298 num_trees 

299 The number of trees used to represent the latent mean function. 

300 n_save 

301 The number of MCMC samples to save, after burn-in, per chain. The 

302 total trace length across all chains is ``num_chains * n_save``. 

303 n_burn 

304 The number of initial MCMC samples to discard as burn-in. This number 

305 of samples is discarded from each chain. 

306 n_skip 

307 The thinning factor for the MCMC samples, after burn-in. 

308 printevery 

309 The number of iterations (including thinned-away ones) between each log 

310 line. Set to `None` to disable progress reporting entirely (this ignores 

311 `pbar`). ^C interrupts the MCMC only every `printevery` iterations, so 

312 with reporting disabled it's impossible to kill the MCMC conveniently. 

313 pbar 

314 If `True`, show a `tqdm` progress bar instead of printing log lines. The 

315 bar advances every iteration and refreshes the acceptance statistics 

316 every `printevery` iterations. Ignored if `printevery` is `None`. 

317 num_chains 

318 The number of independent Markov chains to run. 

319 

320 The difference between ``num_chains=None`` and ``num_chains=1`` is that 

321 in the latter case in the object attributes and some methods there will 

322 be an explicit chain axis of size 1. 

323 num_chain_devices 

324 The number of devices to spread the chains across. Must be a divisor of 

325 `num_chains`. Each device will run a fraction of the chains. If 'auto' 

326 (default) and running on cpu, the number of devices is picked 

327 automatically based on the number of cores and the number of available 

328 devices (all the virtual jax cpu devices, or the `devices` list if set). 

329 num_data_devices 

330 The number of devices to split datapoints across. Must be a divisor of 

331 `n`. This is useful only with very high `n`, about > 1000_000. `predict` 

332 parallelizes across the same devices, splitting the test points; the 

333 number of test points must be a multiple of `num_data_devices` as well. 

334 

335 If both num_chain_devices and num_data_devices are specified, the total 

336 number of devices used is the product of the two. 

337 devices 

338 One or more devices used to run the MCMC on. If not specified, the 

339 computation will follow the placement of the input arrays. If a list of 

340 devices, this argument can be longer than the number of devices needed. 

341 seed 

342 The seed for the random number generator. 

343 maxdepth 

344 The maximum depth of the trees. This is 1-based, so with the default 

345 ``maxdepth=6``, the depths of the levels range from 0 to 5. 

346 init_kw 

347 Additional arguments passed to `bartz.mcmcstep.init`. 

348 run_mcmc_kw 

349 Additional arguments passed to `bartz.mcmcloop.run_mcmc`. 

350 

351 References 

352 ---------- 

353 .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for 

354 High-Dimensional Prediction and Variable Selection”. In: Journal of the 

355 American Statistical Association 113.522, pp. 626-636. 

356 .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART: 

357 Bayesian additive regression trees," The Annals of Applied Statistics, 

358 Ann. Appl. Stat. 4(1), 266-298, (March 2010). 

359 .. [3] Um, Seungha, Antonio R. Linero, Debajyoti Sinha, and Dipankar 

360 Bandyopadhyay (2023). "Bayesian additive regression trees for 

361 multivariate skewed responses". In: Statistics in Medicine 42.3, 

362 pp. 246-263. 

363 .. [4] Goh, Yong Chen, Wuu Kuang Soh, Andrew C. Parnell, and Keefe 

364 Murphy (2024). "Joint Models for Handling Non-Ignorable Missing 

365 Data using Bayesian Additive Regression Trees: Application to 

366 Leaf Photosynthetic Traits Data". arXiv:2412.14946 [stat.ME]. 

367 

368 """ 

369 

370 _main_trace: MainTrace 

371 _burnin_trace: BurninTrace 

372 _mcmc_state: State 

373 _binner: Binner 

374 _binary_mask: Bool[Array, ''] | Bool[Array, ' k'] 

375 # WORKAROUND(jax<0.9.1): use `jax.tree.static` instead of `field(static=True)` 

376 _x_train_fmt: Any = field(static=True) 

377 _device: Device | None = field(static=True) 

378 

379 _error_scale: Float32[Array, ' n'] | Float32[Array, 'k n'] | None = None 

380 

def __init__(
    self,
    x_train: Real[ArrayLike, 'p n'] | DataFrame,
    y_train: Float32[ArrayLike, ' n']
    | Float32[ArrayLike, 'k n']
    | Series
    | DataFrame,
    *,
    outcome_type: OutcomeType | str | Sequence[OutcomeType | str] = 'continuous',
    sparse: SparseConfig = SparseConfig(enabled=False),
    varprob: Float[ArrayLike, ' p'] | None = None,
    binner: BinnerFactory = UniqueQuantileBinner,
    rm_const: bool = True,
    sigma_df: FloatLike = 3.0,
    sigma_scale: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'] = 'auto',
    sigma_init: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'] = 'auto',
    k: FloatLike = 2.0,
    power: FloatLike = 2.0,
    base: FloatLike = 0.95,
    tau_num: FloatLike | None = None,
    offset: FloatLike | Float[ArrayLike, ' k'] | None = None,
    error_scale: Float[ArrayLike, ' n']
    | Float[ArrayLike, 'k n']
    | Series
    | DataFrame
    | None = None,
    missing: Bool[ArrayLike, ' n']
    | Bool[ArrayLike, 'k n']
    | Series
    | DataFrame
    | None = None,
    num_trees: int = 200,
    n_save: int = 1000,
    n_burn: int = 1000,
    n_skip: int = 1,
    printevery: int | None = 100,
    pbar: bool = True,
    num_chains: int | None = 4,
    num_chain_devices: int | None | Literal['auto'] = 'auto',
    num_data_devices: int | None = None,
    devices: Literal['cpu', 'gpu'] | Device | Sequence[Device] | None = None,
    seed: int | Key[Array, ''] = 0,
    maxdepth: int = 6,
    init_kw: Mapping = MappingProxyType({}),
    run_mcmc_kw: Mapping = MappingProxyType({}),
) -> None:
    # --- input validation and conversion -------------------------------
    x_train, x_train_fmt = _process_predictor_input(x_train)
    self._x_train_fmt = x_train_fmt
    y_train = _process_response_input(y_train)
    _check_same_length(x_train, y_train)

    if error_scale is not None:
        # `keep=True` yields a second copy: the first one is donated
        # downstream, the second is retained on the model for prediction
        error_scale, retained_scale = _process_response_input(
            error_scale, keep=True
        )
        self._error_scale = retained_scale
        _check_same_length(x_train, error_scale)

    if missing is not None:
        missing = _process_response_input(missing, dtype=jnp.bool_)
        _check_same_length(x_train, missing)

    # validate the data against the requested regression type
    # (continuous/binary/multivariate)
    outcome_type, binary_mask = _check_type_settings(
        y_train, outcome_type, error_scale
    )
    self._binary_mask = binary_mask

    # --- "standardization" settings -------------------------------------
    offset = _process_offset_settings(y_train, binary_mask, offset)
    leaf_prior_cov_inv = _process_leaf_variance_settings(
        y_train, binary_mask, k, num_trees, tau_num
    )
    error_cov_inv = _process_error_variance_settings(
        y_train,
        outcome_type,
        binary_mask,
        missing,
        sigma_df,
        sigma_scale,
        sigma_init,
        error_scale,
    )

    # --- random keys ----------------------------------------------------
    # the user seed is split in two: the first key popped goes to the
    # binner, the second to the mcmc
    seed = seed if is_key(seed) else random.key(seed)
    key_pool = split(seed)

    # --- binning of the predictors --------------------------------------
    fitted_binner = binner(x_train, key=key_pool.pop())
    self._binner = fitted_binner
    x_train = fitted_binner.bin(x_train)
    # `mcmcstep.init` donates `max_split`, so hand it a copy
    max_split = jnp.array(fitted_binner.max_split)

    # --- mcmc setup and run ---------------------------------------------
    initial_state, mcmc_key, device = _setup_mcmc(
        x_train,
        y_train,
        outcome_type,
        offset,
        error_scale,
        missing,
        max_split,
        leaf_prior_cov_inv,
        error_cov_inv,
        power,
        base,
        maxdepth,
        num_trees,
        init_kw,
        rm_const,
        sparse,
        varprob,
        num_chains,
        num_chain_devices,
        num_data_devices,
        devices,
        n_burn,
        key_pool.pop(),
    )
    self._device = device

    mcmc_output = _run_mcmc(
        initial_state,
        n_save,
        n_burn,
        n_skip,
        printevery,
        pbar,
        mcmc_key,
        run_mcmc_kw,
    )

    # --- store the outcome of the mcmc ----------------------------------
    self._main_trace = mcmc_output.main_trace
    self._burnin_trace = mcmc_output.burnin_trace
    self._mcmc_state = mcmc_output.final_state

521 

def predict(
    self,
    x_test: Real[ArrayLike, 'p m'] | DataFrame | str,
    *,
    kind: PredictKind | str = 'mean',
    key: Key[Array, ''] | None = None,
    error_scale: Float[ArrayLike, ' m']
    | Float[ArrayLike, 'k m']
    | Series
    | DataFrame
    | None = None,
) -> (
    Float32[Array, ' m']
    | Float32[Array, 'k m']
    | Float32[Array, 'ndpost m']
    | Float32[Array, 'ndpost k m']
):
    """
    Evaluate the fitted model at `x_test`.

    Parameters
    ----------
    x_test
        The predictors to evaluate the model at. Pass the string
        ``'train'`` to predict on the training data.
    kind
        Which output to produce, see `PredictKind`.
    key
        Jax random key. Mandatory with ``kind='outcome_samples'``.
    error_scale
        Per-observation error scale, used with ``kind='outcome_samples'``.
        Mandatory if the model was fit with weights and ``x_test`` is new
        data. The shape follows the one used at fitting: ``(m,)`` for
        scalar weights, ``(k, m)`` for multivariate vector weights.

    Returns
    -------
    Predictions at `x_test` in the requested format.

    Raises
    ------
    ValueError
        If `key` is missing with ``kind='outcome_samples'``, if the format
        of `x_test` differs from the one of `x_train`, if `error_scale` is
        given when it should be `None` or omitted when it is required, or
        if the model splits datapoints across devices
        (`num_data_devices`) and the number of test points is not a
        multiple of the number of data devices.

    Notes
    -----
    When the model splits datapoints across devices (`num_data_devices`),
    the test points and the returned predictions are split the same way.
    """
    # normalize and validate the arguments
    kind = PredictKind(kind)
    if key is None and kind is PredictKind.outcome_samples:
        msg = '`key` not specified'
        raise ValueError(msg)
    error_scale = self._process_error_scale_test(x_test, kind, error_scale)
    # remember whether the training data was requested before `x_test` is
    # replaced by its processed version
    on_training_data = isinstance(x_test, str) and x_test == 'train'
    x_test = self._process_x_test(x_test, error_scale)

    # the training data already sits on the devices of the model, only new
    # test data has to be moved there
    if not on_training_data:
        x_test, error_scale = self._device_put_test(x_test, error_scale)

    # delegate to the jitted module-level implementation
    state = self._mcmc_state
    has_binary_outcome = state.binary_y is not None
    return predict(
        key,
        self._main_trace,
        x_test,
        error_scale,
        state.binary_indices,
        has_binary_outcome,
        kind,
        # the test points are sharded over the mesh 'data' axis (when there
        # is one): by `init` for the training data, by `_device_put_test`
        # for new test data. `evaluate_trace` can't find this out on its
        # own at trace time, so it is declared explicitly.
        'shard_and_autobatch',
    )

604 

def _drop_device_info(self) -> 'Bart':
    """Return a copy of the model stripped of device placement metadata.

    The mesh is set to `None` in the MCMC state config and in both traces,
    and the explicitly requested device is cleared. This affects static
    metadata only: the arrays stay where they are.
    """

    def targets(model: 'Bart') -> tuple:
        return (
            model._mcmc_state.config,  # noqa: SLF001
            model._main_trace,  # noqa: SLF001
            model._burnin_trace,  # noqa: SLF001
        )

    meshless = (
        replace(self._mcmc_state.config, mesh=None),
        replace(self._main_trace, mesh=None),
        replace(self._burnin_trace, mesh=None),
    )
    stripped = tree_at(targets, self, meshless)
    # `tree_at` can't reach `_device` because it is a static field, so set
    # it in place on the freshly created copy
    object.__setattr__(stripped, '_device', None)
    return stripped

624 

def dump(self, path: str | PathLike) -> None:
    """Write the fitted model to a file using `pickle`.

    Parameters
    ----------
    path
        The destination file.

    Notes
    -----
    Meant for short-term storage (e.g. caching across processes) rather
    than long-term archival, since the format is tied to the versions of
    bartz, jax and equinox. Arrays are copied to host memory and any
    device/sharding placement is discarded; `load` rebuilds a single-device
    model.
    """
    # `Device` objects are not picklable, so strip all device info first;
    # then gather sharded arrays to host, which drops their sharding (the
    # model is reloaded on a single device)
    host_model = jax.device_get(self._drop_device_info())
    with Path(path).open('wb') as stream:
        pickle.dump(host_model, stream, protocol=pickle.HIGHEST_PROTOCOL)

647 

@classmethod
def load(cls, path: str | PathLike) -> 'Bart':
    """Read back a model written by `dump`.

    Parameters
    ----------
    path
        The source file.

    Returns
    -------
    The deserialized model, on host memory with no device placement.

    Raises
    ------
    TypeError
        If the content of the file is not a `Bart` instance.
    """
    with Path(path).open('rb') as stream:
        loaded = pickle.load(stream)  # noqa: S301, the user owns the file
    if isinstance(loaded, cls):
        return loaded
    msg = f'unpickled a {type(loaded).__name__}, not a {cls.__name__}'
    raise TypeError(msg)

672 

@property
def offset(self) -> Float32[Array, ''] | Float32[Array, ' k']:
    """Prior mean of the latent mean function, read from the MCMC state."""
    state = self._mcmc_state
    return state.offset

677 

@property
def n_save(self) -> int:
    """Number of post-burn-in posterior samples saved for each chain."""
    trace = self._main_trace
    # the position of the sample axis of `grow_prop_count` depends on the trace
    axis = trace_sample_axes(trace).grow_prop_count
    return trace.grow_prop_count.shape[axis]

683 

@property
def num_chains(self) -> int | None:
    """Number of chains as reported by the MCMC state, `None` if scalar."""
    state = self._mcmc_state
    return state.num_chains()

688 

@property
def ndpost(self) -> int:
    """Total number of post-burn-in posterior samples, counting all chains."""
    # `grow_prop_count` holds one entry per saved sample per chain
    per_sample_counts = self._main_trace.grow_prop_count
    return per_sample_counts.size

693 

@property
def num_trees(self) -> int:
    """Number of trees in the model, read off the shape of the forest."""
    trees = self._mcmc_state.forest
    # without chains `split_tree` is (num_trees, half_tree_size), i.e. the
    # trees are on core axis 0; shift that index to account for the chain
    # axis, if any
    tree_axis = chainful_axis(0, chain_vmap_axes(trees).split_tree)
    return trees.split_tree.shape[tree_axis]

702 

def get_latent_prec(
    self, only_continuous: bool = False
) -> (
    Float32[Array, ' n_burn_plus_n_save']
    | Float32[Array, 'n_burn_plus_n_save k k']
    | Float32[Array, 'num_chains n_burn_plus_n_save']
    | Float32[Array, 'num_chains n_burn_plus_n_save k k']
):
    """Return the MCMC trace of the latent error precision matrix.

    Parameters
    ----------
    only_continuous
        If `True` and the model mixes binary and continuous outcomes,
        restrict the output to the submatrix of the continuous components.

    Returns
    -------
    MCMC samples of the error precision matrix.

    Raises
    ------
    ValueError
        If `only_continuous` is `True` but all outcomes are binary, in
        which case there is no continuous submatrix.

    Notes
    -----
    The purpose of this method is checking convergence, hence the whole
    MCMC trace is returned and the chains are not concatenated. For probit
    regression, the result is the precision of the latent error term, not
    the Bernoulli precision of the binary outcome. For heteroskedastic
    regression, the result is the global precision parameter; the precision
    on a given datapoint is obtained by dividing it by a squared weight.
    """
    state = self._mcmc_state
    binary_indices = state.binary_indices
    if only_continuous:
        # binary outcomes without per-component indices: all are binary
        if binary_indices is None and state.binary_y is not None:
            msg = 'Model has only binary outcomes, so there is no continuous submatrix to return.'
            raise ValueError(msg)

    # delegate to the module-level implementation
    return get_latent_prec(
        self._burnin_trace,
        self._main_trace,
        binary_indices,
        only_continuous=only_continuous,
    )

754 

def get_error_sdev(
    self, mean: bool = False
) -> (
    Float32[Array, ' ndpost']
    | Float32[Array, 'ndpost k']
    | Float32[Array, '']
    | Float32[Array, ' k']
):
    """Return the post-burn-in error standard deviation, chains concatenated.

    Parameters
    ----------
    mean
        If `True`, the error covariance matrix is averaged over the samples
        before the square root is taken, so a single scalar or vector is
        returned in place of posterior samples.

    Returns
    -------
    Posterior samples (or single estimate) of the error standard deviation; NaN for binary outcomes.

    Notes
    -----
    A binary outcome has a standard deviation too, but this method does not
    return it: the Bernoulli variance is p(1-p), so computing it would
    require evaluating predictions at a given X.
    """
    trace = self._main_trace
    binary_mask = self._binary_mask
    # the entries for binary outcomes are NaN on purpose, so the NaN check
    # must be off while the result is computed
    with debug_nans(False):
        return get_error_sdev(trace, binary_mask, mean=mean)

785 

@cached_property
def varcount(self) -> Int32[Array, 'ndpost p']:
    """Per-sample histogram of the predictors used by the decision rules of the trees."""
    # the number of predictors is the length of `max_split`
    num_predictors = self._mcmc_state.forest.max_split.size
    return varcount(num_predictors, self._main_trace)

791 

@cached_property
def varcount_mean(self) -> Float32[Array, ' p']:
    """Mean of `varcount` over the MCMC iterations."""
    counts = self.varcount
    return counts.mean(axis=0)

796 

@cached_property
def varprob(self) -> Float32[Array, 'ndpost p']:
    """Posterior samples of the probability that each predictor is picked for a decision rule."""
    max_split = self._mcmc_state.forest.max_split
    return varprob(max_split, self._main_trace)

801 

@cached_property
def varprob_mean(self) -> Float32[Array, ' p']:
    """Marginal posterior probability that each predictor is picked for a decision rule."""
    # average the posterior draws over their first axis
    draws = self.varprob
    return draws.mean(axis=0)

806 

def _process_error_scale_test(
    self,
    x_test: Real[ArrayLike, 'p m'] | DataFrame | str,
    kind: PredictKind,
    error_scale: Float[ArrayLike, ' m']
    | Float[ArrayLike, 'k m']
    | Series
    | DataFrame
    | None,
) -> Float32[Array, ' m'] | Float32[Array, 'k m'] | None:
    """Check the error weights passed for prediction and work out which to use.

    Parameters
    ----------
    x_test
        The test predictors as given by the user (unprocessed), or
        ``'train'``.
    kind
        The prediction kind.
    error_scale
        The per-observation error scale given by the user, or `None`.

    Returns
    -------
    The error scale to use, as a float32 array, or `None` when weights do not apply.

    Raises
    ------
    ValueError
        If `error_scale` is given in a configuration where it must be
        `None`, or is missing where it is required.
    """
    train_scale = self._error_scale
    on_train_data = isinstance(x_test, str) and x_test == 'train'
    binary_model = self._mcmc_state.binary_y is not None
    weights_apply = (
        kind is PredictKind.outcome_samples
        and not binary_model
        and train_scale is not None
    )

    # weights play no role here: the only valid input is `None`
    if not weights_apply:
        if error_scale is None:
            return None
        msg = (
            '`error_scale` must be `None` in this configuration'
            " (it is used only with kind='outcome_samples',"
            ' continuous regression fitted with weights)'
        )
        raise ValueError(msg)

    # predicting on the training set: reuse the weights stored at fit time
    if on_train_data:
        if error_scale is None:
            return train_scale
        msg = (
            "`error_scale` must be `None` when x_test='train'"
            ' (training weights are used automatically)'
        )
        raise ValueError(msg)

    # new test data for a model fitted with weights: the user must give them
    if error_scale is None:
        msg = (
            '`error_scale` is required because the model was fit with'
            ' weights and x_test is new data'
        )
        raise ValueError(msg)
    error_scale_test = _process_response_input(error_scale)
    assert train_scale is not None  # guaranteed by `weights_apply`
    if error_scale_test.ndim != train_scale.ndim:
        msg = (
            f'`error_scale` shape mismatch with training weights: got '
            f'{error_scale_test.shape=}, expected {train_scale.ndim}D '
            f'(matching the training-weight shape).'
        )
        raise ValueError(msg)
    return error_scale_test

881 

882 def _process_x_test( 

883 self, 

884 x_test: Real[ArrayLike, 'p m'] | DataFrame | str, 

885 error_scale: Float32[Array, ' m'] | Float32[Array, 'k m'] | None, 

886 ) -> UInt[Array, 'p m']: 

887 """Convert x_test to binned format suitable for prediction.""" 

888 if isinstance(x_test, str): 

889 if x_test != 'train': 889 ↛ 890line 889 didn't jump to line 890 because the condition on line 889 was never true

890 msg = ( 

891 f"x_test must be an array, a DataFrame, or 'train', got {x_test!r}" 

892 ) 

893 raise ValueError(msg) 

894 return self._mcmc_state.X 

895 x_test, x_test_fmt = _process_predictor_input(x_test) 

896 if x_test_fmt != self._x_train_fmt: 

897 msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}' 

898 raise ValueError(msg) 

899 if error_scale is not None: 

900 _check_same_length(error_scale, x_test) 

901 return self._binner.bin(x_test) 

902 

903 def _device_put_test( 

904 self, 

905 x_test: UInt[Array, 'p m'], 

906 error_scale: Float32[Array, ' m'] | Float32[Array, 'k m'] | None, 

907 ) -> tuple[UInt[Array, 'p m'], Float32[Array, ' m'] | Float32[Array, 'k m'] | None]: 

908 """Place new test data on the devices of the model. 

909 

910 Mirror the placement of the training data done at fit time: shard over 

911 the mesh if there is one (the observation axis over 'data'), else move 

912 to the device requested explicitly at construction, if any. The inputs 

913 are donated, so they must not be used elsewhere. 

914 """ 

915 mesh = self._mcmc_state.config.mesh 

916 if mesh is not None: 

917 put = lambda a: device_put( 

918 a, 

919 NamedSharding(mesh, _leaf_partition_spec(a.ndim, None, -1, mesh)), 

920 donate=True, 

921 ) 

922 elif self._device is not None: 

923 put = lambda a: device_put(a, self._device, donate=True) 

924 else: 

925 return x_test, error_scale 

926 if error_scale is None: 

927 return put(x_test), None 

928 else: 

929 return put(x_test), put(error_scale) 

930 

def _check_trees(
    self, error: bool = False
) -> UInt[Array, 'num_chains n_save num_trees']:
    """Run `bartz.grove.check_trace` on every tree draw.

    Parameters
    ----------
    error
        If `True`, raise an exception when at least one invalid tree is
        found.

    Returns
    -------
    An array whose non-zero entries mark invalid trees.

    Raises
    ------
    RuntimeError
        If `error` is `True` and there are invalid trees.
    """
    invalid = check_trees(self._main_trace, self._mcmc_state.forest.max_split)
    if not error:
        return invalid
    num_invalid = jnp.count_nonzero(invalid).item()
    if num_invalid > 0:
        msg = f'Found {num_invalid} invalid trees in the MCMC trace.'
        raise RuntimeError(msg)
    return invalid

957 

def _tree_goes_bad(self) -> Bool[Array, 'num_chains n_save num_trees']:
    """Locate the iterations at which a tree turns invalid.

    Returns
    -------
    An array whose entry ``(i, j, k)`` is `True` if tree `k` is invalid at iteration `j` in chain `i` but not at iteration ``j - 1``.
    """
    max_split = self._mcmc_state.forest.max_split
    return tree_goes_bad(self._main_trace, max_split)

966 

967 def _check_replicated_trees(self) -> None: 

968 """Check that the trees are equal across data-sharded devices. 

969 

970 If the data is sharded across devices, verify that the trees (which 

971 should be replicated) are identical on all shards. 

972 

973 Raises 

974 ------ 

975 RuntimeError 

976 If the trees differ across devices. 

977 """ 

978 state = self._mcmc_state 

979 mesh = state.config.mesh 

980 if mesh is not None and 'data' in mesh.axis_names: 

981 # drop the data-sharded `leaf_indices` (not replicated) before the 

982 # cross-shard equality check; `None` is a deliberately off-type 

983 # placeholder, so use `tree_at`, which (unlike `dataclasses.replace`) 

984 # bypasses the `__init__` type checks 

985 replicated_forest = tree_at(lambda f: f.leaf_indices, state.forest, None) 

986 equal = equal_shards( 

987 replicated_forest, 'data', in_specs=PartitionSpec(), mesh=mesh 

988 ) 

989 equal_array = jnp.stack(tree.leaves(equal)) 

990 all_equal = jnp.all(equal_array) 

991 if not all_equal.item(): 991 ↛ 992line 991 didn't jump to line 992 because the condition on line 991 was never true

992 msg = 'The trees differ across data-sharded devices.' 

993 raise RuntimeError(msg) 

994 

995 def _compare_resid( 

996 self, y: Float32[Array, ' n'] | Float32[Array, 'k n'] | None = None 

997 ) -> tuple[ 

998 Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'], 

999 Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'], 

1000 ]: 

1001 """Re-compute residuals to compare them with the updated ones. 

1002 

1003 Parameters 

1004 ---------- 

1005 y 

1006 The response variable. Required for continuous regression (since 

1007 ``State`` does not store ``y`` in continuous mode). Ignored for 

1008 binary regression (where ``State.z`` is used instead). 

1009 

1010 Returns 

1011 ------- 

1012 resid1 

1013 The final state of the residuals updated during the MCMC. 

1014 resid2 

1015 The residuals computed from the final state of the trees. 

1016 """ 

1017 state = self._mcmc_state 

1018 if state.binary_indices is not None: 

1019 assert y is not None, 'y is required for mixed regression' 

1020 elif state.z is None: 

1021 assert y is not None, 'y is required for continuous regression' 

1022 y_arr = jnp.asarray(y) if y is not None else None 

1023 return compare_resid(state, y_arr) 

1024 

def _depth_distr(self) -> Int32[Array, '*num_chains n_save d']:
    """Compute the histogram of tree depths for each state of the trees.

    Returns
    -------
    A matrix in which every row is a histogram of tree depths.
    """
    trace = self._main_trace
    return depth_distr(trace)

1033 

def _points_per_node_distr(
    self, node_type: Literal['leaf', 'leaf-parent']
) -> Int32[Array, '*num_chains n_save n_plus_1']:
    """Histogram of number of points per node of the given type, over the trace."""
    predictors = self._mcmc_state.X
    return points_per_node_distr_trace(predictors, self._main_trace, node_type)

1040 

1041 def _points_per_decision_node_distr( 

1042 self, 

1043 ) -> Int32[Array, '*num_chains n_save n_plus_1']: 

1044 """Histogram of number of points belonging to parent-of-leaf nodes. 

1045 

1046 Returns 

1047 ------- 

1048 For each chain, a matrix where each row contains a histogram of number of points. 

1049 """ 

1050 return self._points_per_node_distr('leaf-parent') 

1051 

1052 def _points_per_leaf_distr(self) -> Int32[Array, '*num_chains n_save n_plus_1']: 

1053 """Histogram of number of points belonging to leaves. 

1054 

1055 Returns 

1056 ------- 

1057 A matrix where each row contains a histogram of number of points. 

1058 """ 

1059 return self._points_per_node_distr('leaf') 

1060 

def _print_tree(
    self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
) -> None:
    """Print one tree in a human-readable form.

    Parameters
    ----------
    i_chain
        Index of the MCMC chain.
    i_sample
        Index of the (post-burnin) sample within the chain.
    i_tree
        Index of the tree within the sample.
    print_all
        If `True`, the content of unused node slots is printed as well.
    """
    trace = self._main_trace
    all_trees = _trees_chain_first(trace)
    # without a chain axis the leading index is an Ellipsis, which matches
    # no axis here, so the sample and tree indices line up either way
    if trace.has_chains:
        selector = (i_chain, i_sample, i_tree, slice(None))
    else:
        selector = (Ellipsis, i_sample, i_tree, slice(None))
    single_tree = tree.map(lambda heap: heap[selector], all_trees)
    print(format_tree(single_tree, print_all=print_all))  # noqa: T201, this method is intended for debug

1083 

1084 

def _process_predictor_input(
    x: Real[ArrayLike, 'p n'] | DataFrame,
) -> tuple[Shaped[Array, 'p n'], Any]:
    """Convert the predictors to a 2D array and describe their input format.

    Parameters
    ----------
    x
        The predictors, either a dataframe (one column per predictor) or an
        array-like with one row per predictor.

    Returns
    -------
    x
        The predictors as an array of shape (p, n).
    fmt
        A dictionary describing the input format: the kind of container and
        either the dataframe columns or the number of predictors.

    Raises
    ------
    ValueError
        If the converted predictors are not 2-dimensional.
    """
    if isinstance(x, DataFrame):
        fmt = dict(kind='dataframe', columns=x.columns)
        # dataframes have one column per predictor, arrays one row
        x = x.to_numpy().T
    else:
        fmt = None
    # convert before looking at the shape, so that inputs without a `.shape`
    # attribute (e.g., nested lists) are handled as well
    x = jnp.asarray(x)
    # explicit check instead of `assert`, which is skipped under `python -O`
    if x.ndim != 2:
        msg = f'predictors must be a 2D (p, n) matrix, got {x.shape=}.'
        raise ValueError(msg)
    if fmt is None:
        fmt = dict(kind='array', num_covar=x.shape[0])
    return x, fmt

1096 

1097 

@overload
def _process_response_input(
    arr: Shaped[ArrayLike, ' n'] | Shaped[ArrayLike, 'k n'] | Series | DataFrame,
    /,
    *,
    keep: Literal[False] = False,
    dtype: DTypeLike = jnp.float32,
) -> Shaped[Array, ' n'] | Shaped[Array, 'k n']: ...


@overload
def _process_response_input(
    arr: Shaped[ArrayLike, ' n'] | Shaped[ArrayLike, 'k n'] | Series | DataFrame,
    /,
    *,
    keep: Literal[True],
    dtype: DTypeLike = jnp.float32,
) -> tuple[
    Shaped[Array, ' n'] | Shaped[Array, 'k n'],
    Shaped[Array, ' n'] | Shaped[Array, 'k n'],
]: ...


def _process_response_input(
    arr: Shaped[ArrayLike, ' n'] | Shaped[ArrayLike, 'k n'] | Series | DataFrame,
    /,
    *,
    keep: bool = False,
    dtype: DTypeLike = jnp.float32,
) -> (
    Shaped[Array, ' n']
    | Shaped[Array, 'k n']
    | tuple[
        Shaped[Array, ' n'] | Shaped[Array, 'k n'],
        Shaped[Array, ' n'] | Shaped[Array, 'k n'],
    ]
):
    """Convert a response-like input to a 1D or 2D array of type `dtype`.

    Parameters
    ----------
    arr
        The input: an array-like, a series, or a dataframe (which is
        transposed so that columns become rows).
    keep
        If `False`, return a single array. If `True`, return a pair
        ``(copy, arr)`` where the first item is a `jnp.copy` of the second.
    dtype
        The data type of the output.

    Returns
    -------
    The converted array, or the pair of arrays if `keep` is `True`.

    Raises
    ------
    ValueError
        If the converted input is neither 1D nor 2D.
    """
    if isinstance(arr, DataFrame):
        arr = arr.to_numpy().T
    elif isinstance(arr, Series):
        arr = arr.to_numpy()

    # normal mode makes one unconditional copy, which is safe to donate
    # downstream; `keep` mode converts without copying when possible, and
    # the separate disposable copy is made with `jnp.copy` at the end
    force_copy = not keep
    arr = jnp.array(arr, dtype, copy=force_copy)

    if arr.ndim not in (1, 2):
        msg = f'response-like input must be 1D (n,) or 2D (k, n). Got {arr.ndim=}.'
        raise ValueError(msg)

    return (jnp.copy(arr), arr) if keep else arr

1149 

1150 

1151def _check_same_length(x1: Shaped[Array, '... n'], x2: Shaped[Array, '... n']) -> None: 

1152 get_length = lambda x: x.shape[-1] 

1153 assert get_length(x1) == get_length(x2) 

1154 

1155 

def _check_type_settings(
    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    outcome_type: OutcomeType | str | Sequence[OutcomeType | str],
    error_scale: Float[Array, ' n'] | Float[Array, 'k n'] | None,
) -> tuple[OutcomeType | tuple[OutcomeType, ...], Bool[Array, ''] | Bool[Array, ' k']]:
    """Standardize `outcome_type`, validate it, and build the binary mask.

    Returns
    -------
    outcome_type
        A single `OutcomeType`, or a tuple of them when a sequence with
        different types was given.
    binary_mask
        Boolean array, with the shape of `y_train` minus its last axis,
        marking the binary outcomes.

    Raises
    ------
    ValueError
        If a sequence `outcome_type` does not match the shape of `y_train`,
        if weights are given together with a non-continuous outcome type,
        or if a 2D `error_scale` does not match `y_train`.
    """
    # bring outcome_type to the form OutcomeType or tuple[OutcomeType, ...];
    # a sequence whose types are all equal collapses to the single type
    if isinstance(outcome_type, str) or not isinstance(outcome_type, Sequence):
        num_types = None
        outcome_type = OutcomeType(outcome_type)
    else:
        outcome_type = tuple(OutcomeType(t) for t in outcome_type)
        num_types = len(outcome_type)
        if len(set(outcome_type)) == 1:
            outcome_type = outcome_type[0]

    # a sequence must provide one type per row of a 2D y_train
    if num_types is not None:
        if y_train.ndim != 2 or num_types != y_train.shape[0]:
            msg = (
                f'Sequence outcome_type of length {num_types}'
                f' requires y_train.shape=({num_types}, n),'
                f' found {y_train.shape=}.'
            )
            raise ValueError(msg)

    if error_scale is not None:
        if outcome_type is not OutcomeType.continuous:
            msg = 'Weights are not supported when any outcome is binary.'
            raise ValueError(msg)
        if error_scale.ndim == 2 and (
            y_train.ndim != 2 or error_scale.shape[0] != y_train.shape[0]
        ):
            msg = (
                f'2D error_scale (vector per-component weights) requires y_train of '
                f'shape (k, n) with matching k; got {error_scale.shape=}, '
                f'{y_train.shape=}.'
            )
            raise ValueError(msg)

    if isinstance(outcome_type, tuple):
        flags = jnp.array([t is OutcomeType.binary for t in outcome_type])
    else:
        flags = jnp.bool_(outcome_type is OutcomeType.binary)
    binary_mask = jnp.broadcast_to(flags, y_train.shape[:-1])

    return outcome_type, binary_mask

1201 

1202 

1203def _process_sparsity_settings( 

1204 x_train: Real[Array, 'p n'], sparse: SparseConfig 

1205) -> ( 

1206 tuple[None, None, None, None] 

1207 | tuple[FloatLike, None, None, None] 

1208 | tuple[None, FloatLike, FloatLike, FloatLike] 

1209): 

1210 """Return (theta, a, b, rho).""" 

1211 if not sparse.enabled: 

1212 return None, None, None, None 

1213 elif sparse.theta is not None: 

1214 return sparse.theta, None, None, None 

1215 else: 

1216 rho = sparse.rho 

1217 if rho is None: 

1218 p, _ = x_train.shape 

1219 rho = float(p) 

1220 return None, sparse.a, sparse.b, rho 

1221 

1222 

1223def _process_offset_settings( 

1224 y_train: Float32[Array, ' n'] | Float32[Array, 'k n'], 

1225 binary_mask: Bool[Array, ''] | Bool[Array, ' k'], 

1226 offset: FloatLike | Float[ArrayLike, ' k'] | None, 

1227) -> Float32[Array, ''] | Float32[Array, ' k']: 

1228 """Return offset.""" 

1229 if offset is not None: 

1230 off = jnp.asarray(offset, jnp.float32) 

1231 return jnp.broadcast_to(off, y_train.shape[:-1]) 

1232 if y_train.shape[-1] < 1: 

1233 return jnp.zeros(y_train.shape[:-1]) 

1234 

1235 bound = 1 / (1 + y_train.shape[-1]) 

1236 binary_offset = ndtri(jnp.clip((y_train != 0).mean(-1), bound, 1 - bound)) 

1237 continuous_offset = y_train.mean(-1) 

1238 return jnp.where(binary_mask, binary_offset, continuous_offset) 

1239 

1240 

1241def _process_leaf_variance_settings( 

1242 y_train: Float32[Array, ' n'] | Float32[Array, 'k n'], 

1243 binary_mask: Bool[Array, ''] | Bool[Array, ' k'], 

1244 k: FloatLike, 

1245 num_trees: int, 

1246 tau_num: FloatLike | None, 

1247) -> Float32[Array, ''] | Float32[Array, 'k k']: 

1248 """Return `leaf_prior_cov_inv`.""" 

1249 # determine `tau_num` if not specified 

1250 if tau_num is None: 

1251 if y_train.shape[-1] < 2: 

1252 continuous_tau = jnp.ones(y_train.shape[:-1]) 

1253 else: 

1254 continuous_tau = (y_train.max(-1) - y_train.min(-1)) / 2 

1255 tau_num = jnp.where(binary_mask, 3.0, continuous_tau) 

1256 

1257 # leaf prior standard deviation 

1258 sigma_mu = tau_num / (k * math.sqrt(num_trees)) 

1259 

1260 # leaf prior precision matrix 

1261 leaf_prior_cov_inv = jnp.reciprocal(jnp.square(sigma_mu)) 

1262 if y_train.ndim == 2: 

1263 leaf_prior_cov_inv = jnp.diag( 

1264 jnp.broadcast_to(leaf_prior_cov_inv, y_train.shape[:-1]) 

1265 ) 

1266 return leaf_prior_cov_inv 

1267 

1268 

def _process_error_variance_settings(
    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    outcome_type: OutcomeType | tuple[OutcomeType, ...],
    binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
    missing: Bool[Array, ' n'] | Bool[Array, 'k n'] | None,
    sigma_df: FloatLike,
    sigma_scale: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'],
    sigma_init: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'],
    error_scale: Float32[Array, ' n'] | Float32[Array, 'k n'] | None,
) -> Wishart | None:
    """Turn the user settings into the prior on the error precision.

    Returns `None` for pure binary regression, where `sigma_scale` and
    `sigma_init` must both be left to their string setting.

    Raises
    ------
    ValueError
        If `sigma_scale` or `sigma_init` is set for binary regression.
    """
    if outcome_type is OutcomeType.binary:
        if not (isinstance(sigma_scale, str) and isinstance(sigma_init, str)):
            msg = (
                'Do not set `sigma_scale` or `sigma_init` for binary regression, '
                'they are ignored'
            )
            raise ValueError(msg)
        return None

    *kdims, _ = y_train.shape  # () or (k,)
    num_components = kdims[0] if kdims else 1
    nu = jnp.asarray(sigma_df, jnp.float32) + (num_components - 1)

    # the guarded per-component variance of y_train is needed only by 'auto'
    # specs; this function is not jitted, so an unused computation would not
    # be elided and is skipped explicitly
    needs_vary = any(isinstance(spec, str) for spec in (sigma_scale, sigma_init))
    vary = (
        _guarded_response_variance(y_train, error_scale, missing)
        if needs_vary
        else None
    )

    # prior rate: E[precision] = nu / rate, so rate = nu * var per component
    scale_var = _resolve_error_variance(sigma_scale, vary, kdims)
    rate_diag = jnp.where(binary_mask, 0.0, nu * scale_var)

    # initial precision = 1 / var per component (1 for binary components)
    init_var = _resolve_error_variance(sigma_init, vary, kdims)
    init_diag = jnp.where(binary_mask, 1.0, jnp.reciprocal(init_var))

    if y_train.ndim == 2:
        rate = jnp.diag(rate_diag)
        init_prec = jnp.diag(init_diag)
    else:
        rate = rate_diag
        init_prec = init_diag
    return make_error_cov_prior(nu, rate, init_prec, outcome_type, missing)

1314 

1315 

1316@jit 

1317def _guarded_response_variance( 

1318 y_train: Float32[Array, ' n'] | Float32[Array, 'k n'], 

1319 error_scale: Float32[Array, ' n'] | Float32[Array, 'k n'] | None, 

1320 missing: Bool[Array, ' n'] | Bool[Array, 'k n'] | None, 

1321) -> Float32[Array, '*k']: 

1322 """Per-component variance of `y_train`, used by the 'auto' error scale. 

1323 

1324 A precision-weighted variance (precision ``1 / error_scale ** 2``) estimates 

1325 the unit-weight ``sigma ** 2``; `missing` entries are dropped. The variance 

1326 is guarded to 1 when undefined (fewer than 2 valid points) or non-positive. 

1327 """ 

1328 if error_scale is None and missing is None: 

1329 vary = jnp.var(y_train, axis=-1) 

1330 return jnp.where(vary > 0, vary, 1.0) 

1331 else: 

1332 prec = ( 

1333 jnp.ones(()) 

1334 if error_scale is None 

1335 else jnp.reciprocal(jnp.square(error_scale)) 

1336 ) 

1337 if missing is not None: 

1338 prec = jnp.where(missing, 0.0, prec) 

1339 y_train = jnp.where(missing, 0.0, y_train) 

1340 n_valid = jnp.count_nonzero(prec, axis=-1) 

1341 wmean = jnp.sum(prec * y_train, axis=-1) / jnp.sum(prec, axis=-1) 

1342 sqdev = prec * jnp.square(y_train - wmean[..., None]) 

1343 vary = jnp.sum(sqdev, axis=-1) / n_valid 

1344 # guard on n_valid too: with a single valid point the variance is 0 in 

1345 # exact arithmetic, but float rounding in wmean can leave a tiny 

1346 # positive vary that would slip past the `vary > 0` guard 

1347 return jnp.where((n_valid > 1) & (vary > 0), vary, 1.0) 

1348 

1349 

1350def _resolve_error_variance( 

1351 spec: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'], 

1352 vary: Float32[Array, '*k'] | None, 

1353 shape: Sequence[int], 

1354) -> Float32[Array, '*k']: 

1355 """Per-component error variance from a scale spec ('auto' uses var(y)).""" 

1356 if isinstance(spec, str): 

1357 if spec != 'auto': 

1358 msg = f"unrecognized value {spec!r}, expected 'auto' or a number" 

1359 raise ValueError(msg) 

1360 assert vary is not None # computed iff some spec is 'auto' 

1361 return vary 

1362 else: 

1363 return jnp.broadcast_to(jnp.square(jnp.asarray(spec, jnp.float32)), shape) 

1364 

1365 

def make_error_cov_prior(
    nu: Float32[Array, ''],
    rate: Float32[Array, ''] | Float32[Array, 'k k'],
    value: Float32[Array, ''] | Float32[Array, 'k k'],
    outcome_type: OutcomeType | tuple[OutcomeType, ...],
    missing: Bool[Array, ' n'] | Bool[Array, 'k n'] | None,
) -> Wishart:
    """Create the prior on the error precision, diagonal where it has to be.

    The error covariance is restricted to diagonal in mixed
    binary-continuous regression and in partial-missing regression (2-D
    mask), so those cases get a `DiagWishart`; the dense cases get a
    `Wishart`. `init` re-checks this choice. `value` is the initial value of
    the precision.
    """
    # mixed = some, but not all, of the outcomes are binary
    is_mixed = False
    if isinstance(outcome_type, tuple):
        num_binary = sum(t is OutcomeType.binary for t in outcome_type)
        is_mixed = 0 < num_binary < len(outcome_type)

    # a 2-D missingness mask only occurs with multivariate y (checked in `init`)
    partial_missing = missing is not None and missing.ndim == 2

    prior_type = DiagWishart if (is_mixed or partial_missing) else Wishart
    return prior_type(nu=nu, rate=rate, value=value)

1391 

1392 

def _setup_mcmc(
    x_train: Real[Array, 'p n'],
    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    outcome_type: OutcomeType | tuple[OutcomeType, ...],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    error_scale: Float[Array, ' n'] | Float[Array, 'k n'] | None,
    missing: Bool[Array, ' n'] | Bool[Array, 'k n'] | None,
    max_split: UInt[Array, ' p'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    error_cov_inv: Wishart | None,
    power: FloatLike,
    base: FloatLike,
    maxdepth: int,
    num_trees: int,
    init_kw: Mapping[str, Any],
    rm_const: bool,
    sparse: SparseConfig,
    varprob: Float[ArrayLike, ' p'] | None,
    num_chains: int | None,
    num_chain_devices: int | None | Literal['auto'],
    num_data_devices: int | None,
    devices: Literal['cpu', 'gpu'] | Device | Sequence[Device] | None,
    n_burn: int,
    mcmc_key: Key[Array, ''],
) -> tuple[State, Key[Array, ''], Device | None]:
    """Assemble the arguments of `init` and create the initial MCMC state.

    Returns
    -------
    state
        The state produced by `init`.
    mcmc_key
        The random key, moved to `device` together with the state if a
        device is set.
    device
        The device returned by `process_device_settings`, possibly `None`.
    """
    p_nonterminal = make_p_nonterminal(maxdepth, base, power)

    # hyperparameters of the sparsity prior
    theta, a, b, rho = _process_sparsity_settings(x_train, sparse)

    # device-related arguments
    device_kw, device = process_device_settings(
        y_train, num_chains, num_chain_devices, num_data_devices, devices
    )

    kw: dict = {
        'X': x_train,
        'y': y_train,
        'outcome_type': outcome_type,
        'offset': offset,
        'error_scale': error_scale,
        'missing': missing,
        'max_split': max_split,
        'num_trees': num_trees,
        'p_nonterminal': p_nonterminal,
        'leaf_prior_cov_inv': leaf_prior_cov_inv,
        'error_cov_inv': error_cov_inv,
        'min_points_per_decision_node': 10,
        'log_s': process_varprob(varprob, max_split),
        'theta': theta,
        'a': a,
        'b': b,
        'rho': rho,
        'sparse_on_at': n_burn // 2 if sparse.enabled else None,
        'augment': sparse.augment,
        **device_kw,
    }

    if rm_const:
        # number of predictors that have no available split
        kw['filter_splitless_vars'] = jnp.sum(max_split == 0).item()

    # user-provided arguments come last, so they override the ones set above
    kw.update(init_kw)

    state = init(**kw)

    # move state and mcmc key only if a device was requested explicitly
    if device is not None:
        mcmc_key, state = device_put((mcmc_key, state), device, donate=True)

    return state, mcmc_key, device

1464 

1465 

def _run_mcmc(
    mcmc_state: State,
    n_save: int,
    n_burn: int,
    n_skip: int,
    printevery: int | None,
    pbar: bool,
    key: Key[Array, ''],
    run_mcmc_kw: Mapping,
) -> RunMCMCResult:
    """Call `run_mcmc` with the progress-reporting callback set up.

    The entries of `run_mcmc_kw` are applied last, so they override the
    arguments prepared here.
    """
    kw: dict = {'n_burn': n_burn, 'n_skip': n_skip, 'inner_loop_length': printevery}

    # with `printevery=None` there is no progress reporting at all: no
    # callback is installed, so the loop traces without any `debug.callback`
    # effect (a tqdm bar would otherwise advance every iteration regardless
    # of `printevery`).
    if printevery is not None:
        if pbar:
            callback_kw = make_tqdm_callback(mcmc_state, report_every=printevery)
        else:
            callback_kw = make_print_callback(
                mcmc_state,
                dot_every=None if printevery == 1 else 1,
                report_every=printevery,
            )
        kw.update(callback_kw)

    kw.update(run_mcmc_kw)

    return run_mcmc(key, mcmc_state, n_save, **kw)

1495 

1496 

@jit(static_argnames='p')
# jitting lets the lax.collapse below avoid making a copy
def varcount(p: int, trace: MainTrace) -> Int32[Array, 'ndpost p']:
    """Count predictor usage in the decision rules of the trees, merging the chains."""
    counts: Int32[Array, '*chains samples p']
    counts = compute_varcount(p, trace, out_chain_axis=0)
    # merge all the leading axes into one, leaving the predictor axis last
    return lax.collapse(counts, 0, -1)

1504 

1505 

@jit(static_argnames='mean')
def get_error_sdev(
    trace: MainTrace,
    binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
    *,
    mean: bool = False,
) -> (
    Float32[Array, ' ndpost']
    | Float32[Array, 'ndpost k']
    | Float32[Array, '']
    | Float32[Array, ' k']
):
    """Compute the post-burnin error standard deviation, chains concatenated.

    With `mean=True` the variance is averaged over the samples before the
    square root is taken. Entries selected by `binary_mask` are NaN.
    """
    precision = trace.error_cov_inv
    if trace.has_chains:
        # from (chains, samples) or (chains, samples, k, k) to a single
        # leading axis that runs over all the samples of all the chains
        precision = chain_to_axis(precision, chain_vmap_axes(trace).error_cov_inv)
        precision = lax.collapse(precision, 0, 2)

    univariate = precision.ndim == 1
    if univariate:
        # treat the scalar precision as a 1x1 matrix
        precision = precision[..., None, None]

    # covariance = inverse of the precision; the variances are its diagonal
    covariance = _inv_via_chol_with_gersh(precision)
    variance = jnp.diagonal(covariance, axis1=-2, axis2=-1)
    if mean:
        variance = variance.mean(0)
    sdev = jnp.sqrt(variance)

    if univariate:
        sdev = sdev.squeeze(-1)
    return jnp.where(binary_mask, jnp.nan, sdev)

1538 

1539 

@jit(static_argnames='only_continuous')
def get_latent_prec(
    burnin_trace: BurninTrace,
    main_trace: MainTrace,
    binary_indices: Int32[Array, ' kb'] | None,
    *,
    only_continuous: bool = False,
) -> (
    Float32[Array, ' n_burn_plus_n_save']
    | Float32[Array, 'n_burn_plus_n_save k k']
    | Float32[Array, 'num_chains n_burn_plus_n_save']
    | Float32[Array, 'num_chains n_burn_plus_n_save k k']
):
    """Return the trace of the latent error precision, burn-in followed by main.

    With `only_continuous=True` and `binary_indices` set, only the rows and
    columns that are not listed in `binary_indices` are kept.
    """
    pieces = [burnin_trace.error_cov_inv, main_trace.error_cov_inv]
    sample_axis = trace_sample_axes(main_trace).error_cov_inv
    prec = jnp.concatenate(pieces, axis=sample_axis)
    prec = chain_to_axis(prec, chain_vmap_axes(main_trace).error_cov_inv)

    if not only_continuous or binary_indices is None:
        return prec

    *_, k, _ = prec.shape
    num_continuous = k - binary_indices.size
    # flag the components that are not binary, then list their indices; the
    # static `size` keeps the output shape known under jit
    is_continuous = jnp.ones(k, dtype=bool).at[binary_indices].set(False)
    (cont_indices,) = jnp.nonzero(is_continuous, size=num_continuous)
    rows = cont_indices[:, None]
    cols = cont_indices[None, :]
    return prec[..., rows, cols]

1566 

1567 

@jit
def varprob(
    max_split: UInt[Array, ' p'], trace: MainTrace
) -> Float32[Array, 'ndpost p']:
    """Return posterior samples of the predictor selection probability, chains concatenated."""
    num_predictors = max_split.size
    samples = trace.varprob
    if samples is not None:
        samples = chain_to_axis(samples, chain_vmap_axes(trace).varprob)
        return samples.reshape(-1, num_predictors)

    # the trace has no probabilities: use equal probability on the
    # predictors with nonzero `max_split`, zero on the others, repeated
    # once per element of `grow_prop_count`
    num_samples = trace.grow_prop_count.size
    num_usable = jnp.count_nonzero(max_split)
    uniform = jnp.where(max_split, 1 / num_usable, 0)
    return jnp.broadcast_to(uniform, (num_samples, num_predictors))

1582 

1583 

def _trees_chain_first(obj: TreeHeaps) -> TreesTrace:
    """Pull the heap arrays out of `obj`, with the chain axis (if any) in front.

    The result is a `TreesTrace`. If `obj` has a chain axis, it becomes the
    leading axis of the arrays; otherwise the heap arrays are returned as
    they are in `obj`.
    """
    trees = TreesTrace.from_dataclass(obj)
    if not get_has_chains(obj):
        return trees
    axes = trees.axes_from_dataclass(chain_vmap_axes(obj))
    # WORKAROUND(python<3.14): use operator.is_none
    return tree.map(chain_to_axis, trees, axes, is_leaf=lambda x: x is None)

1596 

1597 

@jit
def check_trees(
    trace: MainTrace, max_split: UInt[Array, ' p']
) -> UInt[Array, 'num_chains n_save num_trees']:
    """Run `bartz.grove.check_trace` on every tree draw."""
    result: UInt[Array, '*chains samples num_trees']
    result = check_trace(_trees_chain_first(trace), max_split)
    # add a leading axis if the result has fewer than 3 axes
    return result if result.ndim >= 3 else result[None, :, :]

1609 

1610 

@jit
def tree_goes_bad(
    trace: MainTrace, max_split: UInt[Array, ' p']
) -> Bool[Array, 'num_chains n_save num_trees']:
    """Locate the iterations at which a tree turns invalid."""
    is_bad = check_trees(trace, max_split).astype(bool)
    # status at the previous iteration: shift by one along the sample axis,
    # zero-padding the first iteration
    was_bad = jnp.pad(is_bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
    return jnp.logical_and(is_bad, jnp.logical_not(was_bad))

1619 

1620 

@jit
def compare_resid(
    state: State, y: Float32[Array, ' n'] | Float32[Array, 'k n'] | None
) -> tuple[
    Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
    Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
]:
    """Compute the residuals afresh, to compare them with the tracked ones.

    Returns the residuals stored in `state` and the residuals recomputed
    from the trees in `state`, in this order.
    """
    chain_axes = chain_vmap_axes(state)
    stored_resid = chain_to_axis(state.resid, chain_axes.resid)
    if state.z is None:
        latent = None
    else:
        latent = chain_to_axis(state.z, chain_axes.z)

    forests = _trees_chain_first(state.forest)
    sum_of_trees = evaluate_forest(state.X, forests, sum_batch_axis=-1)

    # pick the target that the sum of trees is compared against
    if state.binary_indices is not None:
        # mixed binary-continuous: z has only binary rows, y has all rows
        assert y is not None
        target = jnp.broadcast_to(y, stored_resid.shape)
        target = target.at[..., state.binary_indices, :].set(latent)
    elif latent is not None:
        target = latent
    else:
        assert y is not None
        target = y
    recomputed_resid = target - (sum_of_trees + state.offset[..., None])

    return stored_resid, recomputed_resid

1649 

1650 

@jit
def depth_distr(trace: MainTrace) -> Int32[Array, '*num_chains n_save d']:
    """Compute the histogram of tree depths for each saved state of the trees.

    A leading axis of length 1 is prepended when the histogram has fewer than
    three dimensions.
    """
    split_axis = chain_vmap_axes(trace).split_tree
    hist = forest_depth_distr(chain_to_axis(trace.split_tree, split_axis))
    return hist if hist.ndim >= 3 else hist[None, :, :]

1660 

1661 

@jit(static_argnames='node_type')
def points_per_node_distr_trace(
    X: UInt[Array, 'p n'], trace: MainTrace, node_type: Literal['leaf', 'leaf-parent']
) -> Int32[Array, '*num_chains n_save n+1']:
    """Compute the points-per-node histogram for every tree draw in `trace`.

    `node_type` is forwarded unchanged to `points_per_node_distr`. A leading
    axis of length 1 is prepended when the histogram has fewer than three
    dimensions.
    """
    axes = chain_vmap_axes(trace)
    variables = chain_to_axis(trace.var_tree, axes.var_tree)
    splits = chain_to_axis(trace.split_tree, axes.split_tree)
    hist = points_per_node_distr(X, variables, splits, node_type, sum_batch_axis=-1)
    return hist if hist.ndim >= 3 else hist[None, :, :]

1675 

1676 

class DeviceKwArgs(TypedDict):
    """Device-related keyword arguments built by `process_device_settings`."""

    # number of chains, passed through unchanged from the caller
    num_chains: int | None
    # device mesh produced by `_determine_mesh`, or `None` when no mesh is used
    mesh: Mesh | None

1680 

1681 

def process_device_settings(
    y_train: Shaped[Array, '...'],
    num_chains: int | None,
    num_chain_devices: int | None | Literal['auto'],
    num_data_devices: int | None,
    devices: Literal['cpu', 'gpu'] | Device | Sequence[Device] | None,
) -> tuple[DeviceKwArgs, Device | None]:
    """Build the device-related arguments for `mcmcstep.init`.

    Returns
    -------
    settings
        The `num_chains` and `mesh` keyword arguments.
    device
        An optional single device where to put the state, as returned by
        `_determine_mesh`.
    """
    # Record, before `devices` is normalized, whether the user pinned a
    # concrete pool of devices (as opposed to inheriting all the platform's
    # devices); the auto chain sharding may not exceed that pool.
    pinned_pool = not (devices is None or isinstance(devices, str))

    platform, device, pool = _determine_devices(y_train, devices)
    chain_axis_size = _determine_num_chain_devices(
        platform, num_chains, num_chain_devices, num_data_devices, len(pool), pinned_pool
    )
    mesh, device = _determine_mesh(chain_axis_size, num_data_devices, device, pool)

    # arguments to `init`, plus the optional target device
    return DeviceKwArgs(num_chains=num_chains, mesh=mesh), device

1708 

1709 

1710def _determine_devices( 

1711 y_train: Shaped[Array, '...'], 

1712 devices: Literal['cpu', 'gpu'] | Device | Sequence[Device] | None, 

1713) -> tuple[str, Device | None, Sequence[Device]]: 

1714 """Determine the target platform and set of devices for the MCMC, and possibly a single target device.""" 

1715 if isinstance(devices, str): 

1716 platform = devices 

1717 devices = jax.devices(platform) 

1718 return platform, devices[0], devices 

1719 elif devices is not None: 

1720 if not hasattr(devices, '__len__'): 

1721 devices = (devices,) 

1722 device = devices[0] 

1723 return device.platform, device, devices 

1724 elif hasattr(y_train, 'platform'): 1724 ↛ 1731line 1724 didn't jump to line 1731 because the condition on line 1724 was always true

1725 # set device=None because if the devices were not specified explicitly 

1726 # we may be in the case where computation will follow data placement, 

1727 # do not disturb jax as the user may be playing with vmap, jit, reshard... 

1728 platform = y_train.platform() # ty: ignore[call-non-callable] 

1729 return platform, None, jax.devices(platform) 

1730 else: 

1731 msg = 'not possible to infer device from `y_train`, please set `devices`' 

1732 raise ValueError(msg) 

1733 

1734 

1735def _largest_divisor_at_most(n: int, cap: int) -> int: 

1736 """Return the largest divisor of `n` in [1, cap].""" 

1737 for d in range(cap, 0, -1): 1737 ↛ 1740line 1737 didn't jump to line 1740 because the loop on line 1737 didn't complete

1738 if n % d == 0: 

1739 return d 

1740 return 1 # unreachable: 1 always divides n 

1741 

1742 

1743def _determine_num_chain_devices( 

1744 platform: str, 

1745 num_chains: int | None, 

1746 num_chain_devices: int | None | Literal['auto'], 

1747 num_data_devices: int | None, 

1748 num_devices: int, 

1749 explicit_devices: bool, 

1750) -> int | None: 

1751 """Resolve and validate `num_chain_devices`, returning the chain mesh axis size or `None`.""" 

1752 if num_chain_devices == 'auto': 

1753 num_chain_devices = _auto_num_chain_devices( 

1754 platform, num_chains, num_data_devices, num_devices, explicit_devices 

1755 ) 

1756 

1757 # an explicit value must be a positive divisor of the number of chains 

1758 if num_chain_devices is not None: 

1759 effective_chains = 1 if num_chains is None else num_chains 

1760 if num_chain_devices < 1 or effective_chains % num_chain_devices: 

1761 chains_desc = ( 

1762 'a single chain (num_chains=None)' 

1763 if num_chains is None 

1764 else f'num_chains={num_chains}' 

1765 ) 

1766 msg = ( 

1767 f'num_chain_devices={num_chain_devices} must be a positive ' 

1768 f'divisor of the number of chains ({chains_desc})' 

1769 ) 

1770 raise ValueError(msg) 

1771 

1772 # there is no chain axis to shard when the chains are scalar 

1773 if num_chains is None: 

1774 return None 

1775 return num_chain_devices 

1776 

1777 

1778def _auto_num_chain_devices( 

1779 platform: str, 

1780 num_chains: int | None, 

1781 num_data_devices: int | None, 

1782 num_devices: int, 

1783 explicit_devices: bool, 

1784) -> int | None: 

1785 """Pick `num_chain_devices` automatically for multi-chain cpu runs. 

1786 

1787 `num_data_devices` reserves devices for the data axis, so the chain axis can 

1788 only use a fraction of them; this keeps the ``chains x data`` mesh within the 

1789 `num_devices` available devices. 

1790 """ 

1791 if num_chains is None or num_chains == 1 or platform != 'cpu': 

1792 return None 

1793 data_devices = num_data_devices or 1 

1794 num_cores = cpu_count() 

1795 assert num_cores is not None, 'could not determine number of cpu cores' 

1796 

1797 # devices available for the chain axis after reserving for the data axis 

1798 core_budget = max(1, num_cores // data_devices) 

1799 num_shards = _largest_divisor_at_most(num_chains, core_budget) 

1800 

1801 if num_shards > 1: 

1802 # the mesh draws from `num_devices` devices, whether those are all the 

1803 # platform's devices or an explicit subset passed by the user 

1804 device_budget = max(1, num_devices // data_devices) 

1805 if device_budget < num_shards: 

1806 new_num_shards = _largest_divisor_at_most(num_chains, device_budget) 

1807 warn( 

1808 _auto_chain_devices_warning( 

1809 num_chains, 

1810 num_shards, 

1811 new_num_shards, 

1812 device_budget, 

1813 num_devices, 

1814 num_data_devices, 

1815 explicit_devices, 

1816 ) 

1817 ) 

1818 num_shards = new_num_shards 

1819 

1820 return num_shards if num_shards > 1 else None 

1821 

1822 

1823def _auto_chain_devices_warning( 

1824 num_chains: int, 

1825 desired: int, 

1826 actual: int, 

1827 device_budget: int, 

1828 num_devices: int, 

1829 num_data_devices: int | None, 

1830 explicit_devices: bool, 

1831) -> str: 

1832 """Compose the warning shown when auto chain sharding is capped by the device count.""" 

1833 if explicit_devices: 

1834 pool = f'the {num_devices} devices passed in `devices`' 

1835 few = f'only {num_devices} devices were passed in `devices`' 

1836 advice = '' 

1837 else: 

1838 pool = f'the {num_devices} jax cpu devices' 

1839 few = f'jax is set up with only {num_devices} cpu devices' 

1840 advice = ( 

1841 ' To enable more parallelization, increase the limit with ' 

1842 '`jax.config.update("jax_num_cpu_devices", <num_devices>)`.' 

1843 ) 

1844 if num_data_devices: 

1845 limit = ( 

1846 f'only {device_budget} of {pool} are free for chains ' 

1847 f'(num_data_devices={num_data_devices} reserves the rest)' 

1848 ) 

1849 else: 

1850 limit = few 

1851 return ( 

1852 f'`Bart` would like to shard {num_chains} chains across {desired} ' 

1853 f'devices, but {limit}, so it will use {actual} devices for chains ' 

1854 f'instead.{advice}' 

1855 ) 

1856 

1857 

1858def _determine_mesh( 

1859 num_chain_devices: int | None, 

1860 num_data_devices: int | None, 

1861 device: Device | None, 

1862 devices: Sequence[Device], 

1863) -> tuple[Mesh | None, Device | None]: 

1864 """Create a jax device mesh for `mcmcstep.init()`.""" 

1865 if num_chain_devices is None and num_data_devices is None: 

1866 return None, device 

1867 else: 

1868 mesh = dict() 

1869 if num_chain_devices is not None: 

1870 mesh.update(chains=num_chain_devices) 

1871 if num_data_devices is not None: 

1872 mesh.update(data=num_data_devices) 

1873 mesh = make_mesh( 

1874 axis_shapes=tuple(mesh.values()), 

1875 axis_names=tuple(mesh), 

1876 axis_types=(AxisType.Auto,) * len(mesh), 

1877 devices=devices, 

1878 ) 

1879 return mesh, None 

1880 # set device=None because `mcmcstep.init` will `device_put` with the 

1881 # mesh already, we don't want to undo its work 

1882 

1883 

def process_varprob(
    varprob: Float[ArrayLike, ' p'] | None, max_split: UInt[Array, ' p']
) -> Float32[Array, ' p'] | None:
    """Convert `varprob` to its elementwise logarithm (``log_s``).

    `None` is passed through. Otherwise `varprob` must have the same shape as
    `max_split`, and non-positive entries are flagged through `error_if`.
    """
    if varprob is None:
        return None

    probs = jnp.asarray(varprob)
    assert probs.shape == max_split.shape, 'varprob must have shape (p,)'
    checked = error_if(probs, probs <= 0, 'varprob must be > 0')
    return jnp.log(checked)

1894 

1895 

def predict_latent(
    x: UInt[Array, 'p m'],
    trace: MainTrace,
    test_points: Literal['none', 'autobatch', 'shard_and_autobatch'] = 'none',
) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
    """Evaluate the trees in `trace` on already quantized `x`, squashing chains.

    Thin wrapper around `evaluate_trace` with ``flatten_chains=True``;
    `test_points` is forwarded unchanged.
    """
    latent = evaluate_trace(x, trace, flatten_chains=True, test_points=test_points)
    return latent

1903 

1904 

@jit(static_argnums=(5, 6, 7))
def predict(
    key: Key[Array, ''] | None,
    trace: MainTrace,
    x_test: UInt[Array, 'p m'],
    error_scale: Float[Array, ' m'] | Float[Array, 'k m'] | None,
    binary_indices: Int32[Array, ' kb'] | None,
    has_binary: bool,
    kind: PredictKind | str,
    test_points: Literal['none', 'autobatch', 'shard_and_autobatch'],
    /,
) -> (
    Float32[Array, ' m']
    | Float32[Array, 'k m']
    | Float32[Array, 'ndpost m']
    | Float32[Array, 'ndpost k m']
):
    """Implement `Bart.predict`.

    Depending on `kind`, return the latent sum-of-trees samples, samples of
    the outcome (requires `key`), the mean over samples of the predicted mean,
    or the samples of the predicted mean. `kind` is compared by identity
    against `PredictKind` members; any value that matches none of the members
    tested here yields the samples of the predicted mean.
    """
    # bare sum-of-trees predictions
    latent = predict_latent(x_test, trace, test_points)
    if kind is PredictKind.latent_samples:
        return latent

    # posterior predictive draws work on the latent scale, no squashing needed
    if kind is PredictKind.outcome_samples:
        assert key is not None
        return sample_outcome(
            key, trace, latent, error_scale, binary_indices, has_binary
        )

    # map the probit components to (0, 1)
    if binary_indices is not None:
        rows = jnp.s_[..., binary_indices, :]
        squashed = latent.at[rows].set(ndtr(latent[rows]))
    elif has_binary:
        squashed = ndtr(latent)
    else:
        squashed = latent

    # average over the samples or hand them back as they are
    return squashed.mean(axis=0) if kind is PredictKind.mean else squashed

1948 

1949 

@jit(static_argnums=(5,))
def sample_outcome(
    key: Key[Array, ''],
    trace: MainTrace,
    latent: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'],
    error_scale: Float32[Array, ' m'] | Float32[Array, 'k m'] | None,
    binary_indices: Int32[Array, ' kb'] | None,
    has_binary: bool,
    /,
) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
    """Draw samples from the posterior predictive distribution.

    Gaussian noise is added to `latent`: correlated noise derived from
    `trace.error_cov_inv` when `latent` has more than two axes, standard
    normal noise in the purely binary univariate case, and noise scaled by the
    error standard deviation otherwise. Where applicable, the noise is
    multiplied by `error_scale`. Binary components of the result are then
    thresholded at zero to 0/1.
    """
    # bring the chain axis of the error precision to the front
    prec_axis = chain_vmap_axes(trace).error_cov_inv
    precision = chain_to_axis(trace.error_cov_inv, prec_axis)

    if latent.ndim > 2:  # multivariate case
        flat_precision = lax.collapse(precision, 0, -2)

        # Cholesky factor of the precision: flat_precision = L @ L^T
        chol = chol_with_gersh(flat_precision)  # (ndpost, k, k)

        # With z ~ N(0, I), solving L^T @ noise = z gives
        # noise = L^{-T} z ~ N(0, L^{-T} L^{-1}) = N(0, Sigma)
        white = random.normal(key, latent.shape)  # (ndpost, k, m)
        noise = solve_triangular(chol, white, trans='T', lower=True)  # (ndpost, k, m)
        if error_scale is not None:
            # error_scale is (m,) or (k, m) so it always broadcasts right
            noise = noise * error_scale
    elif has_binary:
        # pure binary univariate: probit has sigma = 1
        noise = random.normal(key, latent.shape)
    else:  # univariate continuous
        sigma = jnp.sqrt(jnp.reciprocal(precision)).reshape(-1)
        noise = sigma[..., None] * random.normal(key, latent.shape)
        if error_scale is not None:
            noise = noise * error_scale[None, :]

    draws = latent + noise

    # binary outcomes come from thresholding the latent probit variable
    if binary_indices is not None:
        rows = jnp.s_[..., binary_indices, :]
        return draws.at[rows].set(jnp.where(draws[rows] > 0, 1.0, 0.0))
    if has_binary:
        return jnp.where(draws > 0, 1.0, 0.0)
    return draws