Coverage for src/bartz/mcmcstep/_state.py: 99%

462 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/mcmcstep/_state.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Module defining the BART MCMC state and initialization.""" 

26 

27from collections.abc import Callable, Sequence 

28from dataclasses import replace 

29from enum import Enum 

30from functools import partial, wraps 

31from typing import Literal, TypedDict, TypeVar, cast 

32 

33import jax 

34import numpy 

35from equinox import error_if, filter_jit 

36from jax import NamedSharding, device_put, lax, make_mesh, random, shard_map, tree, vmap 

37from jax import numpy as jnp 

38from jax.scipy.linalg import solve_triangular 

39from jax.sharding import AxisType, Mesh, PartitionSpec 

40from jaxtyping import ( 

41 Array, 

42 Bool, 

43 Float, 

44 Float32, 

45 Int32, 

46 Integer, 

47 Key, 

48 PyTree, 

49 Shaped, 

50 UInt, 

51 UInt32, 

52) 

53from numpy import ndarray 

54 

55from bartz._jaxext import Module, field, jaxtyping_disabled, jit, minimal_unsigned_dtype 

56from bartz.grove import tree_depths 

57from bartz.mcmcstep._axes import CHAIN_AXIS, chain_vmap_axes, data_vmap_axes 

58from bartz.mcmcstep._lazy import ( 

59 _is_lazy_or_none, 

60 _lazy, 

61 _lazy_from_array, 

62 _LazyArray, 

63 _wrap_chain, 

64 add_dummy_axis, 

65) 

66from bartz.mcmcstep._reduction import ( 

67 AutoBatchedReduction, 

68 AutoOneHotReduction, 

69 ReductionConfig, 

70) 

71 

# Any array accepted at the API boundary: either a jax array or a numpy array.
ArrayLike = Array | ndarray

# A scalar float: either a Python `float` or a zero-dimensional float array.
FloatLike = float | Float[ArrayLike, '']

# Position of the chain axis in arrays where it is placed after the leading
# `num_trees` axis (see `Forest.leaf_indices`), as a function of `CHAIN_AXIS`.
# Only the `CHAIN_AXIS` values 0 and -1 are handled; any other value raises
# `KeyError` when the module is imported.
CHAIN_AXIS_AFTER_TREES = {0: 1, -1: -1}[CHAIN_AXIS]

77 

78 

class OutcomeType(Enum):
    """Kinds of likelihood available for an outcome component of the regression."""

    continuous = 'continuous'
    """Real-valued outcome with Normal error."""

    binary = 'binary'
    """Outcome in {0, 1}, modeled through a probit link."""

87 

88 

# Generic, unconstrained type variable.
T = TypeVar('T')

90 

91 

class Wishart(Module):
    """Current value of a precision matrix together with its Wishart prior.

    ``value`` is a random precision (inverse covariance) whose prior is a
    Wishart distribution with `nu` degrees of freedom and rate matrix `rate`.
    For ``k = 1`` this reduces to a Gamma distribution, linked to the
    inverse-gamma prior on the variance by ``alpha = nu / 2``,
    ``beta = rate / 2``. The precision has prior mean ``nu * rate^-1``.

    Passing `None` for both `nu` and `rate` means there is no prior and the
    precision stays fixed at `value` (e.g. the identity in binary regression).
    """

    nu: Float32[Array, ''] | None
    """Wishart degrees of freedom; `None` when there is no prior."""

    rate: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """Wishart rate matrix (a scalar in the univariate case); `None` when
    there is no prior. In the univariate case it coincides with the
    inverse-gamma ``scale``."""

    value: Float32[Array, '*chains k k'] | Float32[Array, '*chains'] = field(
        chains=CHAIN_AXIS
    )
    """Current precision matrix (a scalar in the univariate case)."""

    def __init__(
        self,
        nu: FloatLike | None,
        rate: FloatLike | Float[ArrayLike, 'k k'] | None,
        value: FloatLike
        | Float[ArrayLike, '*chains k k']
        | Float[ArrayLike, '*chains'],
    ) -> None:
        # the prior is either fully specified or absent altogether
        assert (nu is None) == (rate is None), 'set both or neither of nu and rate'

        as_float32 = partial(jnp.asarray, dtype=jnp.float32)
        self.nu = nu if nu is None else as_float32(nu)
        self.rate = rate if rate is None else as_float32(rate)

        # `init` hands over a deferred `_LazyArray` (cast to `Array`) for
        # `value` to route it through sharding, so it is stored untouched;
        # anything else is converted to a float32 array.
        self.value = (
            cast(Array, value) if isinstance(value, _LazyArray) else as_float32(value)
        )

135 

136 

class DiagWishart(Wishart):
    """Diagonal precision matrix whose diagonal entries are independent.

    The name notwithstanding, this is not a Wishart distribution restricted to
    diagonal matrices. It is a convenience type for a diagonal precision whose
    entries are mutually independent, each one having its own Gamma (scaled
    chi-square) prior. Only the multivariate (matrix) case is supported.

    A component whose `rate` is 0 has no prior: its precision stays fixed at
    its `value` (1 for the binary components of a mixed regression).

    It is used for mixed binary-continuous regression and for continuous
    multivariate regression with per-datapoint missingness.
    """

    def __init__(
        self,
        nu: FloatLike | None,
        rate: FloatLike | Float[ArrayLike, 'k k'] | None,
        value: FloatLike
        | Float[ArrayLike, '*chains k k']
        | Float[ArrayLike, '*chains'],
    ) -> None:
        # The constructor is spelled out (and delegates to the parent) so that
        # the static checker picks up this signature rather than synthesizing
        # a stricter one from the inherited fields.
        matrix_or_absent = rate is None or jnp.ndim(rate) == 2
        assert matrix_or_absent, (
            'DiagWishart supports only the multivariate (matrix) case'
        )
        super().__init__(nu=nu, rate=rate, value=value)

166 

167 

class Forest(Module):
    """Represents the MCMC state of a sum of trees.

    Fields declared with ``field(chains=...)`` carry an optional chain axis
    (``*chains`` in their shape annotation); the remaining fields have no
    chain axis in their annotation. The order in which the fields are declared
    matters for the runtime typechecker, see the comment below.
    """

    # Heap-array fields follow the `bartz.grove.TreesTrace` convention: the
    # union-free integer trees are declared before `leaf_tree` and carry the
    # bindable `half_tree_size` axis, while `leaf_tree` (and `p_nonterminal`) are
    # checked against `2*half_tree_size`. Declaring a union-free `*chains` field
    # first binds the variadic chain axis (plus `num_trees` and `half_tree_size`)
    # before `leaf_tree`'s `... | ... k ...` union is evaluated, so the runtime
    # typechecker can't mis-bind `*chains` against the `k` axis of a multivariate
    # forest (the layouts are otherwise rank-ambiguous). No anchor field needed.
    var_tree: UInt[Array, '*chains num_trees half_tree_size'] = field(chains=CHAIN_AXIS)
    """The decision axes."""

    split_tree: UInt[Array, '*chains num_trees half_tree_size'] = field(
        chains=CHAIN_AXIS
    )
    """The decision boundaries."""

    affluence_tree: Bool[Array, '*chains num_trees half_tree_size'] = field(
        chains=CHAIN_AXIS
    )
    """Marks leaves that can be grown."""

    leaf_tree: (
        Float32[Array, '*chains num_trees 2*half_tree_size']
        | Float32[Array, '*chains num_trees k 2*half_tree_size']
    ) = field(chains=CHAIN_AXIS)
    """The leaf values."""

    # Per-chain counters of proposed and accepted grow/prune moves.
    grow_prop_count: Int32[Array, '*chains'] = field(chains=CHAIN_AXIS)
    """The number of grow proposals made during one full MCMC cycle."""

    prune_prop_count: Int32[Array, '*chains'] = field(chains=CHAIN_AXIS)
    """The number of prune proposals made during one full MCMC cycle."""

    grow_acc_count: Int32[Array, '*chains'] = field(chains=CHAIN_AXIS)
    """The number of grow moves accepted during one full MCMC cycle."""

    prune_acc_count: Int32[Array, '*chains'] = field(chains=CHAIN_AXIS)
    """The number of prune moves accepted during one full MCMC cycle."""

    max_split: UInt[Array, ' p']
    """The maximum split index for each predictor."""

    blocked_vars: UInt[Array, ' q'] | None
    """Indices of variables that are not used. This shall include at least
    the `i` such that ``max_split[i] == 0``, otherwise behavior is
    undefined."""

    p_nonterminal: Float32[Array, ' 2*half_tree_size']
    """The prior probability of each node being nonterminal, conditional on
    its ancestors. Includes the nodes at maximum depth which should be set
    to 0."""

    p_propose_grow: Float32[Array, ' half_tree_size']
    """The unnormalized probability of picking a leaf for a grow proposal."""

    # NOTE: the only field whose chain axis is `CHAIN_AXIS_AFTER_TREES` instead
    # of `CHAIN_AXIS`; it is also the only one with a data axis (the last one).
    leaf_indices: UInt[Array, 'num_trees *chains n'] = field(
        chains=CHAIN_AXIS_AFTER_TREES, data=-1
    )
    """The index of the leaf each datapoints falls into, for each tree.

    The chain axis sits after `num_trees` (not leading, unlike sibling fields)
    so the per-tree `lax.scan` in `step`, under the chain `vmap`, avoids a
    transpose of this large array that otherwise inflates gpu peak memory."""

    # `count_tree` and `prec_tree` are alternative per-node summaries; see
    # their docstrings for when each one is `None`.
    count_tree: UInt32[Array, '*chains num_trees 2*half_tree_size'] | None = field(
        chains=CHAIN_AXIS
    )
    """The number of datapoints per leaf. Valid at the leaves and at the nodes
    involved in the latest moves, dirty elsewhere. `None` if the error
    precision is weighted and there are no minimum-points-per-node
    constraints, which makes the counts unused."""

    prec_tree: (
        Float32[Array, '*chains num_trees 2*half_tree_size']
        | Float32[Array, '*chains num_trees k k 2*half_tree_size']
        | None
    ) = field(chains=CHAIN_AXIS)
    """The likelihood precision scale summed over the datapoints in each leaf.
    Valid at the leaves and at the nodes involved in the latest moves, dirty
    elsewhere. `None` if the error precision is not weighted, in which case
    `count_tree` takes its place."""

    min_points_per_decision_node: Int32[Array, ''] | None
    """The minimum number of data points in a decision node."""

    min_points_per_leaf: Int32[Array, ''] | None
    """The minimum number of data points in a leaf node."""

    log_trans_prior: Float32[Array, '*chains num_trees'] | None = field(
        chains=CHAIN_AXIS
    )
    """The log transition and prior Metropolis-Hastings ratio for the
    proposed move on each tree."""

    log_likelihood: Float32[Array, '*chains num_trees'] | None = field(
        chains=CHAIN_AXIS
    )
    """The log likelihood ratio."""

    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """The prior precision matrix of a leaf, conditional on the tree structure.
    For the univariate case (k=1), this is a scalar (the inverse variance).
    The prior covariance of the sum of trees is
    ``num_trees * leaf_prior_cov_inv^-1``."""

    # Variable-selection parameters: `log_s` and `theta` are per-chain, while
    # the hyperparameters `a`, `b`, `rho` have no chain axis.
    log_s: Float32[Array, '*chains p'] | None = field(chains=CHAIN_AXIS)
    """The logarithm of the prior probability for choosing a variable to split
    along in a decision rule, conditional on the ancestors. Not normalized.
    If `None`, use a uniform distribution."""

    theta: Float32[Array, '*chains'] | None = field(chains=CHAIN_AXIS)
    """The concentration parameter for the Dirichlet prior on the variable
    distribution `s`. Required only to update `log_s`."""

    a: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`."""

    b: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`."""

    rho: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`."""

    @property
    def has_chains(self) -> bool:
        """Whether this forest carries an explicit chain axis."""
        # Without chains `var_tree` has shape (num_trees, half_tree_size), so
        # any additional axis is the chain axis.
        return self.var_tree.ndim > 2

298 

299 

class StepConfig(Module):
    """Options for the MCMC step.

    Besides the configuration proper, it also holds the running step counter
    `steps_done`.
    """

    steps_done: Int32[Array, '']
    """The number of MCMC steps completed so far."""

    sparse_on_at: Int32[Array, ''] | None
    """After how many steps to turn on variable selection."""

    resid_reduction_config: ReductionConfig
    """How to sum the residuals in each leaf."""

    count_reduction_config: ReductionConfig
    """How to count the datapoints in each leaf."""

    prec_reduction_config: ReductionConfig
    """How to sum the likelihood precisions in each leaf."""

    # NOTE(review): the fields below are declared with `field(static=True)`;
    # presumably this marks them as non-array metadata for jax transformations,
    # confirm against `bartz._jaxext.field`.
    prec_count_num_trees: int | None = field(static=True)
    """Batch size for processing trees to compute count and prec trees."""

    sequential_unroll: int | bool = field(static=True)
    """How much to unroll the sequential accept/reject loop over trees in
    `step`. See the ``unroll`` argument of `jax.lax.scan`."""

    augment: bool = field(static=True)
    """Whether to account exactly, via data augmentation, for the decision rules
    forbidden by the ancestors of each node when updating `log_s`."""

    mesh: Mesh | None = field(static=True)
    """The mesh used to shard data and computation across multiple devices."""

    @property
    def data_sharded(self) -> bool:
        """Whether the data axis is sharded across devices."""
        # True only if a mesh is set and one of its axes is named 'data'.
        return self.mesh is not None and 'data' in self.mesh.axis_names

336 

337 

class State(Module):
    """Represents the MCMC state of BART.

    It bundles the data (`X`, `binary_y`), the latent variables and residuals
    (`z`, `resid`), the error precision with its prior (`error_cov_inv`), the
    sum of trees (`forest`) and the step options (`config`).

    Note that `has_chains` is a property while `num_chains` is a regular
    method.
    """

    # Must stay the first field: see its docstring for the reason.
    _chain_anchor: Float32[Array, '*chains'] = field(chains=CHAIN_AXIS)
    """Unused per-chain scalar, declared first as a runtime-typechecker anchor.
    Its single (union-free) ``*chains`` annotation binds the variadic chain axis
    before the ``... | ... k ...`` unions of `z`/`resid` (z over the
    binary-outcome ``kb`` axis) are checked; otherwise those can mis-bind
    ``*chains`` against the outcome axis for a multivariate-without-chains state
    (the layouts are rank-ambiguous). Unlike `Forest`, `State` has no genuine
    union-free chain field to reorder into this slot, so a dummy one is
    carried."""

    X: UInt[Array, 'p n'] = field(data=-1)
    """The predictors."""

    binary_y: None | Bool[Array, ' n'] | Bool[Array, 'kb n'] = field(data=-1)
    """The response as booleans for binary regression, `None` for continuous.
    In the mixed binary-continuous case, only the binary outcome components
    are stored, with shape ``(kb, n)``."""

    z: None | Float32[Array, '*chains n'] | Float32[Array, '*chains kb n'] = field(
        chains=CHAIN_AXIS, data=-1
    )
    """The latent variable for binary regression. `None` in continuous
    regression. In the mixed binary-continuous case, only the binary outcome
    components are stored, with shape ``(*chains, kb, n)``."""

    binary_indices: None | Int32[Array, ' kb']
    """The indices of binary outcome components in the full list of outcome
    components. `None` when there are no binary components."""

    offset: Float32[Array, ''] | Float32[Array, ' k']
    """Constant shift added to the sum of trees."""

    resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=CHAIN_AXIS, data=-1
    )
    """The residuals (`y` or `z` minus sum of trees)."""

    error_cov_inv: Wishart
    """The inverse error covariance with its Wishart prior. The current value is
    ``error_cov_inv.value`` (scalar for univariate, matrix for multivariate);
    identity with no prior in binary regression."""

    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'] | None = field(data=-1)
    """The scale on the error precision. `None` in binary regression. With
    scalar per-datapoint weights, shape ``(n,)`` and value
    ``1 / error_scale ** 2``. With vector per-datapoint weights, shape ``(k, k, n)``
    and value ``1/outer(error_scale, error_scale)`` repeated over datapoints."""

    inv_sdev_scale: Float32[Array, ' n'] | Float32[Array, 'k n'] | None = field(data=-1)
    """The reciprocal of the per-observation error standard-deviation scale.
    `None` in binary regression. Shape ``(n,)`` for scalar weights, or
    ``(k, n)`` for per-component vector weights."""

    forest: Forest
    """The sum of trees model."""

    config: StepConfig
    """Metadata and configurations for the MCMC step."""

    @property
    def has_chains(self) -> bool:
        """Whether this state carries an explicit chain axis."""
        # delegate to the forest, which decides based on the rank of `var_tree`
        return self.forest.has_chains

    def num_chains(self) -> int | None:
        """Return the number of chains, or `None` if not multichain."""
        if not self.has_chains:
            return None
        # look up the position of the chain axis of `var_tree` instead of
        # hardcoding it, then read the chain count off that axis
        c = chain_vmap_axes(self.forest).var_tree
        return self.forest.var_tree.shape[c]

411 

412 

def _check_diagonal(rate: Float32[Array, 'k k']) -> Float32[Array, 'k k']:
    """Return `rate` unchanged, with a runtime error attached if it has any
    off-diagonal entry that differs from zero."""
    # rebuild the matrix keeping only its diagonal, then compare entrywise
    diagonal_part = jnp.diag(jnp.diag(rate))
    has_off_diagonal = jnp.any(rate != diagonal_part)
    message = 'error_cov_inv.rate must be diagonal'
    return error_if(rate, has_off_diagonal, message)

417 

418 

def _check_binary_unit_precision(
    value: Float32[Array, 'k k'], binary_mask: Sequence[bool]
) -> Float32[Array, 'k k']:
    """Return `value` unchanged, with a runtime error attached if any diagonal
    entry flagged by `binary_mask` is different from 1."""
    is_binary_component = jnp.array(binary_mask)
    diagonal = jnp.diag(value)
    # only the components marked as binary are required to have unit precision
    violations = is_binary_component & (diagonal != 1.0)
    message = 'binary error precision must be 1 (the default for a zero rate)'
    return error_if(value, jnp.any(violations), message)

430 

431 

def _init_shape_shifting_parameters(
    y: Float32[Array, ' n'] | Float32[Array, 'k n'],
    outcome_type: OutcomeType | list[OutcomeType],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    error_scale: Float32[ArrayLike, ' n'] | Float32[ArrayLike, 'k n'] | None,
    error_cov_inv: Wishart | None,
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    missing: Bool[ArrayLike, ' n'] | Bool[ArrayLike, 'k n'] | None,
) -> tuple[bool, tuple[int, ...], Wishart, None | Int32[Array, ' kb']]:
    """
    Validate and set up the parameters whose type/shape depends on the outcome kind.

    Parameters
    ----------
    y
        The response; only its shape is inspected.
    outcome_type
        A single `OutcomeType` for the whole regression, or a list with one
        `OutcomeType` per outcome component (multivariate only).
    offset
        The offset added to the predictions. Its shape defines the outcome
        shape.
    error_scale
        Per-observation error scale, with the shape of `y` or of its last axis
        only. Must be `None` in binary and in mixed binary-continuous
        regression.
    error_cov_inv
        The prior on the error precision together with its initial value. Must
        be `None` when all outcomes are binary, a `DiagWishart` in the mixed
        and partial-missing cases (in the mixed case the binary components
        must have initial precision 1), and exactly a `Wishart` otherwise.
    leaf_prior_cov_inv
        The inverse of the leaf prior covariance; only its shape is checked.
    missing
        The missingness mask. A 2-D mask with a multivariate outcome selects
        the diagonal (partial-missing) mode.

    Returns
    -------
    is_binary
        Whether all outcomes are binary.
    kshape
        The outcome shape: ``()`` for univariate, ``(k,)`` for multivariate.
    error_cov_inv
        The prior with its initial value: a fixed identity without prior if
        all outcomes are binary, the validated `DiagWishart` in the diagonal
        modes, the input object otherwise.
    binary_indices
        The indices of the binary outcome components in the mixed case, `None`
        otherwise.
    """
    kshape = offset.shape
    # shape of a per-outcome matrix: `()` if univariate, `(k, k)` if multivariate
    matrix_shape = 2 * kshape

    # classify the regression as all-binary, mixed or all-continuous
    if not isinstance(outcome_type, list):
        is_binary = outcome_type is OutcomeType.binary
        is_mixed = False
    else:
        assert kshape, 'per-component outcome_type requires multivariate y'
        (k,) = kshape
        assert len(outcome_type) == k
        binary_mask = [kind is OutcomeType.binary for kind in outcome_type]
        is_binary = all(binary_mask)
        is_mixed = any(binary_mask) and not is_binary

    # positions of the binary components, tracked only in the mixed case
    binary_indices = (
        jnp.array([pos for pos, flag in enumerate(binary_mask) if flag])
        if is_mixed
        else None
    )

    partial_missing = bool(missing is not None and missing.ndim == 2 and kshape)

    if is_binary:
        # all outcomes binary: no prior, precision fixed at the identity
        assert error_scale is None
        assert error_cov_inv is None, 'no error covariance prior in binary regression'
        if kshape:
            identity = jnp.eye(kshape[0])
        else:
            identity = jnp.array(1.0)
        error_cov_inv = Wishart(nu=None, rate=None, value=identity)

    elif is_mixed or partial_missing:
        # mixed binary-continuous, or continuous multivariate with a 2-D
        # missingness mask: the error covariance is diagonal and updated
        # component by component, so the caller has to pass a `DiagWishart`
        if is_mixed:
            assert error_scale is None
        assert isinstance(error_cov_inv, DiagWishart), (
            'mixed binary-continuous or partial-missing regression requires a '
            'DiagWishart error_cov_inv prior'
        )
        assert error_cov_inv.rate is not None
        assert error_cov_inv.rate.shape == matrix_shape
        assert error_cov_inv.value.shape == matrix_shape
        checked_rate = _check_diagonal(error_cov_inv.rate)
        checked_value = _check_diagonal(error_cov_inv.value)
        if is_mixed:
            # the binary components must start with unit precision
            checked_value = _check_binary_unit_precision(checked_value, binary_mask)
        error_cov_inv = replace(error_cov_inv, rate=checked_rate, value=checked_value)

    else:
        # all outcomes continuous: a dense `Wishart`
        if error_scale is not None:
            # either one scale per entry of y, (k, n), or one per datapoint, (n,)
            assert error_scale.shape in (y.shape, y.shape[-1:])
        assert error_cov_inv is not None
        assert type(error_cov_inv) is Wishart, (
            'continuous regression requires a dense Wishart error_cov_inv prior'
        )
        dense_rate = error_cov_inv.rate
        assert dense_rate is not None
        assert dense_rate.shape == matrix_shape
        assert error_cov_inv.value.shape == matrix_shape

    # checks common to all the outcome kinds
    assert y.shape[:-1] == kshape
    assert leaf_prior_cov_inv.shape == matrix_shape

    return is_binary, kshape, error_cov_inv, binary_indices

545 

546 

def _check_splitless_vars(
    filter_splitless_vars: int,
    max_split: UInt[Array, ' p'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
) -> Float32[Array, ''] | Float32[Array, ' k']:
    """Return `offset` unchanged, with a runtime error attached if the number
    of predictors without splits exceeds `filter_splitless_vars`."""
    # a predictor is splitless when its maximum split index is 0
    num_splitless = jnp.sum(max_split == 0)
    too_many = num_splitless > filter_splitless_vars
    msg = (
        f'there are more than {filter_splitless_vars=} predictors with no splits, '
        'please increase `filter_splitless_vars` or investigate the missing splits'
    )
    return error_if(offset, too_many, msg)

558 

559 

def _parse_outcome_type(
    outcome_type: 'OutcomeType | str | Sequence[OutcomeType | str]',
) -> 'OutcomeType | list[OutcomeType]':
    """Convert `outcome_type` to an `OutcomeType`, or to a list of them if a
    non-string sequence is given."""
    # a string is itself a `Sequence`, but it names a single outcome type
    is_scalar = isinstance(outcome_type, str) or not isinstance(outcome_type, Sequence)
    if is_scalar:
        return OutcomeType(outcome_type)
    return [OutcomeType(item) for item in outcome_type]

568 

569 

def _parse_p_nonterminal(
    p_nonterminal: Float32[ArrayLike, ' d_minus_1'],
) -> Float32[Array, ' d_minus_1+1']:
    """Validate that all probabilities lie strictly between 0 and 1, then
    append a trailing 0."""
    probs = jnp.asarray(p_nonterminal)
    # negate the "in range" test (instead of testing "out of range" directly)
    # so that NaN entries are flagged as invalid too
    invalid = ~((probs > 0) & (probs < 1))
    probs = error_if(probs, invalid, 'p_nonterminal must be in (0, 1)')
    return jnp.pad(probs, (0, 1))

578 

579 

def make_p_nonterminal(
    d: int, alpha: FloatLike = 0.95, beta: FloatLike = 2.0
) -> Float32[Array, ' {d}-1']:
    """Prepare the `p_nonterminal` argument to `init`.

    It is calculated according to the formula:

        P_nt(depth) = alpha / (1 + depth)^beta,     with depth 0-based

    Parameters
    ----------
    d
        The maximum depth of the trees (d=1 means tree with only root node)
    alpha
        The a priori probability of the root node having children, conditional
        on it being possible
    beta
        The exponent of the power decay of the probability of having children
        with depth.

    Returns
    -------
    An array of probabilities, one per tree level but the last. It always has
    dtype float32.
    """
    assert d >= 1
    depth = jnp.arange(d - 1)
    # Cast explicitly to float32 rather than to `float`: the latter becomes
    # float64 when jax 64-bit mode is enabled, contradicting the declared
    # return type. With the default jax configuration the result is unchanged.
    p_nonterminal = alpha / (1 + depth).astype(jnp.float32) ** beta
    # `alpha` or `beta` may themselves be 64-bit arrays, so cast the result too
    return p_nonterminal.astype(jnp.float32)

607 

608 

609def init( 

610 *, 

611 X: UInt[ArrayLike, 'p n'], 

612 y: Float32[ArrayLike, ' n'] | Float32[ArrayLike, ' k n'], 

613 outcome_type: OutcomeType | str | Sequence[OutcomeType | str] = 'continuous', 

614 offset: FloatLike | Float[ArrayLike, ' k'], 

615 max_split: UInt[ArrayLike, ' p'], 

616 num_trees: int, 

617 p_nonterminal: Float32[ArrayLike, ' d_minus_1'], 

618 leaf_prior_cov_inv: FloatLike | Float[ArrayLike, 'k k'], 

619 error_cov_inv: Wishart | None = None, 

620 error_scale: Float32[ArrayLike, ' n'] | Float32[ArrayLike, 'k n'] | None = None, 

621 missing: Bool[ArrayLike, ' n'] | Bool[ArrayLike, 'k n'] | None = None, 

622 min_points_per_decision_node: int | Integer[ArrayLike, ''] | None = None, 

623 resid_reduction_config: ReductionConfig = AutoBatchedReduction(), 

624 count_reduction_config: ReductionConfig = AutoOneHotReduction(), 

625 prec_reduction_config: ReductionConfig = AutoOneHotReduction(), 

626 prec_count_num_trees: int | None | Literal['auto'] = 'auto', 

627 sequential_unroll: int | bool = 2, 

628 save_ratios: bool = False, 

629 filter_splitless_vars: int = 0, 

630 min_points_per_leaf: int | Integer[ArrayLike, ''] | None = None, 

631 log_s: Float32[ArrayLike, ' p'] | None = None, 

632 theta: FloatLike | None = None, 

633 a: FloatLike | None = None, 

634 b: FloatLike | None = None, 

635 rho: FloatLike | None = None, 

636 sparse_on_at: int | Integer[ArrayLike, ''] | None = None, 

637 augment: bool = True, 

638 num_chains: int | None = None, 

639 mesh: Mesh | dict[str, int] | None = None, 

640) -> State: 

641 """ 

