Coverage for src / bartz / grove.py: 100%
97 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
1# bartz/src/bartz/grove.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions to create and manipulate binary decision trees."""
27import math
28from functools import partial
29from typing import Protocol
31from jax import jit, lax, vmap
32from jax import numpy as jnp
33from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Shaped, UInt
35try:
36 from numpy.lib.array_utils import normalize_axis_tuple # numpy 2
37except ImportError:
38 from numpy.core.numeric import normalize_axis_tuple # numpy 1
40from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
class TreeHeaps(Protocol):
    """A protocol for dataclasses that represent trees.

    A tree is represented with arrays as a heap. The root node is at index 1.
    The children nodes of a node at index :math:`i` are at indices :math:`2i`
    (left child) and :math:`2i + 1` (right child). The array element at index 0
    is unused.

    Arrays may have additional initial axes to represent multiple trees.

    Parameters
    ----------
    leaf_tree
        The values in the leaves of the trees. This array can be dirty, i.e.,
        unused nodes can have whatever value. It may have an additional axis
        for multivariate leaves; that axis (of length k) sits right before the
        heap axis.
    var_tree
        The axes along which the decision nodes operate. This array can be
        dirty, except for the always-unused node at index 0, which must be set
        to 0.
    split_tree
        The decision boundaries of the trees. The boundaries are open on the
        right, i.e., a point belongs to the left child iff x < split. Whether a
        node is a leaf is indicated by the corresponding 'split' element being
        0. Unused nodes also have split set to 0. This array can't be dirty.

    Notes
    -----
    Since the nodes at the bottom can only be leaves and not decision nodes,
    `var_tree` and `split_tree` are half as long as `leaf_tree`.
    """

    # leaf values; the optional `k` axis is present only for multivariate leaves
    leaf_tree: (
        Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']
    )
    # decision variable of each internal node (index 0 must be 0)
    var_tree: UInt[Array, '*batch_shape 2**(d-1)']
    # split thresholds; 0 marks leaves and unused nodes
    split_tree: UInt[Array, '*batch_shape 2**(d-1)']
def make_tree(
    depth: int, dtype: DTypeLike, batch_shape: tuple[int, ...] = ()
) -> Shaped[Array, '*batch_shape 2**{depth}']:
    """
    Allocate a zero-filled array that holds a binary tree as a heap.

    Parameters
    ----------
    depth
        The maximum depth of the tree. Depth 1 means that there is only a root
        node.
    dtype
        The dtype of the array.
    batch_shape
        Leading axes placed before the heap axis, used to represent several
        trees and/or multivariate trees at once.

    Returns
    -------
    An array of zeroes of shape ``(*batch_shape, 2**depth)``.
    """
    # heap layout: index 0 is unused, the root is at 1, so 2**depth slots
    return jnp.zeros((*batch_shape, 2**depth), dtype)
def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int:
    """
    Compute the maximum depth of a tree from the length of its heap array.

    Parameters
    ----------
    tree
        A tree created by `make_tree`. For an N-dimensional array, the heap
        is taken to lie along the last axis.

    Returns
    -------
    The maximum depth, i.e., ``log2`` of the last axis length, rounded to the
    nearest integer.
    """
    heap_length = tree.shape[-1]
    return round(math.log2(heap_length))
def traverse_tree(
    x: UInt[Array, ' p'],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
) -> UInt[Array, '']:
    """
    Locate the leaf that a single point falls into.

    Parameters
    ----------
    x
        The coordinates to evaluate the tree at.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The heap index of the leaf reached by `x`.
    """
    # indices can reach 2 * var_tree.size - 1 (the last bottom-level node)
    index_dtype = minimal_unsigned_dtype(2 * var_tree.size - 1)
    init_state = (jnp.zeros((), bool), jnp.ones((), index_dtype))  # start at root

    def descend(state, _):
        done, node = state
        threshold = split_tree[node]
        axis = var_tree[node]
        # a zero split marks a leaf: once reached, stay there
        done = done | (threshold == 0)
        go_right = x[axis] >= threshold
        child = (node << 1) + go_right
        node = jnp.where(done, node, child)
        return (done, node), None

    # one step per internal level is enough to reach the bottom of the heap
    n_levels = tree_depth(var_tree)
    (_, leaf_index), _ = lax.scan(descend, init_state, None, n_levels, unroll=16)
    return leaf_index
