Coverage for src / bartz / mcmcstep / _moves.py: 100%

148 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-13 00:35 +0000

1# bartz/src/bartz/mcmcstep/_moves.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement `propose_moves` and associated dataclasses.""" 

26 

27from functools import partial 

28 

29import jax 

30from equinox import Module 

31from jax import numpy as jnp 

32from jax import random 

33from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt 

34 

35from bartz import grove 

36from bartz._profiler import jit_and_block_if_profiling 

37from bartz.jaxext import minimal_unsigned_dtype, split, vmap_nodoc 

38from bartz.mcmcstep._state import Forest, field, vmap_chains 

39 

40 

class Moves(Module):
    """
    Describe the move proposed for every tree.

    Parameters
    ----------
    allowed
        `True` when some move is possible for the tree. When `False`, the
        remaining fields may hold values that do not make sense. A move flagged
        as allowed can later be vetoed only because it violates
        `min_points_per_leaf`; for efficiency that check is applied post-hoc,
        leaving the rest of the MCMC logic untouched.
    grow
        `True` for a grow move, `False` for a prune move.
    num_growable
        How many leaves of the original tree could be grown.
    node
        Index of the leaf to grow, or of the node to prune.
    left
    right
        Indices of the two children of `node`.
    partial_ratio
        One factor of the Metropolis-Hastings ratio of the move. Still missing
        are the likelihood ratio, the probability of proposing the prune move,
        and the probability that the children of the modified node are
        terminal. For a PRUNE move the ratio is inverted. Becomes `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        Logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings acceptance ratio. `None` until computed. For a
        PRUNE move the log-ratio is negated.
    grow_var
        Decision axes of the new rules.
    grow_split
        Decision boundaries of the new rules.
    var_tree
        Updated decision axes of the trees, valid whichever move is taken.
    affluence_tree
        Partially updated `affluence_tree` that marks non-leaf nodes which
        would turn into leaves if the move were accepted. As it comes out of
        `propose_moves`, the mark only reflects whether decision rules would be
        available to grow the leaf; whether the node holds enough datapoints is
        checked later, in `accept_moves_parallel_stage`.
    logu
        Logarithm of a uniform (0, 1] random variable, used to accept the move.
        It lies in (-oo, 0].
    acc
        Whether the move was accepted. `None` until computed.
    to_prune
        Whether the final operation that applies the move is a pruning, i.e.,
        an accepted prune move or a rejected grow move. `None` until computed.
    """

    # Field order is part of the interface (constructor and pytree layout).
    allowed: Bool[Array, '*chains num_trees'] = field(chains=True)
    grow: Bool[Array, '*chains num_trees'] = field(chains=True)
    num_growable: UInt[Array, '*chains num_trees'] = field(chains=True)
    node: UInt[Array, '*chains num_trees'] = field(chains=True)
    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    partial_ratio: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    log_trans_prior_ratio: Float32[Array, '*chains num_trees'] | None = field(
        chains=True
    )
    grow_var: UInt[Array, '*chains num_trees'] = field(chains=True)
    grow_split: UInt[Array, '*chains num_trees'] = field(chains=True)
    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    logu: Float32[Array, '*chains num_trees'] = field(chains=True)
    acc: Bool[Array, '*chains num_trees'] | None = field(chains=True)
    to_prune: Bool[Array, '*chains num_trees'] | None = field(chains=True)

112 

113 

@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose one move per tree.

    Two kinds of move exist: GROW turns a leaf into a decision node with two
    new leaves below it, and PRUNE turns the parent of two leaves into a leaf,
    removing its children.

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees = forest.leaf_tree.shape[0]
    keys = split(key, 2)
    grow_keys, prune_keys = keys.pop((2, num_trees))

    # build a candidate move of each kind for every tree
    grow_moves = propose_grow_moves(
        grow_keys,
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    # the prune proposal starts from the affluence tree already updated by grow
    prune_moves = propose_prune_moves(
        prune_keys,
        forest.split_tree,
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))

    # pick grow or prune: 50/50 when both are possible, otherwise whichever is
    can_grow = grow_moves.allowed
    can_prune = prune_moves.allowed
    p_grow = jnp.where(can_grow & can_prune, 0.5, can_grow)
    # strict inequality because u is in [0, 1)
    grow = u < p_grow

    # indices of the affected node and of its two children
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    children = (node << 1) | jnp.arange(2)[:, None]
    left, right = children

    partial_ratio = jnp.where(
        grow, grow_moves.partial_ratio, prune_moves.partial_ratio
    )

    # exp1mlogu is in [0, 1), so 1 - exp1mlogu is in (0, 1]
    logu = jnp.log1p(-exp1mlogu)

    return Moves(
        allowed=can_grow | can_prune,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=partial_ratio,
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_moves.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_moves.affluence_tree,
        logu=logu,
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )

