Coverage for src/bartz/mcmcstep/_step.py: 99%

490 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/mcmcstep/_step.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement `step`, `step_trees`, and the accept-reject logic.""" 

26 

27from dataclasses import replace 

28from functools import partial 

29from typing import overload 

30 

31from equinox import AbstractVar 

32from jax import lax, named_call, random, vmap 

33from jax import numpy as jnp 

34from jax.nn import softmax 

35from jax.scipy.linalg import solve_triangular 

36from jax.scipy.special import gammaln, logsumexp 

37from jaxtyping import Array, Bool, Float32, Int32, Key, Shaped, UInt, UInt32 

38 

39from bartz._jaxext import ( 

40 Module, 

41 field, 

42 jit, 

43 split, 

44 truncated_normal_onesided, 

45 vmap_nodoc, 

46) 

47from bartz._jaxext.random import loggamma 

48from bartz.grove import var_histogram 

49from bartz.mcmcstep._moves import Moves, propose_moves, split_range 

50from bartz.mcmcstep._reduction import ReductionConfig 

51from bartz.mcmcstep._state import ( 

52 Forest, 

53 State, 

54 StepConfig, 

55 chol_with_gersh, 

56 get_axis_size, 

57 shard_map_state, 

58 split_key_for_chains, 

59 vmap_chains, 

60) 

61 

62 

@jit(donate_argnums=(1,))
@split_key_for_chains
@shard_map_state
@vmap_chains
def step(key: Key[Array, ''], state: State) -> State:
    """
    Advance the BART MCMC by one iteration.

    Parameters
    ----------
    key
        A jax random key.
    state
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The BART mcmc state after the step.

    Notes
    -----
    The buffers of the input state are donated to the output state, so the
    input state must not be used again after calling `step`. This only
    matters outside of `jax.jit`.
    """
    subkeys = split(key, 4)

    # the forest is always updated first
    state = step_trees(subkeys.pop(), state)

    # optional updates, enabled by the presence of the relevant state fields
    if state.z is not None:
        state = step_z(subkeys.pop(), state)

    if state.error_cov_inv.nu is not None:
        state = step_error_cov_inv(subkeys.pop(), state)

    state = step_sparse(subkeys.pop(), state)
    return step_config(state)

100 

101 

@named_call
def step_trees(key: Key[Array, ''], state: State) -> State:
    """
    Sample the forest: propose tree moves, then accept/reject and draw leaves.

    Parameters
    ----------
    key
        A jax random key.
    state
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    subkeys = split(key)
    propose_key = subkeys.pop()
    accept_key = subkeys.pop()
    proposed = propose_moves(propose_key, state.forest)
    return accept_moves_and_sample_leaves(accept_key, state, proposed)

125 

126 

@named_call
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], state: State, moves: Moves
) -> State:
    """
    Run the three accept-reject stages and sample the new leaf values.

    The work is chained as parallel stage -> sequential stage -> final stage.

    Parameters
    ----------
    key
        A jax random key.
    state
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, state, moves)
    seq_state, seq_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(seq_state, seq_moves)

150 

151 

class Counts(Module):
    """Per-tree datapoint counts at the nodes touched by a proposed move."""

    lrt: UInt[Array, '*num_trees 3']
    """Datapoint counts stacked along the trailing axis as (left child,
    right child, parent), where the parent count is ``left + right``."""

158 

159 

class PreLkV(Module):
    """Per-tree likelihood-ratio terms that do not need the sequential pass.

    They are obtained from the leaf precompute terms (`PreLf`), gathered at
    the nodes each move involves. The entries for the left child, the right
    child and their join (the parent node) are stacked along the axis that
    immediately follows the tree axis. In the univariate case each entry is
    the scalar

    ``error_cov_inv^2 / (leaf_prior_cov_inv + n * error_cov_inv)``.

    In the multivariate homoskedastic or scalar weight case it is the matrix

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n * error_cov_inv) @ error_cov_inv``.

    In the multivariate vector-weight case it is instead

    ``chol(leaf_prior_cov_inv + n * error_cov_inv)``

    Here ``n`` is the number of datapoints in the node, or the likelihood
    precision scale in the heteroskedastic case.
    """

    # Field order matters: `log_sqrt_term` comes before `lrt` so that its
    # single (union-free) annotation is the first to bind the variadic
    # `*num_trees` axis. With the opposite order the runtime typechecker can
    # greedily mis-bind `*num_trees` against the `k` axis of the
    # `... | ... k k` union, because the multivariate and univariate layouts
    # are rank-ambiguous.
    log_sqrt_term: Float32[Array, '*num_trees']
    """Logarithm of the square-root term of the likelihood ratio."""

    lrt: Float32[Array, '*num_trees 3'] | Float32[Array, '*num_trees 3 k k']
    """Scaled full conditional variance, scaled covariance, or precision
    cholesky, for the left child, the right child, and their join."""

193 

194 

class PreLf(Module):
    """Terms pre-computed to sample the leaves from their posterior.

    All of these terms can be computed in parallel across trees.

    For every tree and leaf the terms are scalars in the univariate case
    (`PreLfUV`) and matrices/vectors in the multivariate case (`PreLfMV`,
    `PreLfMVHet`).

    This is an abstract base. The layouts have different ranks, so each one
    lives in a concrete subclass with union-free annotations; a single class
    carrying a shape union would let the greedy variadic axis mis-bind
    against the ``k`` axes under the runtime typechecker. The concrete class
    also tags the meaning of `mean_factor`, which drives the dispatch in
    `precompute_likelihood_terms` and in the sequential stage. The
    ``num_trees`` axis is variadic so that the same annotations also match a
    per-element layout if vmapped over trees.
    """

    mean_factor: AbstractVar[
        Float32[Array, '*num_trees tree_size']
        | Float32[Array, '*num_trees k k tree_size']
    ]
    """Factor that, right-multiplied by the sum of the scaled residuals,
    yields the posterior mean."""

    centered_leaves: AbstractVar[
        Float32[Array, '*num_trees tree_size']
        | Float32[Array, '*num_trees k tree_size']
    ]
    """Mean-zero normal values that are added to the posterior mean to
    produce the posterior leaf samples."""

226 

227 

class PreLfUV(PreLf):
    """Univariate flavor of `PreLf`."""

    mean_factor: Float32[Array, '*num_trees tree_size']
    """``error_cov_inv / prec``, with ``prec`` the posterior precision of the
    leaf."""

    centered_leaves: Float32[Array, '*num_trees tree_size']
    """Zero-mean normal draws having the posterior variance of each leaf."""

237 

238 

class PreLfMV(PreLf):
    """Multivariate homoskedastic or scalar-weight flavor of `PreLf`."""

    # NOTE: the field order is part of the contract, the class is built
    # positionally from a tuple in `_precompute_leaf_terms_mv`.
    mean_factor: Float32[Array, '*num_trees k k tree_size']
    """``error_cov_inv @ inv(prec)``, with ``prec`` the posterior precision of
    the leaf."""

    centered_leaves: Float32[Array, '*num_trees k tree_size']
    """Zero-mean normal draws having the posterior covariance of each leaf."""

    logdet_prec: Float32[Array, '*num_trees tree_size']
    """Log-determinant of the posterior precision of each leaf."""

251 

252 

class PreLfMVHet(PreLf):
    """Multivariate vector-weight flavor of `PreLf`."""

    # NOTE: the field order is part of the contract, the class is built
    # positionally from a tuple in `_precompute_leaf_terms_mv_het`.
    mean_factor: Float32[Array, '*num_trees k k tree_size']
    """Lower Cholesky factor of the posterior precision of each leaf; the
    solve for the mean is deferred to the sequential stage."""

    centered_leaves: Float32[Array, '*num_trees k tree_size']
    """Zero-mean normal draws having the posterior covariance of each leaf."""

262 

263 

class ParallelStageOut(Module):
    """Bundle returned by `accept_moves_parallel_stage`."""

    state: State
    """The BART mcmc state, partially updated."""

    moves: Moves
    """The proposed moves, where `partial_ratio` is `None` and
    `log_trans_prior_ratio` holds its final value."""

    # `num_trees` is kept as a fixed (non-variadic) axis: a `ParallelStageOut`
    # is always built with the tree axis present (never per tree under vmap),
    # so rank/dtype alone disambiguate the union and no anchor is needed
    # (cf. `PreLf`).
    prec_trees: (
        Float32[Array, 'num_trees tree_size']
        | UInt32[Array, 'num_trees tree_size']
        | Float32[Array, 'num_trees k k tree_size']
    )
    """Likelihood precision scale in each potential or actual leaf node."""

    prelkv: PreLkV
    """Pre-computed terms of the likelihood ratios."""

    prelf: PreLf
    """Pre-computed terms of the leaf samples."""

289 

290 

@named_call
def accept_moves_parallel_stage(
    key: Key[Array, ''], state: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    state
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.

    Notes
    -----
    `count_trees` and `move_counts` are bound only when the count cache is
    refreshed. Each later use of them sits behind a condition that is one of
    the disjuncts of the refresh condition, so they are always defined where
    they are read.
    """
    # where the move is grow, modify the state like the move was accepted
    state = replace(
        state,
        forest=replace(
            state.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(
                moves, state.forest.leaf_indices, state.X
            ),
            leaf_tree=adapt_leaf_trees_to_grow_indices(state.forest.leaf_tree, moves),
        ),
    )

    # update the cached number of datapoints per leaf at the nodes involved
    # in the moves; needed by either count threshold, or as the precision
    # scale itself when there are no per-datapoint precision weights
    if (
        state.forest.min_points_per_decision_node is not None
        or state.forest.min_points_per_leaf is not None
        or state.prec_scale is None
    ):
        assert state.forest.count_tree is not None
        count_trees, move_counts = compute_count_trees(
            state.forest.count_tree, state.forest.leaf_indices, moves, state.config
        )
        state = replace(state, forest=replace(state.forest, count_tree=count_trees))

    # affluence of the nodes touched by each move: whether they would be
    # growable as leaves (admissible rule + enough datapoints). The children
    # must also lie within the heap, i.e. not be at the bottom level; the
    # parent always does. These feed the transition ratio and the final
    # `affluence_tree` update.
    _, half = state.forest.var_tree.shape
    lrt_affluent = (moves.lrt_nodes < half) & moves.lrt_growable
    if state.forest.min_points_per_decision_node is not None:
        lrt_affluent &= move_counts.lrt >= state.forest.min_points_per_decision_node
    moves = replace(moves, lrt_affluent=lrt_affluent)

    # veto grove move if new leaves don't have enough datapoints
    # (only the left and right child counts, `lrt[..., :2]`, are checked)
    if state.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & jnp.all(
                move_counts.lrt[..., :2] >= state.forest.min_points_per_leaf, axis=-1
            ),
        )

    # update the cached number of datapoints per leaf, weighted by error
    # precision scale, at the nodes involved in the moves
    if state.prec_scale is None:
        # without weights the plain counts act as the precision scale
        prec_trees = count_trees
    else:
        assert state.forest.prec_tree is not None
        prec_trees = compute_prec_trees(
            state.forest.prec_tree,
            state.prec_scale,
            state.forest.leaf_indices,
            moves,
            state.config,
        )
        state = replace(state, forest=replace(state.forest, prec_tree=prec_trees))

    # compute some missing information about moves
    moves = complete_ratio(moves, state.forest.p_nonterminal)
    # the ratio is stored in the forest only if the log likelihood is tracked
    save_ratios = state.forest.log_likelihood is not None
    state = replace(
        state,
        forest=replace(
            state.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    # leaf posterior terms first: the likelihood terms are derived from them
    prelf = precompute_leaf_terms(
        key, prec_trees, state.error_cov_inv.value, state.forest.leaf_prior_cov_inv
    )
    prelkv = precompute_likelihood_terms(
        state.error_cov_inv.value, state.forest.leaf_prior_cov_inv, prelf, moves
    )

    return ParallelStageOut(
        state=state, moves=moves, prec_trees=prec_trees, prelkv=prelkv, prelf=prelf
    )

396 

397 

@named_call
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Re-route the datapoints as if every proposed grow move were applied.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    updated_indices = _apply_grow_to_indices(moves, leaf_indices, X)
    return updated_indices

419 

420 

@partial(vmap_nodoc, in_axes=(0, 0, None))
def _apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, ' n'], X: UInt[Array, 'p n']
) -> UInt[Array, ' n']:
    """Implement `apply_grow_to_indices` for a single tree."""
    # when the move is not a grow, target an index equal to twice the size of
    # `var_tree`, which presumably no datapoint is assigned to, so that the
    # update leaves everything unchanged
    no_node = jnp.array(2 * moves.var_tree.size)
    target_node = jnp.where(moves.grow, moves.lrt_nodes[2], no_node)

    # datapoints at or above the split go to the right child, which is the
    # node right after the left child
    predictor: UInt[Array, ' n'] = X[moves.grow_var, :]
    goes_right = predictor >= moves.grow_split
    left = moves.lrt_nodes[0].astype(leaf_indices.dtype)
    child = left + goes_right

    return jnp.where(leaf_indices == target_node, child, leaf_indices)

