Coverage for src / bartz / mcmcstep / _step.py: 99%

377 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-13 00:35 +0000

1# bartz/src/bartz/mcmcstep/_step.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement `step`, `step_trees`, and the accept-reject logic.""" 

26 

27from dataclasses import replace 

28from functools import partial 

29 

30try: 

31 # available since jax v0.6.1 

32 from jax import shard_map 

33except ImportError: 

34 # deprecated in jax v0.8.0 

35 from jax.experimental.shard_map import shard_map 

36 

37import jax 

38from equinox import Module, tree_at 

39from jax import lax, random 

40from jax import numpy as jnp 

41from jax.lax import cond 

42from jax.scipy.linalg import solve_triangular 

43from jax.scipy.special import gammaln, logsumexp 

44from jax.sharding import Mesh, PartitionSpec 

45from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt, UInt32 

46 

47from bartz._profiler import ( 

48 get_profile_mode, 

49 jit_and_block_if_profiling, 

50 jit_if_not_profiling, 

51 jit_if_profiling, 

52 vmap_chains_if_not_profiling, 

53 vmap_chains_if_profiling, 

54) 

55from bartz.grove import var_histogram 

56from bartz.jaxext import split, truncated_normal_onesided, vmap_nodoc 

57from bartz.mcmcstep._moves import Moves, propose_moves 

58from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field 

59 

60 

61@partial(jit_if_not_profiling, donate_argnums=(1,)) 

62@partial(vmap_chains_if_not_profiling, auto_split_keys=True) 

63def step(key: Key[Array, ''], bart: State) -> State: 

64 """ 

65 Do one MCMC step. 

66 

67 Parameters 

68 ---------- 

69 key 

70 A jax random key. 

71 bart 

72 A BART mcmc state, as created by `init`. 

73 

74 Returns 

75 ------- 

76 The new BART mcmc state. 

77 

78 Notes 

79 ----- 

80 The memory of the input state is re-used for the output state, so the input 

81 state cannot be used anymore after calling `step`. All this applies 

82 outside of `jax.jit`. 

83 """ 

84 # handle the interactions between chains and profile mode 

85 num_chains = bart.forest.num_chains() 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

86 chain_shape = () if num_chains is None else (num_chains,) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

87 if get_profile_mode() and num_chains is not None and key.ndim == 0: 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

88 key = random.split(key, num_chains) 1#J:=L

89 assert key.shape == chain_shape 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

90 

91 keys = split(key, 3) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

92 

93 if bart.y.dtype == bool: 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

94 bart = replace(bart, error_cov_inv=jnp.ones(chain_shape)) 1#:9!)*+(,

95 bart = step_trees(keys.pop(), bart) 1#:9!)*+(,

96 bart = replace(bart, error_cov_inv=None) 1#:9!)*+(,

97 bart = step_z(keys.pop(), bart) 1#:9!)*+(,

98 

99 else: # continuous or multivariate regression 

100 bart = step_trees(keys.pop(), bart) 1fJ?=eMgKhNrVsWiYAZ$%'jOkPXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

101 bart = step_error_cov_inv(keys.pop(), bart) 1fJ?=eMgKhNrVsWiYAZ$%'jOkPXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

102 

103 bart = step_sparse(keys.pop(), bart) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

104 return step_config(bart) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU

105 

106 

107def step_trees(key: Key[Array, ''], bart: State) -> State: 

108 """ 

109 Forest sampling step of BART MCMC. 

110 

111 Parameters 

112 ---------- 

113 key 

114 A jax random key. 

115 bart 

116 A BART mcmc state, as created by `init`. 

117 

118 Returns 

119 ------- 

120 The new BART mcmc state. 

121 

122 Notes 

123 ----- 

124 This function zeroes the proposal counters. 

125 """ 

126 keys = split(key) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

127 moves = propose_moves(keys.pop(), bart.forest) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

128 return accept_moves_and_sample_leaves(keys.pop(), bart, moves) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

129 

130 

131def accept_moves_and_sample_leaves( 

132 key: Key[Array, ''], bart: State, moves: Moves 

133) -> State: 

134 """ 

135 Accept or reject the proposed moves and sample the new leaf values. 

136 

137 Parameters 

138 ---------- 

139 key 

140 A jax random key. 

141 bart 

142 A valid BART mcmc state. 

143 moves 

144 The proposed moves, see `propose_moves`. 

145 

146 Returns 

147 ------- 

148 A new (valid) BART mcmc state. 

149 """ 

150 pso = accept_moves_parallel_stage(key, bart, moves) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

151 bart, moves = accept_moves_sequential_stage(pso) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

152 return accept_moves_final_stage(bart, moves) 1f#J?:=e9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

153 

154 

155class Counts(Module): 

156 """ 

157 Number of datapoints in the nodes involved in proposed moves for each tree. 

158 

159 Parameters 

160 ---------- 

161 left 

162 Number of datapoints in the left child. 

163 right 

164 Number of datapoints in the right child. 

165 total 

166 Number of datapoints in the parent (``= left + right``). 

167 """ 

168 

169 left: UInt[Array, '*chains num_trees'] = field(chains=True) 

170 right: UInt[Array, '*chains num_trees'] = field(chains=True) 

171 total: UInt[Array, '*chains num_trees'] = field(chains=True) 

172 

173 

174class Precs(Module): 

175 """ 

176 Likelihood precision scale in the nodes involved in proposed moves for each tree. 

177 

178 The "likelihood precision scale" of a tree node is the sum of the inverse 

179 squared error scales of the datapoints selected by the node. 

180 

181 Parameters 

182 ---------- 

183 left 

184 Likelihood precision scale in the left child. 

185 right 

186 Likelihood precision scale in the right child. 

187 total 

188 Likelihood precision scale in the parent (``= left + right``). 

189 """ 

190 

191 left: Float32[Array, '*chains num_trees'] = field(chains=True) 

192 right: Float32[Array, '*chains num_trees'] = field(chains=True) 

193 total: Float32[Array, '*chains num_trees'] = field(chains=True) 

194 

195 

196class PreLkV(Module): 

