Coverage for src / bartz / mcmcstep / _step.py: 99%

441 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 18:11 +0000

1# bartz/src/bartz/mcmcstep/_step.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement `step`, `step_trees`, and the accept-reject logic.""" 

26 

27from dataclasses import replace 

28from functools import partial 

29 

30# WORKAROUND(jax<0.6.1): shard_map was promoted from jax.experimental to top-level in 0.6.1 

31try: 

32 from jax import shard_map 

33except ImportError: 

34 from jax.experimental.shard_map import shard_map 

35 

36import jax 

37from equinox import Module, tree_at 

38from jax import jit, lax, named_call, random, vmap 

39from jax import numpy as jnp 

40from jax.scipy.linalg import solve_triangular 

41from jax.scipy.special import gammaln, logsumexp 

42from jax.sharding import Mesh, PartitionSpec 

43from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt, UInt32 

44 

45from bartz.grove import var_histogram 

46from bartz.jaxext import split, truncated_normal_onesided, vmap_nodoc 

47from bartz.mcmcstep._moves import Moves, propose_moves 

48from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field, vmap_chains 

49 

50 

@partial(jit, donate_argnums=(1,))
@vmap_chains
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Do one MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    The memory of the input state is re-used for the output state (the state
    argument is donated to jit), so the input state can not be used any more
    after calling `step`. All this applies outside of `jax.jit`.
    """
    # one key per potential sub-step; keys for skipped sub-steps are discarded
    keys = split(key, 4)

    # the forest is always updated
    bart = step_trees(keys.pop(), bart)

    # these are Python-level `is not None` checks, so each branch is resolved
    # statically when the function is traced under jit
    if bart.z is not None:
        bart = step_z(keys.pop(), bart)

    if bart.error_cov_df is not None:
        bart = step_error_cov_inv(keys.pop(), bart)

    bart = step_sparse(keys.pop(), bart)
    return step_config(bart)

86 

87 

@named_call
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Forest sampling step of BART MCMC.

    Proposes one move per tree, then accepts or rejects each move and
    resamples the leaf values.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    subkeys = split(key)
    proposed = propose_moves(subkeys.pop(), bart.forest)
    return accept_moves_and_sample_leaves(subkeys.pop(), bart, proposed)

111 

112 

@named_call
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Accept or reject the proposed moves and sample the new leaf values.

    Runs the three stages of the accept/sample pipeline in order: the
    tree-parallel pre-computations, the sequential per-tree accept/reject
    scan, and the final bookkeeping.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    updated_bart, decided_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(updated_bart, decided_moves)

136 

137 

class Counts(Module):
    """Number of datapoints in the nodes involved in proposed moves for each tree.

    By construction ``total == left + right`` (see
    `_compute_count_or_prec_tree`); it is stored pre-computed.
    """

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the left child."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the right child."""

    total: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the parent (``= left + right``)."""

149 

150 

class Precs(Module):
    """Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node. This is the
    weighted (heteroskedastic) analogue of `Counts`; by construction
    ``total == left + right``.
    """

    left: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the left child."""

    right: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the right child."""

    total: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the parent (``= left + right``)."""

166 

167 