434 

435 

436def _fill_lrt_total(lrt: Shaped[Array, '*k_k 3']) -> Shaped[Array, '*k_k 3']: 

437 """Set the total slot of stacked (left, right, total) values to left + right. 

438 

439 The left and right slots pass through unchanged, the stale value in the 

440 total slot is ignored. Implemented with fusable elementwise operations. 

441 """ 

442 total = lrt[..., 0] + lrt[..., 1] 

443 return jnp.where(jnp.arange(3) == 2, total[..., None], lrt) 

444 

445 

@overload
def _compute_count_or_prec_trees(
    prec_scale: None,
    trees: UInt32[Array, 'num_trees tree_size'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, 'num_trees tree_size'], Counts]: ...


@overload
def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'],
    trees: Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k k tree_size'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[Float32[Array, 'num_trees tree_size'], None]
    | tuple[Float32[Array, 'num_trees k k tree_size'], None]
): ...


def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'] | None,
    trees: UInt32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k k tree_size'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, 'num_trees tree_size'], Counts]
    | tuple[Float32[Array, 'num_trees tree_size'], None]
    | tuple[Float32[Array, 'num_trees k k tree_size'], None]
):
    """Implement `compute_count_trees` and `compute_prec_trees`.

    The per-tree update `_compute_count_or_prec_tree` is applied over the
    leading tree axis of `trees`, `leaf_indices` and `moves`, sharing
    `prec_scale` and `config`. When `config.prec_count_num_trees` is `None`
    a plain `vmap` is used, otherwise `lax.map` with that value as
    `batch_size`.
    """
    batch_size = config.prec_count_num_trees

    if batch_size is None:
        per_tree = vmap(_compute_count_or_prec_tree, in_axes=(None, 0, 0, 0, None))
        return per_tree(prec_scale, trees, leaf_indices, moves, config)

    def one_tree(
        args: tuple[
            UInt32[Array, ' tree_size']
            | Float32[Array, ' tree_size']
            | Float32[Array, 'k k tree_size'],
            UInt[Array, ' n'],
            Moves,
        ],
    ) -> (
        tuple[UInt32[Array, ' tree_size'], Counts]
        | tuple[Float32[Array, ' tree_size'], None]
        | tuple[Float32[Array, 'k k tree_size'], None]
    ):
        # `lax.map` hands over one slice of the mapped tuple at a time
        tree, tree_leaf_indices, tree_moves = args
        return _compute_count_or_prec_tree(
            prec_scale, tree, tree_leaf_indices, tree_moves, config
        )

    return lax.map(one_tree, (trees, leaf_indices, moves), batch_size=batch_size)

509 

510 