197 """ 

198 Non-sequential terms of the likelihood ratio for each tree. 

199 

200 These terms can be computed in parallel across trees. 

201 

202 Parameters 

203 ---------- 

204 left 

205 right 

206 total 

207 In the univariate case, this is the scalar term 

208 

209 ``1 / error_cov_inv + n_* / leaf_prior_cov_inv`` 

210 

211 In the multivariate case, this is the matrix term 

212 

213 ``error_cov_inv @ inv(leaf_prior_cov_inv + n_* * error_cov_inv) @ error_cov_inv`` 

214 

215 In both cases, ``n_*`` is n_left/right/total, the number of datapoints 

216 respectively in the left child, right child, and parent node, or the 

217 likelihood precision scale in the heteroskedastic case. 

218 log_sqrt_term 

219 The logarithm of the square root term of the likelihood ratio. 

220 """ 

221 

222 left: ( 

223 Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k'] 

224 ) = field(chains=True) 

225 right: ( 

226 Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k'] 

227 ) = field(chains=True) 

228 total: ( 

229 Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k'] 

230 ) = field(chains=True) 

231 log_sqrt_term: Float32[Array, '*chains num_trees'] = field(chains=True) 

232 

233 

234class PreLk(Module): 

235 """ 

236 Non-sequential terms of the likelihood ratio shared by all trees. 

237 

238 Parameters 

239 ---------- 

240 exp_factor 

241 The factor to multiply the likelihood ratio by, shared by all trees. 

242 """ 

243 

244 exp_factor: Float32[Array, '*chains'] = field(chains=True) 

245 

246 

247class PreLf(Module): 

248 """ 

249 Pre-computed terms used to sample leaves from their posterior. 

250 

251 These terms can be computed in parallel across trees. 

252 

253 For each tree and leaf, the terms are scalars in the univariate case, and 

254 matrices/vectors in the multivariate case. 

255 

256 Parameters 

257 ---------- 

258 mean_factor 

259 The factor to be right-multiplied by the sum of the scaled residuals to 

260 obtain the posterior mean. 

261 centered_leaves 

262 The mean-zero normal values to be added to the posterior mean to 

263 obtain the posterior leaf samples. 

264 """ 

265 

266 mean_factor: ( 

267 Float32[Array, '*chains num_trees 2**d'] 

268 | Float32[Array, '*chains num_trees k k 2**d'] 

269 ) = field(chains=True) 

270 centered_leaves: ( 

271 Float32[Array, '*chains num_trees 2**d'] 

272 | Float32[Array, '*chains num_trees k 2**d'] 

273 ) = field(chains=True) 

274 

275 

276class ParallelStageOut(Module): 

277 """ 

278 The output of `accept_moves_parallel_stage`. 

279 

280 Parameters 

281 ---------- 

282 bart 

283 A partially updated BART mcmc state. 

284 moves 

285 The proposed moves, with `partial_ratio` set to `None` and 

286 `log_trans_prior_ratio` set to its final value. 

287 prec_trees 

288 The likelihood precision scale in each potential or actual leaf node. If 

289 there is no precision scale, this is the number of points in each leaf. 

290 move_counts 

291 The counts of the number of points in the nodes modified by the 

292 moves. If `bart.min_points_per_leaf` is not set and 

293 `bart.prec_scale` is set, they are not computed. 

294 move_precs 

295 The likelihood precision scale in each node modified by the moves. If 

296 `bart.prec_scale` is not set, this is set to `move_counts`. 

297 prelkv 

298 prelk 

299 prelf 

300 Objects with pre-computed terms of the likelihood ratios and leaf 

301 samples. 

302 """ 

303 

304 bart: State 

305 moves: Moves 

306 prec_trees: ( 

307 Float32[Array, '*chains num_trees 2**d'] 

308 | Int32[Array, '*chains num_trees 2**d'] 

309 ) = field(chains=True) 

310 move_precs: Precs | Counts 

311 prelkv: PreLkV 

312 prelk: PreLk | None 

313 prelf: PreLf 

314 

315 

316@partial(jit_and_block_if_profiling, donate_argnums=(1, 2)) 

317@vmap_chains_if_profiling 

318def accept_moves_parallel_stage( 

319 key: Key[Array, ''], bart: State, moves: Moves 

320) -> ParallelStageOut: 

321 """ 

322 Pre-compute quantities used to accept moves, in parallel across trees. 

323 

324 Parameters 

325 ---------- 

326 key 

327 A jax random key. 

328 bart 

329 A BART mcmc state. 

330 moves 

331 The proposed moves, see `propose_moves`. 

332 

333 Returns 

334 ------- 

335 An object with all that could be done in parallel. 

336 """ 

337 # where the move is grow, modify the state like the move was accepted 

338 bart = replace( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

339 bart, 

340 forest=replace( 

341 bart.forest, 

342 var_tree=moves.var_tree, 

343 leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X), 

344 leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves), 

345 ), 

346 ) 

347 

348 # count number of datapoints per leaf 

349 if ( 349 ↛ 359 line 349 didn't jump to line 359 because the condition on line 349 was always true 1eKs(5t

350 bart.forest.min_points_per_decision_node is not None 

351 or bart.forest.min_points_per_leaf is not None 

352 or bart.prec_scale is None 

353 ): 

354 count_trees, move_counts = compute_count_trees( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

355 bart.forest.leaf_indices, moves, bart.config 

356 ) 

357 

358 # mark which leaves & potential leaves have enough points to be grown 

359 if bart.forest.min_points_per_decision_node is not None: 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

360 count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]] 1fJe9Mg!KhNrViYAZjOkPXx4lmn6CDo7EFI8pqG5

361 moves = replace( 1fJe9Mg!KhNrViYAZjOkPXx4lmn6CDo7EFI8pqG5

362 moves, 

363 affluence_tree=moves.affluence_tree 

364 & (count_half_trees >= bart.forest.min_points_per_decision_node), 

365 ) 

366 

367 # copy updated affluence_tree to state 

368 bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

369 

370 # veto grow move if new leaves don't have enough datapoints 

371 if bart.forest.min_points_per_leaf is not None: 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

372 moves = replace( 1fJe9Mg!KhNsWiYAZjOkPXx4

373 moves, 

374 allowed=moves.allowed 

375 & (move_counts.left >= bart.forest.min_points_per_leaf) 

376 & (move_counts.right >= bart.forest.min_points_per_leaf), 

377 ) 

378 

379 # count number of datapoints per leaf, weighted by error precision scale 