642 Make a BART posterior sampling MCMC initial state. 

643 

644 Parameters 

645 ---------- 

646 X 

647 The predictors. Note this is transposed compared to the usual convention. 

648 y 

649 The response. If two-dimensional, the outcome is multivariate with the 

650 first axis indicating the component. For binary data, non-zero means 1, 

651 zero means 0. 

652 outcome_type 

653 Whether the regression is continuous or binary (probit). Can also be a 

654 sequence of `OutcomeType` values, one per outcome component, for mixed 

655 binary-continuous multivariate regression. 

656 offset 

657 Constant shift added to the sum of trees. 0 if not specified. 

658 max_split 

659 The maximum split index for each variable. All split ranges start at 1. 

660 num_trees 

661 The number of trees in the forest. 

662 p_nonterminal 

663 The probability of a nonterminal node at each depth. The maximum depth 

664 of trees is fixed by the length of this array. Use `make_p_nonterminal` 

665 to set it with the conventional formula. 

666 leaf_prior_cov_inv 

667 The prior precision matrix of a leaf, conditional on the tree structure. 

668 For the univariate case (k=1), this is a scalar (the inverse variance). 

669 The prior covariance of the sum of trees is 

670 ``num_trees * leaf_prior_cov_inv^-1``. The prior mean of leaves is 