def _compute_count_or_prec_tree(
    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'] | None,
    tree: UInt32[Array, ' tree_size']
    | Float32[Array, ' tree_size']
    | Float32[Array, 'k k tree_size'],
    leaf_indices: UInt[Array, ' n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, ' tree_size'], Counts]
    | tuple[Float32[Array, ' tree_size'], None]
    | tuple[Float32[Array, 'k k tree_size'], None]
):
    """Refresh the cached count or precision values of one tree.

    Without `prec_scale` the datapoints are counted (each one contributes 1,
    as ``uint32``) and the counts at the move's nodes are returned as well;
    with `prec_scale` its values are summed (as ``float32``) and the second
    output is `None`.
    """
    is_count = prec_scale is None
    (half_size,) = moves.var_tree.shape
    tree_size = 2 * half_size

    if is_count:
        reduction_config = config.count_reduction_config
        dtype = jnp.uint32
        value = 1
    else:
        reduction_config = config.prec_reduction_config
        dtype = jnp.float32
        value = prec_scale

    # the cached tree is already valid at the leaves and a move only alters
    # the values at the nodes it involves, so the reduction is restricted to
    # the move's children: the contiguous pair
    # (left, right) = (2 * node, 2 * node + 1) = lrt_nodes[:2]
    children = reduction_config._reduce(  # noqa: SLF001
        value,
        leaf_indices,
        size=tree_size,
        subset_start=moves.lrt_nodes[0],
        subset_length=2,
        dtype=dtype,
        data_sharded=config.data_sharded,
    )

    # store the children sums in the cache together with their total at the
    # parent node (a non-leaf in the post-grow indexing the reduce runs on);
    # a weighted version of the counts is not needed because the likelihood
    # terms are derived from the leaf terms
    parent = children[..., 0] + children[..., 1]
    lrt = jnp.concatenate([children, parent[..., None]], axis=-1)
    tree = tree.at[..., moves.lrt_nodes].set(lrt)

    counts = Counts(lrt=lrt) if is_count else None
    return tree, counts

562 

563 

@named_call
def compute_count_trees(
    count_trees: UInt32[Array, 'num_trees tree_size'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, 'num_trees tree_size'], Counts]:
    """
    Refresh the cached per-leaf datapoint counts at the nodes of the moves.

    Parameters
    ----------
    count_trees
        The cached number of points in each leaf; valid at the leaves of the
        pre-move trees.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    count_trees : UInt32[Array, 'num_trees tree_size']
        The updated cache, valid in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    # passing no precision scale selects the counting mode of the shared code
    no_prec_scale = None
    return _compute_count_or_prec_trees(
        no_prec_scale, count_trees, leaf_indices, moves, config
    )

596 

597 

@named_call
def compute_prec_trees(
    prec_trees: Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k k tree_size'],
    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> Float32[Array, 'num_trees tree_size'] | Float32[Array, 'num_trees k k tree_size']:
    """
    Refresh the cached per-leaf likelihood precision scale at the moves' nodes.

    Parameters
    ----------
    prec_trees
        The cached likelihood precision scale in each leaf; valid at the leaves
        of the pre-move trees.
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    The updated cache, valid in each potential or actual leaf node.
    """
    # the second output is only meaningful in counting mode, discard it
    updated = _compute_count_or_prec_trees(
        prec_scale, prec_trees, leaf_indices, moves, config
    )
    return updated[0]

633 

634 

@partial(vmap_nodoc, in_axes=(0, None))
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' tree_size']) -> Moves:
    """
    Finish the part of the MH ratio that does not involve the likelihood.

    Two factors are multiplied into the partial ratio: the probability of
    choosing a prune move over the grow move in the inverse transition, and
    the prior odds that the modified node is nonterminal with terminal
    children.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    assert moves.lrt_affluent is not None
    assert moves.partial_ratio is not None

    # Probability of proposing a prune in the reverse transition.
    #
    # After a grow: a further grow is possible if other growable leaves exist
    # or if any of the two children can be grown. `lrt_affluent` already
    # includes the `min_points_per_decision_node` threshold, because the grow
    # proposal draws from the pool of leaves that pass it. This enters only
    # the transition probability.
    has_other_growable = moves.num_growable >= 2
    children_growable = jnp.any(moves.lrt_affluent[:2])
    p_prune_after_grow = jnp.where(has_other_growable | children_growable, 0.5, 1.0)

    # After a prune.
    p_prune_after_prune = jnp.where(moves.num_growable, 0.5, 1)

    p_prune = jnp.where(moves.grow, p_prune_after_grow, p_prune_after_prune)

    # Prior odds of the node being nonterminal, times the prior probability
    # that both children are terminal. The terminality of the children uses
    # the admissibility ignoring counts, because the standard BART prior
    # conditions the non-terminal probability only on the existence of
    # available decision rules, not on the count thresholds (which are a
    # bartz proposal-efficiency device, not part of the target distribution).
    # The fill value avoids a 0 and then an inf in the log if the move is not
    # allowed and the indices are out of bounds.
    pnt = p_nonterminal.at[moves.lrt_nodes].get(mode='fill', fill_value=0.5)
    parent_odds = pnt[2] / (1 - pnt[2])
    children_terminal = jnp.prod(1 - pnt[:2] * moves.lrt_growable[:2])
    prior_ratio = parent_odds * children_terminal

    log_ratio = jnp.log(moves.partial_ratio * prior_ratio * p_prune)
    return replace(moves, log_trans_prior_ratio=log_ratio, partial_ratio=None)

693 

694 

@named_call
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k tree_size'],
    moves: Moves,
) -> Float32[Array, 'num_trees tree_size'] | Float32[Array, 'num_trees k tree_size']:
    """
    Make the leaf values compatible with post-grow leaf indices.

    For each grow move, the value of the leaf being grown is duplicated into
    the nodes that would be its children were the move accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    adapted = _adapt_leaf_trees_to_grow_indices(leaf_trees, moves)
    return adapted

719 

720 

@vmap_nodoc
def _adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, ' tree_size'] | Float32[Array, ' k tree_size'],
    moves: Moves,
) -> Float32[Array, ' tree_size'] | Float32[Array, ' k tree_size']:
    """Implement `adapt_leaf_trees_to_grow_indices` for a single tree."""
    parent_value = leaf_trees[..., moves.lrt_nodes[2]]

    # all three nodes are targeted, so the parent slot is written back
    # unchanged and a single scatter suffices; when the move is not a grow
    # the index is replaced with the array size, i.e., past the end of the
    # trailing axis (presumably turning the write into a no-op)
    destinations = jnp.where(moves.grow, moves.lrt_nodes, leaf_trees.size)
    return leaf_trees.at[..., destinations].set(parent_value[..., None])

732 

733 

734def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']: 

735 """Compute logdet of A = LL' via Cholesky (sum of log of diag^2).""" 

736 diags: Float32[Array, '... k'] = jnp.diagonal(L, axis1=-2, axis2=-1) 

737 return 2.0 * jnp.sum(jnp.log(diags), axis=-1) 

738 

739 

def compute_B(
    error_cov_inv: Float32[Array, 'k k'], resid: Float32[Array, 'k k *tree_size']
) -> Float32[Array, ' k *tree_size']:
    """Compute the leaf score from the leaf weighted sum of residuals.

    The two leading axes of `resid` are multiplied elementwise by
    `error_cov_inv` and the second one is summed away; any trailing axes are
    carried through.
    """
    subscripts = 'ab,ab...->a...'
    return jnp.einsum(subscripts, error_cov_inv, resid)

745 

746 

def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees tree_size']
    | UInt32[Array, 'num_trees tree_size'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees tree_size'] | None = None,
) -> PreLfUV:
    """Compute the leaf sampling terms in the univariate case.

    The standard normal noise is drawn from `key` unless `z` is provided.
    """
    if z is None:
        z = random.normal(key, prec_trees.shape, error_cov_inv.dtype)

    # posterior variance = 1 / (likelihood precision + prior precision)
    likelihood_prec = prec_trees * error_cov_inv
    posterior_var = jnp.reciprocal(likelihood_prec + leaf_prior_cov_inv)

    # Derivation of `mean_factor`:
    # | mean = mean_lk * prec_lk * var_post
    # | resid_tree = mean_lk * prec_tree -->
    # |     --> mean_lk = resid_tree / prec_tree (kind of)
    # | mean_factor =
    # |     = mean / resid_tree =
    # |     = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
    # |     = 1 / prec_tree * prec_tree / sigma2 * var_post =
    # |     = var_post / sigma2
    mean_factor = posterior_var * error_cov_inv
    centered_leaves = z * jnp.sqrt(posterior_var)
    return PreLfUV(mean_factor=mean_factor, centered_leaves=centered_leaves)

771 

772 

def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees tree_size']
    | UInt32[Array, 'num_trees tree_size'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees k tree_size'] | None = None,
) -> PreLfMV:
    """Compute the leaf sampling terms in the multivariate scalar-weight case.

    For each tree and leaf, the posterior precision
    ``leaf_prior_cov_inv + prec * error_cov_inv`` is factorized with
    `chol_with_gersh`, and triangular solves against the factor produce the
    mean factor and the centered draw. The standard normal noise is drawn
    from `key`, with the dtype of `error_cov_inv`, unless `z` is provided.
    """
    num_trees, tree_size = prec_trees.shape
    k, _ = error_cov_inv.shape
    if z is None:
        # draw with the dtype of the precision, like the univariate path,
        # rather than with the global default float dtype
        z = random.normal(key, (num_trees, k, tree_size), error_cov_inv.dtype)

    def per_leaf(
        prec: Float32[Array, ''] | UInt32[Array, ''], z: Float32[Array, ' k']
    ) -> tuple[Float32[Array, 'k k'], Float32[Array, ' k'], Float32[Array, '']]:
        L_prec = chol_with_gersh(leaf_prior_cov_inv + prec * error_cov_inv)
        # two triangular solves apply the inverse of L L' to error_cov_inv
        Y = solve_triangular(L_prec, error_cov_inv, lower=True)
        mean_factor = solve_triangular(L_prec, Y, trans='T', lower=True).mT
        # solving L' x = z turns the standard normal draw into the centered one
        centered = solve_triangular(L_prec, z[:, None], trans='T', lower=True).squeeze(
            -1
        )
        # only a few leaves per tree end up using their logdet, but reducing
        # right away is lighter on memory than storing diagonals for later
        return mean_factor, centered, _logdet_from_chol(L_prec)

    # vmap over trees then over leaves; the leaf axis is trailing in both
    # `prec_trees`/`z` (in_axes) and the stored output (out_axes=-1)
    return PreLfMV(*vmap(vmap(per_leaf, in_axes=(0, -1), out_axes=-1))(prec_trees, z))