380 if bart.prec_scale is None: 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

381 prec_trees = count_trees 1f#e9g!h)rsi*A+$%'j(k,Xx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

382 move_precs = move_counts 1f#e9g!h)rsi*A+$%'j(k,Xx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

383 else: 

384 prec_trees, move_precs = compute_prec_trees( 1JMKNVWYZOP

385 bart.prec_scale, bart.forest.leaf_indices, moves, bart.config 

386 ) 

387 assert move_precs is not None 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

388 

389 # compute some missing information about moves 

390 moves = complete_ratio(moves, bart.forest.p_nonterminal) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

391 save_ratios = bart.forest.log_likelihood is not None 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

392 bart = replace( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

393 bart, 

394 forest=replace( 

395 bart.forest, 

396 grow_prop_count=jnp.sum(moves.grow), 

397 prune_prop_count=jnp.sum(moves.allowed & ~moves.grow), 

398 log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None, 

399 ), 

400 ) 

401 

402 assert bart.error_cov_inv is not None 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

403 prelkv, prelk = precompute_likelihood_terms( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

404 bart.error_cov_inv, bart.forest.leaf_prior_cov_inv, move_precs 

405 ) 

406 prelf = precompute_leaf_terms( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

407 key, prec_trees, bart.error_cov_inv, bart.forest.leaf_prior_cov_inv 

408 ) 

409 

410 return ParallelStageOut( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

411 bart=bart, 

412 moves=moves, 

413 prec_trees=prec_trees, 

414 move_precs=move_precs, 

415 prelkv=prelkv, 

416 prelk=prelk, 

417 prelf=prelf, 

418 ) 

419 

420 

421@partial(vmap_nodoc, in_axes=(0, 0, None)) 

422def apply_grow_to_indices( 

423 moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n'] 

424) -> UInt[Array, 'num_trees n']: 

425 """ 

426 Update the leaf indices to apply a grow move. 

427 

428 Parameters 

429 ---------- 

430 moves 

431 The proposed moves, see `propose_moves`. 

432 leaf_indices 

433 The index of the leaf each datapoint falls into. 

434 X 

435 The predictors matrix. 

436 

437 Returns 

438 ------- 

439 The updated leaf indices. 

440 """ 

441 left_child = moves.node.astype(leaf_indices.dtype) << 1 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

442 x: UInt[Array, ' n'] = X[moves.grow_var, :] 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

443 go_right = x >= moves.grow_split 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

444 tree_size = jnp.array(2 * moves.var_tree.size) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

445 node_to_update = jnp.where(moves.grow, moves.node, tree_size) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

446 return jnp.where( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

447 leaf_indices == node_to_update, left_child + go_right, leaf_indices 

448 ) 

449 

450 

451@partial(vmap_nodoc, in_axes=(None, 0, 0, None)) 

452def _compute_count_or_prec_trees( 

453 prec_scale: Float32[Array, ' n'] | None, 

454 leaf_indices: UInt[Array, 'num_trees n'], 

455 moves: Moves, 

456 config: StepConfig, 

457) -> ( 

458 tuple[UInt32[Array, 'num_trees 2**d'], Counts] 

459 | tuple[Float32[Array, 'num_trees 2**d'], Precs] 

460): 

461 (tree_size,) = moves.var_tree.shape 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

462 tree_size *= 2 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

463 

464 if prec_scale is None: 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

465 value = 1 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

466 cls = Counts 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

467 dtype = jnp.uint32 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

468 else: 

469 value = prec_scale 1JMKNVWYZOP

470 cls = Precs 1JMKNVWYZOP

471 dtype = jnp.float32 1JMKNVWYZOP

472 

473 trees = _scatter_add( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

474 value, leaf_indices, tree_size, dtype, config.count_batch_size, config.mesh 

475 ) 

476 

477 # count datapoints in nodes modified by move 

478 left = trees[moves.left] 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

479 right = trees[moves.right] 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

480 counts = cls(left=left, right=right, total=left + right) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

481 

482 # write count into non-leaf node 

483 trees = trees.at[moves.node].set(counts.total) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

484 

485 return trees, counts 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

486 

487 

488def compute_count_trees( 

489 leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, config: StepConfig 

490) -> tuple[UInt32[Array, 'num_trees 2**d'], Counts]: 

491 """ 

492 Count the number of datapoints in each leaf. 

493 

494 Parameters 

495 ---------- 

496 leaf_indices 

497 The index of the leaf each datapoint falls into, with the deeper version 

498 of the tree (post-GROW, pre-PRUNE). 

499 moves 

500 The proposed moves, see `propose_moves`. 

501 config 

502 The MCMC configuration. 

503 

504 Returns 

505 ------- 

506 count_trees : UInt32[Array, 'num_trees 2**d'] 

507 The number of points in each potential or actual leaf node. 

508 counts : Counts 

509 The counts of the number of points in the leaves grown or pruned by the 

510 moves. 

511 """ 

512 return _compute_count_or_prec_trees(None, leaf_indices, moves, config) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

513 

514 

515def compute_prec_trees( 

516 prec_scale: Float32[Array, ' n'], 

517 leaf_indices: UInt[Array, 'num_trees n'], 

518 moves: Moves, 

519 config: StepConfig, 

520) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]: 

521 """ 

522 Compute the likelihood precision scale in each leaf. 

523 

524 Parameters 

525 ---------- 

526 prec_scale 

527 The scale of the precision of the error on each datapoint. 

528 leaf_indices 

529 The index of the leaf each datapoint falls into, with the deeper version 

530 of the tree (post-GROW, pre-PRUNE). 

531 moves 

532 The proposed moves, see `propose_moves`. 

533 config 

534 The MCMC configuration. 

535 

536 Returns 

537 ------- 

538 prec_trees : Float32[Array, 'num_trees 2**d'] 

539 The likelihood precision scale in each potential or actual leaf node. 

540 precs : Precs 

541 The likelihood precision scale in the nodes involved in the moves. 

542 """ 

543 return _compute_count_or_prec_trees(prec_scale, leaf_indices, moves, config) 1JMKNVWYZOP

544 

545 

546@partial(vmap_nodoc, in_axes=(0, None)) 

547def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves: 

