Coverage for src / bartz / mcmcstep / _step.py: 99%

420 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/mcmcstep/_step.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

"""Implement `step`, `step_trees`, and the accept-reject logic.

`step` performs one full MCMC iteration (forest update, then error precision or
`step_z`, then `step_sparse` and `step_config`). `step_trees` proposes tree moves
and hands them to `accept_moves_and_sample_leaves`, which runs a parallel
(across trees) stage, a sequential stage and a final stage.
"""

26 

27from dataclasses import replace 

28from functools import partial 

29 

30try: 

31 # available since jax v0.6.1 

32 from jax import shard_map 

33except ImportError: 

34 # deprecated in jax v0.8.0 

35 from jax.experimental.shard_map import shard_map 

36 

37import jax 

38from equinox import Module, tree_at 

39from jax import jit, lax, named_call, random, vmap 

40from jax import numpy as jnp 

41from jax.scipy.linalg import solve_triangular 

42from jax.scipy.special import gammaln, logsumexp 

43from jax.sharding import Mesh, PartitionSpec 

44from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt, UInt32 

45 

46from bartz.grove import var_histogram 

47from bartz.jaxext import split, truncated_normal_onesided, vmap_nodoc 

48from bartz.mcmcstep._moves import Moves, propose_moves 

49from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field, vmap_chains 

50 

51 

@partial(jit, donate_argnums=(1,))
@vmap_chains
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Run a single full iteration of the BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The updated BART mcmc state.

    Notes
    -----
    The buffers of the input state are donated to the output state, so the
    input state cannot be used anymore after calling `step`. This only matters
    when calling `step` outside of `jax.jit`.
    """
    subkeys = split(key, 3)

    if bart.y.dtype != bool:
        # continuous or multivariate outcome: update the forest, then the
        # error precision
        bart = step_trees(subkeys.pop(), bart)
        bart = step_error_cov_inv(subkeys.pop(), bart)
    else:
        # boolean outcome: update the forest with the error precision pinned to
        # 1, then drop it again and update via `step_z`
        bart = replace(bart, error_cov_inv=jnp.array(1.0))
        bart = step_trees(subkeys.pop(), bart)
        bart = replace(bart, error_cov_inv=None)
        bart = step_z(subkeys.pop(), bart)

    bart = step_sparse(subkeys.pop(), bart)
    return step_config(bart)

89 

90 

@named_call
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Update the forest: propose one move per tree, then accept/reject them.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The updated BART mcmc state.

    Notes
    -----
    The proposal counters are reset by this function.
    """
    subkeys = split(key)
    propose_key = subkeys.pop()
    accept_key = subkeys.pop()
    proposed = propose_moves(propose_key, bart.forest)
    return accept_moves_and_sample_leaves(accept_key, bart, proposed)

114 

115 

@named_call
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Decide which proposed moves to keep and draw the new leaf values.

    The work is split in three stages: one that runs in parallel across trees,
    one that runs sequentially over trees, and a final one.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new, valid, BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    updated_bart, final_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(updated_bart, final_moves)

139 

140 

class Counts(Module):
    """Per-tree datapoint counts in the nodes touched by the proposed moves."""

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """How many datapoints fall into the left child."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """How many datapoints fall into the right child."""

    total: UInt[Array, '*chains num_trees'] = field(chains=True)
    """How many datapoints fall into the parent node (``left + right``)."""

152 

153 

class Precs(Module):
    """Per-tree likelihood precision scale in the nodes touched by the moves.

    A node's likelihood precision scale is the sum, over the datapoints that
    fall into the node, of their inverse squared error scales.
    """

    left: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Precision scale accumulated in the left child."""

    right: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Precision scale accumulated in the right child."""

    total: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Precision scale accumulated in the parent node (``left + right``)."""

169 

170 

