Coverage for src / bartz / mcmcstep / _moves.py: 100%
175 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/mcmcstep/_moves.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement `propose_moves` and associated dataclasses."""
27from functools import partial
29import jax
30from equinox import Module
31from jax import named_call, random
32from jax import numpy as jnp
33from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
35from bartz import grove
36from bartz.jaxext import minimal_unsigned_dtype, split, vmap_nodoc
37from bartz.mcmcstep._state import Forest, field
class Moves(Module):
    """Moves proposed to modify each tree.

    Instances are built by `propose_moves`, which leaves
    `log_trans_prior_ratio`, `acc` and `to_prune` set to `None`; those fields
    are meant to be filled in by later stages of the MCMC step.
    """

    # --- which move is proposed and where ---------------------------------

    allowed: Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether there is a possible move. If `False`, the other values may not
    make sense. The only case in which a move is marked as allowed but is
    then vetoed is if it does not satisfy `min_points_per_leaf`, which for
    efficiency is implemented post-hoc without changing the rest of the
    MCMC logic."""

    grow: Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether the move is a grow move or a prune move."""

    num_growable: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The number of growable leaves in the original tree."""

    node: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The index of the leaf to grow or node to prune."""

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The index of the left child of 'node'."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The index of the right child of 'node'."""

    # --- Metropolis-Hastings ratio terms ----------------------------------

    partial_ratio: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """A factor of the Metropolis-Hastings ratio of the move. It lacks the
    likelihood ratio, the probability of proposing the prune move, and the
    probability that the children of the modified node are terminal. If the
    move is PRUNE, the ratio is inverted. `None` once
    `log_trans_prior_ratio` has been computed."""

    log_trans_prior_ratio: None | Float32[Array, '*chains num_trees'] = field(
        chains=True
    )
    """The logarithm of the product of the transition and prior terms of the
    Metropolis-Hastings ratio for the acceptance of the proposed move.
    `None` if not yet computed. If PRUNE, the log-ratio is negated."""

    # --- new decision rule and updated per-tree arrays --------------------

    grow_var: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The decision axes of the new rules."""

    grow_split: UInt[Array, '*chains num_trees'] = field(chains=True)
    """The decision boundaries of the new rules."""

    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The updated decision axes of the trees, valid whatever move."""

    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """A partially updated `affluence_tree`, marking non-leaf nodes that would
    become leaves if the move was accepted. This mark initially (out of
    `propose_moves`) takes into account if there would be available decision
    rules to grow the leaf, and whether there are enough datapoints in the
    node is instead checked later in `accept_moves_parallel_stage`."""

    # --- acceptance bookkeeping -------------------------------------------

    logu: Float32[Array, '*chains num_trees'] = field(chains=True)
    """The logarithm of a uniform (0, 1] random variable to be used to
    accept the move. It's in (-oo, 0]."""

    acc: None | Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether the move was accepted. `None` if not yet computed."""

    to_prune: None | Bool[Array, '*chains num_trees'] = field(chains=True)
    """Whether the final operation to apply the move is pruning. This indicates
    an accepted prune move or a rejected grow move. `None` if not yet
    computed."""
@named_call
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Draw one candidate GROW or PRUNE move for every tree.

    A GROW move turns a leaf into a decision node with two new leaves below
    it; a PRUNE move turns the parent of two leaves back into a leaf, removing
    its children.

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    n_trees = forest.leaf_tree.shape[0]
    key_stream = split(key, 2)
    keys_for_grow, keys_for_prune = key_stream.pop((2, n_trees))

    # build a candidate of each kind for every tree; arguments are passed
    # positionally because the helpers are vmapped with positional in_axes
    grow_candidates = propose_grow_moves(
        keys_for_grow,
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    # the prune proposal receives the affluence tree already updated by grow
    prune_candidates = propose_prune_moves(
        keys_for_prune,
        forest.split_tree,
        grow_candidates.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    # first row picks the move type, second row feeds the acceptance step
    uniforms = random.uniform(key_stream.pop(), (2, n_trees))
    u_choice = uniforms[0]
    u_accept = uniforms[1]

    # pick grow with probability 1/2 when both are possible, otherwise take
    # whichever one is possible
    can_grow = grow_candidates.allowed
    can_prune = prune_candidates.allowed
    prob_grow = jnp.where(can_grow & can_prune, 0.5, can_grow)
    # strict inequality because the uniform draw lies in [0, 1)
    is_grow = u_choice < prob_grow

    # node touched by the selected move and the indices of its two children
    target = jnp.where(is_grow, grow_candidates.node, prune_candidates.node)
    children = (target << 1) | jnp.arange(2)[:, None]
    left_child, right_child = children

    selected_ratio = jnp.where(
        is_grow, grow_candidates.partial_ratio, prune_candidates.partial_ratio
    )

    return Moves(
        allowed=can_grow | can_prune,
        grow=is_grow,
        num_growable=grow_candidates.num_growable,
        node=target,
        left=left_child,
        right=right_child,
        partial_ratio=selected_ratio,
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_candidates.var,
        grow_split=grow_candidates.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_candidates.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_candidates.affluence_tree,
        logu=jnp.log1p(-u_accept),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )
class GrowMoves(Module):
    """Represent a proposed grow move for each tree.

    This is the return type of `propose_grow_moves`.
    """

    # --- whether and where to grow ----------------------------------------

    allowed: Bool[Array, ' num_trees']
    """Whether the move is allowed for proposal."""

    num_growable: UInt[Array, ' num_trees']
    """The number of leaves that can be proposed for grow."""

    node: UInt[Array, ' num_trees']
    """The index of the leaf to grow. ``2 ** d`` if there are no growable
    leaves."""

    # --- the new decision rule --------------------------------------------

    var: UInt[Array, ' num_trees']
    """The decision axis of the new rule."""

    split: UInt[Array, ' num_trees']
    """The decision boundary of the new rule."""

    partial_ratio: Float32[Array, ' num_trees']
    """A factor of the Metropolis-Hastings ratio of the move. It lacks
    the likelihood ratio and the probability of proposing the prune
    move."""

    # --- updated per-tree arrays ------------------------------------------

    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    """The updated decision axes of the tree."""

    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    """A partially updated `affluence_tree` that marks each new leaf that
    would be produced as `True` if it would have available decision rules."""
@named_call
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move selects a leaf and turns it into a decision node with two
    leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    No move is proposed when every leaf is at maximum depth, or has fewer
    datapoints than the threshold `min_points_per_decision_node`, or has no
    decision rule left given its ancestors. In that case `allowed` is `False`
    and `num_growable` is 0.
    """
    key_stream = split(key, 3)

    # which leaf to grow
    leaf, n_growable, p_leaf, n_prunable = choose_leaf(
        key_stream.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # which decision rule to put on it
    var, n_available_var = choose_variable(
        key_stream.pop(), var_tree, split_tree, max_split, leaf, blocked_vars, log_s
    )
    cut, lower, upper = choose_split(
        key_stream.pop(), var, var_tree, split_tree, max_split, leaf
    )

    partial_ratio = compute_partial_ratio(p_leaf, n_prunable, p_nonterminal, leaf)

    # would the two new leaves still have decision rules available? these
    # values may not make sense when the move is blocked
    room_on_left = lower < cut
    room_on_right = cut + 1 < upper
    children_growable = (n_available_var > 1) | jnp.stack(
        [room_on_left, room_on_right]
    )
    children = (leaf << 1) | jnp.arange(2)
    new_affluence_tree = affluence_tree.at[children].set(children_growable)

    new_var_tree = var_tree.at[leaf].set(var.astype(var_tree.dtype))

    return GrowMoves(
        allowed=n_growable > 0,
        num_growable=n_growable,
        node=leaf,
        var=var,
        split=cut,
        partial_ratio=partial_ratio,
        var_tree=new_var_tree,
        affluence_tree=new_affluence_tree,
    )
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Pick the leaf of a tree that a GROW move would act on.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two leaves
        satisfying the `min_points_per_decision_node` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaf nodes that can be grown, i.e., are nonterminal
        and have at least twice `min_points_per_decision_node`.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    growable_mask = growable_leaves(split_tree, affluence_tree)
    n_growable = jnp.count_nonzero(growable_mask)

    # sample among growable leaves only, weighted by p_propose_grow
    weights = jnp.where(growable_mask, p_propose_grow, 0)
    sampled, weights_total = categorical(key, weights)

    # sentinel index (twice the array size) when nothing can be grown
    chosen = jnp.where(n_growable, sampled, 2 * split_tree.size)

    # guard the normalization against a zero total
    safe_total = jnp.where(weights_total, weights_total, 1)
    p_chosen = weights[chosen] / safe_total

    # count prunable nodes in the tree as it would look after the grow
    grown_split_tree = split_tree.at[chosen].set(1)
    n_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))

    return chosen, n_growable, p_chosen, n_prunable
