Coverage for src/bartz/mcmcstep/_state.py: 95% (454 statements, coverage.py v7.13.5, created at 2026-05-01 18:11 +0000)

# bartz/src/bartz/mcmcstep/_state.py
#
# Copyright (c) 2024-2026, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Module defining the BART MCMC state and initialization."""

from collections.abc import Callable, Hashable, Sequence
from dataclasses import fields, replace
from enum import Enum
from functools import partial, wraps
from math import log2
from typing import Any, Literal, TypedDict, TypeVar

import numpy
from equinox import Module, error_if, filter_jit
from equinox import field as eqx_field
from jax import (
    NamedSharding,
    device_put,
    eval_shape,
    jit,
    lax,
    make_mesh,
    random,
    tree,
    vmap,
)
from jax import numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.sharding import AxisType, Mesh, PartitionSpec
from jaxtyping import Array, Bool, Float32, Int32, Integer, PyTree, Shaped, UInt

from bartz.grove import tree_depths
from bartz.jaxext import get_default_device, is_key, minimal_unsigned_dtype

class OutcomeType(Enum):
    """Whether the regression outcome is continuous or binary (probit)."""

    continuous = 'continuous'
    binary = 'binary'

def field(*, chains: bool = False, data: bool = False, **kwargs: Any):  # noqa: ANN202
    """Extend `equinox.field` with two new parameters.

    Parameters
    ----------
    chains
        Whether the arrays in the field have an optional first axis that
        represents independent Markov chains.
    data
        Whether the last axis of the arrays in the field represents units of
        the data.
    **kwargs
        Other parameters passed to `equinox.field`.

    Returns
    -------
    A dataclass field descriptor with the special flags stored in the metadata;
    each flag is left unset if False.
    """
    metadata = dict(kwargs.pop('metadata', {}))
    assert 'chains' not in metadata
    assert 'data' not in metadata
    if chains:
        metadata['chains'] = True
    if data:
        metadata['data'] = True
    return eqx_field(metadata=metadata, **kwargs)
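
# Hedged usage sketch, not part of the original module: a minimal Module whose
# metadata marks a leading chain axis and a trailing data axis. The class name
# `_ExampleChained` is hypothetical, for illustration only.
#
#     class _ExampleChained(Module):
#         samples: Float32[Array, '*chains n'] = field(chains=True, data=True)
#         scale: Float32[Array, '']  # plain field, no special axes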

def chain_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for chains.

    This function determines the argument to the `in_axes` or `out_axes`
    parameter of `jax.vmap` to vmap over all and only the chain axes found in
    the pytree `x`.

    Parameters
    ----------
    x
        A pytree. Subpytrees that are Module attributes marked with
        ``field(..., chains=True)`` are considered to have a leading chain axis.

    Returns
    -------
    A pytree with the same structure as `x` with 0 or None in the leaves.
    """
    return _find_metadata(x, 'chains', 0, None)


def data_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for data.

    This is analogous to `chain_vmap_axes` but returns -1 for all fields
    marked with ``field(..., data=True)``.
    """
    return _find_metadata(x, 'data', -1, None)
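
# Hedged sketch of what the two helpers above return, continuing the
# hypothetical `_ExampleChained` example: `chain_vmap_axes` maps the
# chains-marked field to 0 and everything else to None, while `data_vmap_axes`
# maps the data-marked field to -1, so e.g.
#
#     m = _ExampleChained(samples=jnp.zeros((4, 100)), scale=jnp.asarray(1.0))
#     chain_vmap_axes(m)  # _ExampleChained(samples=0, scale=None)
#     data_vmap_axes(m)   # _ExampleChained(samples=-1, scale=None)
#
# and either result can be passed directly as `in_axes`/`out_axes` to `jax.vmap`.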

T = TypeVar('T')


def _find_metadata(
    x: PyTree[Any, ' S'], key: Hashable, if_true: T, if_false: T
) -> PyTree[T, ' S']:
    """Replace all subtrees of x marked with a metadata key."""

    def is_lazy_array(x: object) -> bool:
        return isinstance(x, _LazyArray)

    def is_module(x: object) -> bool:
        return isinstance(x, Module) and not is_lazy_array(x)

    if is_module(x):
        args = []
        for f in fields(x):
            v = getattr(x, f.name)
            if f.metadata.get('static', False):
                args.append(v)
            elif f.metadata.get(key, False):
                subtree = tree.map(lambda _: if_true, v, is_leaf=is_lazy_array)
                args.append(subtree)
            else:
                args.append(_find_metadata(v, key, if_true, if_false))
        return x.__class__(*args)

    def get_axes(x: object) -> PyTree[T]:
        if is_module(x):
            return _find_metadata(x, key, if_true, if_false)
        else:
            return tree.map(lambda _: if_false, x, is_leaf=is_lazy_array)

    def is_leaf(x: object) -> bool:
        return isinstance(x, Module)  # this catches _LazyArray as well

    return tree.map(get_axes, x, is_leaf=is_leaf)

class Forest(Module):
    """Represents the MCMC state of a sum of trees."""

    leaf_tree: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """The leaf values."""

    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The decision axes."""

    split_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The decision boundaries."""

    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Marks leaves that can be grown."""

    max_split: UInt[Array, ' p']
    """The maximum split index for each predictor."""

    blocked_vars: UInt[Array, ' q'] | None
    """Indices of variables that are not used. This shall include at least
    the `i` such that ``max_split[i] == 0``, otherwise behavior is
    undefined."""

    p_nonterminal: Float32[Array, ' 2**d']
    """The prior probability of each node being nonterminal, conditional on
    its ancestors. Includes the nodes at maximum depth, which should be set
    to 0."""

    p_propose_grow: Float32[Array, ' 2**(d-1)']
    """The unnormalized probability of picking a leaf for a grow proposal."""

    leaf_indices: UInt[Array, '*chains num_trees n'] = field(chains=True, data=True)
    """The index of the leaf each datapoint falls into, for each tree."""

    min_points_per_decision_node: Int32[Array, ''] | None
    """The minimum number of data points in a decision node."""

    min_points_per_leaf: Int32[Array, ''] | None
    """The minimum number of data points in a leaf node."""

    log_trans_prior: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """The log transition and prior Metropolis-Hastings ratio for the
    proposed move on each tree."""

    log_likelihood: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """The log likelihood ratio."""

    grow_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of grow proposals made during one full MCMC cycle."""

    prune_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of prune proposals made during one full MCMC cycle."""

    grow_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of grow moves accepted during one full MCMC cycle."""

    prune_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of prune moves accepted during one full MCMC cycle."""

    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """The prior precision matrix of a leaf, conditional on the tree structure.
    For the univariate case (k=1), this is a scalar (the inverse variance).
    The prior covariance of the sum of trees is
    ``num_trees * leaf_prior_cov_inv^-1``."""

    log_s: Float32[Array, '*chains p'] | None = field(chains=True)
    """The logarithm of the prior probability for choosing a variable to split
    along in a decision rule, conditional on the ancestors. Not normalized.
    If `None`, use a uniform distribution."""

    theta: Float32[Array, '*chains'] | None = field(chains=True)
    """The concentration parameter for the Dirichlet prior on the variable
    distribution `s`. Required only to update `log_s`."""

    a: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    b: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    rho: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    def num_chains(self) -> int | None:
        """Return the number of chains, or `None` if not multichain."""
        # maybe this should be replaced by chain_shape() -> () | (int,)
        if self.var_tree.ndim == 2:
            return None
        else:
            return self.var_tree.shape[0]
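
# Note on the tree array layout (an explanatory sketch; the convention is the
# one suggested by `tree_depths` and the `2**d` shapes above, assuming bartz's
# usual heap indexing): each tree is stored as a flat heap with the root at
# index 1, index 0 unused, and the children of node i at 2*i and 2*i + 1. So
# for maximum depth d, `leaf_tree` has 2**d slots per tree, while `var_tree`,
# `split_tree`, and `affluence_tree` only need the 2**(d-1) slots that can
# hold decision nodes.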

class StepConfig(Module):
    """Options for the MCMC step."""

    steps_done: Int32[Array, '']
    """The number of MCMC steps completed so far."""

    sparse_on_at: Int32[Array, ''] | None
    """After how many steps to turn on variable selection."""

    resid_num_batches: int | None = field(static=True)
    """The number of batches for computing the sum of residuals. If
    `None`, they are computed with no batching."""

    count_num_batches: int | None = field(static=True)
    """The number of batches for computing counts. If
    `None`, they are computed with no batching."""

    prec_num_batches: int | None = field(static=True)
    """The number of batches for computing precision scales. If
    `None`, they are computed with no batching."""

    prec_count_num_trees: int | None = field(static=True)
    """Batch size for processing trees to compute count and prec trees."""

    mesh: Mesh | None = field(static=True)
    """The mesh used to shard data and computation across multiple devices."""

class State(Module):
    """Represents the MCMC state of BART."""

    X: UInt[Array, 'p n'] = field(data=True)
    """The predictors."""

    binary_y: None | Bool[Array, ' n'] | Bool[Array, 'k n'] = field(data=True)
    """The response as booleans for binary regression, `None` for continuous.
    In the mixed binary-continuous case, only the binary outcome components
    are stored, with shape ``(kb, n)``."""

    z: None | Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=True, data=True
    )
    """The latent variable for binary regression. `None` in continuous
    regression. In the mixed binary-continuous case, only the binary outcome
    components are stored, with shape ``(*chains, kb, n)``."""

    binary_indices: None | Int32[Array, ' kb']
    """The indices of binary outcome components in the full list of outcome
    components. `None` when there are no binary components. Filled in by
    `init` and used by `step_z` to update only the binary rows of `resid`."""

    offset: Float32[Array, ''] | Float32[Array, ' k']
    """Constant shift added to the sum of trees."""

    resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=True, data=True
    )
    """The residuals (`y` or `z` minus sum of trees)."""

    error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] = field(
        chains=True
    )
    """The inverse error covariance (scalar for univariate, matrix for
    multivariate). Identity in binary regression."""

    prec_scale: Float32[Array, ' n'] | None = field(data=True)
    """The scale on the error precision, i.e., ``1 / error_scale ** 2``.
    `None` in binary regression."""

    error_cov_df: Float32[Array, ''] | None
    """The df parameter of the inverse Wishart prior on the noise
    covariance. For the univariate case, the relationship to the inverse
    gamma prior parameters is ``alpha = df / 2``.
    `None` in binary regression."""

    error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """The scale parameter of the inverse Wishart prior on the noise
    covariance. For the univariate case, the relationship to the inverse
    gamma prior parameters is ``beta = scale / 2``.
    `None` in binary regression."""

    forest: Forest
    """The sum of trees model."""

    config: StepConfig
    """Metadata and configurations for the MCMC step."""

def _init_shape_shifting_parameters(
    y: Float32[Array, ' n'] | Float32[Array, 'k n'],
    outcome_type: OutcomeType | list[OutcomeType],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    error_scale: Float32[Any, ' n'] | None,
    error_cov_df: float | Float32[Any, ''] | None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Any, 'k k'] | None,
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
) -> tuple[
    bool,
    tuple[()] | tuple[int],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
    None | Int32[Array, ' kb'],
]:
    """
    Check and initialize parameters that change array type/shape based on outcome kind.

    Parameters
    ----------
    y
        The response variable (used only for shape checks).
    outcome_type
        Whether the regression is continuous or binary. Can be a list of
        `OutcomeType` for per-component specification in the multivariate case.
    offset
        The offset to add to the predictions.
    error_scale
        Per-observation error scale (univariate only).
    error_cov_df
        The error covariance degrees of freedom.
    error_cov_scale
        The error covariance scale.
    leaf_prior_cov_inv
        The inverse of the leaf prior covariance.

    Returns
    -------
    is_binary
        Whether all outcomes are binary.
    kshape
        The outcome shape: empty for univariate, (k,) for multivariate.
    error_cov_inv
        The initialized error covariance inverse.
    error_cov_df
        The error covariance degrees of freedom (as array).
    error_cov_scale
        The error covariance scale (as array).
    binary_indices
        The indices of binary outcome components, or `None` if there are none.
    """
    kshape = offset.shape

    # determine per-component outcome kinds
    if isinstance(outcome_type, list):
        assert kshape, 'per-component outcome_type requires multivariate y'
        (k,) = kshape
        assert len(outcome_type) == k
        binary_mask = [t is OutcomeType.binary for t in outcome_type]
        is_binary = all(binary_mask)
        is_mixed = any(binary_mask) and not is_binary
    else:
        is_binary = outcome_type is OutcomeType.binary
        is_mixed = False

    if is_mixed:
        binary_indices = jnp.array([i for i, b in enumerate(binary_mask) if b])
    else:
        binary_indices = None

    # All-binary
    if is_binary:
        assert error_scale is None
        assert error_cov_df is None
        assert error_cov_scale is None
        if kshape:
            error_cov_inv = jnp.eye(kshape[0])
        else:
            error_cov_inv = jnp.array(1.0)

    # Mixed binary-continuous (multivariate, diagonal error covariance)
    elif is_mixed:
        assert error_scale is None, (
            'error_scale is not supported for mixed binary-continuous'
        )
        error_cov_df = jnp.asarray(error_cov_df)
        error_cov_scale = jnp.asarray(error_cov_scale)
        assert error_cov_scale.shape == 2 * kshape

        # enforce diagonal error_cov_scale
        diag = jnp.diag(jnp.diag(error_cov_scale))
        error_cov_scale = error_if(
            error_cov_scale,
            jnp.any(error_cov_scale != diag),
            'error_cov_scale must be diagonal for mixed binary-continuous',
        )

        # initialize diagonal error_cov_inv: use inv-gamma mode for continuous
        # components, 1.0 for binary components
        scale_diag = jnp.diag(error_cov_scale)
        inv_diag = jnp.where(
            jnp.array(binary_mask),
            1.0,
            error_cov_df / jnp.where(scale_diag, scale_diag, 1.0),
        )
        error_cov_inv = jnp.diag(inv_diag)

    # All-continuous
    else:
        error_cov_df = jnp.asarray(error_cov_df)
        error_cov_scale = jnp.asarray(error_cov_scale)
        assert error_cov_scale.shape == 2 * kshape

        # Multivariate vs univariate
        if kshape:
            error_cov_inv = error_cov_df * _inv_via_chol_with_gersh(error_cov_scale)
        else:
            # inverse gamma prior: alpha = df / 2, beta = scale / 2
            error_cov_inv = error_cov_df / error_cov_scale

    assert y.shape[:-1] == kshape
    assert leaf_prior_cov_inv.shape == 2 * kshape

    return (
        is_binary,
        kshape,
        error_cov_inv,
        error_cov_df,
        error_cov_scale,
        binary_indices,
    )
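
# Worked example (illustrative) of the univariate continuous branch above:
# sigma^2 ~ InvGamma(alpha=df/2, beta=scale/2), so 1/sigma^2 ~ Gamma(alpha,
# rate=beta), and the line `error_cov_inv = error_cov_df / error_cov_scale`
# starts the precision at its prior mean alpha/beta = df/scale. E.g. with
# df = 6 and scale = 2 (alpha = 3, beta = 1), the initial precision is 3.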

def _check_splitless_vars(
    filter_splitless_vars: int,
    max_split: UInt[Array, ' p'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
) -> Float32[Array, ''] | Float32[Array, ' k']:
    """Check there aren't too many deactivated predictors."""
    msg = (
        f'there are more than {filter_splitless_vars=} predictors with no splits, '
        'please increase `filter_splitless_vars` or investigate the missing splits'
    )
    return error_if(offset, jnp.sum(max_split == 0) > filter_splitless_vars, msg)

def _parse_outcome_type(
    outcome_type: 'OutcomeType | str | Sequence[OutcomeType | str]',
) -> 'OutcomeType | list[OutcomeType]':
    """Normalize outcome_type to an enum (or list of enums)."""
    if isinstance(outcome_type, Sequence) and not isinstance(outcome_type, str):
        return [OutcomeType(t) for t in outcome_type]
    else:
        return OutcomeType(outcome_type)
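
# Illustrative examples of the normalization above:
#
#     _parse_outcome_type('binary')  # OutcomeType.binary
#     _parse_outcome_type(['continuous', 'binary'])
#     # [OutcomeType.continuous, OutcomeType.binary], for a mixed 2-component y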

def _parse_p_nonterminal(
    p_nonterminal: Float32[Any, ' d_minus_1'],
) -> Float32[Array, ' d_minus_1+1']:
    """Check it's in (0, 1) and pad with a 0 at the end."""
    p_nonterminal = jnp.asarray(p_nonterminal)
    ok = (p_nonterminal > 0) & (p_nonterminal < 1)
    p_nonterminal = error_if(p_nonterminal, ~ok, 'p_nonterminal must be in (0, 1)')
    return jnp.pad(p_nonterminal, (0, 1))

def make_p_nonterminal(
    d: int,
    alpha: float | Float32[Array, ''] = 0.95,
    beta: float | Float32[Array, ''] = 2.0,
) -> Float32[Array, ' {d}-1']:
    """Prepare the `p_nonterminal` argument to `init`.

    It is calculated according to the formula

        P_nt(depth) = alpha / (1 + depth)^beta, with depth 0-based.

    Parameters
    ----------
    d
        The maximum depth of the trees (d=1 means a tree with only the root
        node).
    alpha
        The a priori probability of the root node having children, conditional
        on it being possible.
    beta
        The exponent of the power decay of the probability of having children
        with depth.

    Returns
    -------
    An array of probabilities, one per tree level but the last.
    """
    assert d >= 1
    depth = jnp.arange(d - 1)
    return alpha / (1 + depth).astype(float) ** beta
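
# Worked example (illustrative): with the defaults alpha=0.95, beta=2,
#
#     make_p_nonterminal(3)  # [0.95, 0.2375], i.e. 0.95/1**2 and 0.95/2**2
#
# so the root splits with prior probability 0.95, its children with 0.2375,
# and depth-2 nodes are forced to be leaves (the 0 padded by
# `_parse_p_nonterminal`).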

class _LazyArray(Module):
    """Like `functools.partial` but specialized to array-creating functions like `jax.numpy.zeros`."""

    array_creator: Callable
    shape: tuple[int, ...]
    args: tuple

    def __init__(
        self, array_creator: Callable, shape: tuple[int, ...], *args: Any
    ) -> None:
        self.array_creator = array_creator
        self.shape = shape
        self.args = args

    def __call__(self, **kwargs: Any) -> T:
        return self.array_creator(self.shape, *self.args, **kwargs)

    @property
    def ndim(self) -> int:
        return len(self.shape)
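
# Illustrative example: a `_LazyArray` postpones allocation until called, so
#
#     lazy = _LazyArray(jnp.zeros, (4, 8), jnp.float32)
#     lazy.ndim  # 2, known without allocating
#     lazy()     # jnp.zeros((4, 8), jnp.float32), allocated now
#
# `init` uses this to create the big state arrays only once the target
# sharding is known (see `_concretize_lazy_array`).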

def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Float32[Any, ' k n'],
    outcome_type: OutcomeType | str | Sequence[OutcomeType | str] = 'continuous',
    offset: float | Float32[Any, ''] | Float32[Any, ' k'],
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d_minus_1'],
    leaf_prior_cov_inv: float | Float32[Any, ''] | Float32[Array, 'k k'],
    error_cov_df: float | Float32[Any, ''] | None = None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Array, 'k k'] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_num_batches: int | None | Literal['auto'] = 'auto',
    count_num_batches: int | None | Literal['auto'] = 'auto',
    prec_num_batches: int | None | Literal['auto'] = 'auto',
    prec_count_num_trees: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: int = 0,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
    sparse_on_at: int | Integer[Any, ''] | None = None,
    num_chains: int | None = None,
    mesh: Mesh | dict[str, int] | None = None,
    target_platform: Literal['cpu', 'gpu'] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If two-dimensional, the outcome is multivariate with the
        first axis indicating the component. For binary data, non-zero means 1,
        zero means 0.
    outcome_type
        Whether the regression is continuous or binary (probit). Can also be a
        sequence of `OutcomeType` values, one per outcome component, for mixed
        binary-continuous multivariate regression.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array. Use `make_p_nonterminal`
        to set it with the conventional formula.
    leaf_prior_cov_inv
        The prior precision matrix of a leaf, conditional on the tree
        structure. For the univariate case (k=1), this is a scalar (the
        inverse variance). The prior covariance of the sum of trees is
        ``num_trees * leaf_prior_cov_inv^-1``. The prior mean of leaves is
        always zero.
    error_cov_df
    error_cov_scale
        The df and scale parameters of the inverse Wishart prior on the error
        covariance. For the univariate case, the relationship to the inverse
        gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``.
        Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1
        for all points, which allows some calculations to be skipped.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_num_batches
    count_num_batches
    prec_num_batches
        The number of batches, along datapoints, for summing the residuals,
        counting the number of datapoints in each leaf, and computing the
        likelihood precision in each leaf, respectively. `None` for no
        batching. If 'auto', it's chosen automatically based on the target
        platform; see the description of `target_platform` below for how that
        is determined.
    prec_count_num_trees
        The number of trees to process at a time when counting datapoints or
        computing the likelihood precision. If `None`, do all trees at once,
        which may use too much memory. If 'auto' (default), it's chosen
        automatically.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        The maximum number of variables without splits that can be ignored. If
        there are more, `init` raises an exception.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken
        into account in the Metropolis-Hastings ratio because it would be
        expensive to compute. Grow moves that would violate this constraint
        are vetoed. This parameter is independent of
        `min_points_per_decision_node` and there is no check that they are
        coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, use a uniform distribution. If not specified while
        `theta` (or `rho`, `a`, `b`) is, it's initialized automatically.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized automatically.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
    sparse_on_at
        After how many MCMC steps to turn on variable selection.
    num_chains
        The number of independent MCMC chains to represent in the state. A
        single chain with scalar values if not specified.
    mesh
        A jax mesh used to shard data and computation across multiple devices.
        If it has a 'chains' axis, that axis is used to shard the chains. If it
        has a 'data' axis, that axis is used to shard the datapoints.

        As a shorthand, if a dictionary mapping axis names to axis sizes is
        passed, the corresponding mesh is created, e.g., ``dict(chains=4,
        data=2)`` will let jax pick 8 devices to split the chains (whose number
        must be a multiple of 4) across 4 pairs of devices, where in each pair
        the data is split in two.

        Note: if a mesh is passed, the arrays are always sharded according to
        it. In particular, even if the mesh has no 'chains' or 'data' axis, the
        arrays will be replicated on all devices in the mesh.
    target_platform
        Platform ('cpu' or 'gpu') used to determine the number of batches
        automatically. If `mesh` is specified, the platform is inferred from
        the devices in the mesh. Otherwise, if `y` is a concrete array (i.e.,
        `init` is not invoked in a `jax.jit` context), the platform is set to
        the platform of `y`. Otherwise, use `target_platform`.

        To avoid confusion, in all cases where the `target_platform` argument
        would be ignored, `init` raises an exception if `target_platform` is
        set.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If arguments unused in binary regression are set.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.

    In general the arrays passed to this function as arguments may be donated,
    invalidating them. Create copies before passing them to `init` if this
    happens and you need them again.
    """
    # convert to array all array-like arguments that are used in other
    # configurations but don't need further processing themselves
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    assert y.dtype == jnp.float32
    offset = jnp.asarray(offset)
    leaf_prior_cov_inv = jnp.asarray(leaf_prior_cov_inv)
    max_split = jnp.asarray(max_split)

    # normalize outcome_type to enum (or list of enums)
    outcome_type = _parse_outcome_type(outcome_type)

    # check p_nonterminal and pad it with a 0 at the end (still not final shape)
    p_nonterminal = _parse_p_nonterminal(p_nonterminal)

    # process arguments that change depending on outcome type
    is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale, binary_indices = (
        _init_shape_shifting_parameters(
            y,
            outcome_type,
            offset,
            error_scale,
            error_cov_df,
            error_cov_scale,
            leaf_prior_cov_inv,
        )
    )

    # extract array sizes from arguments
    (max_depth,) = p_nonterminal.shape
    p, n = X.shape

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        theta = rho
    if log_s is None and theta is not None:
        log_s = jnp.zeros(max_split.size)
    if not _all_none_or_not_none(theta, sparse_on_at):
        msg = 'sparsity params (either theta or rho,a,b) and sparse_on_at must be either all None or all set'
        raise ValueError(msg)

    # process multichain settings
    chain_shape = () if num_chains is None else (num_chains,)
    resid_shape = chain_shape + y.shape
    add_chains = partial(_add_chains, chain_shape=chain_shape)

    # determine batch sizes for reductions
    mesh = _parse_mesh(num_chains, mesh)
    target_platform = _parse_target_platform(
        y, mesh, target_platform, resid_num_batches, count_num_batches, prec_num_batches
    )
    red_cfg = _parse_reduction_configs(
        resid_num_batches,
        count_num_batches,
        prec_num_batches,
        prec_count_num_trees,
        y,
        num_trees,
        mesh,
        target_platform,
    )

    # check there aren't too many deactivated predictors
    offset = _check_splitless_vars(filter_splitless_vars, max_split, offset)

    # determine shapes for trees
    tree_shape = (*chain_shape, num_trees)
    tree_size = 2**max_depth

    # initialize all remaining stuff and put it in an unsharded state
    state = State(
        X=X,
        binary_y=y,  # temporary, to be sharded together with everything else
        z=(
            _LazyArray(jnp.full, resid_shape, offset[..., None])
            if is_binary
            else _LazyArray(
                jnp.full,
                (*chain_shape, binary_indices.size, n),
                offset[binary_indices, None],
            )
            if binary_indices is not None
            else None
        ),
        binary_indices=binary_indices,
        offset=offset,
        resid=(
            _LazyArray(jnp.zeros, resid_shape)
            if is_binary
            else None  # resid is created later, after y and offset are sharded
        ),
        error_cov_inv=add_chains(error_cov_inv),
        prec_scale=error_scale,  # temporarily set to error_scale, fixed after sharding
        error_cov_df=error_cov_df,
        error_cov_scale=error_cov_scale,
        forest=Forest(
            leaf_tree=_LazyArray(
                jnp.zeros, (*tree_shape, *kshape, tree_size), jnp.float32
            ),
            var_tree=_LazyArray(
                jnp.zeros, (*tree_shape, tree_size // 2), minimal_unsigned_dtype(p - 1)
            ),
            split_tree=_LazyArray(
                jnp.zeros, (*tree_shape, tree_size // 2), max_split.dtype
            ),
            affluence_tree=_LazyArray(
                _initial_affluence_tree,
                (*tree_shape, tree_size // 2),
                n,
                min_points_per_decision_node,
            ),
            blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split),
            max_split=max_split,
            grow_prop_count=_LazyArray(jnp.zeros, chain_shape, int),
            grow_acc_count=_LazyArray(jnp.zeros, chain_shape, int),
            prune_prop_count=_LazyArray(jnp.zeros, chain_shape, int),
            prune_acc_count=_LazyArray(jnp.zeros, chain_shape, int),
            p_nonterminal=p_nonterminal[tree_depths(tree_size)],
            p_propose_grow=p_nonterminal[tree_depths(tree_size // 2)],
            leaf_indices=_LazyArray(
                jnp.ones, (*tree_shape, n), minimal_unsigned_dtype(tree_size - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            log_trans_prior=_LazyArray(jnp.zeros, (*chain_shape, num_trees))
            if save_ratios
            else None,
            log_likelihood=_LazyArray(jnp.zeros, (*chain_shape, num_trees))
            if save_ratios
            else None,
            leaf_prior_cov_inv=leaf_prior_cov_inv,
            log_s=add_chains(_asarray_or_none(log_s)),
            theta=add_chains(_asarray_or_none(theta)),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
        config=StepConfig(
            steps_done=jnp.int32(0),
            sparse_on_at=_asarray_or_none(sparse_on_at),
            mesh=mesh,
            **red_cfg,
        ),
    )

    # drop references to the big input arrays so their buffers can be freed as
    # soon as they are sharded; only those arrays that contain an (n,)-sized
    # axis
    del X, error_scale, y

    # move all arrays to the appropriate device
    state = _shard_state(state)

    # calculate the initial resid in the continuous outcome case, such that y
    # and offset are already sharded if needed
    if state.resid is None:
        resid = _LazyArray(
            _initial_resid,
            resid_shape,
            state.binary_y,  # this is actually y
            state.offset,
            binary_indices,
        )
        resid = _shard_leaf(resid, 0, -1, state.config.mesh)
        state = replace(state, resid=resid)

    # calculate the initial binary_y
    if is_binary or binary_indices is not None:
        binary_y = _LazyArray(
            _initial_binary_y,
            state.binary_y.shape
            if binary_indices is None
            else (binary_indices.size, n),
            state.binary_y,  # this is actually y
            binary_indices,
        )
        binary_y = _shard_leaf(binary_y, None, -1, state.config.mesh)
    else:
        binary_y = None
    state = replace(state, binary_y=binary_y)

    # calculate prec_scale after sharding to do the calculation on the right
    # devices
    if state.prec_scale is not None:
        prec_scale = _compute_prec_scale(state.prec_scale)
        state = replace(state, prec_scale=prec_scale)

    # make all types strong to avoid unwanted recompilations
    return _remove_weak_types(state)
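
# Hedged usage sketch, not part of the original module: a minimal `init` call
# for univariate continuous regression. All shapes and prior values here are
# made up for illustration.
#
#     n, p, num_trees = 100, 5, 50
#     key = random.key(0)
#     X = random.randint(key, (p, n), 0, 256).astype(jnp.uint8)  # pre-binned
#     y = jnp.zeros(n, jnp.float32)
#     state = init(
#         X=X,
#         y=y,
#         offset=y.mean(),
#         max_split=jnp.full(p, 255, jnp.uint8),
#         num_trees=num_trees,
#         p_nonterminal=make_p_nonterminal(6),
#         leaf_prior_cov_inv=4.0 * num_trees,  # sum-of-trees prior sd = 1/2
#         error_cov_df=2 * 3.0,    # alpha = 3 in inverse gamma terms
#         error_cov_scale=2 * 1.0,  # beta = 1
#     )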

def _initial_resid(
    shape: tuple[int, ...],
    y: Float32[Array, ' n'] | Float32[Array, 'k n'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    binary_indices: Int32[Array, ' kb'] | None,
) -> Float32[Array, ' n'] | Float32[Array, 'k n']:
    """Calculate the initial value for `State.resid` in the continuous outcome case.

    In the mixed binary-continuous case, binary rows are zeroed out (their
    residual starts at ``z - trees - offset = 0``).
    """
    resid = jnp.broadcast_to(y - offset[..., None], shape)
    if binary_indices is not None:
        resid = resid.at[..., binary_indices, :].set(0.0)
    return resid

def _initial_binary_y(
    shape: tuple[int, ...],
    y: Float32[Array, 'k n'] | Float32[Array, ' n'],
    binary_indices: Int32[Array, ' kb'] | None,
) -> Bool[Array, 'kb n'] | Bool[Array, ' n']:
    """Extract and convert the binary outcome components from ``y``."""
    if binary_indices is None:
        out = y.astype(bool)
    else:
        out = y[binary_indices, :].astype(bool)
    assert out.shape == shape
    return out

def _initial_affluence_tree(
    shape: tuple[int, ...], n: int, min_points_per_decision_node: int | None
) -> Array:
    """Create the initial value of `Forest.affluence_tree`."""
    return (
        jnp.zeros(shape, bool)
        .at[..., 1]
        .set(
            True
            if min_points_per_decision_node is None
            else n >= min_points_per_decision_node
        )
    )

@partial(jit, donate_argnums=(0,))
def _compute_prec_scale(error_scale: Float32[Array, ' n']) -> Float32[Array, ' n']:
    """Compute 1 / error_scale**2.

    This is a separate function to use donate_argnums to avoid intermediate
    copies.
    """
    return jnp.reciprocal(jnp.square(error_scale))

def _get_blocked_vars(
    filter_splitless_vars: int, max_split: UInt[Array, ' p']
) -> None | UInt[Array, ' q']:
    """Initialize the `blocked_vars` field."""
    if filter_splitless_vars:
        (p,) = max_split.shape
        (blocked_vars,) = jnp.nonzero(
            max_split == 0, size=filter_splitless_vars, fill_value=p
        )
        # see `fully_used_variables` for the type cast
        return blocked_vars.astype(minimal_unsigned_dtype(p))
    else:
        return None
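
# Illustrative example: with max_split = [3, 0, 5, 0] and
# filter_splitless_vars = 2, `jnp.nonzero(..., size=2, fill_value=4)` yields
# blocked_vars = [1, 3], the two predictors with no available splits. If only
# one variable had been splitless, the remaining slot would be filled with
# p = 4, an out-of-range index that blocks nothing.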

def _add_chains(
    x: Shaped[Array, '*shape'] | None, chain_shape: tuple[int, ...]
) -> Shaped[Array, '*shape'] | Shaped[Array, ' num_chains *shape'] | None:
    """Broadcast `x` to all chains."""
    if x is None:
        return None
    else:
        return jnp.broadcast_to(x, chain_shape + x.shape)

def _parse_mesh(
    num_chains: int | None, mesh: Mesh | dict[str, int] | None
) -> Mesh | None:
    """Parse the `mesh` argument."""
    if mesh is None:
        return None

    # convert dict format to an actual mesh
    if isinstance(mesh, dict):
        assert set(mesh).issubset({'chains', 'data'})
        mesh = make_mesh(
            tuple(mesh.values()), tuple(mesh), axis_types=(AxisType.Auto,) * len(mesh)
        )

    # check there's no chain mesh axis if there are no chains
    if num_chains is None:
        assert 'chains' not in mesh.axis_names

    # check the axes we use are in auto mode
    assert 'chains' not in mesh.axis_names or 'chains' in mesh.auto_axes
    assert 'data' not in mesh.axis_names or 'data' in mesh.auto_axes

    return mesh
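
# Illustrative example of the dict shorthand handled above (requires the
# corresponding number of visible devices):
#
#     mesh = _parse_mesh(num_chains=8, mesh=dict(chains=4, data=2))
#     # equivalent to make_mesh((4, 2), ('chains', 'data'),
#     #                         axis_types=(AxisType.Auto, AxisType.Auto))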

def _parse_target_platform(
    y: Array,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
    resid_num_batches: int | None | Literal['auto'],
    count_num_batches: int | None | Literal['auto'],
    prec_num_batches: int | None | Literal['auto'],
) -> Literal['cpu', 'gpu'] | None:
    if mesh is not None:
        assert target_platform is None, 'mesh provided, do not set target_platform'
        return mesh.devices.flat[0].platform
    elif hasattr(y, 'platform'):
        assert target_platform is None, 'device inferred from y, unset target_platform'
        return y.platform()
    elif (
        resid_num_batches == 'auto'
        or count_num_batches == 'auto'
        or prec_num_batches == 'auto'
    ):
        assert target_platform in ('cpu', 'gpu')
        return target_platform
    else:
        assert target_platform is None, 'target_platform not used, unset it'
        return target_platform

@partial(filter_jit, donate='all')
# jit and donate because otherwise the type conversion would create copies
def _remove_weak_types(x: PyTree[Array, 'T']) -> PyTree[Array, 'T']:
    """Make all types strong.

    This is to avoid recompilation in `run_mcmc` or `step`.
    """

    def remove_weak(x: T) -> T:
        if isinstance(x, Array) and x.weak_type:
            return x.astype(x.dtype)
        else:
            return x

    return tree.map(remove_weak, x)

def _shard_state(state: State) -> State:
    """Place all arrays on the appropriate devices, and instantiate lazily defined arrays."""
    mesh = state.config.mesh
    shard_leaf = partial(_shard_leaf, mesh=mesh)
    return tree.map(
        shard_leaf,
        state,
        chain_vmap_axes(state),
        data_vmap_axes(state),
        is_leaf=lambda x: x is None or isinstance(x, _LazyArray),
    )

def _shard_leaf(
    x: Array | None | _LazyArray,
    chain_axis: int | None,
    data_axis: int | None,
    mesh: Mesh | None,
) -> Array | None:
    """Create `x` if it's lazy, and shard it."""
    if x is None:
        return None

    if mesh is None:
        sharding = None
    else:
        spec = [None] * x.ndim
        if chain_axis is not None and 'chains' in mesh.axis_names:
            spec[chain_axis] = 'chains'
        if data_axis is not None and 'data' in mesh.axis_names:
            spec[data_axis] = 'data'

        # remove trailing Nones to be consistent with jax's output; it's useful
        # for comparing shardings during debugging
        while spec and spec[-1] is None:
            spec.pop()

        spec = PartitionSpec(*spec)
        sharding = NamedSharding(mesh, spec)

    if isinstance(x, _LazyArray):
        x = _concretize_lazy_array(x, sharding)
    elif sharding is not None:
        x = device_put(x, sharding, donate=True)

    return x

@filter_jit
# jit such that in recent jax versions the shards are created on the right
# devices immediately, instead of being created on the wrong device and then
# copied
def _concretize_lazy_array(x: _LazyArray, sharding: NamedSharding | None) -> Array:
    """Create an array from an abstract spec on the appropriate devices."""
    x = x()
    if sharding is not None:
        x = lax.with_sharding_constraint(x, sharding)
    return x

def _all_none_or_not_none(*args: object) -> bool:
    is_none = [x is None for x in args]
    return all(is_none) or not any(is_none)


def _asarray_or_none(x: object) -> Array | None:
    if x is None:
        return None
    return jnp.asarray(x)


def _get_platform(mesh: Mesh | None) -> str:
    if mesh is None:
        return get_default_device().platform
    else:
        return mesh.devices.flat[0].platform

class _ReductionConfig(TypedDict):
    """Fields of `StepConfig` related to reductions."""

    resid_num_batches: int | None
    count_num_batches: int | None
    prec_num_batches: int | None
    prec_count_num_trees: int | None

def _parse_reduction_configs(
    resid_num_batches: int | None | Literal['auto'],
    count_num_batches: int | None | Literal['auto'],
    prec_num_batches: int | None | Literal['auto'],
    prec_count_num_trees: int | None | Literal['auto'],
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'],
    num_trees: int,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
) -> _ReductionConfig:
    """Determine settings for indexed reductions."""
    n = y.shape[-1]
    n //= get_axis_size(mesh, 'data')  # per-device datapoints
    parse_num_batches = partial(_parse_num_batches, target_platform, n)
    return dict(
        resid_num_batches=parse_num_batches(resid_num_batches, 'resid'),
        count_num_batches=parse_num_batches(count_num_batches, 'count'),
        prec_num_batches=parse_num_batches(prec_num_batches, 'prec'),
        prec_count_num_trees=_parse_prec_count_num_trees(
            prec_count_num_trees, num_trees, n
        ),
    )

def _parse_num_batches(
    target_platform: Literal['cpu', 'gpu'] | None,
    n: int,
    num_batches: int | None | Literal['auto'],
    which: Literal['resid', 'count', 'prec'],
) -> int | None:
    """Return the number of batches, or determine it automatically."""
    final_round = partial(_final_round, n)
    if num_batches != 'auto':
        nb = num_batches
    elif target_platform == 'cpu':
        nb = final_round(16)
    elif target_platform == 'gpu':
        nb = dict(resid=1024, count=2048, prec=1024)[which]  # on an A4000
        nb = final_round(nb)
    return nb

def _final_round(n: int, num: float) -> int | None:
    """Bound the batch size, round the number of batches to a power of 2, and disable batching if there's only 1 batch."""
    # keep at least some elements per batch
    num = min(n // 32, num)

    # round to the nearest power of 2 because I guess XLA and the hardware
    # will like that (not sure about this, maybe just a multiple of 32?)
    num = 2 ** round(log2(num)) if num else 0

    # disable batching if the batch is as large as the whole dataset
    return num if num > 1 else None
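
# Worked examples (illustrative) of the rounding above:
#
#     _final_round(100_000, 16)  # min(3125, 16) = 16, already a power of 2 -> 16
#     _final_round(100, 16)      # min(3, 16) = 3, rounds to 2**2 -> 4
#     _final_round(32, 16)       # min(1, 16) = 1, not > 1 -> None (no batching)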

def _parse_prec_count_num_trees(
    prec_count_num_trees: int | None | Literal['auto'], num_trees: int, n: int
) -> int | None:
    """Return the number of trees to process at a time, or determine it automatically."""
    if prec_count_num_trees != 'auto':
        return prec_count_num_trees
    max_n_by_ntree = 2**27  # about 100M
    pcnt = max_n_by_ntree // max(1, n)
    pcnt = min(num_trees, pcnt)
    pcnt = max(1, pcnt)
    pcnt = _search_divisor(
        pcnt, num_trees, max(1, pcnt // 2), max(1, min(num_trees, pcnt * 2))
    )
    if pcnt >= num_trees:
        pcnt = None
    return pcnt

def _search_divisor(target_divisor: int, dividend: int, low: int, up: int) -> int:
    """Find the divisor of `dividend` closest to `target_divisor` in [low, up], if `target_divisor` is not one already.

    If there is none, give up and return `target_divisor`.
    """
    assert target_divisor >= 1
    assert 1 <= low <= up <= dividend
    if dividend % target_divisor == 0:
        return target_divisor
    candidates = numpy.arange(low, up + 1)
    divisors = candidates[dividend % candidates == 0]
    if divisors.size == 0:
        return target_divisor
    penalty = numpy.abs(divisors - target_divisor)
    closest = numpy.argmin(penalty)
    return divisors[closest].item()
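
# Worked example (illustrative): _search_divisor(6, 50, 3, 12) first checks
# 50 % 6 != 0, then looks at candidates 3..12, keeps the divisors of 50 in
# that range (5 and 10), and returns the one closest to 6, i.e. 5.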

def get_axis_size(mesh: Mesh | None, axis_name: str) -> int:
    if mesh is None or axis_name not in mesh.axis_names:
        return 1
    else:
        i = mesh.axis_names.index(axis_name)
        return mesh.axis_sizes[i]

def chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool = False
) -> Float32[Array, '*batch_shape k k']:
    """Cholesky with Gershgorin stabilization; supports batching."""
    return _chol_with_gersh_impl(mat, absolute_eps)


@partial(jnp.vectorize, signature='(k,k)->(k,k)', excluded=(1,))
def _chol_with_gersh_impl(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool
) -> Float32[Array, '*batch_shape k k']:
    rho = jnp.max(jnp.sum(jnp.abs(mat), axis=1), initial=0.0)
    eps = jnp.finfo(mat.dtype).eps
    u = mat.shape[0] * rho * eps
    if absolute_eps:
        u += eps
    mat = mat.at[jnp.diag_indices_from(mat)].add(u)
    return jnp.linalg.cholesky(mat)

def _inv_via_chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'],
) -> Float32[Array, '*batch_shape k k']:
    """Compute a matrix inverse via Cholesky with Gershgorin stabilization.

    DO NOT USE THIS FUNCTION UNLESS YOU REALLY NEED TO.
    """
    # mat = L L^T
    # mat^-1 = L^-T L^-1 = L^-T I L^-1 = L^-T (L^-T I)^T
    # I suspect this to be more accurate than (L^-1 I)^T (L^-1 I)
    L = chol_with_gersh(mat)
    eye = jnp.broadcast_to(jnp.eye(mat.shape[-1]), mat.shape)
    Ltinv = solve_triangular(L, eye, trans='T', lower=True)
    return solve_triangular(L, Ltinv.mT, trans='T', lower=True)
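
# Illustrative sanity check of the identity used above: for
# mat = jnp.diag(jnp.array([4.0, 9.0])), L = diag(2, 3), so the two triangular
# solves return approximately diag(1/4, 1/9), i.e. mat^-1, up to the tiny
# Gershgorin diagonal shift that `chol_with_gersh` adds for stabilization.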

def get_num_chains(x: PyTree) -> int | None:
    """Get the number of chains of a pytree.

    Find all nodes in the structure that define `num_chains()`, stopping
    traversal at nodes that define it. Check that all values obtained invoking
    `num_chains` are equal, then return it.
    """
    leaves, _ = tree.flatten(x, is_leaf=lambda x: hasattr(x, 'num_chains'))
    num_chains = [x.num_chains() for x in leaves if hasattr(x, 'num_chains')]
    ref = num_chains[0]
    assert all(c == ref for c in num_chains)
    return ref

def _chain_axes_with_keys(x: PyTree) -> PyTree[int | None]:
    """Return `chain_vmap_axes(x)`, but also set to 0 for random keys."""
    axes = chain_vmap_axes(x)

    def axis_if_key(x: object, axis: int | None) -> int | None:
        if is_key(x):
            return 0
        else:
            return axis

    return tree.map(axis_if_key, x, axes)

def _get_mc_out_axes(
    fun: Callable[[tuple, dict], PyTree], args: PyTree, in_axes: PyTree[int | None]
) -> PyTree[int | None]:
    """Decide chain vmap axes for outputs."""
    vmapped_fun = vmap(fun, in_axes=in_axes)
    out = eval_shape(vmapped_fun, *args)
    return chain_vmap_axes(out)

def _find_mesh(x: PyTree) -> Mesh | None:
    """Find the mesh used for chains."""

    class MeshFound(Exception):
        pass

    def find_mesh(x: object) -> None:
        if isinstance(x, State):
            raise MeshFound(x.config.mesh)

    try:
        tree.map(find_mesh, x, is_leaf=lambda x: isinstance(x, State))
    except MeshFound as e:
        return e.args[0]
    else:
        raise ValueError

def _split_all_keys(x: PyTree, num_chains: int) -> PyTree:
    """Split all random keys into `num_chains` keys."""
    mesh = _find_mesh(x)

    def split_key(x: object) -> object:
        if is_key(x):
            x = random.split(x, num_chains)
            if mesh is not None and 'chains' in mesh.axis_names:
                x = device_put(x, NamedSharding(mesh, PartitionSpec('chains')))
        return x

    return tree.map(split_key, x)

def vmap_chains(fun: Callable[..., T]) -> Callable[..., T]:
    """Apply vmap on chain axes automatically if the inputs are multichain."""

    @wraps(fun)
    def auto_vmapped_fun(*args: Any, **kwargs: Any) -> T:
        all_args = args, kwargs
        num_chains = get_num_chains(all_args)
        if num_chains is not None:
            all_args = _split_all_keys(all_args, num_chains)

            def wrapped_fun(args: tuple[Any, ...], kwargs: dict[str, Any]) -> T:
                return fun(*args, **kwargs)

            mc_in_axes = _chain_axes_with_keys(all_args)
            mc_out_axes = _get_mc_out_axes(wrapped_fun, all_args, mc_in_axes)
            vmapped_fun = vmap(wrapped_fun, in_axes=mc_in_axes, out_axes=mc_out_axes)
            return vmapped_fun(*all_args)

        else:
            return fun(*args, **kwargs)

    return auto_vmapped_fun
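
# Hedged usage sketch, not part of the original module: `vmap_chains` lets the
# same step function work on single-chain and multichain states. `step` here is
# a hypothetical single-chain update with signature (key, State) -> State.
#
#     @vmap_chains
#     def step(key: Array, state: State) -> State:
#         ...  # single-chain update written without a chain axis
#
# With a multichain `state` (e.g. from init(..., num_chains=4)),
# `get_num_chains` detects 4 chains, the key is split into 4 by
# `_split_all_keys`, and the body is vmapped over exactly the chain axes; with
# a single-chain state it runs unwrapped.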