class PreLkV(Module):
    """Likelihood-ratio terms specific to each tree.

    None of these terms depend on the other trees, so they are computed for
    all trees at once, before the sequential stage.
    """

    left: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """Term for the left child.

    Univariate: the scalar ``1 / error_cov_inv + n_left / leaf_prior_cov_inv``.

    Multivariate: the matrix
    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_left * error_cov_inv) @ error_cov_inv``.

    Here ``n_left`` is the datapoint count of the left child, or its likelihood
    precision scale when the errors are heteroskedastic."""

    right: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """Term for the right child.

    Univariate: the scalar ``1 / error_cov_inv + n_right / leaf_prior_cov_inv``.

    Multivariate: the matrix
    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_right * error_cov_inv) @ error_cov_inv``.

    Here ``n_right`` is the datapoint count of the right child, or its
    likelihood precision scale when the errors are heteroskedastic."""

    total: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """Term for the parent node.

    Univariate: the scalar ``1 / error_cov_inv + n_total / leaf_prior_cov_inv``.

    Multivariate: the matrix
    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_total * error_cov_inv) @ error_cov_inv``.

    Here ``n_total`` is the datapoint count of the parent node, or its
    likelihood precision scale when the errors are heteroskedastic."""

    log_sqrt_term: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Log of the square-root factor of the likelihood ratio."""

221 

222 

class PreLk(Module):
    """Likelihood-ratio terms common to every tree, computed ahead of time."""

    exp_factor: Float32[Array, '*chains'] = field(chains=True)
    """Multiplicative factor of the likelihood ratio, identical for all trees."""

228 

229 

class PreLf(Module):
    """Terms computed ahead of time to draw the leaves from their posterior.

    Each tree is handled independently, so all trees are processed at once.
    Per tree and leaf, the entries are scalars for a univariate outcome and
    vectors/matrices for a multivariate one.
    """

    mean_factor: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k k 2**d']
    ) = field(chains=True)
    """Right-multiplying this by the sum of the scaled residuals gives the
    posterior mean."""

    centered_leaves: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """Zero-mean normal draws; adding the posterior mean to them gives the
    posterior leaf samples."""

252 

253 

class ParallelStageOut(Module):
    """The output of `accept_moves_parallel_stage`."""

    bart: State
    """A partially updated BART mcmc state."""

    moves: Moves
    """The proposed moves, with `partial_ratio` set to `None` and
    `log_trans_prior_ratio` set to its final value."""

    prec_trees: (
        Float32[Array, '*chains num_trees 2**d']
        | UInt32[Array, '*chains num_trees 2**d']
    ) = field(chains=True)
    """The likelihood precision scale in each potential or actual leaf node. If
    there is no precision scale, this is the (unsigned integer) number of
    points in each leaf."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves. If
    `bart.prec_scale` is not set, this is set to `move_counts`."""

    prelkv: PreLkV
    """Pre-computed terms of the likelihood ratios, one set per tree."""

    prelk: PreLk | None
    """Pre-computed terms of the likelihood ratios shared by all trees, or
    `None` in the multivariate case."""

    prelf: PreLf
    """Object with pre-computed terms of the leaf samples."""

283 

284 

@named_call
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Compute, for all trees at once, what the acceptance step needs.

    Parameters
    ----------
    key
        A jax random key, passed on to `precompute_leaf_terms`.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with everything that can be done in parallel across trees.

    Notes
    -----
    In the returned state, grow moves are applied to the tree structure, the
    leaf indices and the leaf values as if they had been accepted.
    """
    # Where the move is a grow, modify the state as if the move was accepted,
    # so that the counts below refer to the deeper (post-grow) trees.
    old_forest = bart.forest
    grown_forest = replace(
        old_forest,
        var_tree=moves.var_tree,
        leaf_indices=apply_grow_to_indices(moves, old_forest.leaf_indices, bart.X),
        leaf_tree=adapt_leaf_trees_to_grow_indices(old_forest.leaf_tree, moves),
    )
    bart = replace(bart, forest=grown_forest)

    min_node_points = grown_forest.min_points_per_decision_node
    min_leaf_points = grown_forest.min_points_per_leaf

    # Plain datapoint counts are needed by either threshold, and also stand in
    # for the precision scale when `prec_scale` is not set.
    if (
        min_node_points is not None
        or min_leaf_points is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            grown_forest.leaf_indices, moves, bart.config
        )

    # Only nodes (actual or potential leaves) with enough points may be grown.
    if min_node_points is not None:
        half_tree_size = grown_forest.var_tree.shape[1]
        enough_points = count_trees[:, :half_tree_size] >= min_node_points
        moves = replace(moves, affluence_tree=moves.affluence_tree & enough_points)

    # Store the (possibly restricted) affluence tree in the state.
    bart = tree_at(lambda b: b.forest.affluence_tree, bart, moves.affluence_tree)

    # Veto moves whose children would not hold enough datapoints.
    if min_leaf_points is not None:
        children_ok = (move_counts.left >= min_leaf_points) & (
            move_counts.right >= min_leaf_points
        )
        moves = replace(moves, allowed=moves.allowed & children_ok)

    # Likelihood precision per (potential) leaf: weighted by `prec_scale` if
    # set, otherwise just the datapoint counts.
    if bart.prec_scale is not None:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale, bart.forest.leaf_indices, moves, bart.config
        )
    else:
        prec_trees, move_precs = count_trees, move_counts
    assert move_precs is not None

    # Finish the non-likelihood part of the acceptance ratio and record the
    # proposal statistics.
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    keep_log_ratio = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if keep_log_ratio else None,
        ),
    )

    assert bart.error_cov_inv is not None
    error_cov_inv = bart.error_cov_inv
    leaf_prior_cov_inv = bart.forest.leaf_prior_cov_inv
    prelkv, prelk = precompute_likelihood_terms(
        error_cov_inv, leaf_prior_cov_inv, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, error_cov_inv, leaf_prior_cov_inv)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )

387 

388 

@named_call
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Move the datapoints of a grown leaf into its new children.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        For each datapoint, the index of the leaf it falls into.
    X
        The predictors matrix.

    Returns
    -------
    The leaf indices after applying the grow moves.
    """
    # When the move is not a grow, target an index one past the last node of
    # the full tree, so that no datapoint matches and nothing changes.
    past_end = jnp.array(2 * moves.var_tree.size)
    target = jnp.where(moves.grow, moves.node, past_end)

    # The children of node i are 2i (left) and 2i + 1 (right); a datapoint
    # goes right when its predictor is >= the split.
    split_values: UInt[Array, ' n'] = X[moves.grow_var, :]
    child = (moves.node.astype(leaf_indices.dtype) << 1) + (
        split_values >= moves.grow_split
    )

    in_target = leaf_indices == target
    return jnp.where(in_target, child, leaf_indices)

418 

419 

def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, 'num_trees 2**d'], Counts]
    | tuple[Float32[Array, 'num_trees 2**d'], Precs]
):
    """Shared implementation of `compute_count_trees` and `compute_prec_trees`.

    Trees are either all vectorized together, or processed with `lax.map` in
    batches of `config.prec_count_num_trees` trees.
    """
    batch_size = config.prec_count_num_trees

    def one_tree(
        args: tuple[UInt[Array, ' n'], Moves],
    ) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
        tree_leaf_indices, tree_moves = args
        return _compute_count_or_prec_tree(
            prec_scale, tree_leaf_indices, tree_moves, config
        )

    if batch_size is not None:
        return lax.map(one_tree, (leaf_indices, moves), batch_size=batch_size)

    all_trees = vmap(_compute_count_or_prec_tree, in_axes=(None, 0, 0, None))
    return all_trees(prec_scale, leaf_indices, moves, config)

443 

444 

def _compute_count_or_prec_tree(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, ' n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
    """Accumulate counts (or precision scales) per node of a single tree."""
    (half_size,) = moves.var_tree.shape
    full_size = 2 * half_size

    # Without a precision scale every datapoint weighs 1 and the result is an
    # unsigned count; otherwise each datapoint weighs its precision scale.
    if prec_scale is not None:
        weight, result_cls = prec_scale, Precs
        acc_dtype, num_batches = jnp.float32, config.prec_num_batches
    else:
        weight, result_cls = 1, Counts
        acc_dtype, num_batches = jnp.uint32, config.count_num_batches

    per_node = _scatter_add(
        weight, leaf_indices, full_size, acc_dtype, num_batches, config.mesh
    )

    # Gather the values of the two children involved in the move.
    left_val = per_node[moves.left]
    right_val = per_node[moves.right]
    move_totals = result_cls(left=left_val, right=right_val, total=left_val + right_val)

    # The parent of the move holds the sum of its two children.
    per_node = per_node.at[moves.node].set(move_totals.total)

    return per_node, move_totals

479 

480 

@named_call
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, config: StepConfig
) -> tuple[UInt32[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    count_trees : UInt32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node. The node
        touched by each move holds the sum of its two children.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    # `None` as precision scale selects unit weights and a uint32 accumulator.
    return _compute_count_or_prec_trees(None, leaf_indices, moves, config)

507 

508 

@named_call
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Sum the per-datapoint precision scales within each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the error precision of each datapoint.
    leaf_indices
        For each datapoint, the index of the leaf it falls into in the deeper
        version of the tree (after GROW, before PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale of every potential or actual leaf.
    precs : Precs
        The likelihood precision scale of the nodes involved in the moves.
    """
    return _compute_count_or_prec_trees(
        prec_scale=prec_scale, leaf_indices=leaf_indices, moves=moves, config=config
    )