192 

193 

class GrowMoves(Module):
    """
    Hold the grow move proposed for each tree.

    Parameters
    ----------
    allowed
        Whether the move may be proposed.
    num_growable
        How many leaves are candidates for a grow move.
    node
        Index of the leaf to grow; ``2 ** d`` when no leaf is growable.
    var
    split
        Decision axis and decision boundary of the new rule.
    partial_ratio
        One factor of the Metropolis-Hastings ratio of the move. It does not
        include the likelihood ratio nor the probability of proposing the
        prune move.
    var_tree
        The decision axes of the tree after the update.
    affluence_tree
        Partially updated `affluence_tree`: each leaf the move would create is
        marked `True` if decision rules would be available for it.
    """

    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']

229 

230 

@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move selects a leaf and turns it into a non-terminal node that has
    two leaves as children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    No move is proposed when every leaf is already at maximum depth, or has
    fewer datapoints than the threshold `min_points_per_decision_node`, or has
    no decision rule available given its ancestors. This case is signalled by
    `allowed` being `False` and `num_growable` being 0.
    """
    # one key per random choice, consumed in this order: leaf, variable, split
    keys = split(key, 3)

    leaf, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # draw the decision rule of the new node
    var, num_available_var = choose_variable(
        keys.pop(), var_tree, split_tree, max_split, leaf, blocked_vars, log_s
    )
    cutpoint, lower, upper = choose_split(
        keys.pop(), var, var_tree, split_tree, max_split, leaf
    )

    # A child can be split again if another variable is available, or if the
    # chosen variable still has cutpoints left on that child's side. When the
    # move is blocked these values may not make sense.
    side_has_cutpoints = jnp.stack([lower < cutpoint, cutpoint + 1 < upper])
    children_growable = (num_available_var > 1) | side_has_cutpoints
    children = (leaf << 1) | jnp.arange(2)
    updated_affluence = affluence_tree.at[children].set(children_growable)

    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf
    )

    updated_var_tree = var_tree.at[leaf].set(var.astype(var_tree.dtype))

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf,
        var=var,
        split=cutpoint,
        partial_ratio=partial_ratio,
        var_tree=updated_var_tree,
        affluence_tree=updated_affluence,
    )

319 

320 

def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Pick the leaf of a tree that a grow move will act on.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two leaves
        satisfying the `min_points_per_decision_node` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaf nodes that can be grown, i.e., are nonterminal
        and have at least twice `min_points_per_decision_node`.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(growable)

    # restrict the proposal weights to the growable leaves and sample
    weights = jnp.where(growable, p_propose_grow, 0)
    sampled, total = categorical(key, weights)

    # with no growable leaf, use twice the array size as the index
    leaf_to_grow = jnp.where(num_growable, sampled, 2 * split_tree.size)

    # guard the normalization against a zero total weight
    safe_total = jnp.where(total, total, 1)
    prob_choose = weights[leaf_to_grow] / safe_total

    # count the prunable nodes of the tree as it would be after growing
    grown_split_tree = split_tree.at[leaf_to_grow].set(1)
    num_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))

    return leaf_to_grow, num_growable, prob_choose, num_prunable

366 

367 

def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Compute the mask of leaves that a grow move may be proposed on.

    A leaf qualifies if it is not at the bottom level, has decision rules
    available given its ancestors, and holds at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    `split_tree` is required in addition to `affluence_tree` because
    `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
    """
    is_leaf = grove.is_actual_leaf(split_tree)
    return is_leaf & affluence_tree

395 

396 

