Coverage for src / bartz / mcmcstep / _moves.py: 100%

175 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/mcmcstep/_moves.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement `propose_moves` and associated dataclasses.""" 

26 

27from functools import partial 

28 

29import jax 

30from equinox import Module 

31from jax import named_call, random 

32from jax import numpy as jnp 

33from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt 

34 

35from bartz import grove 

36from bartz.jaxext import minimal_unsigned_dtype, split, vmap_nodoc 

37from bartz.mcmcstep._state import Forest, field 

38 

39 

class Moves(Module):
    """Moves proposed to modify each tree."""

    allowed: Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether there is a possible move. If `False`, the other values may not
    make sense. The only case in which a move is marked as allowed but is
    then vetoed is if it does not satisfy `min_points_per_leaf`, which for
    efficiency is implemented post-hoc without changing the rest of the
    MCMC logic."""

    grow: Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether the move is a grow move or a prune move."""

    num_growable: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The number of growable leaves in the original tree."""

    node: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The index of the leaf to grow or node to prune."""

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The index of the left child of 'node'."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The index of the right child of 'node'."""

    partial_ratio: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """A factor of the Metropolis-Hastings ratio of the move. It lacks the
    likelihood ratio, the probability of proposing the prune move, and the
    probability that the children of the modified node are terminal. If the
    move is PRUNE, the ratio is inverted. `None` once
    `log_trans_prior_ratio` has been computed."""

    log_trans_prior_ratio: None | Float32[Array, '*chains num_trees'] = field(
        chains=True
    )
    """The logarithm of the product of the transition and prior terms of the
    Metropolis-Hastings ratio for the acceptance of the proposed move.
    `None` if not yet computed. If PRUNE, the log-ratio is negated."""

    grow_var: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The decision axes of the new rules."""

    grow_split: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The decision boundaries of the new rules."""

    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The updated decision axes of the trees, valid whatever move."""

    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """A partially updated `affluence_tree`, marking non-leaf nodes that would
    become leaves if the move was accepted. Initially (out of
    `propose_moves`) this mark takes into account whether there would be
    available decision rules to grow the leaf; whether there are enough
    datapoints in the node is instead checked later in
    `accept_moves_parallel_stage`."""

    logu: Float32[Array, '*chains num_trees'] = field(chains=True)
    """The logarithm of a uniform (0, 1] random variable to be used to
    accept the move. It's in (-oo, 0]."""

    acc: None | Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether the move was accepted. `None` if not yet computed."""

    to_prune: None | Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether the final operation to apply the move is pruning. This indicates
    an accepted prune move or a rejected grow move. `None` if not yet
    computed."""

106 

107 

@named_call
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose moves for all the trees.

    There are two types of moves: GROW (convert a leaf to a decision node and
    add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
    leaf, deleting its children).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees = forest.leaf_tree.shape[0]
    keys = split(key, 2)
    grow_keys, prune_keys = keys.pop((2, num_trees))

    # compute moves; both kinds are always computed for every tree, the
    # choice between them is made afterwards
    grow_moves = propose_grow_moves(
        grow_keys,
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    # the prune proposal starts from the affluence tree already updated by
    # the grow proposal, so the final one carries both updates
    prune_moves = propose_prune_moves(
        prune_keys,
        forest.split_tree,
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    # one uniform to pick the move type, one to later accept/reject the move
    u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))

    # choose between grow or prune: 50/50 if both are possible, otherwise
    # the probability is 1 or 0 depending on whether grow is allowed
    p_grow = jnp.where(
        grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
    )
    grow = u < p_grow  # use < instead of <= because u is in [0, 1)

    # compute children indices
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    left, right = (node << 1) | jnp.arange(2)[:, None]

    return Moves(
        allowed=grow_moves.allowed | prune_moves.allowed,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=jnp.where(
            grow, grow_moves.partial_ratio, prune_moves.partial_ratio
        ),
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_moves.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_moves.affluence_tree,
        # exp1mlogu is in [0, 1), so 1 - exp1mlogu is in (0, 1] and its
        # logarithm is in (-oo, 0]
        logu=jnp.log1p(-exp1mlogu),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )

185 

186 