802 

803 

def _precompute_leaf_terms_mv_het(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees k k tree_size'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees k tree_size'] | None = None,
) -> PreLfMVHet:
    """Compute the leaf sampling terms in the multivariate vector-weight case.

    For each tree and leaf, the posterior precision
    ``leaf_prior_cov_inv + error_cov_inv * prec`` (elementwise product) is
    factorized with `chol_with_gersh`; the factor itself is stored as
    `mean_factor`. The standard normal noise is drawn from `key`, with the
    dtype of `error_cov_inv`, unless `z` is provided.
    """
    num_trees, k, _, tree_size = prec_trees.shape
    if z is None:
        # draw with the dtype of the precision, like the univariate path,
        # rather than with the global default float dtype
        z = random.normal(key, (num_trees, k, tree_size), error_cov_inv.dtype)

    def per_leaf(
        prec: Float32[Array, 'k k'], z: Float32[Array, ' k']
    ) -> tuple[Float32[Array, 'k k'], Float32[Array, ' k']]:
        # mean_factor stores the precision cholesky itself; the mean solve happens
        # downstream in `accept_move_and_sample_leaves`
        L_prec = chol_with_gersh(leaf_prior_cov_inv + error_cov_inv * prec)
        # solving L' x = z turns the standard normal draw into the centered one
        centered = solve_triangular(L_prec, z[:, None], trans='T', lower=True).squeeze(
            -1
        )
        return L_prec, centered

    # vmap over trees then over leaves; the leaf axis is trailing in both
    # `prec_trees`/`z` (in_axes=-1) and the stored output (out_axes=-1)
    return PreLfMVHet(
        *vmap(vmap(per_leaf, in_axes=(-1, -1), out_axes=-1))(prec_trees, z)
    )

831 

832 

@named_call
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees tree_size']
    | UInt32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k k tree_size'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k tree_size']
    | None = None,
) -> PreLf:
    """
    Pre-compute the terms needed to draw the leaves from their posterior.

    The univariate, multivariate, and multivariate vector-weight
    implementations are selected from the number of dimensions of
    `error_cov_inv` and `prec_trees`.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse error variance (univariate) or the inverse of error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    z
        Optional standard normal noise to use for sampling the centered leaves.
        This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    if error_cov_inv.ndim == 0:
        # scalar error precision: univariate
        implementation = _precompute_leaf_terms_uv
    elif prec_trees.ndim == 4:
        # matrix-valued precision scale per leaf: vector weights
        implementation = _precompute_leaf_terms_mv_het
    else:
        implementation = _precompute_leaf_terms_mv

    return implementation(key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z)

884 

885 

@vmap_nodoc
def _gather_lrt(
    leaf_values: Float32[Array, '*k_k tree_size'], lrt_nodes: Int32[Array, ' 3']
) -> Float32[Array, ' 3 *k_k']:
    """Pick one tree's leaf values at the left child, right child, and parent, moving that axis first."""
    # index the trailing (node) axis with the three node indices...
    picked = leaf_values[..., lrt_nodes]
    # ...then bring the resulting length-3 axis to the front
    return jnp.moveaxis(picked, -1, 0)

892 

893 

def _precompute_likelihood_terms_uv(
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    prelf: PreLfUV,
    lrt_nodes: Int32[Array, 'num_trees 3'],
) -> PreLkV:
    """Univariate implementation of `precompute_likelihood_terms`."""
    # `prelf.mean_factor` holds error_cov_inv / prec; multiplying once more by
    # error_cov_inv completes the sandwich for the three nodes of each move
    lrt = error_cov_inv * _gather_lrt(prelf.mean_factor, lrt_nodes)
    left = lrt[..., 0]
    right = lrt[..., 1]
    total = lrt[..., 2]

    # Prior-only counterpart of `lrt`. It is built with the same sequence of
    # operations as in `_precompute_leaf_terms_uv` so that it is bitwise equal
    # to `lrt` on empty nodes, making the ratio exactly 1 without data.
    prior_lrt = error_cov_inv * (jnp.reciprocal(leaf_prior_cov_inv) * error_cov_inv)

    log_sqrt_term = jnp.log(left * right / (prior_lrt * total)) / 2
    return PreLkV(lrt=lrt, log_sqrt_term=log_sqrt_term)

908 

909 

def _precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    prelf: PreLfMV,
    lrt_nodes: Int32[Array, 'num_trees 3'],
) -> PreLkV:
    """Multivariate implementation of `precompute_likelihood_terms`."""
    # `prelf.mean_factor` holds error_cov_inv @ inv(prec); multiplying on the
    # right by error_cov_inv completes the sandwich
    gathered = _gather_lrt(prelf.mean_factor, lrt_nodes)  # (num_trees, 3, k, k)
    lrt = gathered @ error_cov_inv

    # log-determinant part: prior minus left minus right plus parent, halved
    logdet_prior = _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))
    logdet_nodes = _gather_lrt(prelf.logdet_prec, lrt_nodes)
    signs = jnp.array([-1.0, -1.0, 1.0])
    log_sqrt_term = (logdet_prior + logdet_nodes @ signs) / 2

    return PreLkV(lrt=lrt, log_sqrt_term=log_sqrt_term)

923 

924 

def _precompute_likelihood_terms_mv_het(
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    prelf: PreLfMVHet,
    lrt_nodes: Int32[Array, 'num_trees 3'],
) -> PreLkV:
    """Implementation of `precompute_likelihood_terms` for per-leaf precision matrices."""
    # here `prelf.mean_factor` is the cholesky factor of the precision itself
    chol_nodes = _gather_lrt(prelf.mean_factor, lrt_nodes)  # (num_trees, 3, k, k)

    # log-determinant part: prior minus left minus right plus parent, halved
    logdet_prior = _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))
    signs = jnp.array([-1.0, -1.0, 1.0])
    log_sqrt_term = (logdet_prior + _logdet_from_chol(chol_nodes) @ signs) / 2

    # the factors themselves are stored as `lrt` and solved against later
    return PreLkV(lrt=chol_nodes, log_sqrt_term=log_sqrt_term)

938 

939 

@named_call
def precompute_likelihood_terms(
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    prelf: PreLf,
    moves: Moves,
) -> PreLkV:
    """
    Pre-compute the terms of the likelihood ratio used by the acceptance step.

    Most of these terms are a subset of the leaf sampling terms, so they are
    obtained from `prelf` by gathering at the nodes touched by each move. The
    implementation is selected from the concrete type of `prelf`.

    Parameters
    ----------
    error_cov_inv
        The inverse error variance (univariate) or the inverse of the error
        covariance matrix (multivariate). This is the inverse global error
        variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse
        of the prior covariance matrix of each leaf (multivariate).
    prelf
        The pre-computed terms of the leaf sampling, see
        `precompute_leaf_terms`.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    Pre-computed terms of the likelihood ratio, one per tree.
    """
    lrt_nodes = moves.lrt_nodes

    if isinstance(prelf, PreLfUV):
        return _precompute_likelihood_terms_uv(
            error_cov_inv, leaf_prior_cov_inv, prelf, lrt_nodes
        )

    if isinstance(prelf, PreLfMVHet):
        # this variant does not take `error_cov_inv`
        return _precompute_likelihood_terms_mv_het(
            leaf_prior_cov_inv, prelf, lrt_nodes
        )

    assert isinstance(prelf, PreLfMV)
    return _precompute_likelihood_terms_mv(
        error_cov_inv, leaf_prior_cov_inv, prelf, lrt_nodes
    )

985 

986 