def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Sample an integer according to an arbitrary discrete distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are zero,
        return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.

    Notes
    -----
    The sampling inverts a cumulative sum rather than using the Gumbel trick,
    so it is appropriate only for small ranges with probabilities well greater
    than 0.
    """
    cumulative = jnp.cumsum(distr)
    total = cumulative[-1]
    # uniform draw over [0, total), located within the cumulative weights
    draw = random.uniform(key, (), cumulative.dtype, 0, total)
    index = jnp.searchsorted(cumulative, draw, side='right', method='compare_all')
    return index, total

426 

427 

def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the variable a new non-terminal node will split on.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # variables whose split range is already exhausted by the ancestors
    ignored = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        ignored = jnp.concatenate([ignored, blocked_vars])

    # weighted draw when log-probabilities are given, uniform draw otherwise
    if log_s is not None:
        return categorical_exclude(key, log_s, ignored)
    return randint_exclude(key, max_split.size, ignored)

475 

476 

def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    List the ancestor variables of a node whose split range is empty.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of unused variables is not known in advance. Unused values in the
    array are filled with `p`. The fill values are not guaranteed to be placed
    in any particular order, and variables may appear more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)

    # evaluate the split range at the node once per candidate variable
    ranges_per_var = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lower, upper = ranges_per_var(
        var_tree, split_tree, max_split, leaf_index, candidates
    )
    is_exhausted = (upper - lower) == 0

    # the type of candidates is already sufficient to hold max_split.size,
    # see ancestor_variables()
    return jnp.where(is_exhausted, candidates, max_split.size)

514 

515 

def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Collect the variables used by the ancestors of a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes on the path from the root to the parent of the
    node. How many there are is not known at tracing time, so the unused slots
    of the output array are filled with `p`.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1

    # shifting right by k bits gives the k-th ancestor; a result of 0 means
    # there is no such ancestor
    shifts = jnp.arange(max_num_ancestors, 0, -1)
    ancestor_index = node_index >> shifts
    ancestor_var = var_tree[ancestor_index]

    # fill value `p`, in the smallest unsigned type able to represent it
    fill_dtype = minimal_unsigned_dtype(max_split.size)
    fill = jnp.array(max_split.size, fill_dtype)

    return jnp.where(ancestor_index, ancestor_var, fill)

549 

550 

def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute the splits still allowed for a variable at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1

    # the node itself followed by its successive ancestors
    descendant = node_index >> jnp.arange(max_num_ancestors)
    # the lowest bit tells whether each of them is a right child
    is_right_child = (descendant & 1).astype(bool)
    ancestor = descendant >> 1

    ancestor_split = split_tree[ancestor].astype(jnp.int32)
    # keep only existing ancestors (nonzero index) that split on `ref_var`
    is_relevant = (var_tree[ancestor] == ref_var) & ancestor.astype(bool)

    # lower bound: largest split among ancestors entered from the right
    lower = jnp.max(ancestor_split, initial=0, where=is_relevant & is_right_child)

    # upper bound: smallest split among ancestors entered from the left,
    # starting from 1 + max_split[ref_var] (0 if `ref_var` is out of bounds)
    var_max_split = max_split.at[ref_var].get(mode='fill', fill_value=0)
    upper_start = 1 + var_max_split.astype(jnp.int32)
    upper = jnp.min(
        ancestor_split, initial=upper_start, where=is_relevant & ~is_right_child
    )

    return lower + 1, upper

591 

592 

def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw a random integer from a range with some values removed.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range, must be >= 1.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    unique_exclude, num_allowed = _process_exclude(sup, exclude)

    # draw a position among the allowed values only
    draw = random.randint(key, (), 0, num_allowed)

    # convert the position to an actual value by stepping over the excluded
    # values found at or below it
    candidates = jnp.minimum(draw + jnp.arange(unique_exclude.size), sup - 1)
    num_skipped = jnp.sum(candidates >= unique_exclude)

    return draw + num_skipped, num_allowed

627 

628 

def _process_exclude(sup, exclude):
    """
    Deduplicate the excluded values and count the values left in ``[0, sup)``.

    Parameters
    ----------
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude. Values can appear more than once.

    Returns
    -------
    exclude
        The unique values of `exclude`, in an array of the same size as the
        input, with the leftover slots filled with `sup`.
    num_allowed
        `sup` minus the number of unique excluded values below `sup`.
    """
    unique_exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_excluded = jnp.sum(unique_exclude < sup)
    return unique_exclude, sup - num_excluded

633 

634 

def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Sample from a categorical distribution with some categories removed.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        The values to exclude from the range [0, k). Values greater than or
        equal to `logits.size` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, the result is unspecified.
    """
    unique_exclude, num_allowed = _process_exclude(logits.size, exclude)

    # give the excluded categories the most negative finite logit
    lowest_logit = jnp.finfo(logits.dtype).min
    masked_logits = logits.at[unique_exclude].set(lowest_logit)

    return random.categorical(key, masked_logits), num_allowed

667 

668 