548 """ 

549 Complete non-likelihood MH ratio calculation. 

550 

551 This function adds the probability of choosing a prune move over the grow 

552 move in the inverse transition, and the a priori probability that the 

553 children nodes are leaves. 

554 

555 Parameters 

556 ---------- 

557 moves 

558 The proposed moves. Must have already been updated to take into account 

559 the thresholds on the number of datapoints per node, this happens in 

560 `accept_moves_parallel_stage`. 

561 p_nonterminal 

562 The a priori probability of each node being nonterminal conditional on 

563 its ancestors, including at the maximum depth where it should be zero. 

564 

565 Returns 

566 ------- 

567 The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set. 

568 """ 

569 # can the leaves be grown? 

570 left_growable = moves.affluence_tree.at[moves.left].get( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

571 mode='fill', fill_value=False 

572 ) 

573 right_growable = moves.affluence_tree.at[moves.right].get( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

574 mode='fill', fill_value=False 

575 ) 

576 

577 # p_prune if grow 

578 other_growable_leaves = moves.num_growable >= 2 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

579 grow_again_allowed = other_growable_leaves | left_growable | right_growable 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

580 grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1.0) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

581 

582 # p_prune if prune 

583 prune_p_prune = jnp.where(moves.num_growable, 0.5, 1) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

584 

585 # select p_prune 

586 p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

587 

588 # prior probability of both children being terminal 

589 pt_left = 1 - p_nonterminal[moves.left] * left_growable 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

590 pt_right = 1 - p_nonterminal[moves.right] * right_growable 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

591 pt_children = pt_left * pt_right 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

592 

593 assert moves.partial_ratio is not None 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

594 return replace( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

595 moves, 

596 log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune), 

597 partial_ratio=None, 

598 ) 

599 

600 

601@vmap_nodoc 

602def adapt_leaf_trees_to_grow_indices( 

603 leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves 

604) -> Float32[Array, 'num_trees 2**d']: 

605 """ 

606 Modify leaves such that post-grow indices work on the original tree. 

607 

608 The value of the leaf to grow is copied to what would be its children if the 

609 grow move was accepted. 

610 

611 Parameters 

612 ---------- 

613 leaf_trees 

614 The leaf values. 

615 moves 

616 The proposed moves, see `propose_moves`. 

617 

618 Returns 

619 ------- 

620 The modified leaf values. 

621 """ 

622 values_at_node = leaf_trees[..., moves.node] 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

623 return ( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

624 leaf_trees.at[..., jnp.where(moves.grow, moves.left, leaf_trees.size)] 

625 .set(values_at_node) 

626 .at[..., jnp.where(moves.grow, moves.right, leaf_trees.size)] 

627 .set(values_at_node) 

628 ) 

629 

630 

631def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']: 

632 """Compute logdet of A = LL' via Cholesky (sum of log of diag^2).""" 

633 diags: Float32[Array, '... k'] = jnp.diagonal(L, axis1=-2, axis2=-1) 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

634 return 2.0 * jnp.sum(jnp.log(diags), axis=-1) 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

635 

636 

637def _precompute_likelihood_terms_uv( 

638 error_cov_inv: Float32[Array, ''], 

639 leaf_prior_cov_inv: Float32[Array, ''], 

640 move_precs: Precs | Counts, 

641) -> tuple[PreLkV, PreLk]: 

642 sigma2 = lax.reciprocal(error_cov_inv) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx467850123.

643 sigma_mu2 = lax.reciprocal(leaf_prior_cov_inv) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx467850123.

644 left = sigma2 + move_precs.left * sigma_mu2 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx467850123.

645 right = sigma2 + move_precs.right * sigma_mu2 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx467850123.

646 total = sigma2 + move_precs.total * sigma_mu2 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx467850123.

647 prelkv = PreLkV( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx467850123.

648 left=left, 

649 right=right, 

650 total=total, 

651 log_sqrt_term=jnp.log(sigma2 * total / (left * right)) / 2, 

652 ) 

653 return prelkv, PreLk(exp_factor=error_cov_inv / leaf_prior_cov_inv / 2) 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx467850123.

654 

655 

656def _precompute_likelihood_terms_mv( 

657 error_cov_inv: Float32[Array, 'k k'], 

658 leaf_prior_cov_inv: Float32[Array, 'k k'], 

659 move_precs: Counts, 

660) -> tuple[PreLkV, None]: 

661 nL: UInt[Array, 'num_trees 1 1'] = move_precs.left[..., None, None] 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

662 nR: UInt[Array, 'num_trees 1 1'] = move_precs.right[..., None, None] 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

663 nT: UInt[Array, 'num_trees 1 1'] = move_precs.total[..., None, None] 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

664 

665 L_left: Float32[Array, 'num_trees k k'] = chol_with_gersh( 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

666 error_cov_inv * nL + leaf_prior_cov_inv 

667 ) 

668 L_right: Float32[Array, 'num_trees k k'] = chol_with_gersh( 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

669 error_cov_inv * nR + leaf_prior_cov_inv 

670 ) 

671 L_total: Float32[Array, 'num_trees k k'] = chol_with_gersh( 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

672 error_cov_inv * nT + leaf_prior_cov_inv 

673 ) 

674 

675 log_sqrt_term: Float32[Array, ' num_trees'] = 0.5 * ( 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

676 _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv)) 

677 + _logdet_from_chol(L_total) 

678 - _logdet_from_chol(L_left) 

679 - _logdet_from_chol(L_right) 

680 ) 

681 

682 def _term_from_chol( 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

683 L: Float32[Array, 'num_trees k k'], 

684 ) -> Float32[Array, 'num_trees k k']: 

685 rhs: Float32[Array, 'num_trees k k'] = jnp.broadcast_to(error_cov_inv, L.shape) 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

686 Y: Float32[Array, 'num_trees k k'] = solve_triangular(L, rhs, lower=True) 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

687 return Y.mT @ Y 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

688 

689 prelkv = PreLkV( 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

690 left=_term_from_chol(L_left), 

691 right=_term_from_chol(L_right), 

692 total=_term_from_chol(L_total), 

693 log_sqrt_term=log_sqrt_term, 

694 ) 

695 

696 return prelkv, None 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123.

697 

698 

699def precompute_likelihood_terms( 

700 error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'], 

701 leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'], 

702 move_precs: Precs | Counts, 

703) -> tuple[PreLkV, PreLk | None]: 