@named_call
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept or reject the proposed moves, processing one tree after another.

    This is the most performance-sensitive function because it contains all
    and only the parts of the algorithm that can not be parallelized across
    trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    state : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    start = pso.state

    # inputs that do not change from tree to tree; the error precision is
    # passed along only when the leaf terms are of the `PreLfMVHet` kind
    shared = SeqStageInAllTrees(
        start.X,
        start.config.resid_reduction_config,
        start.config.data_sharded,
        start.prec_scale,
        start.forest.log_likelihood is not None,
        start.error_cov_inv.value if isinstance(pso.prelf, PreLfMVHet) else None,
    )

    # inputs stacked along the tree axis, sliced one tree at a time by the scan
    per_tree = SeqStageInPerTree(
        pso.state.forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.state.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )

    def loop(
        resid: Float32[Array, ' n'] | Float32[Array, ' k n'], pt: SeqStageInPerTree
    ) -> tuple[
        Float32[Array, ' n'] | Float32[Array, ' k n'],
        tuple[
            Float32[Array, ' tree_size'] | Float32[Array, ' k tree_size'],
            Bool[Array, ''],
            Bool[Array, ''],
            Float32[Array, ''] | None,
        ],
    ]:
        # the residuals are the carry, everything else is a per-tree output
        resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            resid, shared, pt
        )
        return resid, (leaf_tree, acc, to_prune, lkratio)

    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(
        loop, start.resid, per_tree, unroll=start.config.sequential_unroll
    )

    new_forest = replace(start.forest, leaf_tree=leaf_trees, log_likelihood=lkratio)
    new_state = replace(start, resid=resid, forest=new_forest)
    new_moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return new_state, new_moves

1055 

1056 

class SeqStageInAllTrees(Module):
    """Inputs of `accept_move_and_sample_leaves` that are common to all the trees."""

    X: UInt[Array, 'p n']
    """The predictors matrix."""

    resid_reduction_config: ReductionConfig
    """The configuration of the per-leaf sum of the residuals."""

    data_sharded: bool = field(static=True)
    """True if the data axis is sharded across devices."""

    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'] | None
    """The per-datapoint scale of the error precision. `None` stands for a
    scale of 1 everywhere."""

    save_ratios: bool = field(static=True)
    """True if the acceptance ratios have to be saved."""

    error_cov_inv: Float32[Array, 'k k'] | None
    """The global error precision scale. It is set only in the multivariate
    vector-weight case, because there the sequential stage uses it to compute
    the leaf scores."""

1080 

1081 

class SeqStageInPerTree(Module):
    """Inputs of `accept_move_and_sample_leaves` that differ from tree to tree."""

    # `lax.scan` consumes this object one tree at a time, but the object is
    # only ever built in its stacked (batched) form, the one handed to the
    # scan. Hence `num_trees` is declared as a fixed (non-variadic) leading
    # axis, disambiguated by rank/dtype (cf. `ParallelStageOut`). The per-tree
    # slices seen by `loop` are produced by scan, which does not re-run
    # `__init__`.
    leaf_tree: (
        Float32[Array, 'num_trees tree_size'] | Float32[Array, 'num_trees k tree_size']
    )
    """The values in the leaves of the trees."""

    prec_tree: (
        Float32[Array, 'num_trees tree_size']
        | UInt32[Array, 'num_trees tree_size']
        | Float32[Array, 'num_trees k k tree_size']
    )
    """The scale of the likelihood precision in each potential or actual leaf."""

    move: Moves
    """The move proposed for the tree, see `propose_moves`."""

    leaf_indices: UInt[Array, 'num_trees n']
    """The index of the leaf of each datapoint, in the largest version of the
    tree compatible with the move."""

    prelkv: PreLkV
    """The tree-specific pre-computed terms of the likelihood ratio."""

    prelf: PreLf
    """The tree-specific pre-computed terms of the leaf sampling."""

1114 

1115 

@named_call
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' tree_size'] | Float32[Array, ' k tree_size'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Decide whether to accept the move proposed for a tree and redraw its leaves.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n'] | Float32[Array, ' k n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, 'tree_size'] | Float32[Array, ' k tree_size']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state
        should be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to
        be saved.
    """
    move = pt.move
    lrt_nodes = move.lrt_nodes
    tree_size = pt.leaf_tree.shape[-1]  # 2**d

    # per-leaf sums of the (precision-scaled) residuals, in the tree proposed
    # by the grow move
    scaled_resid = resid if at.prec_scale is None else resid * at.prec_scale
    resid_tree = sum_resid(
        scaled_resid,
        pt.leaf_indices,
        tree_size,
        at.resid_reduction_config,
        at.data_sharded,
    )

    # take the starting tree out of the function
    resid_tree = resid_tree + pt.prec_tree * pt.leaf_tree

    # Fill in the sum for the parent node modified by the move. The children
    # slots are written back unchanged, so that one scatter is enough.
    assert lrt_nodes.dtype == jnp.int32
    resid_lrt = _fill_lrt_total(resid_tree[..., lrt_nodes])
    resid_tree = resid_tree.at[..., lrt_nodes].set(resid_lrt)

    # accept/reject: the ratio is that of a grow move, negate it for a prune
    log_lk_ratio = compute_likelihood_ratio(resid_lrt, pt.prelkv, at.error_cov_inv)
    log_ratio = move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(move.grow, log_ratio, -log_ratio)
    acc = move.allowed & (move.logu <= log_ratio)

    # posterior mean of the leaves
    if at.error_cov_inv is not None:
        # multivariate w/ vector weights: `mean_factor` is the cholesky factor
        # of the precision, so solve twice against it, batching over leaves
        scores = compute_B(at.error_cov_inv, resid_tree)  # (k, 2**d)
        chol = jnp.moveaxis(pt.prelf.mean_factor, -1, 0)  # (2**d, k, k)
        rhs = scores.T[:, :, None]  # (2**d, k, 1)
        half_solved = solve_triangular(chol, rhs, lower=True)
        solved = solve_triangular(chol, half_solved, lower=True, trans='T')
        mean_post = solved.squeeze(-1).T  # (k, 2**d)
    elif resid.ndim > 1:
        # multivariate homoskedastic or scalar weights
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, resid_tree)
    else:
        # univariate
        mean_post = resid_tree * pt.prelf.mean_factor

    # sample the leaves by adding the pre-drawn centered noise
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # Copy the parent leaf onto the nodes of the move when pruning, such that
    # the leaf indices point to the correct leaf. The parent slot is written
    # back unchanged to share a single scatter. When not pruning, the index
    # `tree_size` is used instead of the node indices.
    to_prune = acc ^ move.grow
    targets = jnp.where(to_prune, lrt_nodes, tree_size)
    parent_leaf = leaf_tree[..., lrt_nodes[2], None]
    leaf_tree = leaf_tree.at[..., targets].set(parent_leaf)

    # swap the old tree for the new one in the function values
    resid = resid + (pt.leaf_tree - leaf_tree)[..., pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, (log_lk_ratio if at.save_ratios else None)

1218 

1219 

@named_call
def sum_resid(
    scaled_resid: (
        Float32[Array, ' n'] | Float32[Array, 'k n'] | Float32[Array, 'k k n']
    ),
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    reduction_config: ReductionConfig,
    data_sharded: bool,
) -> (
    Float32[Array, ' {tree_size}']
    | Float32[Array, 'k {tree_size}']
    | Float32[Array, 'k k {tree_size}']
):
    """
    Accumulate the residuals leaf by leaf.

    The actual work is delegated to the `_reduce` method of
    `reduction_config`, asking for a float32 result.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls
        into).
    tree_size
        The size of the tree array (2 ** d).
    reduction_config
        How to sum the residuals in each leaf.
    data_sharded
        Whether the data axis is sharded; if true, the result is psum-reduced
        across the ``'data'`` axis of the enclosing `shard_map`.

    Returns
    -------
    The per-leaf sum, with the same leading dimensions as ``scaled_resid`` and a trailing axis over the leaves.
    """
    reduce = reduction_config._reduce  # noqa: SLF001
    return reduce(
        scaled_resid,
        leaf_indices,
        size=tree_size,
        dtype=jnp.float32,
        data_sharded=data_sharded,
    )

1263 

1264 

def _compute_likelihood_ratio_uv(
    resid_lrt: Float32[Array, ' 3'], prelkv: PreLkV
) -> Float32[Array, '']:
    """Univariate implementation of `compute_likelihood_ratio`."""
    # r * v * r for the left, right and total terms
    quad_forms = resid_lrt * resid_lrt * prelkv.lrt
    # left + right - total, halved
    signs = jnp.array([1.0, 1.0, -1.0])
    exp_term = 0.5 * (quad_forms @ signs)
    return prelkv.log_sqrt_term + exp_term

1272 

1273 

def _compute_likelihood_ratio_mv(
    resid_lrt: Float32[Array, 'k 3'], prelkv: PreLkV
) -> Float32[Array, '']:
    """Multivariate implementation of `compute_likelihood_ratio`."""
    # r' M r for the left, right and total terms, in a single contraction
    quad_forms = jnp.einsum('it,tij,jt->t', resid_lrt, prelkv.lrt, resid_lrt)
    # left + right - total, halved
    signs = jnp.array([1.0, 1.0, -1.0])
    exp_term = 0.5 * (quad_forms @ signs)
    return prelkv.log_sqrt_term + exp_term

1281 

1282 

def _compute_likelihood_ratio_mv_het(
    resid_lrt: Float32[Array, 'k k 3'],
    error_cov_inv: Float32[Array, 'k k'],
    prelkv: PreLkV,
) -> Float32[Array, '']:
    """Implementation of `compute_likelihood_ratio` for per-leaf precision matrices."""
    scores = compute_B(error_cov_inv, resid_lrt)  # (k, 3)
    # solve against the lower triangular factors stored in `prelkv.lrt`, one
    # per node; the quadratic form is then the squared norm of the solution
    rhs = scores.T[..., None]
    solved = solve_triangular(prelkv.lrt, rhs, lower=True).squeeze(-1)  # (3, k)
    quad_forms = jnp.einsum('ti,ti->t', solved, solved)
    # left + right - total, halved
    signs = jnp.array([1.0, 1.0, -1.0])
    exp_term = 0.5 * (quad_forms @ signs)
    return prelkv.log_sqrt_term + exp_term

1293 

1294 

@named_call
def compute_likelihood_ratio(
    resid_lrt: (Float32[Array, ' 3'] | Float32[Array, 'k 3'] | Float32[Array, 'k k 3']),
    prelkv: PreLkV,
    error_cov_inv: Float32[Array, 'k k'] | None,
) -> Float32[Array, '']:
    """
    Compute the logarithm of the likelihood ratio of a grow move.

    Parameters
    ----------
    resid_lrt
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the left child, right child, and parent node
        involved in the move, stacked along the trailing axis.
    prelkv
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.
    error_cov_inv
        The global error precision scale. Set only in the multivariate
        vector-weight case.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    # a provided error precision marks the multivariate vector-weight case
    if error_cov_inv is not None:
        return _compute_likelihood_ratio_mv_het(resid_lrt, error_cov_inv, prelkv)

    # otherwise the rank of the sums tells multivariate from univariate
    if resid_lrt.ndim > 1:
        return _compute_likelihood_ratio_mv(resid_lrt, prelkv)

    return _compute_likelihood_ratio_uv(resid_lrt, prelkv)

