Coverage for src / bartz / mcmcstep / _moves.py: 100%
148 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
1# bartz/src/bartz/mcmcstep/_moves.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement `propose_moves` and associated dataclasses."""
27from functools import partial
29import jax
30from equinox import Module
31from jax import numpy as jnp
32from jax import random
33from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
35from bartz import grove
36from bartz._profiler import jit_and_block_if_profiling
37from bartz.jaxext import minimal_unsigned_dtype, split, vmap_nodoc
38from bartz.mcmcstep._state import Forest, field, vmap_chains
class Moves(Module):
    """
    Bundle of the moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        `True` if a move is possible for the tree. When `False`, the other
        fields may hold meaningless values. A move flagged as allowed can
        still be vetoed afterwards, but only when it does not satisfy
        `min_points_per_leaf`; for efficiency that check is applied post-hoc,
        leaving the rest of the MCMC logic unchanged.
    grow
        `True` for a grow move, `False` for a prune move.
    num_growable
        How many leaves of the original tree can be grown.
    node
        Index of the leaf to grow or of the node to prune.
    left
    right
        Indices of the two children of `node`.
    partial_ratio
        One factor of the Metropolis-Hastings ratio of the move. Still
        missing from it are the likelihood ratio, the probability of
        proposing the prune move, and the probability that the children of
        the modified node are terminal. For a PRUNE move the ratio is
        inverted. Becomes `None` once `log_trans_prior_ratio` has been
        computed.
    log_trans_prior_ratio
        Logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings acceptance ratio of the proposed move. `None`
        until it is computed. For a PRUNE move the log-ratio is negated.
    grow_var
        Decision axes of the new rules.
    grow_split
        Decision boundaries of the new rules.
    var_tree
        Updated decision axes of the trees, valid whichever move is chosen.
    affluence_tree
        Partially updated `affluence_tree` that marks the non-leaf nodes
        which would turn into leaves if the move were accepted. As returned
        by `propose_moves`, the mark accounts for whether decision rules
        would be available to grow the leaf; whether the node holds enough
        datapoints is checked later, in `accept_moves_parallel_stage`.
    logu
        Logarithm of a uniform (0, 1] random variable, used to accept the
        move. Its value lies in (-oo, 0].
    acc
        Whether the move was accepted. `None` until it is computed.
    to_prune
        Whether the final operation applying the move is a pruning, i.e., an
        accepted prune move or a rejected grow move. `None` until it is
        computed.
    """

    allowed: Bool[Array, '*chains num_trees'] = field(chains=True)
    grow: Bool[Array, '*chains num_trees'] = field(chains=True)
    num_growable: UInt[Array, '*chains num_trees'] = field(chains=True)
    node: UInt[Array, '*chains num_trees'] = field(chains=True)
    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    partial_ratio: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    log_trans_prior_ratio: Float32[Array, '*chains num_trees'] | None = field(
        chains=True
    )
    grow_var: UInt[Array, '*chains num_trees'] = field(chains=True)
    grow_split: UInt[Array, '*chains num_trees'] = field(chains=True)
    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    logu: Float32[Array, '*chains num_trees'] = field(chains=True)
    acc: Bool[Array, '*chains num_trees'] | None = field(chains=True)
    to_prune: Bool[Array, '*chains num_trees'] | None = field(chains=True)
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Draw one candidate move for every tree of the forest.

    A move is either GROW (turn a leaf into a decision node and attach two
    leaves below it) or PRUNE (turn the parent of two leaves into a leaf,
    removing its children).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees = forest.leaf_tree.shape[0]
    key_pool = split(key, 2)
    grow_keys, prune_keys = key_pool.pop((2, num_trees))

    # build a candidate of each kind for every tree
    grow_proposal = propose_grow_moves(
        grow_keys,
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    # the prune proposal receives the affluence tree already touched by the
    # grow proposal, so its output carries both updates
    prune_proposal = propose_prune_moves(
        prune_keys,
        forest.split_tree,
        grow_proposal.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    # one uniform picks the kind of move, the other feeds the acceptance test
    u_kind, u_accept = random.uniform(key_pool.pop(), (2, num_trees))

    # grow with probability 1/2 if both kinds are possible, otherwise take
    # whichever one is allowed
    grow_ok = grow_proposal.allowed
    prune_ok = prune_proposal.allowed
    p_grow = jnp.where(grow_ok & prune_ok, 0.5, grow_ok)
    grow = u_kind < p_grow  # strict inequality because the uniform is in [0, 1)

    # node affected by the selected move, and its two children
    node = jnp.where(grow, grow_proposal.node, prune_proposal.node)
    children = (node << 1) | jnp.arange(2)[:, None]
    left, right = children

    partial_ratio = jnp.where(
        grow, grow_proposal.partial_ratio, prune_proposal.partial_ratio
    )

    return Moves(
        allowed=grow_ok | prune_ok,
        grow=grow,
        num_growable=grow_proposal.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=partial_ratio,
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_proposal.var,
        grow_split=grow_proposal.split,
        # a prune move does not need var_tree to be updated
        var_tree=grow_proposal.var_tree,
        # affluence_tree is updated unconditionally for both moves, prune last
        affluence_tree=prune_proposal.affluence_tree,
        logu=jnp.log1p(-u_accept),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )
class GrowMoves(Module):
    """
    Hold the grow move proposed for each tree.

    Parameters
    ----------
    allowed
        Whether the move may be proposed.
    num_growable
        How many leaves are eligible to be grown.
    node
        Index of the leaf to grow, or ``2 ** d`` when no leaf is growable.
    var
    split
        Decision axis and decision boundary of the new rule.
    partial_ratio
        One factor of the Metropolis-Hastings ratio of the move. The
        likelihood ratio and the probability of proposing the prune move are
        not included.
    var_tree
        Updated decision axes of the tree.
    affluence_tree
        Partially updated `affluence_tree` in which each leaf the move would
        create is marked `True` if it would have available decision rules.
    """

    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Build a GROW proposal for each tree.

    A GROW move selects a leaf and turns it into a non-terminal node with two
    leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be
        zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    No move is proposed when every leaf is already at maximum depth, or has
    fewer datapoints than the requested threshold
    `min_points_per_decision_node`, or has no decision rule left given its
    ancestors. That situation is signalled by `allowed` being `False` and
    `num_growable` being 0.
    """
    key_pool = split(key, 3)

    # pick the leaf
    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        key_pool.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # pick the decision rule: first the axis, then the cutpoint
    var, num_available_var = choose_variable(
        key_pool.pop(),
        var_tree,
        split_tree,
        max_split,
        leaf_to_grow,
        blocked_vars,
        log_s,
    )
    split_idx, lower, upper = choose_split(
        key_pool.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # Would the new leaves have decision rules available? A child has some if
    # another variable is still usable, or if cutpoints remain on its side of
    # the chosen split. These values may be meaningless for a blocked move.
    other_vars_usable = num_available_var > 1
    room_on_each_side = jnp.stack([lower < split_idx, split_idx + 1 < upper])
    children_growable = other_vars_usable | room_on_each_side
    children = (leaf_to_grow << 1) | jnp.arange(2)
    updated_affluence = affluence_tree.at[children].set(children_growable)

    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )
    updated_var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=partial_ratio,
        var_tree=updated_var_tree,
        affluence_tree=updated_affluence,
    )
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Pick the leaf of a tree that a grow move would expand.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two
        leaves satisfying the `min_points_per_decision_node` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaves selected by `growable_leaves`, i.e., that can
        be proposed for growth.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting
        the selected leaf to a non-terminal node.
    """
    growable_mask = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(growable_mask)

    # restrict the proposal weights to the growable leaves and sample
    weights = jnp.where(growable_mask, p_propose_grow, 0)
    sampled, total_weight = categorical(key, weights)

    # with nothing growable, use an index past the end of the tree arrays
    out_of_tree = 2 * split_tree.size
    leaf_to_grow = jnp.where(num_growable, sampled, out_of_tree)

    # normalize, avoiding a division by zero when all the weights are zero
    safe_total = jnp.where(total_weight, total_weight, 1)
    prob_choose = weights[leaf_to_grow] / safe_total

    # count the parents of leaves in the tree as it would be after growing
    grown_split_tree = split_tree.at[leaf_to_grow].set(1)
    num_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))

    return leaf_to_grow, num_growable, prob_choose, num_prunable