539 

540 

@partial(vmap_nodoc, in_axes=(0, None))
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Finish the part of the MH ratio that does not involve the likelihood.

    Two factors are multiplied into the partial ratio: the probability of
    choosing prune over grow in the inverse transition, and the prior
    probability that the two children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. The thresholds on the number of datapoints per node
        must already be reflected in them, as done by
        `accept_moves_parallel_stage`.
    p_nonterminal
        Prior probability of each node being nonterminal given its ancestors,
        including at the maximum depth where it should be zero.

    Returns
    -------
    The moves with `log_trans_prior_ratio` filled in and `partial_ratio=None`.
    """

    def is_growable(child):
        # Children beyond the end of the affluence tree can never be grown.
        return moves.affluence_tree.at[child].get(mode='fill', fill_value=False)

    left_growable = is_growable(moves.left)
    right_growable = is_growable(moves.right)

    # Probability of picking prune in the inverse transition. After a grow,
    # growing is still possible if another leaf was growable or if either new
    # child is growable.
    still_growable_after_grow = (
        (moves.num_growable >= 2) | left_growable | right_growable
    )
    p_prune = jnp.where(
        moves.grow,
        jnp.where(still_growable_after_grow, 0.5, 1.0),
        jnp.where(moves.num_growable != 0, 0.5, 1),
    )

    # Prior probability that both children are leaves; a child that cannot be
    # grown counts as terminal with probability one.
    pt_children = (1 - p_nonterminal[moves.left] * left_growable) * (
        1 - p_nonterminal[moves.right] * right_growable
    )

    assert moves.partial_ratio is not None
    trans_prior_ratio = moves.partial_ratio * pt_children * p_prune
    return replace(
        moves, log_trans_prior_ratio=jnp.log(trans_prior_ratio), partial_ratio=None
    )

594 

595 

@named_call
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Make post-grow leaf indices usable on the original leaf values.

    For grow moves, the value of the leaf being grown is copied into the two
    nodes that would become its children if the move were accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The leaf values with the copies applied.
    """
    parent_value = leaf_trees[..., moves.node]

    # For non-grow moves use an out-of-bounds index; jax index updates drop
    # out-of-bounds writes by default, so the tree is left untouched.
    out_of_bounds = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, out_of_bounds)
    right_target = jnp.where(moves.grow, moves.right, out_of_bounds)

    with_left = leaf_trees.at[..., left_target].set(parent_value)
    return with_left.at[..., right_target].set(parent_value)

625 

626 

def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']:
    """Return log det(A) for A = L L', given the triangular Cholesky factor L.

    Since det(A) = prod(diag(L))**2, log det(A) = 2 * sum(log(diag(L))).
    """
    log_diag = jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1))
    return 2.0 * log_diag.sum(axis=-1)

631 

632 

def _precompute_likelihood_terms_uv(
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """Univariate likelihood-ratio terms (scalar error and leaf precisions)."""
    error_var = jnp.reciprocal(error_cov_inv)
    leaf_var = jnp.reciprocal(leaf_prior_cov_inv)

    def node_term(n):
        # error variance plus n times the leaf prior variance
        return error_var + n * leaf_var

    left = node_term(move_precs.left)
    right = node_term(move_precs.right)
    total = node_term(move_precs.total)

    prelkv = PreLkV(
        left=left,
        right=right,
        total=total,
        log_sqrt_term=jnp.log(error_var * total / (left * right)) / 2,
    )
    prelk = PreLk(exp_factor=error_cov_inv / leaf_prior_cov_inv / 2)
    return prelkv, prelk

650 

651 

def _precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    move_precs: Counts,
) -> tuple[PreLkV, None]:
    """Multivariate likelihood-ratio terms; there are no tree-shared terms."""

    def chol_node_prec(n: UInt[Array, ' num_trees']) -> Float32[Array, 'num_trees k k']:
        # Cholesky factor of error_cov_inv * n + leaf_prior_cov_inv, per tree
        return chol_with_gersh(error_cov_inv * n[..., None, None] + leaf_prior_cov_inv)

    chol_left = chol_node_prec(move_precs.left)
    chol_right = chol_node_prec(move_precs.right)
    chol_total = chol_node_prec(move_precs.total)

    logdet_prior = _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))
    log_sqrt_term: Float32[Array, ' num_trees'] = 0.5 * (
        logdet_prior
        + _logdet_from_chol(chol_total)
        - _logdet_from_chol(chol_left)
        - _logdet_from_chol(chol_right)
    )

    def sandwich(chol: Float32[Array, 'num_trees k k']) -> Float32[Array, 'num_trees k k']:
        # With Y = chol^{-1} @ error_cov_inv, Y' Y equals
        # error_cov_inv' @ inv(chol @ chol') @ error_cov_inv. The right-hand
        # side is broadcast explicitly because the batched triangular solve
        # does not broadcast implicitly.
        rhs = jnp.broadcast_to(error_cov_inv, chol.shape)
        Y = solve_triangular(chol, rhs, lower=True)
        return Y.mT @ Y

    prelkv = PreLkV(
        left=sandwich(chol_left),
        right=sandwich(chol_right),
        total=sandwich(chol_total),
        log_sqrt_term=log_sqrt_term,
    )
    return prelkv, None

693 

694 