class GrowMoves(Module):
    """Represent a proposed grow move for each tree."""

    allowed: Bool[Array, ' num_trees']
    """Whether the move is allowed for proposal."""

    num_growable: UInt[Array, ' num_trees']
    """The number of leaves that can be proposed for grow."""

    node: UInt[Array, ' num_trees']
    """The index of the leaf to grow. ``2 ** d`` if there are no growable
    leaves."""

    var: UInt[Array, ' num_trees']
    """The decision axis of the new rule."""

    split: UInt[Array, ' num_trees']
    """The decision boundary of the new rule."""

    partial_ratio: Float32[Array, ' num_trees']
    """A factor of the Metropolis-Hastings ratio of the move. It lacks
    the likelihood ratio and the probability of proposing the prune
    move."""

    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    """The updated decision axes of the tree."""

    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    """A partially updated `affluence_tree` that marks each new leaf that
    would be produced as `True` if it would have available decision rules."""

217 

218 

@named_call
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if each leaf is already at maximum depth, or has
    fewer datapoints than the requested threshold
    `min_points_per_decision_node`, or it does not have any available decision
    rules given its ancestors. This is marked by setting `allowed` to `False`
    and `num_growable` to 0.

    The function body is written for a single tree; the decorator maps it
    over the leading axis of the first four arguments.
    """
    keys = split(key, 3)

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # sample a decision rule
    var, num_available_var = choose_variable(
        keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
    )
    split_idx, l, r = choose_split(
        keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # determine if the new leaves would have available decision rules; if the
    # move is blocked, these values may not make sense
    # a child is growable if some other variable is available, or if there
    # are cutpoints left on its side of the chosen split within [l, r)
    leftright_growable = (num_available_var > 1) | jnp.stack(
        [l < split_idx, split_idx + 1 < r]
    )
    leftright = (leaf_to_grow << 1) | jnp.arange(2)
    affluence_tree = affluence_tree.at[leftright].set(leftright_growable)

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
        affluence_tree=affluence_tree,
    )

308 

309 

def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two leaves
        satisfying the `min_points_per_decision_node` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaf nodes that can be grown, i.e., the number of
        `True` entries in the mask returned by `growable_leaves`.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    is_growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(is_growable)

    # sample amongst the growable leaves only
    distr = jnp.where(is_growable, p_propose_grow, 0)
    leaf_to_grow, distr_norm = categorical(key, distr)

    # if nothing is growable, use the sentinel index 2 * size = 2 ** d
    leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)

    # avoid dividing by zero when the distribution is identically zero
    prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1)

    # count the prunable nodes in the tree as it would be after the grow move;
    # the chosen leaf is tentatively marked with a placeholder split value 1
    is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
    num_prunable = jnp.count_nonzero(is_parent)

    return leaf_to_grow, num_growable, prob_choose, num_prunable

355 

356 

def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Return a mask indicating the leaf nodes that can be proposed for growth.

    The condition is that a leaf is not at the bottom level, has available
    decision rules given its ancestors, and has at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    This function needs `split_tree` and not just `affluence_tree` because
    `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
    """
    # restrict the (possibly dirty) affluence marks to the actual leaves
    return grove.is_actual_leaf(split_tree) & affluence_tree

384 

385 

def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Return a random integer from an arbitrary distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are zero,
        return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.

    Notes
    -----
    This function uses a cumsum instead of the Gumbel trick, so it's ok only
    for small ranges with probabilities well greater than 0.
    """
    # inverse-cdf sampling on the unnormalized cumulative distribution
    ecdf = jnp.cumsum(distr)
    u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
    # side='right' counts the entries <= u, so zero-probability categories
    # are skipped, and an all-zero distribution yields n
    return jnp.searchsorted(ecdf, u, 'right', method='compare_all'), ecdf[-1]

415 

416 

def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Choose a variable to split on for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # variables exhausted by the ancestors' rules, plus the ones that never
    # had cutpoints in the first place
    var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars])

    # `log_s is None` is a Python-level check, so the branch is selected when
    # the function is traced rather than at run time
    if log_s is None:
        return randint_exclude(key, max_split.size, var_to_ignore)
    else:
        return categorical_exclude(key, log_s, var_to_ignore)

464 

465 

def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Find variables in the ancestors of a node that have an empty split range.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of fully used variables is not known in advance. Unused values
    in the array are filled with `p`. The fill values are not guaranteed to be
    placed in any particular order, and variables may appear more than once.
    """
    var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)

    # measure the remaining split range of each ancestor variable at the node
    split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
    num_split = r - l

    # keep only the variables with no cutpoints left, fill the rest with p;
    # the type of var_to_ignore is already sufficient to hold max_split.size,
    # see ancestor_variables()
    return jnp.where(num_split == 0, var_to_ignore, max_split.size)