704 """ 

705 Pre-compute terms used in the likelihood ratio of the acceptance step. 

706 

707 Handles both univariate and multivariate cases based on the shape of the 

708 input arrays. The multivariate implementation assumes a homoskedastic error 

709 model (i.e., the residual covariance is the same for all observations). 

710 

711 Parameters 

712 ---------- 

713 error_cov_inv 

714 The inverse error variance (univariate) or the inverse of the error 

715 covariance matrix (multivariate). For univariate case, this is the 

716 inverse global error variance factor if `prec_scale` is set. 

717 leaf_prior_cov_inv 

718 The inverse prior variance of each leaf (univariate) or the inverse of 

719 prior covariance matrix of each leaf (multivariate). 

720 move_precs 

721 The likelihood precision scale in the leaves grown or pruned by the 

722 moves, in the attributes `left`, `right`, and `total` (left + right). 

723 

724 Returns 

725 ------- 

726 prelkv : PreLkV 

727 Pre-computed terms of the likelihood ratio, one per tree. 

728 prelk : PreLk | None 

729 Pre-computed terms of the likelihood ratio, shared by all trees. 

730 """ 

731 if error_cov_inv.ndim == 2: 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx4lmn6CDo7EFI8pqG5uHQvwtbyBcdaRLzSTU0123

732 assert isinstance(move_precs, Counts) 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123

733 return _precompute_likelihood_terms_mv( 1lmnCDoEFIpqGuHQvwtbyBcdaRLzSTU0123

734 error_cov_inv, leaf_prior_cov_inv, move_precs 

735 ) 

736 else: 

737 return _precompute_likelihood_terms_uv( 1f#Je9Mg!Kh)NrVsWi*YA+Z$%'j(Ok,PXx467850123

738 error_cov_inv, leaf_prior_cov_inv, move_precs 

739 ) 

740 

741 

def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees 2**d'] | None = None,
) -> PreLf:
    """
    Pre-compute the leaf-sampling terms in the univariate case.

    Parameters
    ----------
    key
        A jax random key, used only if `z` is not given.
    prec_trees
        The likelihood precision scale in each node of each tree.
    error_cov_inv
        The inverse error variance.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf.
    z
        Optional standard normal noise, same shape as `prec_trees`. If
        `None`, it is drawn from `key`.

    Returns
    -------
    The factor mapping summed residuals to posterior means, and the
    zero-mean posterior leaf samples.
    """
    # posterior variance = 1 / (likelihood precision + prior precision)
    posterior_var = lax.reciprocal(prec_trees * error_cov_inv + leaf_prior_cov_inv)

    noise = z
    if noise is None:
        noise = random.normal(key, prec_trees.shape, error_cov_inv.dtype)

    # | mean = mean_lk * prec_lk * var_post
    # | resid_tree = mean_lk * prec_tree  -->
    # |     -->  mean_lk = resid_tree / prec_tree (kind of)
    # | mean_factor =
    # |     = mean / resid_tree =
    # |     = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
    # |     = 1 / prec_tree * prec_tree / sigma2 * var_post =
    # |     = var_post / sigma2
    mean_factor = posterior_var * error_cov_inv
    centered_leaves = noise * jnp.sqrt(posterior_var)
    return PreLf(mean_factor=mean_factor, centered_leaves=centered_leaves)

765 

766 

def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d k'] | None = None,
) -> PreLf:
    """
    Pre-compute the leaf-sampling terms in the multivariate case.

    Parameters
    ----------
    key
        A jax random key, used only if `z` is not given.
    prec_trees
        The likelihood precision scale in each node of each tree.
    error_cov_inv
        The inverse of the error covariance matrix.
    leaf_prior_cov_inv
        The inverse of the prior covariance matrix of each leaf.
    z
        Optional standard normal noise with shape
        ``(num_trees, tree_size, k)``. If `None`, it is drawn from `key`.

    Returns
    -------
    A `PreLf` with `mean_factor` of shape ``(num_trees, k, k, tree_size)`` and
    `centered_leaves` of shape ``(num_trees, k, tree_size)``.
    """
    num_trees, tree_size = prec_trees.shape
    k = error_cov_inv.shape[0]
    n_k: Float32[Array, 'num_trees tree_size 1 1'] = prec_trees[..., None, None]

    # Only broadcast the inverse of error covariance matrix to satisfy JAX's
    # batching rules for `lax.linalg.solve_triangular`, which does not support
    # implicit broadcasting.
    error_cov_inv_batched = jnp.broadcast_to(
        error_cov_inv, (num_trees, tree_size, k, k)
    )

    posterior_precision: Float32[Array, 'num_trees tree_size k k'] = (
        leaf_prior_cov_inv + n_k * error_cov_inv_batched
    )

    # factor of the posterior precision, used below as a lower-triangular
    # Cholesky factor: posterior_precision = L_prec @ L_prec.T
    L_prec: Float32[Array, 'num_trees tree_size k k'] = chol_with_gersh(
        posterior_precision
    )

    # mean_factor = posterior_precision^-1 @ error_cov_inv, computed with two
    # triangular solves, then transposed to the layout expected by the einsum
    # in `accept_move_and_sample_leaves`
    Y: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, error_cov_inv_batched, lower=True
    )
    mean_factor: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, Y, trans='T', lower=True
    )
    mean_factor = mean_factor.mT
    mean_factor_out: Float32[Array, 'num_trees k k tree_size'] = jnp.moveaxis(
        mean_factor, 1, -1
    )

    if z is None:
        z = random.normal(key, (num_trees, tree_size, k))

    # centered leaves = L_prec^-T @ z, which has covariance
    # posterior_precision^-1. `lower=True` is required: `solve_triangular`
    # defaults to reading the upper triangle, which for a lower-triangular
    # factor is just its diagonal and would yield uncorrelated samples with
    # the wrong variances.
    centered_leaves: Float32[Array, 'num_trees tree_size k'] = solve_triangular(
        L_prec, z, trans='T', lower=True
    )
    centered_leaves_out: Float32[Array, 'num_trees k tree_size'] = jnp.swapaxes(
        centered_leaves, -1, -2
    )

    return PreLf(mean_factor=mean_factor_out, centered_leaves=centered_leaves_out)

813 

814 