671 always zero. 

672 error_cov_inv 

673 The Wishart prior on the inverse error covariance, together with its 

674 initial value (see `Wishart`). Leave it unspecified for binary 

675 regression. The mixed binary-continuous and partial-missing diagonal 

676 modes require a `DiagWishart`; in the mixed case the binary components 

677 must have an initial precision of 1 (see `DiagWishart`). 

678 error_scale 

679 Each error is scaled by the corresponding factor in `error_scale`. If 

680 ``error_scale[..., i]`` is a scalar, each error variance or covariance 

681 matrix is multiplied by ``error_scale[..., i] ** 2``. If 

682 ``error_scale[:, i]`` is a vector, then the covariance matrix is 

683 rescaled by its outer product. Not supported for binary or mixed 

684 binary-continuous regression. If not specified, defaults to 1 for all 

685 points, but potentially skipping calculations. 

686 missing 

687 Boolean mask, same shape as `y`; `True` marks entries to be ignored 

688 by the MCMC. Values of `y` must be finite everywhere, including at 

689 masked positions. If 2-D, `error_cov_inv.rate` must be diagonal. 

690 min_points_per_decision_node 

691 The minimum number of data points in a decision node. 0 if not 

692 specified. 

693 resid_reduction_config 

694 count_reduction_config 