def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Build the mask of leaves that are eligible for a GROW proposal.

    A leaf is eligible when it is not at the bottom level, still has decision
    rules available given its ancestors, and holds at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    `split_tree` is required in addition to `affluence_tree` because the
    latter can be "dirty", i.e., mark unused nodes as `True`.
    """
    actual_leaf = grove.is_actual_leaf(split_tree)
    return actual_leaf & affluence_tree
def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Sample an index from an unnormalized discrete distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are zero,
        return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.

    Notes
    -----
    The sampling inverts the cumulative sum instead of using the Gumbel trick,
    so it is suitable only for small ranges with probabilities well greater
    than 0.
    """
    cumulative = jnp.cumsum(distr)
    total = cumulative[-1]
    # uniform draw on the scale of the unnormalized total
    draw = random.uniform(
        key, shape=(), dtype=cumulative.dtype, minval=0, maxval=total
    )
    # the number of cumulative entries <= draw is the sampled index
    index = jnp.searchsorted(cumulative, draw, side='right', method='compare_all')
    return index, total
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Sample the splitting variable for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # variables already exhausted along the path to the leaf
    excluded = fully_used_variables(var_tree, split_tree, max_split, leaf_index)

    # plus the variables blocked for the whole forest, if any
    if blocked_vars is not None:
        excluded = jnp.concatenate([excluded, blocked_vars])

    # uniform choice without a prior, weighted choice with one
    if log_s is None:
        return randint_exclude(key, max_split.size, excluded)
    return categorical_exclude(key, log_s, excluded)