@named_call
def precompute_likelihood_terms(
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk | None]:
    """
    Compute ahead of time the terms of the likelihood ratio used to accept moves.

    The univariate or multivariate implementation is chosen from the number of
    dimensions of `error_cov_inv`. The multivariate one assumes homoskedastic
    errors, i.e., the same residual covariance for every observation.

    Parameters
    ----------
    error_cov_inv
        Univariate: the inverse error variance; when `prec_scale` is set, the
        inverse global error variance factor. Multivariate: the inverse error
        covariance matrix.
    leaf_prior_cov_inv
        Univariate: the inverse prior variance of each leaf. Multivariate: the
        inverse prior covariance matrix of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, in the fields 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Terms of the likelihood ratio, one per tree.
    prelk : PreLk | None
        Terms of the likelihood ratio shared by all trees.
    """
    if error_cov_inv.ndim != 2:
        return _precompute_likelihood_terms_uv(
            error_cov_inv, leaf_prior_cov_inv, move_precs
        )

    # multivariate: only plain counts are supported (no prec_scale)
    assert isinstance(move_precs, Counts)
    return _precompute_likelihood_terms_mv(
        error_cov_inv, leaf_prior_cov_inv, move_precs
    )

737 

738 

def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees 2**d'] | None = None,
) -> PreLf:
    """Univariate leaf-posterior terms (scalar error and leaf precisions).

    If `z` is given, it is used as the standard normal draws and `key` is not
    used.
    """
    post_var = jnp.reciprocal(prec_trees * error_cov_inv + leaf_prior_cov_inv)
    std_normal = (
        random.normal(key, prec_trees.shape, error_cov_inv.dtype) if z is None else z
    )
    # Derivation of the mean factor:
    #   mean      = mean_lk * prec_lk * post_var, with prec_lk = prec_tree / sigma2
    #   resid_sum = mean_lk * prec_tree
    #   => mean / resid_sum = post_var / sigma2 = post_var * error_cov_inv
    return PreLf(
        mean_factor=post_var * error_cov_inv,
        centered_leaves=std_normal * jnp.sqrt(post_var),
    )

762 

763 

def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d k'] | None = None,
) -> PreLf:
    """
    Pre-compute the leaf sampling terms in the multivariate case.

    The posterior precision of each leaf is
    ``P = leaf_prior_cov_inv + prec_tree * error_cov_inv``. The posterior mean
    is ``P^-1 @ error_cov_inv @ r``, with ``r`` the per-leaf residual sum. A
    centered posterior draw is ``L^-T @ z``, where ``P = L @ L.T``, because
    ``Cov(L^-T z) = (L L^T)^-1 = P^-1``.

    Parameters
    ----------
    key
        A jax random key, used only when `z` is not given.
    prec_trees
        The likelihood precision scale (point counts) in each potential or
        actual leaf node.
    error_cov_inv
        The inverse of the error covariance matrix.
    leaf_prior_cov_inv
        The inverse of the prior covariance matrix of each leaf.
    z
        Optional standard normal noise, for testing.

    Returns
    -------
    A `PreLf` with ``mean_factor`` of shape ``(num_trees, k, k, tree_size)``
    and ``centered_leaves`` of shape ``(num_trees, k, tree_size)``.
    """
    num_trees, tree_size = prec_trees.shape
    k = error_cov_inv.shape[0]
    n_k: Float32[Array, 'num_trees tree_size 1 1'] = prec_trees[..., None, None]

    # Only broadcast the inverse of error covariance matrix to satisfy JAX's
    # batching rules for `lax.linalg.solve_triangular`, which does not support
    # implicit broadcasting.
    error_cov_inv_batched = jnp.broadcast_to(
        error_cov_inv, (num_trees, tree_size, k, k)
    )

    posterior_precision: Float32[Array, 'num_trees tree_size k k'] = (
        leaf_prior_cov_inv + n_k * error_cov_inv_batched
    )

    # lower Cholesky factor, posterior_precision = L_prec @ L_prec.T (all
    # solves below treat it as lower triangular)
    L_prec: Float32[Array, 'num_trees tree_size k k'] = chol_with_gersh(
        posterior_precision
    )

    # mean_factor = L^-T L^-1 error_cov_inv = posterior_precision^-1 error_cov_inv
    Y: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, error_cov_inv_batched, lower=True
    )
    mean_factor: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, Y, trans='T', lower=True
    )
    # transpose and move the leaf axis last, to match the einsum in
    # `accept_move_and_sample_leaves`
    mean_factor = mean_factor.mT
    mean_factor_out: Float32[Array, 'num_trees k k tree_size'] = jnp.moveaxis(
        mean_factor, 1, -1
    )

    if z is None:
        z = random.normal(key, (num_trees, tree_size, k))
    # L^-T z ~ N(0, posterior_precision^-1). `lower=True` is required:
    # `solve_triangular` defaults to `lower=False`, which would read only the
    # upper triangle of the lower-triangular factor (i.e., its diagonal) and
    # silently drop the posterior correlations between outcomes.
    centered_leaves: Float32[Array, 'num_trees tree_size k'] = solve_triangular(
        L_prec, z, trans='T', lower=True
    )
    centered_leaves_out: Float32[Array, 'num_trees k tree_size'] = jnp.swapaxes(
        centered_leaves, -1, -2
    )

    return PreLf(mean_factor=mean_factor_out, centered_leaves=centered_leaves_out)

810 

811 