class PreLkV(Module):
    """Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees (they do not depend
    on the residuals, which change as trees are accepted sequentially).
    """

    left: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_left / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_left * error_cov_inv) @ error_cov_inv``.

    ``n_left`` is the number of datapoints in the left child, or the
    likelihood precision scale in the heteroskedastic case."""

    right: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_right / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_right * error_cov_inv) @ error_cov_inv``.

    ``n_right`` is the number of datapoints in the right child, or the
    likelihood precision scale in the heteroskedastic case."""

    total: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_total / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_total * error_cov_inv) @ error_cov_inv``.

    ``n_total`` is the number of datapoints in the parent node, or the
    likelihood precision scale in the heteroskedastic case."""

    log_sqrt_term: Float32[Array, '*chains num_trees'] = field(chains=True)
    """The logarithm of the square root term of the likelihood ratio."""

218 

219 

class PreLk(Module):
    """Non-sequential terms of the likelihood ratio shared by all trees.

    Only produced in the univariate case; the multivariate code path returns
    ``None`` in its place (see `precompute_likelihood_terms`).
    """

    exp_factor: Float32[Array, '*chains'] = field(chains=True)
    """The factor to multiply the likelihood ratio by, shared by all trees."""

225 

226 

class PreLf(Module):
    """Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees.

    For each tree and leaf, the terms are scalars in the univariate case, and
    matrices/vectors in the multivariate case.
    """

    mean_factor: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k k 2**d']
    ) = field(chains=True)
    """The factor to be right-multiplied by the sum of the scaled residuals to
    obtain the posterior mean."""

    centered_leaves: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """The mean-zero normal values to be added to the posterior mean to
    obtain the posterior leaf samples."""

249 

250 

class ParallelStageOut(Module):
    """The output of `accept_moves_parallel_stage`.

    Bundles the partially updated state together with everything that could be
    pre-computed in parallel across trees, for consumption by
    `accept_moves_sequential_stage`.
    """

    bart: State
    """A partially updated BART mcmc state."""

    moves: Moves
    """The proposed moves, with `partial_ratio` set to `None` and
    `log_trans_prior_ratio` set to its final value."""

    prec_trees: (
        Float32[Array, '*chains num_trees 2**d']
        | Int32[Array, '*chains num_trees 2**d']
    ) = field(chains=True)
    """The likelihood precision scale in each potential or actual leaf node. If
    there is no precision scale, this is the number of points in each leaf."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves. If
    `bart.prec_scale` is not set, this is set to `move_counts`."""

    prelkv: PreLkV
    """Object with pre-computed terms of the likelihood ratios."""

    prelk: PreLk | None
    """Object with pre-computed terms of the likelihood ratios."""

    prelf: PreLf
    """Object with pre-computed terms of the leaf samples."""

280 

281 

@named_call
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf
    # (all `is not None` / `is None` checks below are Python-level, so the
    # branching is resolved statically at trace time)
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.config
        )

        # mark which leaves & potential leaves have enough points to be grown
        if bart.forest.min_points_per_decision_node is not None:
            # only the first half of the heap can hold decision nodes
            count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
            moves = replace(
                moves,
                affluence_tree=moves.affluence_tree
                & (count_half_trees >= bart.forest.min_points_per_decision_node),
            )

        # copy updated affluence_tree to state
        bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

        # veto grove move if new leaves don't have enough datapoints
        if bart.forest.min_points_per_leaf is not None:
            moves = replace(
                moves,
                allowed=moves.allowed
                & (move_counts.left >= bart.forest.min_points_per_leaf)
                & (move_counts.right >= bart.forest.min_points_per_leaf),
            )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        # homoskedastic case: the precisions are the plain counts
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale, bart.forest.leaf_indices, moves, bart.config
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    prelkv, prelk = precompute_likelihood_terms(
        bart.error_cov_inv, bart.forest.leaf_prior_cov_inv, move_precs
    )
    prelf = precompute_leaf_terms(
        key, prec_trees, bart.error_cov_inv, bart.forest.leaf_prior_cov_inv
    )

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )

383 

384 

@named_call
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to apply a grow move.

    Datapoints sitting in the grown node are re-routed to its left or right
    child according to the new decision rule; all other indices are untouched.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    # sentinel one past the end of the heap: comparing leaf indices against it
    # matches no datapoint, so non-grow moves leave everything unchanged
    out_of_bounds = jnp.array(2 * moves.var_tree.size)
    target_node = jnp.where(moves.grow, moves.node, out_of_bounds)

    # which side of the new split each datapoint falls on
    split_feature: UInt[Array, ' n'] = X[moves.grow_var, :]
    goes_right = split_feature >= moves.grow_split

    # heap indexing: children of node i are 2i (left) and 2i + 1 (right)
    left_child = moves.node.astype(leaf_indices.dtype) << 1
    return jnp.where(
        leaf_indices == target_node, left_child + goes_right, leaf_indices
    )

414 

415 

def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, 'num_trees 2**d'], Counts]
    | tuple[Float32[Array, 'num_trees 2**d'], Precs]
):
    """Implement `compute_count_trees` and `compute_prec_trees`."""
    if config.prec_count_num_trees is None:
        # process all trees at once with vmap
        compute = vmap(_compute_count_or_prec_tree, in_axes=(None, 0, 0, None))
        return compute(prec_scale, leaf_indices, moves, config)

    # otherwise process the trees in batches of `prec_count_num_trees` with
    # lax.map to bound peak memory; `prec_scale` and `config` are closed over
    # because lax.map carries only the per-tree arguments
    def compute(
        args: tuple[UInt[Array, ' n'], Moves],
    ) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
        leaf_indices, moves = args
        return _compute_count_or_prec_tree(prec_scale, leaf_indices, moves, config)

    return lax.map(
        compute, (leaf_indices, moves), batch_size=config.prec_count_num_trees
    )

439 

440 

def _compute_count_or_prec_tree(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, ' n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
    """Compute count or precision tree for a single tree."""
    # the full heap has twice the slots of `var_tree`: leaves can sit one
    # level below the deepest decision nodes
    (tree_size,) = moves.var_tree.shape
    tree_size *= 2

    if prec_scale is None:
        # homoskedastic: scatter-add 1 per datapoint -> integer counts
        value = 1
        cls = Counts
        dtype = jnp.uint32
        num_batches = config.count_num_batches
    else:
        # heteroskedastic: scatter-add each datapoint's precision scale
        value = prec_scale
        cls = Precs
        dtype = jnp.float32
        num_batches = config.prec_num_batches

    trees = _scatter_add(
        value, leaf_indices, tree_size, dtype, num_batches, config.mesh
    )

    # count datapoints in nodes modified by move
    left = trees[moves.left]
    right = trees[moves.right]
    counts = cls(left=left, right=right, total=left + right)

    # write count into non-leaf node
    trees = trees.at[moves.node].set(counts.total)

    return trees, counts

475 

476 

@named_call
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, config: StepConfig
) -> tuple[UInt32[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper
        version of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    count_trees : Int32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The number of points in the leaves grown or pruned by the moves.
    """
    # passing no precision scale makes the shared implementation count points
    return _compute_count_or_prec_trees(None, leaf_indices, moves, config)