1327 

1328 

@named_call
def accept_moves_final_stage(state: State, moves: Moves) -> State:
    """
    Finalize the mcmc state once the moves have been accepted or rejected.

    This is kept apart from `accept_moves_sequential_stage` to signal that it
    can work in parallel across trees.

    Parameters
    ----------
    state
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    assert moves.acc is not None
    forest = state.forest

    # count the accepted moves of each kind over the trees
    grow_acc_count = jnp.sum(moves.acc & moves.grow)
    prune_acc_count = jnp.sum(moves.acc & ~moves.grow)

    # bring the tree structures in line with the outcome of the moves
    leaf_indices = apply_moves_to_leaf_indices(forest.leaf_indices, moves)
    split_tree = apply_moves_to_split_trees(forest.split_tree, moves)
    affluence_tree = apply_moves_to_affluence_trees(forest.affluence_tree, moves)

    new_forest = replace(
        forest,
        grow_acc_count=grow_acc_count,
        prune_acc_count=prune_acc_count,
        leaf_indices=leaf_indices,
        split_tree=split_tree,
        affluence_tree=affluence_tree,
    )
    return replace(state, forest=new_forest)

1363 

1364 

@named_call
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Make the leaf indices reflect the outcome of the moves.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # documented front-end of the per-tree implementation
    updated = _apply_moves_to_leaf_indices(leaf_indices, moves)
    return updated

1386 

1387 

@vmap_nodoc
def _apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, ' n'], moves: Moves
) -> UInt[Array, ' n']:
    """Implement `apply_moves_to_leaf_indices`."""
    assert moves.to_prune is not None
    dtype = leaf_indices.dtype

    # clearing the lowest bit (mask ...1111111110) maps both children onto the
    # index of the left one, so a single comparison spots them
    clear_lowest_bit = ~jnp.array(1, dtype)
    in_children = (leaf_indices & clear_lowest_bit) == moves.lrt_nodes[0]

    # when pruning, the datapoints in the children move up to the parent
    parent = moves.lrt_nodes[2].astype(dtype)
    return jnp.where(in_children & moves.to_prune, parent, leaf_indices)

1401 

1402 

@named_call
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees half_tree_size'], moves: Moves
) -> UInt[Array, 'num_trees half_tree_size']:
    """
    Make the split trees reflect the outcome of the moves.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    # documented front-end of the per-tree implementation
    updated = _apply_moves_to_split_trees(split_tree, moves)
    return updated

1423 

1424 

@vmap_nodoc
def _apply_moves_to_split_trees(
    split_tree: UInt[Array, ' half_tree_size'], moves: Moves
) -> UInt[Array, ' half_tree_size']:
    """Implement `apply_moves_to_split_trees`."""
    assert moves.to_prune is not None

    # One scatter covers both cases: an accepted grow stores the new cutpoint,
    # while pruning (accepted prune or rejected grow) resets the node to 0.
    # Otherwise the index is set to `split_tree.size`, past the end.
    touched = moves.grow | moves.to_prune
    node = jnp.where(touched, moves.lrt_nodes[2], split_tree.size)
    cutpoint = jnp.where(moves.to_prune, 0, moves.grow_split)
    return split_tree.at[node].set(cutpoint.astype(split_tree.dtype))

1436 

1437 

@named_call
def apply_moves_to_affluence_trees(
    affluence_tree: Bool[Array, 'num_trees half_tree_size'], moves: Moves
) -> Bool[Array, 'num_trees half_tree_size']:
    """
    Make the affluence trees reflect the outcome of the moves.

    The affluence tree marks the growable leaves. Starting from the clean
    pre-move mask, this re-marks only the nodes touched by the move, which
    restores that invariant after the move.

    Parameters
    ----------
    affluence_tree
        The mask of the growable leaves in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated affluence trees.
    """
    # documented front-end of the per-tree implementation
    updated = _apply_moves_to_affluence_trees(affluence_tree, moves)
    return updated

1462 

1463 

@vmap_nodoc
def _apply_moves_to_affluence_trees(
    affluence_tree: Bool[Array, ' half_tree_size'], moves: Moves
) -> Bool[Array, ' half_tree_size']:
    """Implement `apply_moves_to_affluence_trees`."""
    assert moves.to_prune is not None
    assert moves.lrt_affluent is not None

    # GROW: the node turns internal and its children turn into leaves, each
    # with its own affluence.
    # PRUNE (accepted prune or rejected grow): the node turns into a leaf with
    # its affluence and the children are deleted.
    # In both cases all three nodes get written: the mask below keeps the
    # affluence of the nodes that become leaves and zeroes the others.
    children_then_parent = jnp.array([True, True, False])
    becomes_leaf = moves.to_prune ^ children_then_parent

    # If no move is applied (a rejected prune), the indices resolve to `size`
    # and the writes drop.
    touched = moves.grow | moves.to_prune
    nodes = jnp.where(touched, moves.lrt_nodes, affluence_tree.size)
    return affluence_tree.at[nodes].set(moves.lrt_affluent & becomes_leaf)

1481 

1482 

@jit
def _sample_wishart_bartlett(
    key: Key[Array, ''],
    df: Float32[Array, ''] | float,
    scale_inv: Float32[Array, 'k k'],
) -> Float32[Array, 'k k']:
    """
    Draw a precision matrix W ~ Wishart(df, scale_inv^-1) with the Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key
    df
        Degrees of freedom
    scale_inv
        Scale matrix of the corresponding Inverse Wishart distribution, i.e.,
        the inverse of the scale matrix of the Wishart being sampled

    Returns
    -------
    A sample from Wishart(df, scale_inv^-1)
    """
    keys = split(key)
    k, _ = scale_inv.shape

    # Diagonal of the Bartlett factor: A_ii ~ sqrt(chi^2(df - i)), where
    # chi^2(nu) = Gamma(nu/2, scale=2). The gamma is sampled in log space, so
    # sqrt(2 * Gamma) = sqrt(2) * exp(loggamma / 2), with the sqrt folded into
    # the exponential.
    chi2_df = df - jnp.arange(k)
    log_gamma = loggamma(keys.pop(), chi2_df / 2.0)
    diagonal = jnp.sqrt(2.0) * jnp.exp(log_gamma / 2.0)

    # strictly lower triangle: standard normals
    normals = random.normal(keys.pop(), (k, k))
    bartlett = jnp.tril(normals, -1) + jnp.diag(diagonal)

    # apply the scale by solving against the transposed cholesky factor of
    # `scale_inv`
    chol = chol_with_gersh(scale_inv, absolute_eps=True)
    factor = solve_triangular(chol, bartlett, lower=True, trans='T')

    return factor @ factor.T