def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Build the mask of the leaves that a grow move may select.

    A leaf qualifies if it is not at the bottom level, has decision rules
    available given its ancestors, and has at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    `affluence_tree` alone is not enough: it can be "dirty", i.e., mark
    unused nodes as `True`, so `split_tree` is needed to single out the
    actual leaves.
    """
    actual_leaf = grove.is_actual_leaf(split_tree)
    return actual_leaf & affluence_tree
def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Sample an integer from an arbitrary discrete distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are
        zero, return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.

    Notes
    -----
    The sampling inverts the cumulative sum instead of using the Gumbel
    trick, so it's ok only for small ranges with probabilities well greater
    than 0.
    """
    cumulative = jnp.cumsum(distr)
    total = cumulative[-1]
    # uniform draw between 0 and the total mass, then locate it in the
    # cumulative sums
    draw = random.uniform(
        key, shape=(), dtype=cumulative.dtype, minval=0, maxval=total
    )
    index = jnp.searchsorted(cumulative, draw, side='right', method='compare_all')
    return index, total
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the variable a new non-terminal node will split on.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was
        chosen from.
    """
    # variables exhausted by the ancestors of the leaf, plus the blocked ones
    excluded = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        excluded = jnp.concatenate([excluded, blocked_vars])

    # no weights given: uniform draw over the remaining variables
    if log_s is None:
        return randint_exclude(key, max_split.size, excluded)

    return categorical_exclude(key, log_s, excluded)
def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    List the ancestor variables of a node whose split range is empty.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of unused variables is not known in advance. Unused values in
    the array are filled with `p`. The fill values are not guaranteed to be
    placed in any particular order, and variables may appear more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)

    # measure the split range at the node for each candidate variable
    ranges_of = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lower, upper = ranges_of(var_tree, split_tree, max_split, leaf_index, candidates)
    num_split = upper - lower

    # the type of `candidates` is already sufficient to hold max_split.size,
    # see ancestor_variables()
    return jnp.where(num_split == 0, candidates, max_split.size)
def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Collect the variables used by the ancestors of a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the
    node. The number of ancestors is not known at tracing time; unused spots
    in the output array are filled with `p`.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1

    # shifting the index right by k bits gives the k-th ancestor; shifts that
    # go past the root produce index 0
    shifts = jnp.arange(max_num_ancestors, 0, -1)
    ancestors = node_index >> shifts
    ancestor_vars = var_tree[ancestors]

    # filler value `p`, in a dtype large enough to represent it
    filler_dtype = minimal_unsigned_dtype(max_split.size)
    filler = jnp.array(max_split.size, filler_dtype)

    return jnp.where(ancestors, ancestor_vars, filler)
def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute the range of splits a variable can still take at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1

    # the node followed by its successive ancestors; the lowest bit of each
    # index says whether that node is the right child of its parent
    lineage = node_index >> jnp.arange(max_num_ancestors)
    is_right_child = (lineage & 1).astype(bool)
    parents = lineage >> 1

    parent_split = split_tree[parents].astype(jnp.int32)
    # keep only the parents that split on `ref_var`; index 0 means the shift
    # went past the root
    uses_ref_var = (var_tree[parents] == ref_var) & parents.astype(bool)

    # being on the right of a split raises the lower bound, being on the left
    # lowers the upper bound
    lower = jnp.max(parent_split, initial=0, where=uses_ref_var & is_right_child)
    max_split_ref = max_split.at[ref_var].get(mode='fill', fill_value=0)
    upper_start = 1 + max_split_ref.astype(jnp.int32)
    upper = jnp.min(
        parent_split, initial=upper_start, where=uses_ref_var & ~is_right_child
    )

    return lower + 1, upper
def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw a random integer from a range with some values removed.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range, must be >= 1.
    exclude
        The values to exclude from the range. Values greater than or equal
        to `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    excluded, num_allowed = _process_exclude(sup, exclude)

    # draw among the allowed values, then shift the draw past the excluded
    # values it has to skip
    draw = random.randint(key, (), 0, num_allowed)
    shifted = jnp.minimum(draw + jnp.arange(excluded.size), sup - 1)
    num_skipped = jnp.sum(shifted >= excluded)

    return draw + num_skipped, num_allowed
def _process_exclude(sup, exclude):
    """
    Deduplicate the excluded values and count the values left in the range.

    Parameters
    ----------
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude. Duplicates are removed, and values greater
        than or equal to `sup` are not counted.

    Returns
    -------
    exclude
        The unique excluded values, in an array of the same size as the
        input, with the leftover slots filled with `sup`.
    num_allowed
        `sup` minus the number of unique excluded values below `sup`.
    """
    unique_excluded = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_excluded = jnp.sum(unique_excluded < sup)
    return unique_excluded, sup - num_excluded
def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Sample from a categorical distribution with some categories removed.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        The values to exclude from the range [0, k). Values greater than or
        equal to `logits.size` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, the result is unspecified.
    """
    excluded, num_allowed = _process_exclude(logits.size, exclude)

    # give the excluded categories the lowest finite logit (rather than an
    # actual -inf) so that they are never drawn
    lowest_logit = jnp.finfo(logits.dtype).min
    masked_logits = logits.at[excluded].set(lowest_logit)

    return random.categorical(key, masked_logits), num_allowed