def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    List the ancestor variables of a node whose split range is empty.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of unused variables is not known in advance. Unused values in the
    array are filled with `p`. The fill values are not guaranteed to be placed
    in any particular order, and variables may appear more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)

    # evaluate the split range at the node once per candidate variable
    ranges_per_var = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lower, upper = ranges_per_var(
        var_tree, split_tree, max_split, leaf_index, candidates
    )
    n_splits = upper - lower

    # the type of `candidates` is already sufficient to hold max_split.size,
    # see ancestor_variables()
    return jnp.where(n_splits == 0, candidates, max_split.size)
def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Collect the splitting variables used by the ancestors of a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the node.
    The number of ancestors is not known at tracing time; unused spots in the
    output array are filled with `p`.
    """
    n_slots = grove.tree_depth(var_tree) - 1

    # shifting by n_slots, ..., 1 bits walks the ancestor indices; a shift
    # that yields 0 marks an unused slot
    shifts = jnp.arange(n_slots, 0, -1)
    ancestor_index = node_index >> shifts
    ancestor_var = var_tree[ancestor_index]

    # filler value `p`, stored in the smallest unsigned type that can hold it
    filler_dtype = minimal_unsigned_dtype(max_split.size)
    filler = jnp.array(max_split.size, filler_dtype)

    return jnp.where(ancestor_index, ancestor_var, filler)
def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute which splits on a variable are still allowed at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    n_slots = grove.tree_depth(var_tree) - 1

    # the node itself followed by its successive ancestors
    path = node_index >> jnp.arange(n_slots)
    # the low bit tells whether each path node is the right child of its parent
    via_right = (path & 1).astype(bool)
    parent = path >> 1

    parent_cut = split_tree[parent].astype(jnp.int32)
    # keep only real parents (nonzero index) that split on `ref_var`
    relevant = (var_tree[parent] == ref_var) & parent.astype(bool)

    # largest cut among relevant parents reached through their right branch
    lower = jnp.max(parent_cut, initial=0, where=relevant & via_right)

    # smallest cut among relevant parents reached through their left branch,
    # starting from one past the maximum split of the variable (0 if the
    # variable index is out of bounds)
    var_max_split = max_split.at[ref_var].get(mode='fill', fill_value=0)
    upper_start = 1 + var_max_split.astype(jnp.int32)
    upper = jnp.min(parent_cut, initial=upper_start, where=relevant & ~via_right)

    return lower + 1, upper
def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw a uniform random integer from a range with some values removed.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range, must be >= 1.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    excluded, num_allowed = _process_exclude(sup, exclude)

    # draw a rank among the allowed values
    rank = random.randint(key, (), 0, num_allowed)

    # turn the rank into a value by counting the excluded values it has to
    # step over
    probe = jnp.minimum(rank + jnp.arange(excluded.size), sup - 1)
    n_skipped = jnp.sum(probe >= excluded)

    return rank + n_skipped, num_allowed
def _process_exclude(
    sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Integer[Array, ' n'], Integer[Array, '']]:
    """
    Deduplicate the excluded values and count the values left in ``[0, sup)``.

    Parameters
    ----------
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude; duplicates are allowed.

    Returns
    -------
    exclude : Integer[Array, ' n']
        The unique values of `exclude`, kept at the input length by padding
        with `sup`.
    num_allowed : Integer[Array, '']
        `sup` minus the number of unique excluded values below `sup`.
    """
    unique_excluded = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    n_excluded_in_range = jnp.sum(unique_excluded < sup)
    return unique_excluded, sup - n_excluded_in_range
def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Sample a category from given logits while forbidding some categories.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        The values to exclude from the range [0, k). Values greater than or
        equal to `logits.size` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, the result is unspecified.
    """
    excluded, num_allowed = _process_exclude(logits.size, exclude)

    # overwrite the forbidden categories with the most negative finite logit
    lowest_logit = jnp.finfo(logits.dtype).min
    masked_logits = logits.at[excluded].set(lowest_logit)

    return random.categorical(key, masked_logits), num_allowed