@named_call
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d']
    | Float32[Array, 'num_trees 2**d k']
    | None = None,
) -> PreLf:
    """
    Pre-compute the terms needed to sample the leaves from their posterior.

    Dispatches to the univariate or multivariate implementation. The
    multivariate one is chosen when `error_cov_inv` is a matrix (2 dims).

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse error variance (univariate) or the inverse of the error
        covariance matrix (multivariate). In the univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        the prior covariance matrix of each leaf (multivariate).
    z
        Optional standard normal noise to use for sampling the centered leaves.
        This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    is_univariate = error_cov_inv.ndim != 2
    impl = _precompute_leaf_terms_uv if is_univariate else _precompute_leaf_terms_mv
    return impl(key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z)

857 

858 

@named_call
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept or reject the proposed moves, going through the trees in order.

    The residuals are threaded through a scan over the trees, so each tree
    sees the residuals left by the trees processed before it. This stage holds
    all, and only, the parts of the algorithm that cannot be parallelized
    across trees, which makes it the most performance-sensitive one.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    bart = pso.bart
    forest = bart.forest

    # inputs common to all trees: built once, closed over by the scan body
    shared = SeqStageInAllTrees(
        bart.X,
        bart.config.resid_num_batches,
        bart.config.mesh,
        bart.prec_scale,
        forest.log_likelihood is not None,
        pso.prelk,
    )

    # inputs with a leading tree axis, sliced one tree at a time by the scan
    per_tree = SeqStageInPerTree(
        forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )

    def scan_body(
        carry_resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
        tree_inputs: SeqStageInPerTree,
    ) -> tuple[
        Float32[Array, ' n'] | Float32[Array, ' k n'],
        tuple[
            Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
            Bool[Array, ''],
            Bool[Array, ''],
            Float32[Array, ''] | None,
        ],
    ]:
        new_resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            carry_resid, shared, tree_inputs
        )
        return new_resid, (leaf_tree, acc, to_prune, lkratio)

    final_resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(
        scan_body, bart.resid, per_tree
    )

    new_forest = replace(forest, leaf_tree=leaf_trees, log_likelihood=lkratio)
    new_bart = replace(bart, resid=final_resid, forest=new_forest)
    new_moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return new_bart, new_moves

924 

925 

class SeqStageInAllTrees(Module):
    """Inputs of `accept_move_and_sample_leaves` that are the same for every tree."""

    X: UInt[Array, 'p n']
    """The predictor matrix."""

    resid_num_batches: int | None = field(static=True)
    """How many batches the per-leaf residual sums are split into, or `None`
    for no batching."""

    mesh: Mesh | None = field(static=True)
    """The device mesh, or `None` to run on a single device."""

    prec_scale: Float32[Array, ' n'] | None
    """Per-datapoint scale of the error precision; `None` means a scale of 1
    for every datapoint."""

    save_ratios: bool = field(static=True)
    """If true, the log-likelihood ratio of each move is returned."""

    prelk: PreLk | None
    """Likelihood-ratio terms pre-computed once for all trees; only the
    univariate likelihood ratio uses them."""

948 

949 

class SeqStageInPerTree(Module):
    """Inputs of `accept_move_and_sample_leaves` that differ from tree to tree."""

    leaf_tree: Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d']
    """The current leaf values of the tree (one row per outcome if multivariate)."""

    prec_tree: Float32[Array, ' 2**d']
    """The likelihood precision scale of every potential or actual leaf node."""

    move: Moves
    """The move proposed for this tree, see `propose_moves`."""

    move_precs: Precs | Counts
    """The likelihood precision scale of the nodes touched by the move."""

    leaf_indices: UInt[Array, ' n']
    """The leaf of each datapoint in the largest version of the tree that is
    compatible with the move."""

    prelkv: PreLkV
    """The tree-specific pre-computed terms of the likelihood ratio."""

    prelf: PreLf
    """The tree-specific pre-computed terms of the leaf sampling."""

974 

975 

@named_call
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Decide on the move proposed for one tree and draw its new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n'] | Float32[Array, ' k n']
        The residuals after swapping in the new leaf values of the tree.
    leaf_tree : Float32[Array, '2**d'] | Float32[Array, ' k 2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether the state should be updated by pruning the leaves involved in
        the move, to reflect its acceptance status.
    log_lk_ratio : Float32[Array, ''] | None
        The log-likelihood ratio of the move, or `None` if it is not saved.
    """
    move = pt.move
    num_nodes = pt.leaf_tree.shape[-1]  # 2**d

    # per-node sums of precision-weighted residuals, in the tree grown by the move
    weighted = resid if at.prec_scale is None else resid * at.prec_scale
    node_sums = sum_resid(
        weighted, pt.leaf_indices, num_nodes, at.resid_num_batches, at.mesh
    )
    # add back this tree's current contribution, i.e., remove it from the fit
    node_sums = node_sums + pt.prec_tree * pt.leaf_tree

    # the parent node of the move collects the sums of its two children
    left_sum = node_sums[..., move.left]
    right_sum = node_sums[..., move.right]
    parent_sum = left_sum + right_sum
    assert move.node.dtype == jnp.int32
    node_sums = node_sums.at[..., move.node].set(parent_sum)

    log_lk_ratio = compute_likelihood_ratio(
        parent_sum, left_sum, right_sum, pt.prelkv, at.prelk
    )

    # the ratio is written for a grow move; a prune move uses its inverse
    grow_log_ratio = move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(move.grow, grow_log_ratio, -grow_log_ratio)
    acc = move.allowed & (move.logu <= log_ratio)
    saved_log_lk_ratio = log_lk_ratio if at.save_ratios else None

    # new leaves = posterior mean + pre-drawn zero-mean posterior noise
    if resid.ndim > 1:
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, node_sums)
    else:
        mean_post = node_sums * pt.prelf.mean_factor
    new_leaves = mean_post + pt.prelf.centered_leaves

    # if the tree ends up pruned, copy the parent value into both children so
    # the leaf indices (computed for the grown tree) still find it; an index
    # equal to num_nodes is out of bounds and the write is skipped
    to_prune = acc ^ move.grow
    parent_value = new_leaves[..., move.node]
    left_target = jnp.where(to_prune, move.left, num_nodes)
    right_target = jnp.where(to_prune, move.right, num_nodes)
    new_leaves = (
        new_leaves.at[..., left_target]
        .set(parent_value)
        .at[..., right_target]
        .set(parent_value)
    )

    # swap the old tree's values for the new ones in the residuals
    resid = resid + (pt.leaf_tree - new_leaves)[..., pt.leaf_indices]

    return resid, new_leaves, acc, to_prune, saved_log_lk_ratio

1069 

1070 