def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the split point of a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. Does not need to already contain
        `var` at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range `split` was drawn from is [l, r).

    Notes
    -----
    If `var` is out of bounds, or if the available split range on that
    variable is empty, return 0.
    """
    lower, upper = split_range(var_tree, split_tree, max_split, leaf_index, var)
    drawn = random.randint(key, (), lower, upper)
    # the draw is meaningful only if the range is not empty
    cutpoint = jnp.where(lower < upper, drawn, 0)
    return cutpoint, lower, upper
def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Evaluate the transition ratio times the prior ratio of a grow move.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned, after converting
        the leaf to be grown to a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional
        on its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new
    tree). The "partial" transition ratio returned is missing the factor
    P(propose prune) in the numerator. The prior ratio is P(new tree) /
    P(old tree). The "partial" prior ratio is missing the factor P(children
    are leaves).
    """
    # Both ratios also contain the factors num_available_split *
    # num_available_var * s[var], which cancel out.

    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) are left
    # out because they need the count trees, which are computed in the
    # acceptance phase.

    # The leaf to grow is the root if and only if the tree is just a root, and
    # that is the only case in which a prune move is not an alternative.
    grows_root = leaf_to_grow == 1
    p_grow = jnp.where(grows_root, 1, 0.5)
    inv_trans_ratio = p_grow * prob_choose * num_prunable

    # `.at[].get` with a fill value: if leaf_to_grow is out of bounds (move
    # not allowed), a plain lookup would produce a 0 and then an inf when
    # `complete_ratio` takes the log
    p_node = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    prior_ratio = p_node / (1 - p_node)

    # guard against dividing by zero
    safe_inv_trans_ratio = jnp.where(inv_trans_ratio, inv_trans_ratio, 1)
    return prior_ratio / safe_inv_trans_ratio