@jit
# Broadcast over the leading (forest) axes of `var_trees` and `split_trees`;
# `X` (argument 0) is excluded from vectorization and passed whole.
@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
# Map over axis 1 of `X` (the n points); the tree arrays are not mapped.
@partial(vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(
    X: UInt[Array, 'p n'],
    var_trees: UInt[Array, '*forest_shape 2**(d-1)'],
    split_trees: UInt[Array, '*forest_shape 2**(d-1)'],
) -> UInt[Array, '*forest_shape n']:
    """
    Find the leaves where points fall into for each tree in a set.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at, one point per column.
    var_trees
        The decision axes of the trees.
    split_trees
        The decision boundaries of the trees.

    Returns
    -------
    The indices of the leaves, one per tree and per point.

    Notes
    -----
    Because of the decorators, the body below sees a single point (a column of
    `X`) and a single tree, and delegates to `traverse_tree`.
    """
    return traverse_tree(X, var_trees, split_trees)
@partial(jit, static_argnames=('sum_batch_axis',))
def evaluate_forest(
    X: UInt[Array, 'p n'],
    trees: TreeHeaps,
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> (
    Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
):
    """
    Evaluate an ensemble of trees at a set of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    trees
        The trees.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Negative indices count from the end of the batch dimensions; the core
        dimensions n and k can't be summed over by this function.

    Returns
    -------
    The (sum of) the values of the trees at the points in `X`.
    """
    leaf_idx = traverse_forest(X, trees.var_tree, trees.split_tree)

    # the leaves carry an extra k axis iff they have more dims than var_tree
    if trees.leaf_tree.ndim != trees.var_tree.ndim:
        # values: (*forest, k, 1, size); indices: (*forest, 1, n, 1)
        values = trees.leaf_tree[..., :, None, :]
        gather_idx = leaf_idx[..., None, :, None]
    else:
        # values: (*forest, 1, size); indices: (*forest, n, 1)
        values = trees.leaf_tree[..., None, :]
        gather_idx = leaf_idx[..., None]

    picked = jnp.take_along_axis(values, gather_idx, -1)
    per_tree = jnp.squeeze(picked, -1)  # (*forest, n) or (*forest, k, n)

    # batch axes are the leading ones, so normalizing against the batch rank
    # maps negative indices onto the right positions of `per_tree`
    n_batch = trees.var_tree.ndim - 1
    summed_axes = normalize_axis_tuple(sum_batch_axis, n_batch)
    return jnp.sum(per_tree, axis=summed_axes)
def is_actual_leaf(
    split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
    """
    Flag the nodes of a tree that are reachable leaves.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    add_bottom_level
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    The leaf mask. Its length is doubled if `add_bottom_level` is True.
    """
    n_nodes = split_tree.size
    zero_split = split_tree == 0
    if add_bottom_level:
        # bottom-level nodes have no split entry and can only be leaves
        zero_split = jnp.concatenate([zero_split, jnp.ones_like(zero_split)])
        n_nodes = 2 * n_nodes

    node = jnp.arange(n_nodes, dtype=minimal_unsigned_dtype(n_nodes - 1))
    # a zero-split node is a real leaf only if its parent is a decision node
    parent_splits = split_tree[node >> 1].astype(bool)
    # the root's "parent" slot is index 0, so force the root to count
    parent_splits = parent_splits.at[1].set(True)
    return zero_split & parent_splits
def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
    """
    Flag the decision nodes whose two children are both leaves.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The mask marking nodes that have only leaf children.
    """
    n_nodes = split_tree.size
    # children indices reach 2 * n_nodes - 1
    node = jnp.arange(n_nodes, dtype=minimal_unsigned_dtype(2 * n_nodes - 1))
    left = node << 1
    right = left + 1

    def child_is_leaf(child):
        # children past the end are bottom-level nodes: read them as split 0
        return split_tree.at[child].get(mode='fill', fill_value=0) == 0

    # node 0 has split == 0, so it is never reported
    return split_tree.astype(bool) & child_is_leaf(left) & child_is_leaf(right)
def tree_depths(tree_size: int) -> UInt[Array, ' {tree_size}']:
    """
    Return the depth of each node in a binary tree.

    Parameters
    ----------
    tree_size
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    The depth of each node, as an array whose dtype is the one returned by
    `minimal_unsigned_dtype` for the largest depth (not int32). If
    `tree_size` is 0, an empty array is returned.

    Notes
    -----
    The root node (index 1) has depth 0. The depth is the position of the most
    significant non-zero bit in the index. The first element (the unused node)
    is marked as depth 0.
    """
    # For i >= 1, ``i.bit_length() - 1`` is the position of the most
    # significant set bit, i.e., floor(log2(i)). Index 0 would give -1, so it
    # is clamped to 0.
    depths = [max(i.bit_length() - 1, 0) for i in range(tree_size)]
    # `default=0` keeps the empty case from crashing
    return jnp.array(depths, minimal_unsigned_dtype(max(depths, default=0)))
@partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)')
def is_used(
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
) -> Bool[Array, '*batch_shape 2**d']:
    """
    Flag the nodes that belong to the tree, i.e., decision nodes and real leaves.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    A mask over the full heap (twice the length of `split_tree`) marking the
    nodes that are actually used.
    """
    leaves = is_actual_leaf(split_tree, add_bottom_level=True)
    decisions = split_tree.astype(bool)
    # the bottom level can hold no decision nodes
    decisions = jnp.concatenate([decisions, jnp.zeros_like(decisions)])
    return decisions | leaves
@jit
def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']:
    """
    Compute the fraction of node slots in use across a set of trees.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees.

    Returns
    -------
    The number of used nodes divided by the number of nodes that could be
    stored, excluding the always-unused slot 0 of every tree.
    """
    n_trees = split_tree.size // split_tree.shape[-1]
    used_mask = is_used(split_tree)
    # each tree wastes exactly one slot (index 0) in its heap
    capacity = used_mask.size - n_trees
    return jnp.count_nonzero(used_mask) / capacity
@partial(jit, static_argnames=('p', 'sum_batch_axis'))
def var_histogram(
    p: int,
    var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape {p}']:
    """
    Count how many times each variable is used by a decision node.

    Parameters
    ----------
    p
        The number of variables (the maximum value that can occur in `var_tree`
        is ``p - 1``).
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Negative indices count from the end of the batch dimensions; the core
        dimension p can't be summed over by this function.

    Returns
    -------
    The histogram(s) of the variables used in the tree.
    """
    # only decision nodes (nonzero split) contribute a count
    internal_mask = split_tree.astype(bool)

    def count_vars(
        variables: UInt[Array, '*summed_batch_axes half_tree_size'],
        mask: Bool[Array, '*summed_batch_axes half_tree_size'],
    ) -> Int32[Array, ' p']:
        # scatter-add over every element: summed axes collapse into one histogram
        return jnp.zeros(p, int).at[variables].add(mask)

    n_batch = var_tree.ndim - 1
    summed = normalize_axis_tuple(sum_batch_axis, n_batch)
    kept = [ax for ax in range(n_batch) if ax not in summed]
    # Wrap from the innermost kept axis outward. Negative in_axes keep each
    # position valid no matter which outer axes are left unmapped (summed).
    for ax in reversed(kept):
        count_vars = vmap(count_vars, in_axes=ax - var_tree.ndim)
    return count_vars(var_tree, internal_mask)