def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Sample the cutpoint for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. Does not need to already contain `var`
        at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range `split` was drawn from is [l, r).

    Notes
    -----
    If `var` is out of bounds, or if the available split range on that variable
    is empty, return 0.
    """
    lower, upper = split_range(var_tree, split_tree, max_split, leaf_index, var)
    sampled = random.randint(key, (), lower, upper)
    # an empty range yields the placeholder cutpoint 0
    cutpoint = jnp.where(lower < upper, sampled, 0)
    return cutpoint, lower, upper
def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Compute the transition ratio times the prior ratio of a grow move.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
    The "partial" transition ratio returned is missing the factor P(propose
    prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
    "partial" prior ratio is missing the factor P(children are leaves).
    """
    # the two ratios also contain factors num_available_split *
    # num_available_var * s[var], but they cancel out

    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
    # computed here because they need the count trees, which are computed in the
    # acceptance phase

    # prune allowed <---> the initial tree is not a root
    # leaf to grow is root --> the tree can only be a root
    # tree is a root --> the only leaf I can grow is root
    not_root = leaf_to_grow != 1
    prob_propose_grow = jnp.where(not_root, 0.5, 1)
    inverse_transition = prob_propose_grow * prob_choose * num_prunable

    # .at.get because if leaf_to_grow is out of bounds (move not allowed), this
    # would produce a 0 and then an inf when `complete_ratio` takes the log
    p_split = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    prior_ratio = p_split / (1 - p_split)

    # divide by 1 instead of 0 when the inverse transition ratio vanishes
    denominator = jnp.where(inverse_transition, inverse_transition, 1)
    return prior_ratio / denominator