1519 

1520 

def _step_error_cov_inv_mv(key: Key[Array, ''], state: State) -> State:
    """Wishart update of the full inverse error covariance matrix."""
    prior = state.error_cov_inv
    assert prior.nu is not None
    assert prior.rate is not None

    weights = state.inv_sdev_scale
    resid = state.resid
    if weights is None:
        # every datapoint counts; scale the local count by the size of the
        # 'data' mesh axis
        _, n_local = resid.shape
        n_eff = n_local * get_axis_size(state.config.mesh, 'data')
    else:
        # 2-D inv_sdev_scale dispatches to the diagonal path, so here it is 1-D
        n_eff = jnp.sum(weights != 0, axis=-1)
        if state.config.data_sharded:
            n_eff = lax.psum(n_eff, 'data')
        resid = resid * weights

    # posterior degrees of freedom
    df_post = prior.nu + n_eff

    # posterior scale: prior rate plus the residual outer product, summed
    # over the data shards if any
    outer = resid @ resid.T
    if state.config.data_sharded:
        outer = lax.psum(outer, 'data')
    scale_post = prior.rate + outer

    value = _sample_wishart_bartlett(key, df_post, scale_post)
    return replace(state, error_cov_inv=replace(prior, value=value))

1543 

1544 

def _step_error_cov_inv_diag(key: Key[Array, ''], state: State) -> State:
    """Per-component inverse-gamma update for univariate, mixed, and partial-missing paths."""
    prior = state.error_cov_inv
    assert prior.rate is not None
    assert prior.nu is not None

    weights = state.inv_sdev_scale
    resid = state.resid
    sharded = state.config.data_sharded

    # effective number of datapoints (and weighting of the residuals)
    if weights is None:
        *_, n_local = resid.shape
        n_eff = n_local * get_axis_size(state.config.mesh, 'data')
    else:
        resid = resid * weights
        n_eff = jnp.sum(weights != 0, axis=-1)
        if sharded:
            n_eff = lax.psum(n_eff, 'data')

    # alpha
    alpha = prior.nu / 2 + n_eff / 2

    # beta
    norm2 = jnp.einsum('...n,...n->...', resid, resid)
    if sharded:
        norm2 = lax.psum(norm2, 'data')
    # leading axes of the residuals; empty in the univariate case
    kshape = resid.shape[:-1]
    # with leading axes the rate is a matrix: only its diagonal is used
    rate = jnp.diag(prior.rate) if kshape else prior.rate
    beta = rate / 2 + norm2 / 2

    # draw the gamma from the first of a split, mirroring the Bartlett sampler
    # in the multivariate path so the two branches coincide at k=1
    keys = split(key)
    gamma = jnp.exp(loggamma(keys.pop(), alpha, kshape))
    prec = gamma / beta

    # the binary components have their precision fixed to 1
    if state.binary_indices is not None:
        prec = prec.at[state.binary_indices].set(1.0)

    # store the per-component precisions as a diagonal matrix
    if kshape:
        prec = jnp.diag(prec)
    return replace(state, error_cov_inv=replace(prior, value=prec))

1584 

1585 

@named_call
def step_error_cov_inv(key: Key[Array, ''], state: State) -> State:
    """
    MCMC-update the inverse error covariance.

    The full-matrix update is used when the current value is 2-dimensional,
    there are no binary components, and `inv_sdev_scale` is either unset or
    1-dimensional. All other cases use the per-component update.
    """
    weights = state.inv_sdev_scale
    is_matrix = state.error_cov_inv.value.ndim == 2
    no_binary = state.binary_indices is None
    shared_weights = weights is None or weights.ndim == 1

    if is_matrix and no_binary and shared_weights:
        return _step_error_cov_inv_mv(key, state)
    return _step_error_cov_inv_diag(key, state)

1597 

1598 