@named_call
@partial(jnp.vectorize, excluded=(1, 2, 3, 4), signature='(n)->(ts)')
def sum_resid(
    scaled_resid: Float32[Array, ' n'] | Float32[Array, 'k n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    resid_num_batches: int | None,
    mesh: Mesh | None,
) -> Float32[Array, ' {tree_size}'] | Float32[Array, 'k {tree_size}']:
    """
    Add up the residuals that fall into each leaf.

    The function body handles one row of residuals. `jnp.vectorize` maps it
    over any leading axes, such as the ``k`` outcomes of the multivariate case.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) times the error precision
        scale, of shape ``(n,)`` or ``(k, n)``.
    leaf_indices
        The leaf each datapoint falls into.
    tree_size
        The length of the tree array (2 ** d).
    resid_num_batches
        The number of batches the sum is split into, or `None`.
    mesh
        The device mesh, or `None`.

    Returns
    -------
    The per-leaf residual sums, of shape ``(tree_size,)`` or ``(k, tree_size)``.
    """
    return _scatter_add(
        scaled_resid,
        leaf_indices,
        size=tree_size,
        dtype=jnp.float32,
        batch_size=resid_num_batches,
        mesh=mesh,
    )

1109 

1110 

def _scatter_add(
    values: Float32[Array, ' n'] | int,
    indices: Integer[Array, ' n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    mesh: Mesh | None,
) -> Shaped[Array, ' {size}']:
    """
    Accumulate `values` into `size` bins selected by `indices`.

    Parameters
    ----------
    values
        Either one value per index, or a scalar added once per index (e.g.,
        ``1`` to count occurrences).
    indices
        The bin of each value.
    size
        The number of bins.
    dtype
        The dtype of the output.
    batch_size
        Despite the name, the *number* of batches the accumulation is split
        into (forwarded as ``num_batches`` to `_scatter_add_impl`), or `None`.
    mesh
        The device mesh. If it has a ``'data'`` axis, the indices (and
        per-index values) are sharded along it, and the partial results are
        combined across devices.

    Returns
    -------
    The per-bin sums, of shape ``(size,)``.
    """
    values = jnp.asarray(values)
    assert values.ndim == 0 or values.shape == indices.shape

    scatter = partial(
        _scatter_add_impl, size=size, dtype=dtype, num_batches=batch_size
    )

    sharded = mesh is not None and 'data' in mesh.axis_names
    if not sharded:
        return scatter(values, indices)

    # a scalar value is replicated on all devices, per-index values are split
    # like the indices; each device's partial sums are then psum-ed
    values_spec = PartitionSpec('data') if values.shape else PartitionSpec()
    sharded_scatter = shard_map(
        partial(scatter, final_psum=True),
        in_specs=(values_spec, PartitionSpec('data')),
        out_specs=PartitionSpec(),
        mesh=mesh,
        **_get_shard_map_patch_kwargs(),
    )
    return sharded_scatter(values, indices)

1147 

1148 

1149def _get_shard_map_patch_kwargs() -> dict[str, bool]: 

1150 # see jax/issues/#34249, problem with vmap(shard_map(psum)) 

1151 # we tried the config jax_disable_vmap_shmap_error but it didn't work 

1152 if jax.__version__ in ('0.8.1', '0.8.2'): 1152 ↛ 1153line 1152 didn't jump to line 1153 because the condition on line 1152 was never true10ab4xcdjkelfABghi

1153 return {'check_vma': False} 

1154 else: 

1155 return {} 10ab4xcdjkelfABghi

1156 

1157 

1158def _scatter_add_impl( 

1159 values: Float32[Array, ' n'] | Int32[Array, ''], 

1160 indices: Integer[Array, ' n'], 

1161 /, 

1162 *, 

1163 size: int, 

1164 dtype: jnp.dtype, 

1165 num_batches: int | None, 

1166 final_psum: bool = False, 

1167) -> Shaped[Array, ' {size}']: 

1168 if num_batches is None: 1!0*3a#L+b4xc$MdNjPkQz.Z,C/1%'(e-Ol:2Yw6mno7DEp8FGH9qrI5tJKuvsfABghiyRSTUVWX

1169 out = jnp.zeros(size, dtype).at[indices].add(values) 1!0*a#+b4c$djkz.,C/%'(e-l:fABghiyRSTUVWX

1170 

1171 else: 

1172 # in the sharded case, n is the size of the local shard, not the full size 

1173 (n,) = indices.shape 13LxMNPQZ1O2Yw6mno7DEp8FGH9qrI5tJKuvs

1174 batch_indices = jnp.arange(n) % num_batches 13LxMNPQZ1O2Yw6mno7DEp8FGH9qrI5tJKuvs

1175 out = ( 13LxMNPQZ1O2Yw6mno7DEp8FGH9qrI5tJKuvs

1176 jnp.zeros((size, num_batches), dtype) 

1177 .at[indices, batch_indices] 

1178 .add(values) 

1179 .sum(axis=1) 

1180 ) 

1181 

1182 if final_psum: 1!0*3a#L+b4xc$MdNjPkQz.Z,C/1%'(e-Ol:2Yw6mno7DEp8FGH9qrI5tJKuvsfABghiyRSTUVWX

1183 out = lax.psum(out, 'data') 10ab4xcdjkelfABghi

1184 return out 1!0*3a#L+b4xc$MdNjPkQz.Z,C/1%'(e-Ol:2Yw6mno7DEp8FGH9qrI5tJKuvsfABghiyRSTUVWX

1185 

1186 

1187def _compute_likelihood_ratio_uv( 

1188 total_resid: Float32[Array, ''], 

1189 left_resid: Float32[Array, ''], 

1190 right_resid: Float32[Array, ''], 

1191 prelkv: PreLkV, 

1192 prelk: PreLk, 

1193) -> Float32[Array, '']: 

1194 exp_term = prelk.exp_factor * ( 1!0*3a#L+b4xc$MdNjPkQz.Z,C/1%'(e-Ol:2Yw67895UVWX=

1195 left_resid * left_resid / prelkv.left 

1196 + right_resid * right_resid / prelkv.right 

1197 - total_resid * total_resid / prelkv.total 

1198 ) 

1199 return prelkv.log_sqrt_term + exp_term 1!0*3a#L+b4xc$MdNjPkQz.Z,C/1%'(e-Ol:2Yw67895UVWX=

1200 

1201 

1202def _compute_likelihood_ratio_mv( 

1203 total_resid: Float32[Array, ' k'], 

1204 left_resid: Float32[Array, ' k'], 

1205 right_resid: Float32[Array, ' k'], 

1206 prelkv: PreLkV, 

1207) -> Float32[Array, '']: 

1208 def _quadratic_form( 1mnoDEpFGHqrItJKuvsfABghiyRSTUVWX=

1209 r: Float32[Array, ' k'], mat: Float32[Array, 'k k'] 

1210 ) -> Float32[Array, '']: 

1211 return r @ mat @ r 1mnoDEpFGHqrItJKuvsfABghiyRSTUVWX=

1212 

1213 qf_left = _quadratic_form(left_resid, prelkv.left) 1mnoDEpFGHqrItJKuvsfABghiyRSTUVWX=

1214 qf_right = _quadratic_form(right_resid, prelkv.right) 1mnoDEpFGHqrItJKuvsfABghiyRSTUVWX=

1215 qf_total = _quadratic_form(total_resid, prelkv.total) 1mnoDEpFGHqrItJKuvsfABghiyRSTUVWX=

1216 exp_term = 0.5 * (qf_left + qf_right - qf_total) 1mnoDEpFGHqrItJKuvsfABghiyRSTUVWX=

1217 return prelkv.log_sqrt_term + exp_term 1mnoDEpFGHqrItJKuvsfABghiyRSTUVWX=

1218 

1219 

@named_call
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''] | Float32[Array, ' k'],
    left_resid: Float32[Array, ''] | Float32[Array, ' k'],
    right_resid: Float32[Array, ''] | Float32[Array, ' k'],
    prelkv: PreLkV,
    prelk: PreLk | None,
) -> Float32[Array, '']:
    """
    Compute the log-likelihood ratio of a grow move.

    The univariate path is used for scalar residual sums, the multivariate
    one for vector sums.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sums of the residuals (scaled by the error precision scale) of the
        datapoints falling in the nodes involved in the move.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`. `prelk` is required in the univariate
        case.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    if total_resid.ndim == 0:
        assert prelk is not None
        return _compute_likelihood_ratio_uv(
            total_resid, left_resid, right_resid, prelkv, prelk
        )
    return _compute_likelihood_ratio_mv(total_resid, left_resid, right_resid, prelkv)

1259 

1260 

@named_call
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Finish updating the mcmc state once all moves have been decided.

    Kept apart from `accept_moves_sequential_stage` because this work is
    parallel across trees. It updates the acceptance counters, the leaf
    indices, and the split trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    forest = bart.forest
    accepted_grows = moves.acc & moves.grow
    accepted_prunes = moves.acc & ~moves.grow
    new_forest = replace(
        forest,
        grow_acc_count=jnp.sum(accepted_grows),
        prune_acc_count=jnp.sum(accepted_prunes),
        leaf_indices=apply_moves_to_leaf_indices(forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)

1291 

1292 

@named_call
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Make the leaf indices consistent with the decided moves.

    The indices are computed for the grown tree. When a move ends in a prune,
    the datapoints in either child of the move's node are sent back to the
    node itself.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, assuming the grow
        move was accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # clearing the lowest bit maps both children (left, left + 1) to left
    clear_low_bit = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_move_children = (leaf_indices & clear_low_bit) == moves.left
    assert moves.to_prune is not None
    parent = moves.node.astype(leaf_indices.dtype)
    return jnp.where(in_move_children & moves.to_prune, parent, leaf_indices)

1320 

1321 

@named_call
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Write the decided moves into the split trees.

    The cutpoint of a grow move is set at its node, and a node that ends up
    pruned gets its cutpoint reset to 0. Writes at index ``split_tree.size``
    are out of bounds and are skipped.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    skip = split_tree.size
    grow_target = jnp.where(moves.grow, moves.node, skip)
    prune_target = jnp.where(moves.to_prune, moves.node, skip)
    grown = split_tree.at[grow_target].set(moves.grow_split.astype(split_tree.dtype))
    return grown.at[prune_target].set(0)

1349 

1350 

@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Float32[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix W ~ Wishart(df, scale_inv^-1) with the Bartlett decomposition.

    Equivalently, ``W^-1`` is distributed as an inverse-Wishart with `df`
    degrees of freedom and scale matrix `scale_inv`.

    Parameters
    ----------
    key
        A JAX random key.
    df
        The degrees of freedom. Must be greater than ``k - 1``, otherwise
        the chi-squared degrees of freedom ``df - i`` of the last diagonal
        entries are not positive.
    scale_inv
        The scale matrix of the corresponding inverse-Wishart distribution,
        i.e., the inverse of the Wishart scale matrix.

    Returns
    -------
    A sample from Wishart(df, scale_inv^-1), whose mean is
    ``df * scale_inv^-1``.
    """
    keys = split(key)

    # Bartlett decomposition: A is lower triangular with A_ii^2 ~ chi^2(df - i)
    # and A_ij ~ N(0, 1) below the diagonal, so that A A^T ~ Wishart(df, I).
    # chi^2(k) = Gamma(k/2, scale=2)
    k, _ = scale_inv.shape
    df_vector = df - jnp.arange(k)
    chi2_samples = random.gamma(keys.pop(), df_vector / 2.0) * 2.0
    diag_A = jnp.sqrt(chi2_samples)

    off_diag_A = random.normal(keys.pop(), (k, k))
    A = jnp.tril(off_diag_A, -1) + jnp.diag(diag_A)

    # scale_inv = L L^T; T = L^-T A gives T T^T = L^-T (A A^T) L^-1, which is
    # Wishart with scale (L L^T)^-1 = scale_inv^-1
    L = chol_with_gersh(scale_inv, absolute_eps=True)
    T = solve_triangular(L, A, lower=True, trans='T')

    return T @ T.T