class PruneMoves(Module):
    """
    Container for the prune move proposed for every tree.

    Parameters
    ----------
    allowed
        Flag telling whether a prune move can be performed on the tree.
    node
        Index of the node selected for pruning, or ``2 ** d`` when the tree
        has no prunable node.
    partial_ratio
        One factor of the Metropolis-Hastings ratio of the move. The
        likelihood ratio, the probability of proposing the prune move, and
        the prior probability that the children of the pruned node are leaves
        are not included. The stored value is the inverse of the ratio; it is
        meant to be inverted back in `accept_move_and_sample_leaves`.
    affluence_tree
        A partially updated `affluence_tree` in which the node to prune is
        marked as growable.
    """

    # One entry per tree: can the tree be pruned at all?
    allowed: Bool[Array, ' num_trees']
    # One entry per tree: heap index of the node to prune.
    node: UInt[Array, ' num_trees']
    # One entry per tree: inverted partial Metropolis-Hastings ratio.
    partial_ratio: Float32[Array, ' num_trees']
    # One row per tree: growability flags after the tentative prune.
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a prune move on the tree structure, as part of the BART MCMC.

    The first three arguments are mapped over their leading axis by the
    decorator; the last two are shared across the mapped calls.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    # Select the leaf parent to prune, together with the quantities needed to
    # evaluate the reverse (grow) move.
    selection = choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow)
    pruned_node, prunable_count, reverse_prob, updated_affluence = selection

    # The prune move is scored as the grow move that would undo it.
    partial_ratio = compute_partial_ratio(
        reverse_prob, prunable_count, p_nonterminal, pruned_node
    )

    # A tree can be pruned iff it is not a bare root, i.e., iff the root
    # (index 1) holds a split.
    tree_is_not_root = split_tree[1].astype(bool)

    return PruneMoves(
        allowed=tree_is_not_root,
        node=pruned_node,
        partial_ratio=partial_ratio,
        affluence_tree=updated_affluence,
    )
def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable. It has the same shape as the input `affluence_tree`.
    """
    # sample a node to prune uniformly amongst the parents of two leaves
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # if nothing is prunable, use an index that is out of bounds for all the
    # per-node arrays handled below
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # compute stuff for the reverse (grow) move: build the tree as it would
    # look after pruning, with the pruned node turned into a growable leaf
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)

    # normalize the probability of picking the pruned node as leaf to grow;
    # the lookup yields 0 if the index is out of bounds, and the denominator
    # is replaced with 1 if it is zero to avoid a division by zero
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw a random index among the positions enabled by a mask.

    Parameters
    ----------
    key
        A jax random key.
    mask
        Boolean array flagging which indices may be returned.

    Returns
    -------
    A random integer ``u`` in the range ``[0, n)`` with ``mask[u] == True``.

    Notes
    -----
    When no entry of the mask is `True`, the result is `n`. The implementation
    targets small `n`.
    """
    # Running count of allowed entries; its last element is their total.
    allowed_so_far = jnp.cumsum(mask)
    num_allowed = allowed_so_far[-1]

    # Draw the rank of the allowed entry to return.
    rank = random.randint(key, (), 0, num_allowed)

    # The first position where the running count exceeds the rank is the
    # allowed entry with that rank.
    return jnp.searchsorted(allowed_so_far, rank, side='right', method='compare_all')