def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d']
    | Float32[Array, 'num_trees 2**d k']
    | None = None,
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Dispatches to the multivariate implementation when `error_cov_inv` is a
    matrix, and to the univariate one otherwise.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse error variance (univariate) or the inverse of error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    z
        Optional standard normal noise to use for sampling the centered leaves.
        This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    is_multivariate = error_cov_inv.ndim == 2
    impl = _precompute_leaf_terms_mv if is_multivariate else _precompute_leaf_terms_uv
    return impl(key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z)

859 

860 

@partial(jit_and_block_if_profiling, donate_argnums=(0,))
@vmap_chains_if_profiling
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    state = pso.bart
    forest = state.forest

    # inputs that are identical for every tree
    shared = SeqStageInAllTrees(
        state.X,
        state.config.resid_batch_size,
        state.config.mesh,
        state.prec_scale,
        forest.log_likelihood is not None,
        pso.prelk,
    )

    # inputs with a leading tree axis, consumed one slice per scan iteration
    per_tree = SeqStageInPerTree(
        forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )

    def scan_body(carry_resid, tree_inputs):
        # the residuals are the carry, everything else is stacked per tree
        new_resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            carry_resid, shared, tree_inputs
        )
        return new_resid, (leaf_tree, acc, to_prune, lkratio)

    resid, stacked = lax.scan(scan_body, state.resid, per_tree)
    leaf_trees, acc, to_prune, lkratio = stacked

    new_forest = replace(forest, leaf_tree=leaf_trees, log_likelihood=lkratio)
    bart = replace(state, resid=resid, forest=new_forest)
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return bart, moves

917 

918 

class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    The fields are filled positionally in `accept_moves_sequential_stage`, so
    their order matters.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    mesh
        The mesh of devices to use.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    X: UInt[Array, 'p n']
    # static fields are not traced array leaves
    resid_batch_size: int | None = field(static=True)
    mesh: Mesh | None = field(static=True)
    prec_scale: Float32[Array, ' n'] | None
    save_ratios: bool = field(static=True)
    prelk: PreLk | None

947 

948 

class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    The shapes below are for a single tree; `accept_moves_sequential_stage`
    builds this object from arrays stacked across trees and scans over it.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    leaf_tree: Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d']
    prec_tree: Float32[Array, ' 2**d']
    move: Moves
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, ' n']
    prelkv: PreLkV
    prelf: PreLf

979 

980 

def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n'] | Float32[Array, ' k n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d'] | Float32[Array, ' k 2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    move = pt.move
    old_leaf_tree = pt.leaf_tree
    tree_size = old_leaf_tree.shape[-1]  # 2**d

    # sum residuals in each leaf, in tree proposed by grow move
    weighted_resid = resid if at.prec_scale is None else resid * at.prec_scale
    node_sums = sum_resid(
        weighted_resid, pt.leaf_indices, tree_size, at.resid_batch_size, at.mesh
    )

    # subtract starting tree from function
    node_sums = node_sums + pt.prec_tree * old_leaf_tree

    # sum residuals in parent node modified by move
    left_sum = node_sums[..., move.left]
    right_sum = node_sums[..., move.right]
    parent_sum = left_sum + right_sum
    assert move.node.dtype == jnp.int32
    node_sums = node_sums.at[..., move.node].set(parent_sum)

    # likelihood ratio of the grow move
    log_lk_ratio = compute_likelihood_ratio(
        parent_sum, left_sum, right_sum, pt.prelkv, at.prelk
    )

    # accept/reject ratio; the ratio of a prune move is the inverse of the
    # grow one, hence the sign flip in log space
    grow_log_ratio = move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(move.grow, grow_log_ratio, -grow_log_ratio)
    acc = move.allowed & (move.logu <= log_ratio)

    # the children must be collapsed either if a grow was rejected or if a
    # prune was accepted
    to_prune = acc ^ move.grow

    # compute leaves posterior and sample leaves
    if resid.ndim > 1:
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, node_sums)
    else:
        mean_post = node_sums * pt.prelf.mean_factor
    sampled = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf;
    # index `tree_size` is out of bounds and is used to make the write a no-op
    parent_leaf = sampled[..., move.node]
    left_slot = jnp.where(to_prune, move.left, tree_size)
    right_slot = jnp.where(to_prune, move.right, tree_size)
    leaf_tree = sampled.at[..., left_slot].set(parent_leaf)
    leaf_tree = leaf_tree.at[..., right_slot].set(parent_leaf)

    # replace old tree with new tree in function values
    resid = resid + (old_leaf_tree - leaf_tree)[..., pt.leaf_indices]

    saved_ratio = log_lk_ratio if at.save_ratios else None
    return resid, leaf_tree, acc, to_prune, saved_ratio

1073 

1074 