695 prec_reduction_config 

696 How to sum the residuals, count the datapoints, and sum the likelihood 

697 precisions in each leaf, respectively. See `ReductionConfig` and its 

698 subclasses. 

699 prec_count_num_trees 

700 The number of trees to process at a time when counting datapoints or 

701 computing the likelihood precision. If `None`, do all trees at once, 

702 which may use too much memory. If 'auto' (default), it's chosen 

703 automatically. 

704 sequential_unroll 

705 How much to unroll the sequential accept/reject loop over trees in 

706 `step`. See the ``unroll`` argument of `jax.lax.scan`. Unrolling may 

707 speed up the MCMC at the cost of longer compilation. 1 means no 

708 unrolling; the default is 2. 

709 save_ratios 

710 Whether to save the Metropolis-Hastings ratios. 

711 filter_splitless_vars 

712 The maximum number of variables without splits that can be ignored. If 

713 there are more, `init` raises an exception. 

714 min_points_per_leaf 

715 The minimum number of datapoints in a leaf node. 0 if not specified. 

716 Unlike `min_points_per_decision_node`, this constraint is not taken into 

717 account in the Metropolis-Hastings ratio because it would be expensive 

718 to compute. Grow moves that would violate this constraint are vetoed. 

719 This parameter is independent of `min_points_per_decision_node` and 

720 there is no check that they are coherent. It makes sense to set 

721 ``min_points_per_decision_node >= 2 * min_points_per_leaf``. 

722 log_s 

723 The logarithm of the prior probability for choosing a variable to split 

724 along in a decision rule, conditional on the ancestors. Not normalized. 

725 If not specified, use a uniform distribution. If not specified and 

726 `theta` or `rho`, `a`, `b` are, it's initialized automatically. 

727 theta 

728 The concentration parameter for the Dirichlet prior on `s`. Required 

729 only to update `log_s`. If not specified, and `rho`, `a`, `b` are 

730 specified, it's initialized automatically. 

731 a 

732 b 

733 rho 

734 Parameters of the prior on `theta`. Required only to sample `theta`. 

735 sparse_on_at 

736 After how many MCMC steps to turn on variable selection. 

737 augment 

738 Whether to account exactly, via data augmentation, for the decision 

739 rules forbidden by the ancestors of each node when updating `log_s`. If 

740 not set, those rules are ignored, which is faster but only approximate. 

741 num_chains 

742 The number of independent MCMC chains to represent in the state. Single 

743 chain with scalar values if not specified. 

744 mesh 

745 A jax mesh used to shard data and computation across multiple devices. 

746 If it has a 'chains' axis, that axis is used to shard the chains. If it 

747 has a 'data' axis, that axis is used to shard the datapoints. 

748 

749 As a shorthand, if a dictionary mapping axis names to axis size is 

750 passed, the corresponding mesh is created, e.g., ``dict(chains=4, 

751 data=2)`` will let jax pick 8 devices to split chains (which must be a 

752 multiple of 4) across 4 pairs of devices, where in each pair the data is 

753 split in two. 

754 

755 Note: if a mesh is passed, the arrays are always sharded according to 

756 it. In particular even if the mesh has no 'chains' or 'data' axis, the 

757 arrays will be replicated on all devices in the mesh. 

758 

759 Returns 

760 ------- 

761 An initialized BART MCMC state. 

762 

763 Raises 

764 ------ 

765 ValueError 

766 If arguments unused in binary regression are set. 

767 

768 Notes 

769 ----- 

770 In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out 

771 of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left 

772 child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be 

773 integers in the range ``[0, 1, ..., max_split[i]]``. 

774 

775 In general the arrays passed to this function as arguments may be donated, 

776 invalidating them. Create copies before passing them to `init` if this 

777 happens and you need them again. 