503 

504 

@named_call
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper
        version of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    # with a precision scale, the shared implementation sums precisions
    return _compute_count_or_prec_trees(prec_scale, leaf_indices, moves, config)

535 

536 

@partial(vmap_nodoc, in_axes=(0, None))
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    Folds into the transition/prior ratio the probability of choosing a prune
    move over a grow move in the inverse transition, and the prior
    probability that the children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into
        account the thresholds on the number of datapoints per node, this
        happens in `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional
        on its ancestors, including at the maximum depth where it should be
        zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    # growability of each child; fill=False covers out-of-bounds indices
    growable_left = moves.affluence_tree.at[moves.left].get(
        mode='fill', fill_value=False
    )
    growable_right = moves.affluence_tree.at[moves.right].get(
        mode='fill', fill_value=False
    )

    # probability of proposing a prune next, in the grow branch: 1/2 unless
    # growing would become impossible afterwards
    can_grow_again = (moves.num_growable >= 2) | growable_left | growable_right
    p_prune_if_grow = jnp.where(can_grow_again, 0.5, 1.0)

    # and in the prune branch: 1/2 unless nothing is growable
    p_prune_if_prune = jnp.where(moves.num_growable, 0.5, 1)

    p_prune = jnp.where(moves.grow, p_prune_if_grow, p_prune_if_prune)

    # prior probability that both children stay terminal
    prob_both_terminal = (1 - p_nonterminal[moves.left] * growable_left) * (
        1 - p_nonterminal[moves.right] * growable_right
    )

    assert moves.partial_ratio is not None
    return replace(
        moves,
        log_trans_prior_ratio=jnp.log(
            moves.partial_ratio * prob_both_terminal * p_prune
        ),
        partial_ratio=None,
    )

590 

591 

@named_call
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if the
    grow move was accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    values_at_node = leaf_trees[..., moves.node]
    # where the move is not a grow, the write index is `leaf_trees.size`,
    # which is out of bounds; jax's default out-of-bounds behavior for
    # `.at[].set` drops the update, leaving the tree unchanged
    return (
        leaf_trees.at[..., jnp.where(moves.grow, moves.left, leaf_trees.size)]
        .set(values_at_node)
        .at[..., jnp.where(moves.grow, moves.right, leaf_trees.size)]
        .set(values_at_node)
    )

621 

622 

623def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']: 

624 """Compute logdet of A = LL' via Cholesky (sum of log of diag^2).""" 

625 diags: Float32[Array, '... k'] = jnp.diagonal(L, axis1=-2, axis2=-1) 1de

626 return 2.0 * jnp.sum(jnp.log(diags), axis=-1) 1de

627 

628 

def _precompute_likelihood_terms_uv(
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """Univariate implementation of `precompute_likelihood_terms`."""
    error_var = jnp.reciprocal(error_cov_inv)
    leaf_prior_var = jnp.reciprocal(leaf_prior_cov_inv)

    # per-node term: error variance plus (node precision) * (leaf prior variance)
    def node_term(node_prec):
        return error_var + node_prec * leaf_prior_var

    term_left = node_term(move_precs.left)
    term_right = node_term(move_precs.right)
    term_total = node_term(move_precs.total)

    prelkv = PreLkV(
        left=term_left,
        right=term_right,
        total=term_total,
        log_sqrt_term=jnp.log(error_var * term_total / (term_left * term_right)) / 2,
    )
    prelk = PreLk(exp_factor=error_cov_inv / leaf_prior_cov_inv / 2)
    return prelkv, prelk

646 

647 

def _precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    move_precs: Counts,
) -> tuple[PreLkV, None]:
    """Multivariate implementation of `precompute_likelihood_terms`.

    Returns ``None`` for the shared `PreLk` terms: in the multivariate case
    everything is folded into the per-tree `PreLkV` matrices.
    """
    # reshape per-tree counts to broadcast against the (k, k) matrices
    nL: UInt[Array, 'num_trees 1 1'] = move_precs.left[..., None, None]
    nR: UInt[Array, 'num_trees 1 1'] = move_precs.right[..., None, None]
    nT: UInt[Array, 'num_trees 1 1'] = move_precs.total[..., None, None]

    # Cholesky factors of the per-node posterior precision matrices
    # leaf_prior_cov_inv + n * error_cov_inv
    L_left: Float32[Array, 'num_trees k k'] = chol_with_gersh(
        error_cov_inv * nL + leaf_prior_cov_inv
    )
    L_right: Float32[Array, 'num_trees k k'] = chol_with_gersh(
        error_cov_inv * nR + leaf_prior_cov_inv
    )
    L_total: Float32[Array, 'num_trees k k'] = chol_with_gersh(
        error_cov_inv * nT + leaf_prior_cov_inv
    )

    # log of the square-root determinant term of the likelihood ratio,
    # assembled from the log-determinants of the Cholesky factors
    log_sqrt_term: Float32[Array, ' num_trees'] = 0.5 * (
        _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))
        + _logdet_from_chol(L_total)
        - _logdet_from_chol(L_left)
        - _logdet_from_chol(L_right)
    )

    def _term_from_chol(
        L: Float32[Array, 'num_trees k k'],
    ) -> Float32[Array, 'num_trees k k']:
        # computes error_cov_inv @ inv(L L') @ error_cov_inv as Y'Y with
        # Y = L^{-1} error_cov_inv; broadcast because solve_triangular does
        # not broadcast its operands implicitly
        rhs: Float32[Array, 'num_trees k k'] = jnp.broadcast_to(error_cov_inv, L.shape)
        Y: Float32[Array, 'num_trees k k'] = solve_triangular(L, rhs, lower=True)
        return Y.mT @ Y

    prelkv = PreLkV(
        left=_term_from_chol(L_left),
        right=_term_from_chol(L_right),
        total=_term_from_chol(L_total),
        log_sqrt_term=log_sqrt_term,
    )

    return prelkv, None

689 

690 

@named_call
def precompute_likelihood_terms(
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk | None]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Dispatches on the rank of `error_cov_inv`: a matrix selects the
    multivariate implementation, a scalar the univariate one. The
    multivariate implementation assumes a homoskedastic error model (the
    residual covariance is the same for all observations).

    Parameters
    ----------
    error_cov_inv
        The inverse error variance (univariate) or the inverse of the error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse
        of prior covariance matrix of each leaf (multivariate).
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk | None
        Pre-computed terms of the likelihood ratio, shared by all trees.
        `None` in the multivariate case.
    """
    if error_cov_inv.ndim != 2:
        # scalar error precision -> univariate model
        return _precompute_likelihood_terms_uv(
            error_cov_inv, leaf_prior_cov_inv, move_precs
        )
    # matrix error precision -> multivariate model, which only supports counts
    assert isinstance(move_precs, Counts)
    return _precompute_likelihood_terms_mv(
        error_cov_inv, leaf_prior_cov_inv, move_precs
    )

733 

734 

def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees 2**d'] | None = None,
) -> PreLf:
    """Univariate implementation of `precompute_leaf_terms`."""
    # conjugate normal update: posterior precision is the sum of the
    # likelihood precision (n * error precision) and the prior precision
    likelihood_prec = prec_trees * error_cov_inv
    posterior_var = jnp.reciprocal(likelihood_prec + leaf_prior_cov_inv)

    if z is None:
        z = random.normal(key, prec_trees.shape, error_cov_inv.dtype)

    # mean = mean_factor * (sum of scaled residuals in the leaf); expanding
    # mean = mean_lk * prec_lk * var_post with resid_tree = mean_lk * prec_tree
    # cancels prec_tree and leaves mean_factor = var_post / sigma2
    return PreLf(
        mean_factor=posterior_var * error_cov_inv,
        centered_leaves=jnp.sqrt(posterior_var) * z,
    )

758 

759 

def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d k'] | None = None,
) -> PreLf:
    """Multivariate implementation of `precompute_leaf_terms`."""
    num_trees, tree_size = prec_trees.shape
    k = error_cov_inv.shape[0]
    # per-leaf datapoint counts, shaped to broadcast against (k, k) matrices
    n_k: Float32[Array, 'num_trees tree_size 1 1'] = prec_trees[..., None, None]

    # Only broadcast the inverse of error covariance matrix to satisfy JAX's
    # batching rules for `lax.linalg.solve_triangular`, which does not support
    # implicit broadcasting.
    error_cov_inv_batched = jnp.broadcast_to(
        error_cov_inv, (num_trees, tree_size, k, k)
    )

    # conjugate update: posterior precision = prior precision + n * error precision
    posterior_precision: Float32[Array, 'num_trees tree_size k k'] = (
        leaf_prior_cov_inv + n_k * error_cov_inv_batched
    )

    L_prec: Float32[Array, 'num_trees tree_size k k'] = chol_with_gersh(
        posterior_precision
    )
    # two triangular solves compute inv(posterior_precision) @ error_cov_inv
    Y: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, error_cov_inv_batched, lower=True
    )
    mean_factor: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, Y, trans='T', lower=True
    )
    mean_factor = mean_factor.mT
    # move the leaf axis last to match the PreLf layout
    mean_factor_out: Float32[Array, 'num_trees k k tree_size'] = jnp.moveaxis(
        mean_factor, 1, -1
    )

    if z is None:
        z = random.normal(key, (num_trees, tree_size, k))
    # NOTE(review): unlike the two calls above, this solve does not pass
    # lower=True even though L_prec is a lower Cholesky factor; the default
    # lower=False would read L_prec's upper triangle — looks like lower=True
    # was intended here, verify against the sampling distribution tests
    centered_leaves: Float32[Array, 'num_trees tree_size k'] = solve_triangular(
        L_prec, z, trans='T'
    )
    centered_leaves_out: Float32[Array, 'num_trees k tree_size'] = jnp.swapaxes(
        centered_leaves, -1, -2
    )

    return PreLf(mean_factor=mean_factor_out, centered_leaves=centered_leaves_out)

806 

807 

@named_call
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d']
    | Float32[Array, 'num_trees 2**d k']
    | None = None,
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Dispatches on the rank of `error_cov_inv`: a matrix selects the
    multivariate implementation, a scalar the univariate one.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse error variance (univariate) or the inverse of error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse
        of prior covariance matrix of each leaf (multivariate).
    z
        Optional standard normal noise to use for sampling the centered
        leaves. This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    # pick the implementation matching the error covariance rank
    if error_cov_inv.ndim == 2:
        impl = _precompute_leaf_terms_mv
    else:
        impl = _precompute_leaf_terms_uv
    return impl(key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z)

853 

854 

@named_call
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """

    def loop(
        resid: Float32[Array, ' n'] | Float32[Array, ' k n'], pt: SeqStageInPerTree
    ) -> tuple[
        Float32[Array, ' n'] | Float32[Array, ' k n'],
        tuple[
            Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
            Bool[Array, ''],
            Bool[Array, ''],
            Float32[Array, ''] | None,
        ],
    ]:
        # the residuals are the scan carry: each tree's decision depends on
        # the residuals left by the trees processed before it
        resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            resid,
            SeqStageInAllTrees(
                pso.bart.X,
                pso.bart.config.resid_num_batches,
                pso.bart.config.mesh,
                pso.bart.prec_scale,
                pso.bart.forest.log_likelihood is not None,
                pso.prelk,
            ),
            pt,
        )
        return resid, (leaf_tree, acc, to_prune, lkratio)

    # bundle the per-tree inputs so lax.scan slices out one tree per iteration
    pts = SeqStageInPerTree(
        pso.bart.forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        pso.bart.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )
    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts)

    bart = replace(
        pso.bart,
        resid=resid,
        forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio),
    )
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return bart, moves

920 

921 

class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    See `SeqStageInPerTree` for the inputs that differ from tree to tree.
    """

    X: UInt[Array, 'p n']
    """The predictors."""

    resid_num_batches: int | None = field(static=True)
    """The number of batches for computing the sum of residuals in each leaf."""

    mesh: Mesh | None = field(static=True)
    """The mesh of devices to use."""

    prec_scale: Float32[Array, ' n'] | None
    """The scale of the precision of the error on each datapoint. If None, it
    is assumed to be 1."""

    save_ratios: bool = field(static=True)
    """Whether to save the acceptance ratios."""

    prelk: PreLk | None
    """The pre-computed terms of the likelihood ratio which are shared across
    trees."""

945 

class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    Instances are stacked across trees and sliced one tree at a time by the
    `lax.scan` in `accept_moves_sequential_stage`.
    """

    leaf_tree: Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d']
    """The leaf values of the tree."""

    prec_tree: Float32[Array, ' 2**d']
    """The likelihood precision scale in each potential or actual leaf node."""

    move: Moves
    """The proposed move, see `propose_moves`."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves."""

    leaf_indices: UInt[Array, ' n']
    """The leaf indices for the largest version of the tree compatible with
    the move."""

    prelkv: PreLkV
    """The pre-computed terms of the likelihood ratio which are specific to the tree."""

    prelf: PreLf
    """The pre-computed terms of the leaf sampling which are specific to the tree."""

970 

971 

@named_call
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n'] | Float32[Array, ' k n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d'] | Float32[Array, ' k 2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale

    tree_size = pt.leaf_tree.shape[-1]  # 2**d

    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, tree_size, at.resid_num_batches, at.mesh
    )

    # subtract starting tree from function
    # (resid excludes this tree's contribution; adding prec * leaf restores it)
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move and compute likelihood
    resid_left = resid_tree[..., pt.move.left]
    resid_right = resid_tree[..., pt.move.right]
    resid_total = resid_left + resid_right
    assert pt.move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[..., pt.move.node].set(resid_total)

    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )

    # calculate accept/reject ratio
    # the ratio is defined for a grow move; a prune move is its inverse
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move (Metropolis-Hastings test with
    # pre-drawn log-uniform variate)
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    if resid.ndim > 1:
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, resid_tree)
    else:
        mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf
    # acc XOR grow: True iff the final tree is the pruned/smaller one
    to_prune = acc ^ pt.move.grow
    # when not pruning, the index `tree_size` is out of bounds and the update
    # is dropped (jax scatter out-of-bounds semantics)
    leaf_tree = (
        leaf_tree.at[..., jnp.where(to_prune, pt.move.left, tree_size)]
        .set(leaf_tree[..., pt.move.node])
        .at[..., jnp.where(to_prune, pt.move.right, tree_size)]
        .set(leaf_tree[..., pt.move.node])
    )
    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[..., pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio

1065 

1066 

@named_call
@partial(jnp.vectorize, excluded=(1, 2, 3, 4), signature='(n)->(ts)')
def sum_resid(
    scaled_resid: Float32[Array, ' n'] | Float32[Array, 'k n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    resid_num_batches: int | None,
    mesh: Mesh | None,
) -> Float32[Array, ' {tree_size}'] | Float32[Array, 'k {tree_size}']:
    """
    Sum the residuals falling into each leaf of a tree.

    Works in both the univariate case (``scaled_resid`` has shape ``(n,)``)
    and the multivariate case (shape ``(k, n)``, one row per outcome
    component), the latter handled by `jnp.vectorize`.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value), already multiplied by the
        error precision scale.
    leaf_indices
        For each datapoint, the index of the leaf it falls into.
    tree_size
        The size of the tree array (2 ** d).
    resid_num_batches
        The number of batches for computing the sum of residuals in each
        leaf, or `None` for no batching.
    mesh
        The mesh of devices to use, or `None` for single-device execution.

    Returns
    -------
    The sum of the residuals at data points in each leaf (per outcome
    component in the multivariate case).
    """
    return _scatter_add(
        scaled_resid,
        leaf_indices,
        tree_size,
        jnp.float32,
        resid_num_batches,
        mesh,
    )

1105 

1106 

def _scatter_add(
    values: Float32[Array, ' n'] | int,
    indices: Integer[Array, ' n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    mesh: Mesh | None,
) -> Shaped[Array, ' {size}']:
    """Indexed reduce with optional batching."""
    # validate `values`: either a scalar broadcast over indices or one per index
    values = jnp.asarray(values)
    assert values.ndim == 0 or values.shape == indices.shape

    # bind the static configuration
    scatter = partial(
        _scatter_add_impl, size=size, dtype=dtype, num_batches=batch_size
    )

    # single-device path: no sharding machinery needed
    if mesh is None or 'data' not in mesh.axis_names:
        return scatter(values, indices)

    # multi-device path: shard the datapoint axis, psum the partial results
    if values.shape:
        in_specs = PartitionSpec('data'), PartitionSpec('data')
    else:
        # scalar values are replicated rather than sharded
        in_specs = PartitionSpec(), PartitionSpec('data')
    sharded_scatter = shard_map(
        partial(scatter, final_psum=True),
        in_specs=in_specs,
        out_specs=PartitionSpec(),
        mesh=mesh,
        **_get_shard_map_patch_kwargs(),
    )
    return sharded_scatter(values, indices)

1143 

1144 

1145def _get_shard_map_patch_kwargs() -> dict[str, bool]: 

1146 # WORKAROUND(jax<=0.8.2): vmap(shard_map(psum)), jax#34249; the 

1147 # jax_disable_vmap_shmap_error config did not work 

1148 if jax.__version__ in ('0.8.1', '0.8.2'): 1148 ↛ 1149line 1148 didn't jump to line 1149 because the condition on line 1148 was never true1hn

1149 return {'check_vma': False} 

1150 else: 

1151 return {} 1hn

1152 

1153 

1154def _scatter_add_impl( 

1155 values: Float32[Array, ' n'] | Int32[Array, ''], 

1156 indices: Integer[Array, ' n'], 

1157 /, 

1158 *, 

1159 size: int, 

1160 dtype: jnp.dtype, 

1161 num_batches: int | None, 

1162 final_psum: bool = False, 

1163) -> Shaped[Array, ' {size}']: 

1164 if num_batches is None: 1akbj

1165 out = jnp.zeros(size, dtype).at[indices].add(values) 1aj

1166 

1167 else: 

1168 # in the sharded case, n is the size of the local shard, not the full size 

1169 (n,) = indices.shape 1kb

1170 batch_indices = jnp.arange(n) % num_batches 1kb

1171 out = ( 1kb

1172 jnp.zeros((size, num_batches), dtype) 

1173 .at[indices, batch_indices] 

1174 .add(values) 

1175 .sum(axis=1) 

1176 ) 

1177 

1178 if final_psum: 1ahbn

1179 out = lax.psum(out, 'data') 1hn

1180 return out 1ab

1181 

1182 

1183def _compute_likelihood_ratio_uv( 

1184 total_resid: Float32[Array, ''], 

1185 left_resid: Float32[Array, ''], 

1186 right_resid: Float32[Array, ''], 

1187 prelkv: PreLkV, 

1188 prelk: PreLk, 

1189) -> Float32[Array, '']: 

1190 exp_term = prelk.exp_factor * ( 1ab

1191 left_resid * left_resid / prelkv.left 

1192 + right_resid * right_resid / prelkv.right 

1193 - total_resid * total_resid / prelkv.total 

1194 ) 

1195 return prelkv.log_sqrt_term + exp_term 1ab

1196 

1197 

1198def _compute_likelihood_ratio_mv( 

1199 total_resid: Float32[Array, ' k'], 

1200 left_resid: Float32[Array, ' k'], 

1201 right_resid: Float32[Array, ' k'], 

1202 prelkv: PreLkV, 

1203) -> Float32[Array, '']: 

1204 def _quadratic_form( 1de

1205 r: Float32[Array, ' k'], mat: Float32[Array, 'k k'] 

1206 ) -> Float32[Array, '']: 

1207 return r @ mat @ r 1de

1208 

1209 qf_left = _quadratic_form(left_resid, prelkv.left) 1de

1210 qf_right = _quadratic_form(right_resid, prelkv.right) 1de

1211 qf_total = _quadratic_form(total_resid, prelkv.total) 1de

1212 exp_term = 0.5 * (qf_left + qf_right - qf_total) 1de

1213 return prelkv.log_sqrt_term + exp_term 1de

1214 

1215 

@named_call
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''] | Float32[Array, ' k'],
    left_resid: Float32[Array, ''] | Float32[Array, ' k'],
    right_resid: Float32[Array, ''] | Float32[Array, ' k'],
    prelkv: PreLkV,
    prelk: PreLk | None,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    Dispatches to the univariate or multivariate implementation based on the
    rank of the residual arrays.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    if total_resid.ndim == 0:
        # univariate case: also needs the terms shared across trees
        assert prelk is not None
        return _compute_likelihood_ratio_uv(
            total_resid, left_resid, right_resid, prelkv, prelk
        )
    # multivariate case
    return _compute_likelihood_ratio_mv(total_resid, left_resid, right_resid, prelkv)

1255 

1256 

@named_call
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    # tally accepted moves by kind
    num_acc_grow = jnp.sum(moves.acc & moves.grow)
    num_acc_prune = jnp.sum(moves.acc & ~moves.grow)
    new_forest = replace(
        bart.forest,
        grow_acc_count=num_acc_grow,
        prune_acc_count=num_acc_prune,
        leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)

1287 

1288 

@named_call
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    assert moves.to_prune is not None
    # clearing the lowest bit maps both children onto the left one, so this
    # detects datapoints in either child of the pruned node
    parent_mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_pruned_children = (leaf_indices & parent_mask) == moves.left
    return jnp.where(
        in_pruned_children & moves.to_prune,
        moves.node.astype(leaf_indices.dtype),
        leaf_indices,
    )

1316 

1317 

@named_call
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # index `split_tree.size` is out of bounds: jax drops the scatter update,
    # making it a no-op when the condition is False
    grow_at = jnp.where(moves.grow, moves.node, split_tree.size)
    prune_at = jnp.where(moves.to_prune, moves.node, split_tree.size)
    grown = split_tree.at[grow_at].set(moves.grow_split.astype(split_tree.dtype))
    return grown.at[prune_at].set(0)

1345 

1346 

@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Float32[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key
    df
        Degrees of freedom
    scale_inv
        Scale matrix of the corresponding Inverse Wishart distribution
        (i.e., the inverse of the Wishart scale matrix)

    Returns
    -------
    A sample from Wishart(df, scale)
    """
    keys = split(key)

    # Diagonal elements: A_ii ~ sqrt(chi^2(df - i))
    # chi^2(k) = Gamma(k/2, scale=2)
    k, _ = scale_inv.shape
    df_vector = df - jnp.arange(k)
    chi2_samples = random.gamma(keys.pop(), df_vector / 2.0) * 2.0
    diag_A = jnp.sqrt(chi2_samples)

    # strictly lower-triangular elements of the Bartlett factor A are
    # standard normal; the upper triangle of the draw is discarded by tril
    off_diag_A = random.normal(keys.pop(), (k, k))
    A = jnp.tril(off_diag_A, -1) + jnp.diag(diag_A)
    # L: Cholesky factor of scale_inv (presumably lower-triangular, consistent
    # with lower=True below; regularized by chol_with_gersh — confirm there).
    # Solving L^T T = A gives T = L^-T A, hence
    # T T^T = L^-T (A A^T) L^-1 ~ Wishart(df, (L L^T)^-1) = Wishart(df, scale_inv^-1)
    L = chol_with_gersh(scale_inv, absolute_eps=True)
    T = solve_triangular(L, A, lower=True, trans='T')

    return T @ T.T

1382 

1383 

def _step_error_cov_inv_uv(key: Key[Array, ''], bart: State) -> State:
    """Gibbs-update the scalar error precision with its conjugate gamma posterior."""
    assert bart.error_cov_df is not None
    assert bart.error_cov_scale is not None

    resid = bart.resid
    # inverse gamma prior: alpha = df / 2, beta = scale / 2
    alpha = bart.error_cov_df / 2 + resid.size / 2
    scaled_resid = resid if bart.prec_scale is None else resid * bart.prec_scale
    norm2 = resid @ scaled_resid
    beta = bart.error_cov_scale / 2 + norm2 / 2

    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_draw = random.gamma(key, alpha)
    return replace(bart, error_cov_inv=gamma_draw / beta)

1402 

1403 

def _step_error_cov_inv_mv(key: Key[Array, ''], bart: State) -> State:
    """Gibbs-update the error precision matrix with its conjugate Wishart posterior."""
    assert bart.error_cov_df is not None
    assert bart.error_cov_scale is not None

    # conjugate update of the inverse-Wishart parameters
    num_points = bart.resid.shape[-1]
    posterior_df = bart.error_cov_df + num_points
    posterior_scale = bart.error_cov_scale + bart.resid @ bart.resid.T

    new_prec = _sample_wishart_bartlett(key, posterior_df, posterior_scale)
    return replace(bart, error_cov_inv=new_prec)

1414 

1415 

def _step_error_cov_inv_diag(key: Key[Array, ''], bart: State) -> State:
    """Update diagonal error_cov_inv for mixed binary-continuous.

    Each continuous component gets an independent inverse-gamma update
    (like `_step_error_cov_inv_uv` repeated per component). Binary
    components stay fixed at 1.
    """
    assert bart.binary_indices is not None
    assert bart.error_cov_scale is not None
    assert bart.error_cov_df is not None

    *_, k, n = bart.resid.shape
    # per-component sum of squared residuals, shape (k,)
    sq_norms = jnp.einsum('kn,kn->k', bart.resid, bart.resid)

    # inverse-gamma posterior parameters
    alpha = bart.error_cov_df / 2 + n / 2
    beta = jnp.diag(bart.error_cov_scale) / 2 + sq_norms / 2

    # draw all k precisions at once, then pin the binary components back to 1
    precisions = random.gamma(key, alpha, (k,)) / beta
    precisions = precisions.at[bart.binary_indices].set(1.0)

    return replace(bart, error_cov_inv=jnp.diag(precisions))

1444 

1445 

@named_call
def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the inverse error covariance.

    Dispatches to the appropriate sampler for the univariate, multivariate,
    or mixed binary-continuous case.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `error_cov_inv`.
    """
    if bart.binary_indices is not None:
        # mixed binary-continuous outcomes: diagonal precision matrix
        return _step_error_cov_inv_diag(key, bart)
    if bart.error_cov_inv.ndim == 2:
        # multivariate outcomes: full precision matrix
        return _step_error_cov_inv_mv(key, bart)
    # univariate outcomes: scalar precision
    return _step_error_cov_inv_uv(key, bart)

1470 

1471 

@named_call
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    assert bart.z is not None
    assert bart.binary_y is not None

    # in the mixed case only some outcome components are binary
    mixed = bart.binary_indices is not None
    resid = bart.resid[..., bart.binary_indices, :] if mixed else bart.resid

    # z = (trees + offset) + resid; resample resid from the one-sided
    # truncated normal implied by the observed binary outcome
    trees_plus_offset = bart.z - resid
    new_resid = truncated_normal_onesided(key, (), ~bart.binary_y, -trees_plus_offset)
    z = trees_plus_offset + new_resid

    if mixed:
        # write the resampled components back into the full residual array
        new_resid = bart.resid.at[..., bart.binary_indices, :].set(new_resid)

    return replace(bart, z=z, resid=new_resid)

1504 

1505 

@named_call
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    This full conditional is approximated, because it does not take into account
    that there are forbidden decision rules.
    """
    assert bart.forest.theta is not None

    # count how many times each predictor is used across the forest
    p = bart.forest.max_split.size
    usage = var_histogram(
        p, bart.forest.var_tree, bart.forest.split_tree, sum_batch_axis=-1
    )

    # sample the Dirichlet posterior in log space (unnormalized loggamma draws)
    posterior_alpha = bart.forest.theta / p + usage
    log_s = random.loggamma(key, posterior_alpha)

    return replace(bart, forest=replace(bart.forest, log_s=log_s))

1546 

1547 

@named_call
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    assert bart.forest.log_s is not None
    assert bart.forest.rho is not None
    assert bart.forest.a is not None
    assert bart.forest.b is not None

    # grid of midpoints of num_grid equal-width bins covering (0, 1)
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # normalize s
    log_s = bart.forest.log_s - logsumexp(bart.forest.log_s)

    # evaluate the posterior of lambda = theta / (theta + rho) on the grid,
    # sample a grid point, and map it back to theta
    logp, theta_grid = _log_p_lamda(
        lamda_grid, log_s, bart.forest.rho, bart.forest.a, bart.forest.b
    )
    theta = theta_grid[random.categorical(key, logp)]

    return replace(bart, forest=replace(bart.forest, theta=theta))

1589 

1590 

1591def _log_p_lamda( 

1592 lamda: Float32[Array, ' num_grid'], 

1593 log_s: Float32[Array, ' p'], 

1594 rho: Float32[Array, ''], 

1595 a: Float32[Array, ''], 

1596 b: Float32[Array, ''], 

1597) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]: 

1598 # in the following I use lamda[::-1] == 1 - lamda 

1599 theta = rho * lamda / lamda[::-1] 1afg

1600 p = log_s.size 1afg

1601 return ( 1afg

1602 (a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda) 

1603 + (b - 1) * jnp.log1p(-lamda) # log(1 - lambda) 

1604 + gammaln(theta) 

1605 - p * gammaln(theta / p) 

1606 + theta / p * jnp.sum(log_s) 

1607 ), theta 

1608 

1609 

@named_call
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if the parameters of
    the theta prior are defined.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    # sparsity is disabled entirely when no activation step is configured
    if bart.config.sparse_on_at is None:
        return bart
    # skip the update until the configured burn-in step has passed
    return lax.cond(
        bart.config.steps_done < bart.config.sparse_on_at,
        lambda _key, bart: bart,
        _step_sparse,
        key,
        bart,
    )

1638 

1639 

def _step_sparse(key: Key[Array, ''], bart: State) -> State:
    """Resample `log_s`, then `theta` if its prior (via `rho`) is configured."""
    keys = split(key)
    bart = step_s(keys.pop(), bart)
    if bart.forest.rho is None:
        return bart
    return step_theta(keys.pop(), bart)

1646 

1647 

@named_call
def step_config(bart: State) -> State:
    """
    Increment the step counter in the state's config.

    Parameters
    ----------
    bart
        A BART mcmc state.

    Returns
    -------
    The state with `config.steps_done` increased by 1.
    """
    config = bart.config
    config = replace(config, steps_done=config.steps_done + 1)
    return replace(bart, config=config)
1652 return replace(bart, config=config) 1ab