@partial(jnp.vectorize, excluded=(1, 2, 3, 4), signature='(n)->(ts)')
def sum_resid(
    scaled_resid: Float32[Array, ' n'] | Float32[Array, 'k n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    resid_batch_size: int | None,
    mesh: Mesh | None,
) -> Float32[Array, ' {tree_size}'] | Float32[Array, 'k {tree_size}']:
    """
    Sum the residuals in each leaf.

    The function is vectorized over the leading axes of `scaled_resid`, so it
    handles both the univariate and the multivariate case.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale. For multivariate case, shape is ``(k, n)`` where ``k``
        is the number of outcome columns.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    mesh
        The mesh of devices to use.

    Returns
    -------
    The sum of the residuals at data points in each leaf. For multivariate
    case, returns per-leaf sums of residual vectors.
    """
    # the accumulation is always carried out in float32
    accum_dtype = jnp.float32
    return _scatter_add(
        scaled_resid, leaf_indices, tree_size, accum_dtype, resid_batch_size, mesh
    )

1112 

1113 

def _scatter_add(
    values: Float32[Array, ' n'] | int,
    indices: Integer[Array, ' n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    mesh: Mesh | None,
) -> Shaped[Array, ' {size}']:
    """
    Indexed reduce with optional batching.

    Accumulates `values` (a scalar, or one value per index) into an array of
    length `size` at positions `indices`. If `mesh` has a 'data' axis, the
    reduction is run under `shard_map` and combined with a final `psum`.
    """
    # check `values`
    values = jnp.asarray(values)
    assert values.ndim == 0 or values.shape == indices.shape

    # set configuration
    reducer = partial(
        _scatter_add_impl, size=size, dtype=dtype, batch_size=batch_size
    )

    # single-device invocation
    use_data_axis = mesh is not None and 'data' in mesh.axis_names
    if not use_data_axis:
        return reducer(values, indices)

    # multi-device invocation: a scalar `values` is replicated, an array is
    # split along the data axis together with the indices
    values_spec = PartitionSpec('data') if values.shape else PartitionSpec()
    sharded_reducer = shard_map(
        partial(reducer, final_psum=True),
        in_specs=(values_spec, PartitionSpec('data')),
        out_specs=PartitionSpec(),
        mesh=mesh,
        **_get_shard_map_patch_kwargs(),
    )
    return sharded_reducer(values, indices)

1150 

1151 

def _get_shard_map_patch_kwargs():
    """Return extra `shard_map` keyword arguments needed by some jax versions."""
    # see jax/issues/#34249, problem with vmap(shard_map(psum))
    affected = jax.__version__ in ('0.8.1', '0.8.2')
    return {'check_vma': False} if affected else {}

1158 

1159 

def _scatter_add_impl(
    values: Float32[Array, ' n'] | Int32[Array, ''],
    indices: Integer[Array, ' n'],
    /,
    *,
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    final_psum: bool = False,
) -> Shaped[Array, ' {size}']:
    """
    Accumulate `values` at positions `indices` into a zero array of length `size`.

    If `batch_size` is given, the additions are spread over separate
    accumulator columns which are summed at the end. If `final_psum` is set,
    the result is summed across the 'data' axis with `lax.psum`.
    """
    if batch_size is not None:
        # in the sharded case, n is the size of the local shard, not the full
        # size
        (n,) = indices.shape
        full_batches, remainder = divmod(n, batch_size)
        nbatches = full_batches + bool(remainder)
        # assign datapoints to accumulator columns in round-robin fashion
        column = jnp.arange(n) % nbatches
        accumulators = jnp.zeros((size, nbatches), dtype)
        result = accumulators.at[indices, column].add(values).sum(axis=1)
    else:
        result = jnp.zeros(size, dtype).at[indices].add(values)

    return lax.psum(result, 'data') if final_psum else result

1189 

1190 

def _compute_likelihood_ratio_uv(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """Compute the log likelihood ratio of a grow move, univariate case."""
    # squared residual sums, each divided by its pre-computed per-node term
    left_term = left_resid * left_resid / prelkv.left
    right_term = right_resid * right_resid / prelkv.right
    total_term = total_resid * total_resid / prelkv.total

    # children contribute positively, the parent they replace negatively
    quadratic = left_term + right_term - total_term
    return prelkv.log_sqrt_term + prelk.exp_factor * quadratic

1204 

1205 

def _compute_likelihood_ratio_mv(
    total_resid: Float32[Array, ' k'],
    left_resid: Float32[Array, ' k'],
    right_resid: Float32[Array, ' k'],
    prelkv: PreLkV,
) -> Float32[Array, '']:
    """Compute the log likelihood ratio of a grow move, multivariate case."""
    # quadratic forms r^T M r of each residual sum with its pre-computed matrix
    pairs = (
        (left_resid, prelkv.left),
        (right_resid, prelkv.right),
        (total_resid, prelkv.total),
    )
    qf_left, qf_right, qf_total = (resid @ mat @ resid for resid, mat in pairs)

    # children contribute positively, the parent they replace negatively
    return prelkv.log_sqrt_term + 0.5 * (qf_left + qf_right - qf_total)

1220 

1221 

def compute_likelihood_ratio(
    total_resid: Float32[Array, ''] | Float32[Array, ' k'],
    left_resid: Float32[Array, ''] | Float32[Array, ' k'],
    right_resid: Float32[Array, ''] | Float32[Array, ' k'],
    prelkv: PreLkV,
    prelk: PreLk | None,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    Scalar residual sums select the univariate formula, vector ones the
    multivariate formula.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`. `prelk` is required in the univariate
        case and unused in the multivariate one.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    if total_resid.ndim == 0:
        assert prelk is not None
        return _compute_likelihood_ratio_uv(
            total_resid, left_resid, right_resid, prelkv, prelk
        )
    return _compute_likelihood_ratio_mv(total_resid, left_resid, right_resid, prelkv)

1260 

1261 

@partial(jit_and_block_if_profiling, donate_argnums=(0, 1))
@vmap_chains_if_profiling
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    forest = bart.forest
    accepted = moves.acc

    new_forest = replace(
        forest,
        # acceptance counters, summed over the moves
        grow_acc_count=jnp.sum(accepted & moves.grow),
        prune_acc_count=jnp.sum(accepted & ~moves.grow),
        # bring the tree structure in line with the outcome of the moves
        leaf_indices=apply_moves_to_leaf_indices(forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)

1293 

1294 

@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    assert moves.to_prune is not None
    index_dtype = leaf_indices.dtype

    # clearing the lowest bit maps both children to the left child's index
    clear_lowest_bit = ~jnp.array(1, index_dtype)  # ...1111111110
    in_moved_children = (leaf_indices & clear_lowest_bit) == moves.left

    # datapoints in pruned children are reassigned to the parent node
    parent = moves.node.astype(index_dtype)
    return jnp.where(in_moved_children & moves.to_prune, parent, leaf_indices)

1321 

1322 

@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None

    # an index equal to the array size is out of bounds and turns the
    # corresponding write into a no-op
    no_write = split_tree.size
    grow_slot = jnp.where(moves.grow, moves.node, no_write)
    prune_slot = jnp.where(moves.to_prune, moves.node, no_write)

    # first write the cutpoint of the grow move, then erase the node if it
    # has to be pruned
    grown = split_tree.at[grow_slot].set(moves.grow_split.astype(split_tree.dtype))
    return grown.at[prune_slot].set(0)

1349 

1350 

@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Integer[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key
    df
        Degrees of freedom
    scale_inv
        Scale matrix of the corresponding Inverse Wishart distribution

    Returns
    -------
    A sample from Wishart(df, scale)
    """
    keys = split(key)
    gamma_key = keys.pop()
    normal_key = keys.pop()

    k, _ = scale_inv.shape

    # Diagonal elements: A_ii ~ sqrt(chi^2(df - i))
    # chi^2(k) = Gamma(k/2, scale=2)
    dof_per_row = df - jnp.arange(k)
    diagonal = jnp.sqrt(random.gamma(gamma_key, dof_per_row / 2.0) * 2.0)

    # Strictly lower triangle: standard normal entries
    gaussian = random.normal(normal_key, (k, k))
    bartlett_factor = jnp.tril(gaussian, -1) + jnp.diag(diagonal)

    # scale_inv = L @ L.T, so L^-T @ L^-1 is the Wishart scale matrix and
    # L^-T @ A is a factor of the sample
    chol_scale_inv = chol_with_gersh(scale_inv, absolute_eps=True)
    factor = solve_triangular(chol_scale_inv, bartlett_factor, lower=True, trans='T')

    return factor @ factor.T

1386 

1387 

def _step_error_cov_inv_uv(key: Key[Array, ''], bart: State) -> State:
    """Sample the inverse error variance from its gamma full conditional."""
    resid = bart.resid
    weighted_resid = resid if bart.prec_scale is None else resid * bart.prec_scale

    # inverse gamma prior: alpha = df / 2, beta = scale / 2
    alpha = bart.error_cov_df / 2 + resid.size / 2
    weighted_norm2 = resid @ weighted_resid
    beta = bart.error_cov_scale / 2 + weighted_norm2 / 2

    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    unit_rate_sample = random.gamma(key, alpha)
    return replace(bart, error_cov_inv=unit_rate_sample / beta)

1403 

1404 

def _step_error_cov_inv_mv(key: Key[Array, ''], bart: State) -> State:
    """Sample the inverse error covariance matrix from its Wishart full conditional."""
    resid = bart.resid
    num_datapoints = resid.shape[-1]

    # posterior parameters: add the number of datapoints to the degrees of
    # freedom and the residual cross-products to the scale matrix
    posterior_df = bart.error_cov_df + num_datapoints
    posterior_scale = bart.error_cov_scale + resid @ resid.T

    precision = _sample_wishart_bartlett(key, posterior_df, posterior_scale)
    return replace(bart, error_cov_inv=precision)

1412 

1413 

@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the inverse error covariance.

    Uses the multivariate update when `bart.error_cov_inv` is a matrix, and
    the univariate one otherwise.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `error_cov_inv`.
    """
    assert bart.error_cov_inv is not None
    is_multivariate = bart.error_cov_inv.ndim == 2
    update = _step_error_cov_inv_mv if is_multivariate else _step_error_cov_inv_uv
    return update(key, bart)

1439 

1440 

@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    assert bart.y.dtype == bool

    # the residuals are the latent variable minus the current prediction
    prediction = bart.z - bart.resid

    # draw new residuals with `truncated_normal_onesided`, passing the negated
    # outcome and the negated prediction as truncation side and bound
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -prediction)

    return replace(bart, z=prediction + new_resid, resid=new_resid)

1463 

1464 

def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    This full conditional is approximated, because it does not take into account
    that there are forbidden decision rules.
    """
    forest = bart.forest
    assert forest.theta is not None

    # histogram current variable usage
    num_predictors = forest.max_split.size
    usage_counts = var_histogram(
        num_predictors, forest.var_tree, forest.split_tree, sum_batch_axis=-1
    )

    # sample from Dirichlet posterior; the log-gamma draws are stored as they
    # are, without normalization
    concentration = forest.theta / num_predictors + usage_counts
    sampled_log_s = random.loggamma(key, concentration)

    # update forest with new s
    return replace(bart, forest=replace(forest, log_s=sampled_log_s))

1504 

1505 

def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # the grid points are the midpoints of num_grid bins in (0, 1)
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # normalize s
    normalized_log_s = forest.log_s - logsumexp(forest.log_s)

    # sample lambda on the grid and pick the matching theta
    grid_logp, theta_grid = _log_p_lamda(
        lamda_grid, normalized_log_s, forest.rho, forest.a, forest.b
    )
    chosen = random.categorical(key, grid_logp)

    return replace(bart, forest=replace(forest, theta=theta_grid[chosen]))

1546 

1547 

def _log_p_lamda(
    lamda: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Evaluate the unnormalized log density of lambda = theta / (theta + rho).

    Returns the log density at each grid point and the corresponding values
    of theta. The grid is assumed symmetric around 1/2.
    """
    # in the following I use lamda[::-1] == 1 - lamda
    one_minus_lamda = lamda[::-1]
    theta = rho * lamda / one_minus_lamda
    p = log_s.size

    # terms are accumulated one at a time, in this fixed order
    logp = (a - 1) * jnp.log1p(-one_minus_lamda)  # log(lambda)
    logp = logp + (b - 1) * jnp.log1p(-lamda)  # log(1 - lambda)
    logp = logp + gammaln(theta)
    logp = logp - p * gammaln(theta / p)
    logp = logp + theta / p * jnp.sum(log_s)
    return logp, theta

1565 

1566 

@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    Nothing is done if `bart.config.sparse_on_at` is `None`. Otherwise the
    state is returned unchanged while `bart.config.steps_done` is below
    `bart.config.sparse_on_at`, and is passed through `_step_sparse` from
    then on. `_step_sparse` invokes `step_s`, and then `step_theta` only if
    the parameters of the theta prior are defined.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    # static (Python-level) switch: sparsity is not configured at all
    if bart.config.sparse_on_at is None:
        return bart

    def _leave_unchanged(_key, bart):
        return bart

    # traced switch: both branches take (key, bart) and return a state
    not_yet_active = bart.config.steps_done < bart.config.sparse_on_at
    return cond(not_yet_active, _leave_unchanged, _step_sparse, key, bart)

1596 

1597 

def _step_sparse(key, bart):
    """Run `step_s`, then `step_theta` if `bart.forest.rho` is not `None`."""
    keys = split(key)

    # each sub-step gets its own key taken from `keys`
    bart = step_s(keys.pop(), bart)
    if bart.forest.rho is None:
        return bart
    return step_theta(keys.pop(), bart)

1604 

1605 

@jit_if_profiling
# jit to avoid the overhead of replace(_: Module)
def step_config(bart):
    """Return a copy of `bart` whose `config.steps_done` is increased by 1."""
    advanced = replace(bart.config, steps_done=bart.config.steps_done + 1)
    return replace(bart, config=advanced)