1386 

1387 

def _step_error_cov_inv_uv(key: Key[Array, ''], bart: State) -> State:
    """
    Gibbs-sample the inverse error variance (univariate case).

    The prior on the error variance is inverse gamma with
    ``alpha = error_cov_df / 2`` and ``beta = error_cov_scale / 2``. The
    inverse variance is then drawn from
    ``Gamma(alpha + n / 2, rate=beta + sum(prec_scale * resid**2) / 2)``.
    """
    resid = bart.resid
    weighted_resid = resid if bart.prec_scale is None else resid * bart.prec_scale
    sum_sq = resid @ weighted_resid

    shape = bart.error_cov_df / 2 + resid.size / 2
    rate = bart.error_cov_scale / 2 + sum_sq / 2

    # draw a unit-rate gamma and rescale it. random.gamma seems slow to
    # compile, and cdf inversion might be better, but jax does not implement it
    unit_rate_draw = random.gamma(key, shape)
    return replace(bart, error_cov_inv=unit_rate_draw / rate)

1403 

1404 

def _step_error_cov_inv_mv(key: Key[Array, ''], bart: State) -> State:
    """
    Gibbs-sample the inverse error covariance matrix (multivariate case).

    The degrees of freedom are increased by the number of datapoints, and
    the scale matrix by the residual cross-product ``R R^T``. This matches
    the conjugate update of an inverse-Wishart prior.
    """
    resid = bart.resid
    posterior_df = bart.error_cov_df + resid.shape[-1]
    posterior_scale = bart.error_cov_scale + resid @ resid.T
    precision = _sample_wishart_bartlett(key, posterior_df, posterior_scale)
    return replace(bart, error_cov_inv=precision)