class PruneMoves(Module):
    """Represent a proposed prune move for each tree.

    This is the return type of `propose_prune_moves`. Every field carries a
    leading axis over the trees.
    """

    # Filled in `propose_prune_moves` from `split_tree[1]` cast to bool.
    allowed: Bool[Array, ' num_trees']
    """Whether the move is possible."""

    # Filled in `propose_prune_moves` with the node picked by
    # `choose_leaf_parent`.
    node: UInt[Array, ' num_trees']
    """The index of the node to prune. ``2 ** d`` if no node can be pruned."""

    # Filled in `propose_prune_moves` with the output of
    # `compute_partial_ratio`.
    partial_ratio: Float32[Array, ' num_trees']
    """A factor of the Metropolis-Hastings ratio of the move. It lacks the
    likelihood ratio, the probability of proposing the prune move, and the
    prior probability that the children of the node to prune are leaves.
    This ratio is inverted, and is meant to be inverted back in
    `accept_move_and_sample_leaves`."""

    # Filled in `propose_prune_moves` with the affluence tree returned by
    # `choose_leaf_parent`.
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    """A partially updated `affluence_tree`, marking the node to prune as
    growable."""
@named_call
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a prune move on the tree structure, as part of the BART MCMC.

    The body is written for a single tree. Per the `in_axes` given to the
    decorator, `key`, `split_tree` and `affluence_tree` are mapped over their
    leading axis while `p_nonterminal` and `p_propose_grow` are shared
    (assuming `vmap_nodoc` follows the `jax.vmap` convention).

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    # pruning is allowed iff the tree is not a root, i.e., the root node
    # (index 1) has a split
    tree_is_not_root = split_tree[1].astype(bool)

    selection = choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow)
    chosen_node, n_prunable, p_choose, updated_affluence = selection

    return PruneMoves(
        allowed=tree_is_not_root,
        node=chosen_node,
        partial_ratio=compute_partial_ratio(
            p_choose, n_prunable, p_nonterminal, chosen_node
        ),
        affluence_tree=updated_affluence,
    )
def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    This function operates on a single tree; all arrays have no tree axis.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable.
    """
    # sample a node to prune
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # when nothing is prunable, use an index past the end of both the split
    # tree and the (twice as long) leaf tree as the "no node" marker
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # compute stuff for reverse move: build the tree as it would look after
    # pruning, where the pruned node is a leaf marked as growable
    # NOTE(review): if `node_to_prune` is the out-of-bounds marker, these two
    # updates rely on JAX dropping out-of-bounds scatter updates under the
    # default indexing mode -- confirm against the `.at` documentation
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    # an out-of-bounds `node_to_prune` yields probability 0
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    # guard against a zero normalization to avoid dividing by zero
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw a random integer from a range, restricted to the unmasked values.

    Parameters
    ----------
    key
        A jax random key.
    mask
        Boolean array flagging which values are allowed.

    Returns
    -------
    A random integer ``u`` in the range ``[0, n)`` with ``mask[u] == True``.

    Notes
    -----
    When the mask contains no `True` value, the result is `n`. The
    implementation is tuned for small `n`.
    """
    # running count of allowed entries up to and including each position
    allowed_so_far = jnp.cumsum(mask)
    total_allowed = allowed_so_far[-1]

    # draw the rank (0-based) of the allowed entry to select
    rank = random.randint(key, (), 0, total_allowed)

    # the selected entry is the first position whose running count exceeds
    # the rank
    return jnp.searchsorted(
        allowed_so_far, rank, side='right', method='compare_all'
    )