def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the cutpoint of a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. Does not need to already contain `var`
        at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range `split` was drawn from is [l, r).

    Notes
    -----
    If `var` is out of bounds, or if the available split range on that variable
    is empty, return 0.
    """
    lower, upper = split_range(var_tree, split_tree, max_split, leaf_index, var)
    sampled = random.randint(key, (), lower, upper)
    # use 0 instead of the draw when the range [lower, upper) is empty
    cutpoint = jnp.where(lower < upper, sampled, 0)
    return cutpoint, lower, upper

711 

712 

def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Compute the product of the transition and prior ratios of a grow move.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
    The value returned here is "partial": its transition part lacks the factor
    P(propose prune) in the numerator, and its prior part, P(new tree) /
    P(old tree), lacks the factor P(children are leaves).
    """
    # Both ratios also carry the factors num_available_split *
    # num_available_var * s[var]; they cancel and are therefore omitted.

    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) are not
    # handled here: they need the count trees, which only become available in
    # the acceptance phase.

    # The leaf to grow is the root (index 1) exactly when the starting tree is
    # a bare root. Such a tree cannot be pruned, so grow is proposed with
    # probability 1; otherwise grow and prune are each proposed with
    # probability 0.5.
    growing_root = leaf_to_grow == 1
    p_grow = jnp.where(growing_root, 1, 0.5)
    inverse_transition = p_grow * prob_choose * num_prunable

    # Fetch with a fill value: when `leaf_to_grow` is out of bounds (move not
    # allowed) the prior odds become 1 rather than 0, which would turn into an
    # inf once `complete_ratio` takes the log.
    pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    prior_odds = pnt / (1 - pnt)

    # Guard the division against a zero inverse transition ratio.
    safe_inverse = jnp.where(inverse_transition == 0, 1, inverse_transition)
    return prior_odds / safe_inverse

767 

768 

class PruneMoves(Module):
    """
    Prune proposals, one for each tree.

    Parameters
    ----------
    allowed
        Whether the prune move can be carried out on the tree.
    node
        The index of the node to prune, or ``2 ** d`` when the tree has no
        node that can be pruned.
    partial_ratio
        One factor of the Metropolis-Hastings ratio of the move. Still missing
        are the likelihood ratio, the probability of proposing the prune
        move, and the prior probability that the children of the node to
        prune are leaves. The value is stored inverted; it is meant to be
        inverted back in `accept_move_and_sample_leaves`.
    affluence_tree
        A partially updated `affluence_tree` in which the node to prune is
        marked as growable.
    """

    allowed: Bool[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']

794 

795 

@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a prune move on the structure of a tree in the BART MCMC.

    The decorator maps the function over the leading axis of `key`,
    `split_tree` and `affluence_tree`, while `p_nonterminal` and
    `p_propose_grow` are shared by all trees; the shapes below are those seen
    for a single tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    # Pruning is possible iff the tree is not a bare root, i.e., the root
    # (index 1) has a nonzero split.
    root_is_split = split_tree[1].astype(bool)

    chosen = choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow)
    node_to_prune, num_prunable, prob_choose, updated_affluence = chosen

    # The ratio is the one of the reverse grow move, so it is stored inverted
    # with respect to the prune move.
    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, node_to_prune
    )

    return PruneMoves(
        allowed=root_is_split,
        node=node_to_prune,
        partial_ratio=partial_ratio,
        affluence_tree=updated_affluence,
    )

839 

840 

def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned. It is 0 when there is no node to
        prune.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable. It has the same single-tree shape as the input.
    """
    # sample a node to prune, uniformly amongst the parents of two leaves
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # with no candidate, switch to the sentinel 2 ** d (twice the size of
    # `split_tree`), which is out of bounds for all the arrays indexed below
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # compute stuff for the reverse (grow) move: build the tree as it would be
    # after pruning, where the pruned node is a leaf that can be grown again
    # NOTE(review): with the out-of-bounds sentinel these two updates are
    # expected to be dropped by JAX's default scatter semantics -- confirm
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)

    # normalize the weight of the pruned node over all growable leaves of the
    # pruned tree; the fill value gives weight 0 to the sentinel, and the
    # `where` avoids dividing by a zero normalization
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree

896 

897 

def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw a random index amongst the positions where a mask is set.

    Parameters
    ----------
    key
        A jax random key.
    mask
        The mask indicating the allowed values.

    Returns
    -------
    A random integer in the range ``[0, n)`` such that ``mask[u] == True``.

    Notes
    -----
    If all values in the mask are `False`, return `n`. This function is
    optimized for small `n`.
    """
    # running_count[i] is the number of allowed values in mask[: i + 1]
    running_count = jnp.cumsum(mask)
    num_allowed = running_count[-1]

    # pick which of the allowed values to return, counting from zero
    rank = random.randint(key, shape=(), minval=0, maxval=num_allowed)

    # the first position whose running count exceeds `rank` is the
    # (rank + 1)-th allowed value
    return jnp.searchsorted(running_count, rank, side='right', method='compare_all')