503 

504 

def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Return the list of variables in the ancestors of a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the node.
    The number of ancestors is not known at tracing time; unused spots in the
    output array are filled with `p`.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1

    # shifting the index right by k gives the k-th ancestor; the largest
    # shifts come first, so the output goes from the root side to the parent
    index = node_index >> jnp.arange(max_num_ancestors, 0, -1)
    var = var_tree[index]

    # use a dtype that is wide enough to represent the fill value p
    var_type = minimal_unsigned_dtype(max_split.size)
    p = jnp.array(max_split.size, var_type)

    # an index shifted down to 0 is past the root: not an ancestor
    return jnp.where(index, var, p)

538 

539 

def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return the range of allowed splits for a variable at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1

    # walk up from the node: at each step record whether the current node is
    # a right child (odd index), then move to its parent
    index = node_index >> jnp.arange(max_num_ancestors)
    right_child = (index & 1).astype(bool)
    index >>= 1

    split = split_tree[index].astype(jnp.int32)

    # consider only the actual ancestors (nonzero index) that split on ref_var
    cond = (var_tree[index] == ref_var) & index.astype(bool)

    # being in the right subtree of a rule raises the lower bound
    l = jnp.max(split, initial=0, where=cond & right_child)

    # being in the left subtree of a rule lowers the upper bound; an
    # out-of-bounds ref_var reads max_split as 0, giving an empty range
    initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
        jnp.int32
    )
    r = jnp.min(split, initial=initial_r, where=cond & ~right_child)

    return l + 1, r

580 

581 

def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return a random integer in a range, excluding some values.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range, must be >= 1.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    # sorted, deduplicated exclusion list, padded with sup
    exclude, num_allowed = _process_exclude(sup, exclude)

    # draw a rank amongst the allowed values...
    u = random.randint(key, (), 0, num_allowed)

    # ...then shift it past the excluded values that precede it
    u_shifted = u + jnp.arange(exclude.size)
    u_shifted = jnp.minimum(u_shifted, sup - 1)
    u += jnp.sum(u_shifted >= exclude)
    return u, num_allowed

616 

617 

def _process_exclude(
    sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Integer[Array, ' n'], Integer[Array, '']]:
    """
    Deduplicate a list of excluded values and count the allowed ones.

    Parameters
    ----------
    sup
        The exclusive upper bound of the range ``[0, sup)``.
    exclude
        The values to exclude. Values can appear more than once.

    Returns
    -------
    exclude : Integer[Array, ' n']
        The unique excluded values, with the same length as the input; the
        spots left free by removing duplicates are filled with `sup`.
    num_allowed : Integer[Array, '']
        `sup` minus the number of unique excluded values that are less than
        `sup`.
    """
    # fixed output size keeps the shape static; padding with sup makes the
    # filler entries count as out of range
    exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_allowed = sup - jnp.sum(exclude < sup)
    return exclude, num_allowed

624 

625 

def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw from a categorical distribution, excluding a set of values.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        The values to exclude from the range [0, k). Values greater than or
        equal to `logits.size` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, the result is unspecified.
    """
    exclude, num_allowed = _process_exclude(logits.size, exclude)

    # overwrite the excluded categories with the most negative finite value
    # of the dtype (not -inf) so that they are never drawn
    kinda_neg_inf = jnp.finfo(logits.dtype).min
    logits = logits.at[exclude].set(kinda_neg_inf)

    u = random.categorical(key, logits)
    return u, num_allowed

658 

659 

def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Choose a split point for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. Does not need to already contain `var`
        at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range `split` was drawn from is [l, r).

    Notes
    -----
    If `var` is out of bounds, or if the available split range on that variable
    is empty, return 0.
    """
    l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
    # draw uniformly in [l, r); an empty range yields the placeholder 0
    return jnp.where(l < r, random.randint(key, (), l, r), 0), l, r

702 

703 