@named_call
def step_z(key: Key[Array, ''], state: State) -> State:
    """
    MCMC-update the latent variable of binary regression.

    Parameters
    ----------
    key
        A jax random key.
    state
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    assert state.z is not None
    assert state.binary_y is not None

    # restrict to the binary components, if only some of them are binary
    binary_rows = state.binary_indices
    if binary_rows is None:
        old_resid = state.resid
    else:
        old_resid = state.resid[..., binary_rows, :]

    trees_plus_offset = state.z - old_resid

    if state.config.data_sharded:
        # decorrelate the seed across data shards; the seed is replicated
        # because the trees and most of the algorithm are replicated
        key = random.fold_in(key, lax.axis_index('data'))

    # redraw the residuals, one-sided truncated at minus the tree sum plus
    # offset, on the side set by `~state.binary_y`
    new_resid = truncated_normal_onesided(
        key, (), ~state.binary_y, -trees_plus_offset
    )
    z = trees_plus_offset + new_resid

    # write the new residuals back in place of the binary components
    if binary_rows is not None:
        new_resid = state.resid.at[..., binary_rows, :].set(new_resid)

    return replace(state, z=z, resid=new_resid)

1635 

1636 

def _blocked_mass_tree(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' half_tree_size'],
    split_tree: UInt[Array, ' half_tree_size'],
    max_split: UInt[Array, ' p'],
    s: Float32[Array, ' p'],
) -> Float32[Array, ' p']:
    """Compute the per-variable data-augmentation mass blocked by one tree.

    For every internal node, a latent augmentation weight ``lambda / e`` is
    drawn (``lambda`` exponential, ``e`` the eligible split probability mass
    at the node) and credited to each variable that is ineligible at that
    node.

    Parameters
    ----------
    key
        Random key for sampling.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    s
        Split probabilities normalized over selectable variables.

    Returns
    -------
    The blocked mass for each variable.
    """
    (half_tree_size,) = split_tree.shape
    num_levels = half_tree_size.bit_length() - 1  # number of decision-node levels
    p = max_split.size
    node_ids = jnp.arange(half_tree_size)
    cutpoint = split_tree.astype(jnp.int32)
    is_internal = split_tree.astype(bool)

    # For each node, the range [lo, hi) of cutpoints that its own splitting
    # variable still has available under the constraints set by the ancestors.
    lo, hi = vmap(split_range, in_axes=(None, None, None, 0, 0))(
        var_tree, split_tree, max_split, node_ids, var_tree
    )

    # When the cutpoint of an internal node sits at one end of the available
    # range, the node exhausts its own variable for the child on that side,
    # and the variable is ineligible in the whole subtree of that child.
    # Row 0 refers to the left child (low end lo), row 1 to the right child
    # (high end hi - 1).
    blocks = is_internal & (cutpoint == jnp.stack([lo, hi - 1]))

    # Since a node can block its own splitting variable at most, per-variable
    # totals follow from these per-node blocks by accumulating over the depth
    # levels (top-down, then bottom-up) instead of scanning the ancestors of
    # each node.

    # Ineligible mass of a node: the s-mass of the variables blocked on the
    # path from the root. A variable is blocked at exactly one node per path,
    # so adding up the per-node increments top-down yields the sum over the
    # distinct ineligible variables.
    parent = node_ids >> 1
    side = node_ids & 1  # 0 for a left child, 1 for a right child
    parent_blocks = blocks[side, parent]
    # var_tree[parent] is a valid index wherever parent_blocks holds (the
    # parent is then internal); elsewhere the clamped gather is masked away
    ineligible_mass = jnp.where(parent_blocks, s[var_tree[parent]], 0.0)
    for level in range(1, num_levels):
        first, end = 1 << level, 1 << (level + 1)
        inherited = jnp.repeat(ineligible_mass[first >> 1 : end >> 1], 2)
        ineligible_mass = ineligible_mass.at[first:end].add(inherited)

    # Augmentation weight lambda_b / e_b of each node, zero at non-internal
    # nodes. At internal nodes the eligible mass is positive (the split
    # variable is eligible), so the floor only protects against round-off; it
    # has no effect where the weight is zero.
    eligible_mass = jnp.maximum(1.0 - ineligible_mass, jnp.finfo(jnp.float32).eps)
    draws = random.exponential(key, (half_tree_size,))
    weight = jnp.where(is_internal, draws, 0.0)
    weight = weight / eligible_mass

    # Subtree weight: the weight of a node plus that of all its internal
    # descendants, accumulated bottom-up. The children of the deepest
    # decision level are leaves and add nothing.
    subtree_weight = weight
    for level in range(num_levels - 2, -1, -1):
        first, end = 1 << level, 1 << (level + 1)
        pairs = subtree_weight[2 * first : 2 * end].reshape(-1, 2)
        subtree_weight = subtree_weight.at[first:end].add(pairs.sum(axis=1))

    # A variable blocked by node b at one of its children is ineligible in
    # the whole subtree of that child, so it collects that subtree weight,
    # which is scattered onto the splitting variable of b. Only the nodes in
    # the `[:half]` slice have internal children (the deepest decision nodes
    # block into leaves, which add nothing), and their children are exactly
    # subtree_weight reshaped into [left, right] pairs.
    half = half_tree_size // 2
    child_weights = subtree_weight.reshape(-1, 2).T
    contrib = (blocks[:, :half] * child_weights).sum(axis=0)
    # non-internal nodes are sent to the index `p`, past the end of the output
    scatter_var = jnp.where(is_internal, var_tree, p)[:half]
    return jnp.zeros(p).at[scatter_var].add(contrib)

1730 

1731 

def sample_s_augmentation(key: Key[Array, ''], forest: Forest) -> Int32[Array, ' p']:
    """Draw the data-augmentation counts that make the `s` update exact.

    Some variables cannot be split on at a given internal node: those left
    without an available cutpoint by the node's ancestors, and those that are
    blocked globally. Ignoring them makes the plain Dirichlet update for `s`
    only approximate. This function samples, per variable, how many
    ineligible draws were discarded before each realized split; the result is
    meant to be added to the variable usage counts.

    Parameters
    ----------
    key
        Random key for sampling.
    forest
        The forest, providing the trees and the current `log_s`.

    Returns
    -------
    The discarded-draws count for each variable.
    """
    assert forest.log_s is not None
    keys = split(key)
    num_trees, _tree_size = forest.var_tree.shape

    # consume the keys in a fixed order: one batch for the trees, then one
    # for the final Poisson draw
    tree_keys = keys.pop(num_trees)
    poisson_key = keys.pop()

    # split probabilities, normalized over the variables with max_split > 0
    # only (the others are never selectable)
    can_select = forest.max_split > 0
    s = softmax(forest.log_s, where=can_select)

    # blocked_mass[j] = sum, over the internal nodes where j is ineligible, of
    # lambda_b / e_b with lambda_b ~ Exponential(1); computed tree by tree and
    # then summed over the trees, giving shape (p,)
    blocked_mass_per_tree = vmap(_blocked_mass_tree, in_axes=(0, 0, 0, None, None))
    blocked_mass = blocked_mass_per_tree(
        tree_keys, forest.var_tree, forest.split_tree, forest.max_split, s
    ).sum(axis=0)

    # the per-node discarded-draw counts are negative-multinomial and their
    # sum over nodes has no closed form, but the Gamma-Poisson mixture does:
    # A_j | {lambda_b} ~ Poisson(s_j * blocked_mass[j]), independent across j
    rate = s * blocked_mass
    return random.poisson(poisson_key, rate, dtype=jnp.int32)

1770 

1771 

@named_call
def step_s(key: Key[Array, ''], state: State) -> State:
    """
    Re-sample `log_s` from its Dirichlet full conditional.

    With the prior s ~ Dirichlet(theta/p, ..., theta/p), the posterior is
    s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount counts how many times each variable is used in the current
    forest.

    Parameters
    ----------
    key
        Random key for sampling.
    state
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    By default this full conditional is approximate, because it ignores the
    decision rules forbidden by the ancestors of each node. If
    ``state.config.augment`` is set, the forbidden rules are accounted for
    exactly with the data augmentation of `sample_s_augmentation`.
    """
    forest = state.forest
    assert forest.theta is not None

    # Take the key for the Dirichlet draw before anything else, regardless of
    # the augmentation setting. This way the draw does not depend on the mode,
    # and both modes give the same result when no rule is forbidden (the
    # augmentation counts are then exactly zero).
    keys = split(key)
    dirichlet_key = keys.pop()

    # count how many times each variable is used in the forest
    num_vars = forest.max_split.size
    usage = var_histogram(
        num_vars, forest.var_tree, forest.split_tree, sum_batch_axis=-1
    )

    # posterior concentration; with augmentation on, add the discarded-draw
    # counts that account exactly for the forbidden rules
    concentration = forest.theta / num_vars + usage
    if state.config.augment:
        concentration = concentration + sample_s_augmentation(keys.pop(), forest)

    # draw the new s on the log scale and store it in the forest
    new_log_s = loggamma(dirichlet_key, concentration)
    return replace(state, forest=replace(forest, log_s=new_log_s))

1823 

1824 

@named_call
def step_theta(key: Key[Array, ''], state: State, *, num_grid: int = 1000) -> State:
    """
    Re-sample `theta` on a grid.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    state
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = state.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # split (0, 1) into num_grid equal bins and use their midpoints as the
    # grid for lambda = theta / (theta + rho)
    half_bin = 1 / (2 * num_grid)
    lambda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # make exp(log_s) sum to one
    normalized_log_s = forest.log_s - logsumexp(forest.log_s)

    # evaluate the log-weights on the grid, pick one grid point at random
    # accordingly, and read off the matching theta
    logp, theta_grid = _log_p_lambda(
        lambda_grid, normalized_log_s, forest.rho, forest.a, forest.b
    )
    index = random.categorical(key, logp)
    new_theta = theta_grid[index]

    return replace(state, forest=replace(forest, theta=new_theta))

1866 

1867 

def _log_p_lambda(
    lambda_: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """Evaluate the unnormalized log-weights of lambda = theta / (theta + rho).

    Parameters
    ----------
    lambda_
        Grid of lambda values. The computation relies on
        ``lambda_[::-1] == 1 - lambda_``, i.e., the grid must be symmetric
        around 1/2.
    log_s
        Logarithm of the split probabilities.
    rho
        Scale linking lambda to theta.
    a
    b
        Exponent parameters of the Beta-like factor in lambda.

    Returns
    -------
    logp
        The log-weight of each grid point.
    theta
        The value ``rho * lambda / (1 - lambda)`` at each grid point.
    """
    # the reversed grid stands in for 1 - lambda_ (see the docstring)
    one_minus_lambda = lambda_[::-1]
    theta = rho * lambda_ / one_minus_lambda
    num_vars = log_s.size

    # Beta(a, b) kernel in lambda; log1p(-(1 - lambda)) is log(lambda)
    log_lambda_term = (a - 1) * jnp.log1p(-one_minus_lambda)
    log_one_minus_lambda_term = (b - 1) * jnp.log1p(-lambda_)

    # theta-dependent terms of the Dirichlet(theta/p, ..., theta/p) log-density
    log_norm_numerator = gammaln(theta)
    log_norm_denominator = num_vars * gammaln(theta / num_vars)
    log_s_term = theta / num_vars * jnp.sum(log_s)

    # same left-to-right accumulation order as a single chained expression
    logp = (
        log_lambda_term
        + log_one_minus_lambda_term
        + log_norm_numerator
        - log_norm_denominator
        + log_s_term
    )
    return logp, theta

1885 

1886 

@named_call
def step_sparse(key: Key[Array, ''], state: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if the parameters of
    the theta prior are defined. Nothing is done if
    ``state.config.sparse_on_at`` is `None`, or as long as the number of
    steps done is below ``state.config.sparse_on_at``.

    Parameters
    ----------
    key
        Random key for sampling.
    state
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    config = state.config

    # sparsity is never switched on: static (Python-level) early exit
    if config.sparse_on_at is None:
        return state

    # traced condition: leave the state untouched until enough steps are done
    still_off = config.steps_done < config.sparse_on_at
    return lax.cond(
        still_off,
        lambda _key, unchanged: unchanged,
        _step_sparse,
        key,
        state,
    )

1915 

1916 

def _step_sparse(key: Key[Array, ''], state: State) -> State:
    """Run `step_s`, then `step_theta` if ``state.forest.rho`` is defined."""
    keys = split(key)
    updated = step_s(keys.pop(), state)

    # without rho, theta is not re-sampled
    if updated.forest.rho is None:
        return updated
    return step_theta(keys.pop(), updated)

1923 

1924 

@named_call
def step_config(state: State) -> State:
    """
    Increment the step counter stored in the configuration.

    Parameters
    ----------
    state
        The current BART state.

    Returns
    -------
    A copy of the state whose ``config.steps_done`` is increased by one.
    """
    bumped_config = replace(state.config, steps_done=state.config.steps_done + 1)
    return replace(state, config=bumped_config)