1412 

1413 

@named_call
def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the inverse error covariance.

    Handles both the univariate and multivariate cases. The multivariate
    update is used when ``bart.error_cov_inv`` is a matrix (2 dimensions);
    otherwise the univariate one is used.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state. Its `error_cov_inv` must not be `None`.

    Returns
    -------
    The new BART mcmc state, with an updated `error_cov_inv`.
    """
    assert bart.error_cov_inv is not None
    if bart.error_cov_inv.ndim == 2:
        return _step_error_cov_inv_mv(key, bart)
    else:
        return _step_error_cov_inv_uv(key, bart)

1438 

1439 

@named_call
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    The latent ``z`` is the sum of the trees (plus offset) and the residual.
    The residual is redrawn from a one-sided truncated normal. The truncation
    point is ``-(trees + offset)`` and the side is selected by ``~y``
    (presumably so that ``z > 0`` exactly when ``y`` is true; this depends on
    the convention of `truncated_normal_onesided`).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state with boolean `y`.

    Returns
    -------
    The updated BART MCMC state.
    """
    trees_plus_offset = bart.z - bart.resid
    assert bart.y.dtype == bool
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -trees_plus_offset)
    return replace(bart, z=trees_plus_offset + new_resid, resid=new_resid)

1461 

1462 

@named_call
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p). The posterior used is
    s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount counts how many times each variable is used in the current
    forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`. The new `log_s` holds the
    logs of independent gamma draws, i.e., the log of the Dirichlet sample
    up to an additive constant (normalizing them gives the sample).

    Notes
    -----
    This full conditional is approximated, because it does not take into account
    that there are forbidden decision rules.
    """
    assert bart.forest.theta is not None
    forest = bart.forest

    num_vars = forest.max_split.size
    usage = var_histogram(
        num_vars, forest.var_tree, forest.split_tree, sum_batch_axis=-1
    )

    concentration = forest.theta / num_vars + usage
    new_log_s = random.loggamma(key, concentration)

    return replace(bart, forest=replace(forest, log_s=new_log_s))

1503 

1504 

@named_call
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b). lambda = theta / (theta + rho)
    is sampled from its posterior discretized on a grid, and then mapped back
    to theta.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # midpoints of num_grid equal bins partitioning (0, 1)
    half_bin = 1 / (2 * num_grid)
    grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # log_s is defined up to a constant: normalize it to a log-probability vector
    normalized_log_s = forest.log_s - logsumexp(forest.log_s)

    log_weights, theta_grid = _log_p_lamda(
        grid, normalized_log_s, forest.rho, forest.a, forest.b
    )
    theta = theta_grid[random.categorical(key, log_weights)]

    return replace(bart, forest=replace(forest, theta=theta))

1546 

1547 

def _log_p_lamda(
    lamda: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Unnormalized log conditional density of lambda = theta / (theta + rho).

    Parameters
    ----------
    lamda
        Grid of lambda values in (0, 1). It must be symmetric around 1/2,
        i.e. ``lamda[::-1] == 1 - lamda``, because the reversed grid is used
        in place of ``1 - lamda``.
    log_s
        Log of the predictor weights s. The caller in this module passes
        values normalized so that exp(log_s) sums to 1.
    rho
        Scale linking lambda and theta.
    a, b
        Parameters of the Beta(a, b) prior on lambda.

    Returns
    -------
    logp
        At each grid point, the log Beta(a, b) prior density of lambda plus
        the log Dirichlet(theta/p, ..., theta/p) density of s. Terms that do
        not depend on lambda are dropped.
    theta
        rho * lambda / (1 - lambda) at each grid point.
    """
    # stands in for 1 - lamda; exact only for a symmetric grid
    one_minus_lamda = lamda[::-1]
    theta = rho * lamda / one_minus_lamda
    p = log_s.size
    theta_over_p = theta / p

    # Beta(a, b) prior on lambda
    logp = (a - 1) * jnp.log1p(-one_minus_lamda)  # log(lambda)
    logp = logp + (b - 1) * jnp.log1p(-lamda)  # log(1 - lambda)

    # Dirichlet(theta/p, ..., theta/p) log density of s, theta-dependent part
    logp = logp + gammaln(theta)
    logp = logp - p * gammaln(theta_over_p)
    logp = logp + theta_over_p * jnp.sum(log_s)

    return logp, theta

1565 

1566 

@named_call
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This is a no-op when `bart.config.sparse_on_at` is None. Otherwise,
    while `config.steps_done` is below `sparse_on_at` the state is passed
    through unchanged. From then on, `step_s` is invoked, followed by
    `step_theta` only if the theta prior parameters are defined.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    activation_step = bart.config.sparse_on_at
    if activation_step is None:
        return bart

    def _keep(_key, state):
        # sparsity not active yet: leave the state untouched
        return state

    return lax.cond(
        bart.config.steps_done < activation_step,
        _keep,
        _step_sparse,
        key,
        bart,
    )

1595 

1596 

def _step_sparse(key: Key[Array, ''], bart: State) -> State:
    """Run `step_s`, then `step_theta` if the forest has a `rho` value."""
    keys = split(key)
    bart = step_s(keys.pop(), bart)
    if bart.forest.rho is None:
        # no hyperprior on theta: keep it fixed
        return bart
    return step_theta(keys.pop(), bart)

1603 

1604 

@named_call
def step_config(bart: State) -> State:
    """Return a copy of `bart` with `config.steps_done` incremented by one."""
    new_config = replace(bart.config, steps_done=bart.config.steps_done + 1)
    return replace(bart, config=new_config)