def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Evaluate the partial transition ratio times prior ratio of a grow move.

    Parameters
    ----------
    prob_choose
        The probability with which the leaf was picked amongst the growable
        leaves.
    num_prunable
        The count of leaf parents that could be pruned once the leaf to grow
        has been turned into a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional
        on its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
    The "partial" transition ratio computed here omits the factor P(propose
    prune) in the numerator. The prior ratio is P(new tree) / P(old tree); the
    "partial" prior ratio omits the factor P(children are leaves).
    """
    # Factors num_available_split * num_available_var * s[var] appear in both
    # ratios and cancel out, so they are not computed.

    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) are left
    # out because they require the count trees, which are only available in
    # the acceptance phase.

    # Prior part. Use .at.get with a fill value: if leaf_to_grow is out of
    # bounds (move not allowed), a plain lookup would lead to a 0 and then an
    # inf when `complete_ratio` takes the log.
    pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    prior_factor = pnt / (1 - pnt)

    # Transition part (inverted).
    # growing the root <---> the initial tree is a root <---> prune not allowed
    growing_root = leaf_to_grow == 1
    p_grow = jnp.where(growing_root, 1, 0.5)
    inv_transition = p_grow * prob_choose * num_prunable

    # guard against dividing by zero when the inverse ratio vanishes
    safe_inv_transition = jnp.where(inv_transition, inv_transition, 1)
    return prior_factor / safe_inv_transition

758 

759 

class PruneMoves(Module):
    """Represent a proposed prune move for each tree.

    Container for the per-tree results of `propose_prune_moves`. Every field
    is batched over trees along a leading ``num_trees`` axis, as declared by
    the shape annotations.
    """

    # Set by `propose_prune_moves` from whether the root node has a split.
    allowed: Bool[Array, ' num_trees']
    """Whether the move is possible."""

    # Heap-style node index; the out-of-bounds value ``2 ** d`` is used as the
    # "nothing to prune" marker.
    node: UInt[Array, ' num_trees']
    """The index of the node to prune. ``2 ** d`` if no node can be pruned."""

    # Computed by `compute_partial_ratio` for the reverse (grow) move.
    partial_ratio: Float32[Array, ' num_trees']
    """A factor of the Metropolis-Hastings ratio of the move. It lacks the
    likelihood ratio, the probability of proposing the prune move, and the
    prior probability that the children of the node to prune are leaves.
    This ratio is inverted, and is meant to be inverted back in
    `accept_move_and_sample_leaves`."""

    # One row per tree, one entry per node slot of the tree.
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    """A partially updated `affluence_tree`, marking the node to prune as
    growable."""

779 

780 

@named_call
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a prune move on the tree structure, as part of the BART MCMC.

    The function body handles a single tree; the decorator maps it over the
    leading axis of `key`, `split_tree` and `affluence_tree`, while
    `p_nonterminal` and `p_propose_grow` are shared by all trees.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    # pruning is possible iff the root (index 1) has a split, i.e., the tree
    # is not a bare root
    can_prune = split_tree[1].astype(bool)

    selection = choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow)
    target_node, prunable_count, choose_prob, updated_affluence = selection

    # ratio of the reverse grow move that would recreate the pruned node
    partial_ratio = compute_partial_ratio(
        choose_prob, prunable_count, p_nonterminal, target_node
    )

    return PruneMoves(
        allowed=can_prune,
        node=target_node,
        partial_ratio=partial_ratio,
        affluence_tree=updated_affluence,
    )

825 

826 

def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    This function operates on a single tree; batching over trees is done by
    the caller.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable.
    """
    # sample a node to prune
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # with nothing to prune, use the out-of-bounds index 2 * size == 2 ** d
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # compute stuff for reverse move: build the tree as it would be after the
    # prune, with the pruned node turned into a growable leaf
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)

    # normalize the grow probability over the growable leaves of that tree;
    # an out-of-bounds node reads as 0, and a zero normalization is replaced
    # with 1 to avoid dividing by zero
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree

882 

883 

def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw a random index among the positions where a mask is set.

    Parameters
    ----------
    key
        A jax random key.
    mask
        The mask indicating the allowed values.

    Returns
    -------
    A random integer in the range ``[0, n)`` such that ``mask[u] == True``.

    Notes
    -----
    If all values in the mask are `False`, return `n`. The lookup uses the
    ``compare_all`` search method, a choice aimed at small `n`.
    """
    # running count of allowed positions; the last entry is the total
    allowed_so_far = jnp.cumsum(mask)
    total_allowed = allowed_so_far[-1]

    # rank of the allowed position to select, in [0, total_allowed)
    rank = random.randint(key, (), 0, total_allowed)

    # first position whose running count exceeds the rank
    return jnp.searchsorted(
        allowed_so_far, rank, side='right', method='compare_all'
    )