778 """ 

779 # convert to array all array-like arguments that are used in other 

780 # configurations but don't need further processing themselves 

781 X = jnp.asarray(X) 

782 y = jnp.asarray(y) 

783 assert y.dtype == jnp.float32 

784 offset = jnp.asarray(offset) 

785 leaf_prior_cov_inv = jnp.asarray(leaf_prior_cov_inv) 

786 max_split = jnp.asarray(max_split) 

787 if error_scale is not None: 

788 error_scale = jnp.asarray(error_scale) 

789 if missing is not None: 

790 missing = jnp.asarray(missing) 

791 assert missing is None or missing.ndim <= y.ndim 

792 

793 # normalize outcome_type to enum (or list of enums) 

794 outcome_type = _parse_outcome_type(outcome_type) 

795 

796 # check p_nonterminal and pad it with a 0 at the end (still not final shape) 

797 p_nonterminal = _parse_p_nonterminal(p_nonterminal) 

798 

799 # process arguments that change depending on outcome type 

800 is_binary, kshape, error_cov_inv, binary_indices = _init_shape_shifting_parameters( 

801 y, outcome_type, offset, error_scale, error_cov_inv, leaf_prior_cov_inv, missing 

802 ) 

803 

804 # extract array sizes from arguments 

805 (max_depth,) = p_nonterminal.shape 

806 p, n = X.shape 

807 

808 # check and initialize sparsity parameters 

809 if not _all_none_or_not_none(rho, a, b): 809 ↛ 810line 809 didn't jump to line 810 because the condition on line 809 was never true

810 msg = 'rho, a, b are not either all `None` or all set' 

811 raise ValueError(msg) 

812 if theta is None and rho is not None: 

813 theta = rho 

814 if log_s is None and theta is not None: 

815 log_s = jnp.zeros(max_split.size) 

816 if not _all_none_or_not_none(theta, sparse_on_at): 816 ↛ 817line 816 didn't jump to line 817 because the condition on line 816 was never true

817 msg = 'sparsity params (either theta or rho,a,b) and sparse_on_at must be either all None or all set' 

818 raise ValueError(msg) 

819 

820 # determine settings for reductions 

821 mesh = _parse_mesh(num_chains, mesh) 

822 red_cfg = _parse_reduction_configs( 

823 resid_reduction_config, 

824 count_reduction_config, 

825 prec_reduction_config, 

826 prec_count_num_trees, 

827 y, 

828 num_trees, 

829 num_chains, 

830 mesh, 

831 ) 

832 

833 # check there aren't too many deactivated predictors 

834 offset = _check_splitless_vars(filter_splitless_vars, max_split, offset) 

835 

836 tree_size = 2**max_depth 

837 

838 # Assemble the state, shard it, then fill in the post-shard fields. This 

839 # whole region runs with type-checking disabled because the state carries 

840 # deliberately wrong-typed intermediates parked in its fields for sharding: 

841 # `_LazyArray` leaves (each chain-bearing leaf is built at its core no-chain 

842 # shape, then `_add_chains` wraps it to broadcast in the chain axis), the raw 

843 # float `y` in the bool `binary_y` slot, and the user `missing` mask and 

844 # `error_scale` in the scale slots. The context ends once every field has 

845 # been replaced by its final, correctly-typed array. 

846 with jaxtyping_disabled(): 

847 state = State( 

848 _chain_anchor=_lazy(jnp.zeros, ()), # typechecker chain anchor 

849 X=X, 

850 binary_y=y, # temporary to be sharded together with everything else 

851 z=( 

852 _lazy(jnp.full, y.shape, offset[..., None]) 

853 if is_binary 

854 else _lazy( 

855 jnp.full, (binary_indices.size, n), offset[binary_indices, None] 

856 ) 

857 if binary_indices is not None 

858 else None 

859 ), 

860 binary_indices=binary_indices, 

861 offset=offset, 

862 resid=( 

863 _lazy(jnp.zeros, y.shape) 

864 if is_binary 

865 # resid is created later after y and offset are sharded 

866 else cast(Array, None) 

867 ), 

868 # only `value` carries the chain axis, so it becomes the lazy leaf; 

869 # the prior params `nu`/`rate` are shared across chains 

870 error_cov_inv=replace( 

871 error_cov_inv, value=_lazy_from_array(error_cov_inv.value) 

872 ), 

873 # temporarily store user inputs in these slots so they get sharded 

874 # with everything else; `_compute_scales` replaces them post-shard. 

875 prec_scale=error_scale, 

876 inv_sdev_scale=missing, 

877 forest=Forest( 

878 leaf_tree=_lazy( 

879 jnp.zeros, (num_trees, *kshape, tree_size), jnp.float32 

880 ), 

881 var_tree=_lazy( 

882 jnp.zeros, 

883 (num_trees, tree_size // 2), 

884 minimal_unsigned_dtype(p - 1), 

885 ), 

886 split_tree=_lazy( 

887 jnp.zeros, (num_trees, tree_size // 2), max_split.dtype 

888 ), 

889 affluence_tree=_lazy( 

890 _initial_affluence_tree, 

891 (num_trees, tree_size // 2), 

892 n, 

893 min_points_per_decision_node, 

894 ), 

895 blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split), 

896 max_split=max_split, 

897 grow_prop_count=_lazy(jnp.zeros, (), int), 

898 grow_acc_count=_lazy(jnp.zeros, (), int), 

899 prune_prop_count=_lazy(jnp.zeros, (), int), 

900 prune_acc_count=_lazy(jnp.zeros, (), int), 

901 p_nonterminal=p_nonterminal[tree_depths(tree_size)], 

902 p_propose_grow=p_nonterminal[tree_depths(tree_size // 2)], 

903 leaf_indices=_lazy( 

904 jnp.ones, (num_trees, n), minimal_unsigned_dtype(tree_size - 1) 

905 ), 

906 # the counts serve the minimum-points constraints and stand in 

907 # for the precisions when the error precision is unweighted 

908 # (`prec_scale` is set iff `error_scale` or `missing` is given) 

909 count_tree=( 

910 _lazy(_initial_count_tree, (num_trees, tree_size), n) 

911 if min_points_per_decision_node is not None 

912 or min_points_per_leaf is not None 

913 or (error_scale is None and missing is None) 

914 else None 

915 ), 

916 # prec_tree is created later, it needs the sharded prec_scale 

917 prec_tree=None, 

918 min_points_per_decision_node=_asarray_or_none( 

919 min_points_per_decision_node 

920 ), 

921 min_points_per_leaf=_asarray_or_none(min_points_per_leaf), 

922 log_trans_prior=_lazy(jnp.zeros, (num_trees,)) if save_ratios else None, 

923 log_likelihood=_lazy(jnp.zeros, (num_trees,)) if save_ratios else None, 

924 leaf_prior_cov_inv=leaf_prior_cov_inv, 

925 log_s=_lazy_from_array(_asarray_or_none(log_s)), 

926 theta=_lazy_from_array(_asarray_or_none(theta)), 

927 rho=_asarray_or_none(rho), 

928 a=_asarray_or_none(a), 

929 b=_asarray_or_none(b), 

930 ), 

931 config=StepConfig( 

932 steps_done=jnp.int32(0), 

933 sparse_on_at=_asarray_or_none(sparse_on_at), 

934 sequential_unroll=sequential_unroll, 

935 augment=augment, 

936 mesh=mesh, 

937 **red_cfg, 

938 ), 

939 ) 

940 

941 # add the chain axis to every chain-marked leaf at the position 

942 # declared by its field metadata 

943 state = _add_chains(state, num_chains) 

944 

945 # delete big input arrays such that they can be deleted as soon as they 

946 # are sharded, only those arrays that contain an (n,) sized axis 

947 del X, error_scale, missing, y 

948 

949 # move all arrays to the appropriate device 

950 state = _shard_state(state) 

951 

952 # calculate initial resid in the continuous outcome case, such that y 

953 # and offset are already sharded if needed 

954 if state.resid is None: 

955 state = _set_initial_resid(state, binary_indices, num_chains) 

956 

957 # calculate initial binary_y 

958 if is_binary or binary_indices is not None: 

959 assert state.binary_y is not None # holds y at this point 

960 binary_y = _LazyArray( 

961 _initial_binary_y, 

962 state.binary_y.shape 

963 if binary_indices is None 

964 else (binary_indices.size, n), 

965 state.binary_y, # this is actually y 

966 binary_indices, 

967 ) 

968 binary_y = _shard_leaf(binary_y, None, -1, state.config.mesh) 

969 else: 

970 binary_y = None 

971 state = replace(state, binary_y=binary_y) 

972 

973 # calculate prec_scale and inv_sdev_scale after sharding to do the 

974 # calculation on the right devices. Pre-shard, `state.prec_scale` holds 

975 # the user-supplied `error_scale` and `state.inv_sdev_scale` holds the 

976 # user-supplied `missing` mask. 

977 if state.prec_scale is not None or state.inv_sdev_scale is not None: 

978 inv_sdev_scale, prec_scale = _compute_scales( 

979 state.prec_scale, state.inv_sdev_scale 

980 ) 

981 state = replace(state, inv_sdev_scale=inv_sdev_scale, prec_scale=prec_scale) 

982 

983 # calculate the initial prec_tree from the sharded prec_scale 

984 state = _set_initial_prec_tree(state, num_chains, num_trees, tree_size) 

985 

986 # all the wrong-typed intermediates have now been replaced by their final 

987 # values, so type-checking can resume; make all types strong to avoid 

988 # unwanted recompilations 

989 return _remove_weak_types(state) 

990 

991 

def _set_initial_resid(
    state: 'State', binary_indices: Int32[Array, ' kb'] | None, num_chains: int | None
) -> 'State':
    """Build the continuous-outcome `resid` and shard it.

    Called post-shard so the captured ``state.binary_y`` and ``state.offset``
    are already on the target devices. Sharding axes are read via
    `chain_vmap_axes` / `data_vmap_axes` on a shape preview where the new
    `resid` leaf has the chain-extended ``ndim`` (inflated by a placeholder
    when `num_chains` is not `None`).

    Parameters
    ----------
    state
        The state after sharding. Its ``binary_y`` field must still hold the
        outcome ``y`` (not yet the final binary outcome).
    binary_indices
        Forwarded to `_initial_resid`, which zeroes the residuals of these
        outcome rows. `None` if there is nothing to zero out.
    num_chains
        The number of chains, or `None` for a state without chain axis.

    Returns
    -------
    A copy of `state` with the ``resid`` field replaced.
    """
    assert state.binary_y is not None  # holds y at this point
    # lazy description of the residuals with the shape of `y`, without chain
    # axis; the arguments after the shape match those of `_initial_resid`
    inner = _LazyArray(
        _initial_resid,
        state.binary_y.shape,
        state.binary_y,
        state.offset,
        binary_indices,
    )
    # the preview is only used to look up the axes: give the leaf the number
    # of dimensions it will have once the chain axis is added
    preview_resid = add_dummy_axis(inner) if num_chains is not None else inner
    preview = replace(state, resid=preview_resid)
    chain_axis = chain_vmap_axes(preview).resid
    data_axis = data_vmap_axes(preview).resid
    # the actual leaf is built from `inner`, not from the preview
    resid = _wrap_chain(inner, chain_axis, num_chains)
    resid = _shard_leaf(resid, chain_axis, data_axis, state.config.mesh)
    return replace(state, resid=resid)

1018 

1019 

def _initial_resid(
    shape: tuple[int, ...],
    y: Float32[Array, ' n'] | Float32[Array, 'k n'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    binary_indices: Int32[Array, ' kb'] | None,
) -> Float32[Array, ' n'] | Float32[Array, 'k n']:
    """Compute the starting `State.resid` for continuous outcomes.

    The residuals are ``y`` minus ``offset``, broadcast to `shape`. If
    `binary_indices` is given (mixed binary-continuous case), the rows it
    selects are set to zero, since for binary components the residual starts
    at ``z - trees - offset = 0``.
    """
    # append a trailing axis to the offset to line it up with the datapoints
    centered = jnp.broadcast_to(y - offset[..., None], shape)
    if binary_indices is None:
        return centered
    return centered.at[..., binary_indices, :].set(0.0)

1035 

1036 

def _initial_binary_y(
    shape: tuple[int, ...],
    y: Float32[Array, 'k n'] | Float32[Array, ' n'],
    binary_indices: Int32[Array, ' kb'] | None,
) -> Bool[Array, 'kb n'] | Bool[Array, ' n']:
    """Return the binary components of ``y`` converted to booleans.

    If `binary_indices` is `None` all of ``y`` is converted, otherwise only
    the rows selected by `binary_indices`. The result must have shape `shape`.
    """
    selected = y if binary_indices is None else y[binary_indices, :]
    out = selected.astype(bool)
    assert out.shape == shape
    return out

1049 

1050 

def _initial_affluence_tree(
    shape: tuple[int, ...], n: int, min_points_per_decision_node: int | None
) -> Shaped[Array, '...']:
    """Build the starting value of `Forest.affluence_tree`.

    All entries are `False` but the one at index 1 along the last axis, which
    is `True` if `min_points_per_decision_node` is not set or if `n` reaches
    it.
    """
    if min_points_per_decision_node is None:
        root_is_affluent = True
    else:
        root_is_affluent = n >= min_points_per_decision_node
    affluence = jnp.zeros(shape, bool)
    return affluence.at[..., 1].set(root_is_affluent)

1064 

1065 

def _initial_count_tree(shape: tuple[int, ...], n: int) -> Shaped[Array, '...']:
    """Build the starting value of `Forest.count_tree`.

    All the `n` datapoints are in the root, which is at index 1 along the
    last axis; all other counts are zero.
    """
    counts = jnp.zeros(shape, jnp.uint32)
    return counts.at[..., 1].set(n)

1069 

1070 

def _set_initial_prec_tree(
    state: State, num_chains: int | None, num_trees: int, tree_size: int
) -> State:
    """Build the cached per-leaf precision for root-only trees and shard it.

    Called post-shard so the captured ``state.prec_scale`` is already on the
    target devices; mirrors `_set_initial_resid`.

    Parameters
    ----------
    state
        The state after sharding; ``state.prec_scale`` must be set.
    num_chains
        The number of chains, or `None` for a state without chain axis.
    num_trees
        The size of the leading axis of the precision tree.
    tree_size
        The size of the trailing axis of the precision tree.

    Returns
    -------
    A copy of `state` with ``forest.prec_tree`` replaced.
    """
    assert state.prec_scale is not None
    # keep the leading axes of `prec_scale` and swap its last axis for the
    # tree axis
    shape = (num_trees, *state.prec_scale.shape[:-1], tree_size)
    inner = _LazyArray(_initial_prec_tree, shape, state.prec_scale)
    # the preview is only used to look up the chain axis: give the leaf the
    # number of dimensions it will have once the chain axis is added
    preview_tree = add_dummy_axis(inner) if num_chains is not None else inner
    preview = replace(state, forest=replace(state.forest, prec_tree=preview_tree))
    chain_axis = chain_vmap_axes(preview).forest.prec_tree
    # the actual leaf is built from `inner`, not from the preview
    prec_tree = _wrap_chain(inner, chain_axis, num_chains)
    # no data axis is passed: the shape built above has no datapoints axis
    prec_tree = _shard_leaf(prec_tree, chain_axis, None, state.config.mesh)
    return replace(state, forest=replace(state.forest, prec_tree=prec_tree))

1088 

1089 

def _initial_prec_tree(
    shape: tuple[int, ...], prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n']
) -> Float32[Array, 'num_trees tree_size'] | Float32[Array, 'num_trees k k tree_size']:
    """Build the starting value of `Forest.prec_tree`.

    All datapoints are in the root, so the root (index 1 along the last axis)
    holds the sum of `prec_scale` over its last axis; all other entries are
    zero.
    """
    root_prec = prec_scale.sum(axis=-1)
    prec_tree = jnp.zeros(shape, jnp.float32)
    return prec_tree.at[..., 1].set(root_prec)

1095 

1096 

@jit(donate_argnums=(0, 1))
def _compute_scales(
    error_scale: Float32[Array, ' n'] | Float32[Array, 'k n'] | None,
    missing: Bool[Array, ' n'] | Bool[Array, 'k n'] | None,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, 'k n'],
    Float32[Array, ' n'] | Float32[Array, 'k k n'],
]:
    """Derive ``inv_sdev_scale`` and ``prec_scale`` from the user inputs.

    The inverse standard deviation scale is the reciprocal of `error_scale`
    (1 if not given), set to 0 where `missing` is `True`. The precision scale
    is its square in the 1-dimensional case, and otherwise the product of
    each pair of rows.

    This lives in its own jitted function to donate the arguments and thus
    avoid intermediate copies. At least one of `error_scale` and `missing`
    must be non-None.
    """
    inv_sdev_scale = (
        jnp.array(1.0) if error_scale is None else jnp.reciprocal(error_scale)
    )
    if missing is not None:
        inv_sdev_scale = jnp.where(missing, 0.0, inv_sdev_scale)

    if inv_sdev_scale.ndim == 1:
        return inv_sdev_scale, jnp.square(inv_sdev_scale)

    pairwise = jnp.einsum('an,bn->abn', inv_sdev_scale, inv_sdev_scale)
    return inv_sdev_scale, pairwise

1121 

1122 

def _get_blocked_vars(
    filter_splitless_vars: int, max_split: UInt[Array, ' p']
) -> None | UInt[Array, ' q']:
    """Compute the initial value of the `blocked_vars` field.

    Returns `None` if `filter_splitless_vars` is 0. Otherwise returns the
    indices of the variables with ``max_split == 0``, in an array of length
    `filter_splitless_vars` padded with the number of variables.
    """
    if not filter_splitless_vars:
        return None
    (p,) = max_split.shape
    has_no_splits = max_split == 0
    (blocked_vars,) = jnp.nonzero(
        has_no_splits, size=filter_splitless_vars, fill_value=p
    )
    # see `fully_used_variables` for the type cast
    return blocked_vars.astype(minimal_unsigned_dtype(p))

1136 

1137 

def _add_chains(state: 'State', num_chains: int | None) -> 'State':
    """Extend chain-marked `_LazyArray` leaves to include a chain axis of size `num_chains`.

    Walks `state`, asks `chain_vmap_axes` where each leaf's chain axis lives,
    and wraps the carried `_LazyArray` so its factory creates the core array
    and then broadcasts a chain axis in at that position. To make
    `chain_vmap_axes` normalize against the chain-extended ``ndim``, the
    lookup is done on a shape preview built via `add_dummy_axis`. No-op when
    `num_chains` is `None`.

    Chain-marked leaves are required to be `_LazyArray` (or `None`); eager
    arrays at chain-marked positions are rejected so that all chain insertion
    happens at concretization time inside `_shard_state`.

    Parameters
    ----------
    state
        The state before sharding, with chain-bearing leaves still lazy and
        without chain axis.
    num_chains
        The number of chains, or `None` to leave `state` unchanged.

    Returns
    -------
    The state with the chain-marked leaves wrapped by `_wrap_chain`.
    """
    if num_chains is None:
        return state
    # the preview is only used to look up the chain axes, the returned state
    # is built from the original leaves
    preview = add_dummy_axis(state)
    chain_axes = chain_vmap_axes(preview)

    def wrap(leaf: object, chain_axis: int | None) -> object:
        # leaves without chain axis and missing leaves pass through unchanged
        if chain_axis is None or leaf is None:
            return leaf
        assert isinstance(leaf, _LazyArray), (
            f'expected _LazyArray for chain-marked leaf, got {type(leaf).__name__}'
        )
        return _wrap_chain(leaf, chain_axis, num_chains)

    # `is_leaf` makes `tree.map` hand `None` and `_LazyArray` nodes to `wrap`
    return tree.map(wrap, state, chain_axes, is_leaf=_is_lazy_or_none)

1166 

1167 

def _parse_mesh(
    num_chains: int | None, mesh: Mesh | dict[str, int] | None
) -> Mesh | None:
    """Validate the `mesh` argument and normalize it to a `Mesh` or `None`.

    A dictionary mapping axis names to axis sizes is converted to a mesh with
    `make_mesh`, with all axes in auto mode.

    Raises
    ------
    ValueError
        If the mesh has a 'chains' axis but `num_chains` is `None`, or if the
        size of the 'chains' axis does not divide `num_chains`.
    """
    if mesh is None:
        return None

    # dictionary shorthand: build the actual mesh
    if not isinstance(mesh, Mesh):
        assert set(mesh).issubset({'chains', 'data'})
        axis_names = tuple(mesh)
        axis_sizes = tuple(mesh.values())
        mesh = make_mesh(
            axis_sizes, axis_names, axis_types=(AxisType.Auto,) * len(axis_names)
        )

    # the chains must split evenly along the 'chains' mesh axis
    if 'chains' in mesh.axis_names:
        if num_chains is None:
            msg = "mesh has a 'chains' axis but num_chains is None (scalar, no chain axis)"
            raise ValueError(msg)
        chains_size = get_axis_size(mesh, 'chains')
        if num_chains % chains_size != 0:
            msg = (
                f"mesh 'chains' axis of size {chains_size} does not divide "
                f'num_chains={num_chains}'
            )
            raise ValueError(msg)

    # the axes we use must be in auto mode
    for axis_name in ('chains', 'data'):
        assert axis_name not in mesh.axis_names or axis_name in mesh.auto_axes

    return mesh

1200 

1201 

@partial(filter_jit, donate='all')
# jit and donate because otherwise type conversion would create copies
def _remove_weak_types(x: PyTree[Array, 'T']) -> PyTree[Array, 'T']:
    """Convert all weakly typed arrays in `x` to strongly typed ones.

    The dtypes are unchanged. This is to avoid recompilation in `run_mcmc` or
    `step`.
    """

    def strengthen(leaf: T) -> T:
        # leaves that are not arrays, or are already strong, pass through
        if not (isinstance(leaf, Array) and leaf.weak_type):
            return leaf
        return cast(T, leaf.astype(leaf.dtype))

    return tree.map(strengthen, x)

1217 

1218 

def _shard_state(state: State) -> State:
    """Instantiate the lazily defined arrays and put all arrays on their devices.

    Each leaf is processed by `_shard_leaf` with its chain and data axes and
    the mesh in ``state.config.mesh``.
    """

    def stop_at(node: object) -> bool:
        # hand missing and lazy leaves to `_shard_leaf` as they are
        return node is None or isinstance(node, _LazyArray)

    return tree.map(
        partial(_shard_leaf, mesh=state.config.mesh),
        state,
        chain_vmap_axes(state),
        data_vmap_axes(state),
        is_leaf=stop_at,
    )

1230 

1231 

def _leaf_partition_spec(
    ndim: int, chain_axis: int | None, data_axis: int | None, mesh: Mesh
) -> PartitionSpec:
    """Return the `PartitionSpec` of a leaf given its chain and data axes.

    The chain axis is sharded over 'chains' and the data axis over 'data',
    each only if the axis is set and the mesh has that axis name. All other
    axes are replicated.
    """
    entries: list[str | None] = [None] * ndim
    for axis, mesh_axis_name in ((chain_axis, 'chains'), (data_axis, 'data')):
        if axis is not None and mesh_axis_name in mesh.axis_names:
            entries[axis] = mesh_axis_name

    # drop trailing Nones to be consistent with jax's output, it's useful
    # for comparing shardings during debugging
    length = ndim
    while length > 0 and entries[length - 1] is None:
        length -= 1

    return PartitionSpec(*entries[:length])

1248 

1249 

def _shard_leaf(
    x: Shaped[Array, '*shape'] | None | Shaped[_LazyArray, '*shape'],
    chain_axis: int | None,
    data_axis: int | None,
    mesh: Mesh | None,
) -> Shaped[Array, '*shape'] | None:
    """Instantiate `x` if it is lazy, and shard it if there is a mesh.

    `None` is passed through. Without a mesh, arrays that are already
    concrete are returned as they are.
    """
    if x is None:
        return None

    sharding = None
    if mesh is not None:
        spec = _leaf_partition_spec(x.ndim, chain_axis, data_axis, mesh)
        sharding = NamedSharding(mesh, spec)

    if isinstance(x, _LazyArray):
        return _concretize_lazy_array(x, sharding)
    if sharding is None:
        return x
    return device_put(x, sharding, donate=True)

1272 

1273 

@filter_jit
# jit such that in recent jax versions the shards are created on the right
# devices immediately instead of being created on the wrong device and then
# copied
def _concretize_lazy_array(
    x: Shaped[_LazyArray, '*shape'], sharding: NamedSharding | None
) -> Shaped[Array, '*shape']:
    """Instantiate a lazy array, constrained to `sharding` if it is given."""
    array = x()
    if sharding is None:
        return array
    return lax.with_sharding_constraint(array, sharding)

1286 

1287 

def _all_none_or_not_none(*args: object) -> bool:
    """Return whether the arguments are either all `None` or all not `None`."""
    num_none = sum(arg is None for arg in args)
    return num_none in (0, len(args))

1291 

1292 

def _asarray_or_none(x: object) -> Shaped[Array, '...'] | None:
    """Convert `x` to a jax array, passing `None` through."""
    return None if x is None else jnp.asarray(x)

1297 

1298 

class _ReductionConfig(TypedDict):
    """Fields of `StepConfig` related to reductions."""

    # the three reduction configs are stored by `_parse_reduction_configs`
    # exactly as they were passed to it
    resid_reduction_config: ReductionConfig
    count_reduction_config: ReductionConfig
    prec_reduction_config: ReductionConfig
    # already resolved by `_parse_prec_count_num_trees`, so never 'auto'
    prec_count_num_trees: int | None

1306 

1307 

def _parse_reduction_configs(
    resid_reduction_config: ReductionConfig,
    count_reduction_config: ReductionConfig,
    prec_reduction_config: ReductionConfig,
    prec_count_num_trees: int | None | Literal['auto'],
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'],
    num_trees: int,
    num_chains: int | None,
    mesh: Mesh | None,
) -> _ReductionConfig:
    """Collect the settings for indexed reduces.

    The reduction configs are returned as they are; `prec_count_num_trees`
    is resolved based on the number of datapoints and chains per device.
    """
    # datapoints handled by each device
    n_per_device = y.shape[-1] // get_axis_size(mesh, 'data')
    # chains are vmapped together on each device, so they share the per-step
    # memory of the per-tree reduction
    chains_per_device = (num_chains or 1) // get_axis_size(mesh, 'chains')
    # the reduction configs carry their own datapoint-batch settings (resolved
    # per-platform at run time when 'auto', see `ReductionConfig`), so they are
    # stored verbatim; only `prec_count_num_trees`, which does not depend on the
    # platform, is resolved here
    resolved_num_trees = _parse_prec_count_num_trees(
        prec_count_num_trees, num_trees, n_per_device * chains_per_device
    )
    return {
        'resid_reduction_config': resid_reduction_config,
        'count_reduction_config': count_reduction_config,
        'prec_reduction_config': prec_reduction_config,
        'prec_count_num_trees': resolved_num_trees,
    }

1336 

1337 

def _parse_prec_count_num_trees(
    prec_count_num_trees: int | None | Literal['auto'], num_trees: int, n: int
) -> int | None:
    """Resolve the number of trees to process at a time.

    Any value other than ``'auto'`` is returned as it is. With ``'auto'``,
    the number is picked such that it times `n` stays within a fixed budget,
    preferring a nearby divisor of `num_trees`; `None` is returned if it
    reaches `num_trees`.
    """
    if prec_count_num_trees != 'auto':
        return prec_count_num_trees

    max_n_by_ntree = 2**27  # about 100M
    # as many trees as the budget allows, within [1, num_trees]
    target = max(1, min(num_trees, max_n_by_ntree // max(1, n)))
    # look for a divisor of num_trees between half and twice the target
    lower = max(1, target // 2)
    upper = max(1, min(num_trees, target * 2))
    chosen = _search_divisor(target, num_trees, lower, upper)
    return None if chosen >= num_trees else chosen

1354 

1355 

def _search_divisor(target_divisor: int, dividend: int, low: int, up: int) -> int:
    """Return a divisor of `dividend` as close as possible to `target_divisor`.

    If `target_divisor` divides `dividend` it is returned as it is. Otherwise
    the search is restricted to the range [low, up], and the smallest of the
    closest divisors is picked. If the range contains no divisor, give up and
    return `target_divisor`.
    """
    assert target_divisor >= 1
    assert 1 <= low <= up <= dividend

    if dividend % target_divisor == 0:
        return target_divisor

    candidates = numpy.arange(low, up + 1)
    is_divisor = numpy.remainder(dividend, candidates) == 0
    if not is_divisor.any():
        return target_divisor

    divisors = candidates[is_divisor]
    distance = numpy.abs(divisors - target_divisor)
    # argmin returns the first minimum, so ties go to the smaller divisor
    return int(divisors[distance.argmin()])

1372 

1373 

def get_axis_size(mesh: Mesh | None, axis_name: str) -> int:
    """Return the size of a mesh axis, or 1 if there is no mesh or no such axis."""
    if mesh is not None and axis_name in mesh.shape:
        return mesh.shape[axis_name]
    return 1

1379 

1380 

def chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool = False
) -> Float32[Array, '*batch_shape k k']:
    """Cholesky with Gershgorin stabilization, supports batching.

    Parameters
    ----------
    mat
        The matrices to decompose, with the two matrix axes last.
    absolute_eps
        If `True`, `_chol_with_gersh_impl` additionally adds the machine
        epsilon of the dtype of `mat` to the shift applied to the diagonal.

    Returns
    -------
    The output of `jnp.linalg.cholesky` applied to the stabilized matrices.
    """
    # `absolute_eps` is passed positionally because `_chol_with_gersh_impl`
    # excludes it from vectorization by its position
    return _chol_with_gersh_impl(mat, absolute_eps)

1386 

1387 

@partial(jnp.vectorize, signature='(k,k)->(k,k)', excluded=(1,))
def _chol_with_gersh_impl(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool
) -> Float32[Array, '*batch_shape k k']:
    """Cholesky of a single matrix after adding a small shift to the diagonal.

    The shift is the matrix size times the maximum row sum of absolute values
    times the machine epsilon, plus the machine epsilon itself if
    `absolute_eps` is set.
    """
    size = mat.shape[0]
    eps = jnp.finfo(mat.dtype).eps
    abs_row_sums = jnp.abs(mat).sum(axis=1)
    bound = jnp.max(abs_row_sums, initial=0.0)
    shift = size * bound * eps
    if absolute_eps:
        shift = shift + eps
    stabilized = mat.at[jnp.diag_indices_from(mat)].add(shift)
    return jnp.linalg.cholesky(stabilized)

1399 

1400 

def _inv_via_chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'],
) -> Float32[Array, '*batch_shape k k']:
    """Invert a matrix through its Gershgorin-stabilized Cholesky factor.

    DO NOT USE THIS FUNCTION UNLESS YOU REALLY NEED TO.
    """
    # mat = L L^T
    # mat^-1 = L^-T L^-1 = L^-T I L^-1 = L^-T (L^-T I)^T
    # I suspect this to be more accurate than (L^-1 I)^T (L^-1 I)
    chol = chol_with_gersh(mat)
    identity = jnp.broadcast_to(jnp.eye(mat.shape[-1]), mat.shape)
    # both solves are against the transpose of the lower triangular factor
    solve_options = dict(trans='T', lower=True)
    chol_t_inv = solve_triangular(chol, identity, **solve_options)
    return solve_triangular(chol, chol_t_inv.mT, **solve_options)

1415 

1416 

def split_key_for_chains(
    fun: Callable[[Key[Array, ''] | Key[Array, ' num_chains'], State], State],
) -> Callable[[Key[Array, ''], State], State]:
    """Wrap `fun` such that it receives one PRNG key per chain.

    The wrapper takes a single key. If ``state.num_chains()`` is not `None`,
    the key is split into that many keys before calling `fun`; otherwise the
    key is handed to `fun` as it is.
    """

    @wraps(fun)
    def wrapped(key: Key[Array, ''], state: State) -> State:
        num_chains = state.num_chains()
        chain_keys = key if num_chains is None else random.split(key, num_chains)
        return fun(chain_keys, state)

    return wrapped

1436 

1437 

def partition_specs(x: PyTree, mesh: Mesh) -> PyTree[PartitionSpec]:
    """Build the `PartitionSpec` of each leaf from its chain/data `field` markers.

    A leaf is sharded over ``'chains'`` along its chain axis and over
    ``'data'`` along its data axis, if the axis is marked (see `field`) and
    the mesh has it; the remaining axes are replicated.

    Parameters
    ----------
    x
        A pytree of arrays carrying chain/data `field` markers.
    mesh
        The device mesh to shard over.

    Returns
    -------
    A pytree matching `x` with a `PartitionSpec` in place of each array leaf.
    """

    def leaf_spec(
        leaf: Array, chain_axis: int | None, data_axis: int | None
    ) -> PartitionSpec:
        return _leaf_partition_spec(leaf.ndim, chain_axis, data_axis, mesh)

    return tree.map(leaf_spec, x, chain_vmap_axes(x), data_vmap_axes(x))

1462 

1463 

def shard_map_state(
    fun: Callable[[Key[Array, ''] | Key[Array, ' num_chains'], State], State],
) -> Callable[[Key[Array, ''] | Key[Array, ' num_chains'], State], State]:
    """Run a ``(keys, state) -> state`` function under a manual `jax.shard_map`.

    The mesh is read from `state.config.mesh` (static); if it is `None`,
    `fun` is called directly. The keys are sharded across ``'chains'`` when
    the state is multichain and the mesh has a ``'chains'`` axis, and are
    replicated otherwise. State leaves are sharded according to their
    `chains`/`data` field metadata. The output is sharded like the input.
    """

    @wraps(fun)
    def wrapped(key: Key[Array, ''] | Key[Array, ' num_chains'], state: State) -> State:
        mesh = state.config.mesh
        if mesh is None:
            return fun(key, state)

        shard_keys = state.has_chains and 'chains' in mesh.axis_names
        key_spec = PartitionSpec('chains') if shard_keys else PartitionSpec()
        state_specs = partition_specs(state, mesh)

        sharded_fun = shard_map(
            fun,
            mesh=mesh,
            in_specs=(key_spec, state_specs),
            out_specs=state_specs,
            **_get_shard_map_patch_kwargs(),
        )
        return sharded_fun(key, state)

    return wrapped

1499 

1500 

def vmap_chains(
    fun: Callable[[Key[Array, ''], State], State],
) -> Callable[[Key[Array, ' num_chains'] | Key[Array, ''], State], State]:
    """Vectorize a ``(key, state) -> state`` function over the chains.

    For a multichain state, `fun` is vmapped over the leading axis of `keys`
    and over the chain axes of `state`, and the output carries the same chain
    axes. For a single-chain state, `fun` is called directly.
    """

    @wraps(fun)
    def wrapped(
        keys: Key[Array, ' num_chains'] | Key[Array, ''], state: State
    ) -> State:
        if state.has_chains:
            chain_axes = chain_vmap_axes(state)
            over_chains = vmap(fun, in_axes=(0, chain_axes), out_axes=chain_axes)
            return over_chains(keys, state)
        return fun(keys, state)

    return wrapped

1522 

1523 

class _ShardMapPatchKwargs(TypedDict, total=False):
    """Optional keyword arguments that `shard_map_state` passes to `shard_map`."""

    # set to False by `_get_shard_map_patch_kwargs` on the jax versions it
    # lists as buggy, and left out otherwise
    check_vma: bool

1526 

1527 

def _get_shard_map_patch_kwargs() -> _ShardMapPatchKwargs:
    """Return extra `shard_map` arguments working around bugs of some jax versions."""
    # bug: jax 0.8.1-0.8.2: vmap(shard_map(psum)), jax#34249; the
    # jax_disable_vmap_shmap_error config did not work.

    # bug: jax 0.6.2: `random.poisson`'s internal `while_loop` (used by
    # `sample_s_augmentation`) does not `pvary` its initial carry, so its
    # output type varies over 'chains' while its input does not, whenever
    # the rate argument does.

    # WORKAROUND(jax<=0.8.2): remove this whole function when jax > 0.8.2
    affected_versions = frozenset({'0.6.2', '0.8.1', '0.8.2'})
    return {'check_vma': False} if jax.__version__ in affected_versions else {}