Coverage for src / bartz / grove.py: 93%
184 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/grove.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions to create and manipulate binary decision trees."""
27import math
28from functools import partial
29from typing import Literal, Protocol
31from jax import jit, lax, vmap
32from jax import numpy as jnp
33from jaxtyping import Array, Bool, Float32, Int32, Shaped, UInt
35try:
36 from numpy.lib.array_utils import normalize_axis_tuple # numpy 2
37except ImportError:
38 from numpy.core.numeric import normalize_axis_tuple # numpy 1
40from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc
class TreeHeaps(Protocol):
    """A protocol for dataclasses that represent trees.

    A tree is represented with arrays as a heap. The root node is at index 1.
    The child nodes of a node at index :math:`i` are at indices :math:`2i`
    (left child) and :math:`2i + 1` (right child). The array element at index 0
    is unused.

    Since the nodes at the bottom can only be leaves and not decision nodes,
    `var_tree` and `split_tree` are half as long as `leaf_tree`. With ``2**d``
    slots in `leaf_tree` the tree has ``d`` levels: the root is at depth 0, and
    the bottom level (indices ``2**(d-1)`` to ``2**d - 1``) is at depth
    ``d - 1``.

    Arrays may have additional initial axes to represent multiple trees.
    """

    leaf_tree: (
        Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']
    )
    """The values in the leaves of the trees. This array can be dirty, i.e.,
    unused nodes can have whatever value. It may have an additional axis of
    length k, placed just before the node axis, for multivariate leaves."""

    var_tree: UInt[Array, '*batch_shape 2**(d-1)']
    """The axes along which the decision nodes operate, i.e., the index of the
    coordinate of a point that each decision rule compares with its split.
    This array can be dirty but for the always unused node at index 0 which
    must be set to 0."""

    split_tree: UInt[Array, '*batch_shape 2**(d-1)']
    """The decision boundaries of the trees. The boundaries are open on the
    right, i.e., a point belongs to the left child iff x[var] < split. Whether
    a node is a leaf is indicated by the corresponding 'split' element being
    0. Unused nodes also have split set to 0. This array can't be dirty."""
def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int:
    """
    Return the depth parameter of a heap-shaped tree array.

    Parameters
    ----------
    tree
        A tree array like those in a `TreeHeaps`. If the array is ND, the tree
        structure is assumed to be along the last axis.

    Returns
    -------
    The base-2 logarithm of the length of the last axis of `tree`, rounded to
    the nearest integer. For a `leaf_tree` with ``2**d`` slots this is ``d``,
    the number of levels of the tree (the root is at depth 0 and the deepest
    possible node at depth ``d - 1``). For a `var_tree` or `split_tree`, which
    have ``2**(d-1)`` slots, it is ``d - 1``, the number of levels that can
    hold decision nodes.
    """
    num_slots = tree.shape[-1]
    return round(math.log2(num_slots))
def traverse_tree(
    x: UInt[Array, ' p'],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
) -> UInt[Array, '']:
    """
    Find the leaf that a point falls into.

    Parameters
    ----------
    x
        The coordinates to evaluate the tree at.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The heap index of the leaf reached by `x`.
    """
    # Indices can reach 2 * var_tree.size - 1 (last slot of the bottom level),
    # so use the smallest unsigned dtype able to hold that value.
    index_dtype = minimal_unsigned_dtype(2 * var_tree.size - 1)
    start = (jnp.zeros((), bool), jnp.ones((), index_dtype))

    def step(
        state: tuple[Bool[Array, ''], UInt[Array, '']], _: None
    ) -> tuple[tuple[Bool[Array, ''], UInt[Array, '']], None]:
        done, node = state
        cut = split_tree[node]
        axis = var_tree[node]
        # a zero split marks a leaf: once reached, the index stops moving
        done = done | (cut == 0)
        # go to the left child 2 * node, or to 2 * node + 1 if x[axis] >= cut
        child = (node << 1) + (x[axis] >= cut)
        node = jnp.where(done, node, child)
        return (done, node), None

    # one step per level that can hold decision nodes
    num_steps = tree_depth(var_tree)
    (_, leaf_index), _ = lax.scan(step, start, None, num_steps, unroll=16)
    return leaf_index
@jit
@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
@partial(vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(
    X: UInt[Array, 'p n'],
    var_trees: UInt[Array, '*forest_shape 2**(d-1)'],
    split_trees: UInt[Array, '*forest_shape 2**(d-1)'],
) -> UInt[Array, '*forest_shape n']:
    """
    Find the leaves that points fall into, for each tree in a set.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at, one point per column.
    var_trees
        The decision axes of the trees.
    split_trees
        The decision boundaries of the trees.

    Returns
    -------
    The heap indices of the leaves, one per tree and point.
    """
    # Decorators apply bottom-up: `vmap_nodoc` (presumably a vmap variant that
    # leaves the docstring alone -- confirm in bartz.jaxext) maps over the
    # columns of `X`, and `jnp.vectorize` loops over the leading forest axes of
    # the tree arrays while passing `X` whole (excluded=(0,)). So this body
    # sees a single point and a single tree.
    return traverse_tree(X, var_trees, split_trees)
@partial(jit, static_argnames=('sum_batch_axis',))
def evaluate_forest(
    X: UInt[Array, 'p n'],
    trees: TreeHeaps,
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> (
    Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
):
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at, one point per column.
    trees
        The trees. If `trees.leaf_tree` has one more axis than
        `trees.var_tree`, the leaves are multivariate and that extra axis (of
        length k) is kept in the output just before the points axis.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Negative indices count from the end of the batch dimensions; the core
        dimensions n and k can't be summed over by this function.

    Returns
    -------
    The values of the trees at the points in `X`, summed over `sum_batch_axis`.
    """
    indices: UInt[Array, '*forest_shape n']
    indices = traverse_forest(X, trees.var_tree, trees.split_tree)

    leaf_tree = trees.leaf_tree
    if leaf_tree.ndim != trees.var_tree.ndim:
        # multivariate: (..., k, 1, 2**d) gathered with (..., 1, n, 1)
        # broadcasts to (..., k, n, 1)
        gathered = jnp.take_along_axis(
            leaf_tree[..., :, None, :], indices[..., None, :, None], -1
        )
    else:
        # univariate: (..., 1, 2**d) gathered with (..., n, 1) -> (..., n, 1)
        gathered = jnp.take_along_axis(leaf_tree[..., None, :], indices[..., None], -1)
    leaves = jnp.squeeze(gathered, -1)

    num_batch_axes = trees.var_tree.ndim - 1
    summed_axes = normalize_axis_tuple(sum_batch_axis, num_batch_axes)
    return jnp.sum(leaves, axis=summed_axes)
def is_actual_leaf(
    split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
    """
    Return a mask indicating the leaf nodes in a tree.

    A slot counts as a leaf if it has no split and its parent is a decision
    node, or it is the root. Unused slots hanging below other leaves are
    therefore excluded.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    add_bottom_level
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
    """
    leaf_mask = split_tree == 0
    if add_bottom_level:
        # bottom-level slots never carry a split
        leaf_mask = jnp.concatenate([leaf_mask, jnp.ones_like(leaf_mask)])
    num_slots = leaf_mask.size

    node = jnp.arange(num_slots, dtype=minimal_unsigned_dtype(num_slots - 1))
    # slot 0 reads split_tree[0] == 0, so it is excluded; the root (index 1)
    # has no real parent and is forced to count as reachable
    has_decision_parent = split_tree[node >> 1].astype(bool).at[1].set(True)
    return leaf_mask & has_decision_parent
def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
    """
    Return a mask indicating the nodes with leaf (and only leaf) children.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The mask marking decision nodes whose two children are both leaves.
    """
    num_slots = split_tree.size
    node = jnp.arange(num_slots, dtype=minimal_unsigned_dtype(2 * num_slots - 1))

    def child_is_leaf(child: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
        # children past the end of split_tree lie on the bottom level, which
        # holds only leaves: filling them with split 0 marks them as such
        return split_tree.at[child].get(mode='fill', fill_value=0) == 0

    left = node << 1
    # slot 0 has split == 0, so it is never reported as a parent
    return split_tree.astype(bool) & child_is_leaf(left) & child_is_leaf(left + 1)
def tree_depths(tree_size: int) -> Int32[Array, ' {tree_size}']:
    """
    Return the depth of each node in a binary tree.

    Parameters
    ----------
    tree_size
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    The depth of each node.

    Notes
    -----
    The root node (index 1) has depth 0. The depth is the position of the most
    significant non-zero bit in the index. The first element (the unused node)
    is marked as depth 0.
    """
    # highest set bit of the heap index; index 0 yields -1 and is patched below
    depths = [index.bit_length() - 1 for index in range(tree_size)]
    depths[0] = 0
    return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
@partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)')
def is_used(
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
) -> Bool[Array, '*batch_shape 2**d']:
    """
    Return a mask indicating the used nodes in a tree.

    A node slot is used if it is a decision node or an actual leaf.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    A mask over all ``2**d`` slots indicating which nodes are actually used.
    """
    is_decision = split_tree.astype(bool)
    # the bottom level can only contain leaves, never decision nodes
    is_decision = jnp.concatenate([is_decision, jnp.zeros_like(is_decision)])
    return is_decision | is_actual_leaf(split_tree, add_bottom_level=True)
@jit
def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']:
    """
    Return the fraction of used nodes in a set of trees.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees.

    Returns
    -------
    Number of tree nodes over the maximum number that could be stored.
    """
    used = is_used(split_tree)
    num_trees = split_tree.size // split_tree.shape[-1]
    # every tree has 2**d slots, but slot 0 never holds a node
    capacity = used.size - num_trees
    return jnp.count_nonzero(used) / capacity
@partial(jit, static_argnames=('p', 'sum_batch_axis'))
def var_histogram(
    p: int,
    var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape {p}']:
    """
    Count how many times each variable appears in a tree.

    Parameters
    ----------
    p
        The number of variables (the maximum value that can occur in `var_tree`
        is ``p - 1``).
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed. Note
        that negative indices count from the end of the batch dimensions, the
        core dimension p can't be summed over by this function.

    Returns
    -------
    The histogram(s) of the variables used in the tree.
    """
    num_batch_axes = var_tree.ndim - 1
    summed = normalize_axis_tuple(sum_batch_axis, num_batch_axes)

    def histogram(
        var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
        is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
    ) -> Int32[Array, ' p']:
        # non-decision slots add 0, so only decision nodes are counted; any
        # axes left here are the summed ones and get folded into the scatter
        return jnp.zeros(p, int).at[var_tree].add(is_internal)

    # Map over every kept batch axis, innermost first. Negative axis indices
    # count from the trailing node axis, so they stay valid after the outer
    # vmaps strip leading axes.
    kept = [axis for axis in range(num_batch_axes) if axis not in summed]
    for axis in reversed(kept):
        histogram = vmap(histogram, in_axes=axis - var_tree.ndim)

    return histogram(var_tree, split_tree.astype(bool))
def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
    """Convert a tree to a human-readable string.

    Parameters
    ----------
    tree
        A single tree to format. Leaves may be scalar or multivariate (with
        an extra axis before the node axis of `leaf_tree`).
    print_all
        If `True`, also print the contents of unused node slots in the arrays.

    Returns
    -------
    A string representation of the tree, one line per node, each prefixed with
    the node's heap index.
    """
    tee = '├──'
    corner = '└──'
    # continuation indents must be exactly as wide as the branch glyphs so
    # that children line up under their parent's link
    join = '│' + ' ' * (len(tee) - 1)
    space = ' ' * len(tee)
    down = '┐'  # node with children below
    bottom = '╢'  # leaf on the bottom level (no child slots exist)

    # the node axis is the last one, also for multivariate leaves
    tree_size = tree.leaf_tree.shape[-1]
    ndigits = len(str(tree_size - 1))

    def format_leaf(index: int) -> str:
        # formats one leaf value, or each component of a multivariate leaf
        value = tree.leaf_tree[..., index]
        if value.ndim == 0:
            return f'{value:#.2g}'
        return '[' + ', '.join(f'{v:#.2g}' for v in value.tolist()) + ']'

    def visit(
        lines: list[str],
        index: int,
        indent: str,
        first_indent: str,
        next_indent: str,
        unused: bool,
    ) -> None:
        # appends the line for node `index`, then recurses into its children
        if index >= tree_size:
            return

        # slots past the end of var_tree/split_tree are bottom-level leaves
        var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
        split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()

        is_leaf = split == 0
        left_child = 2 * index
        right_child = 2 * index + 1

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {tree.leaf_tree[..., index]})'
        else:
            # unused subtrees are pruned below when not printing everything
            assert not unused
            node_str = format_leaf(index) if is_leaf else f'x{var} < {split}'

        if not is_leaf or (print_all and left_child < tree_size):
            link = down
        elif not print_all and left_child >= tree_size:
            link = bottom
        else:
            link = ' '

        number = str(index).rjust(ndigits)
        lines.append(f' {number} {indent}{first_indent}{link}{node_str}')

        indent += next_indent
        # everything below a leaf is an unused slot
        unused = unused or is_leaf

        if unused and not print_all:
            return

        visit(lines, left_child, indent, tee, join, unused)
        visit(lines, right_child, indent, corner, space, unused)

    lines: list[str] = []
    visit(lines, 1, '', '', '', False)
    return '\n'.join(lines)
def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
    """Measure the depth of the tree.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules.

    Returns
    -------
    The depth of the deepest leaf in the tree. The root is at depth 0.
    """
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)
    node_depth = tree_depths(leaf_mask.size)
    # non-leaf slots contribute 0, which never exceeds any leaf's depth
    return jnp.where(leaf_mask, node_depth, 0).max()
@jit
@partial(jnp.vectorize, signature='(nt,hts)->(d)')
def forest_depth_distr(
    split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'],
) -> Int32[Array, '*batch_shape d']:
    """Histogram the depths of a set of trees.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules of the trees.

    Returns
    -------
    An integer vector where the i-th element counts how many trees have depth i.
    """
    # split_tree has 2**(d-1) slots per tree, so tree_depth yields d - 1 and
    # the achievable depths are 0, ..., d - 1: d bins in total
    num_levels = tree_depth(split_tree) + 1
    per_tree_depth = vmap(tree_actual_depth)(split_tree)
    return jnp.bincount(per_tree_depth, length=num_levels)
@partial(jit, static_argnames=('node_type', 'sum_batch_axis'))
def points_per_node_distr(
    X: UInt[Array, 'p n'],
    var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    node_type: Literal['leaf', 'leaf-parent'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape n+1']:
    """Histogram points-per-node counts in a set of trees.

    Count how many nodes in a tree select each possible amount of points,
    over a certain subset of nodes.

    Parameters
    ----------
    X
        The set of points to count, one point per column.
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    node_type
        The type of nodes to consider. Can be:

        'leaf'
            Count only leaf nodes (including those on the bottom level).
        'leaf-parent'
            Count only decision nodes whose two children are both leaves.
    sum_batch_axis
        Aggregate the histogram over these batch axes, counting how many nodes
        have each possible amount of points over subsets of trees instead of
        in each tree separately.

    Returns
    -------
    An array whose last axis has length n + 1, where the i-th element counts
    how many nodes have i points.

    Raises
    ------
    ValueError
        If `node_type` is not one of the accepted values.
    """
    batch_ndim = var_tree.ndim - 1
    axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)

    def func(
        var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
        split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    ) -> Int32[Array, '*reduced_batch_shape n+1']:
        # heap index of the leaf each point falls into, per tree
        indices: UInt[Array, '*batch_shape n']
        indices = traverse_forest(X, var_tree, split_tree)

        # per tree: number of points in each node, and mask of counted nodes
        @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
        def count_points(
            split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
            indices: UInt[Array, '*batch_shape n'],
        ) -> (
            tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']]
            | tuple[
                UInt[Array, '*batch_shape 2**(d-1)'],
                Bool[Array, '*batch_shape 2**(d-1)'],
            ]
        ):
            if node_type == 'leaf-parent':
                # move each point from its leaf to the leaf's parent
                indices >>= 1
                predicate = is_leaves_parent(split_tree)
            elif node_type == 'leaf':
                predicate = is_actual_leaf(split_tree, add_bottom_level=True)
            else:
                raise ValueError(node_type)
            # slot 0 is not a node; with 'leaf-parent' it collects the points
            # whose leaf is the root (1 >> 1 == 0), so it is zeroed
            count_tree = jnp.zeros(predicate.size, int).at[indices].add(1).at[0].set(0)
            return count_tree, predicate

        count_tree, predicate = count_points(split_tree, indices)

        # histogram of point counts over the nodes selected by `predicate`;
        # any batch axes still present here are the summed ones
        def count_nodes(
            count_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
            predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
        ) -> Int32[Array, ' n+1']:
            return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate)

        # vmap count_nodes over non-batched dims
        # (negative in_axes count from the trailing node axis, so they stay
        # valid after outer vmaps strip leading axes)
        for i in reversed(range(batch_ndim)):
            neg_i = i - var_tree.ndim
            if i not in axes:
                count_nodes = vmap(count_nodes, in_axes=neg_i)

        return count_nodes(count_tree, predicate)

    # automatically batch over all batch dimensions
    # NOTE(review): presumably autobatch(f, max_bytes, in_axis, out_axis) splits
    # the inputs along in_axis into chunks that keep I/O under max_bytes and
    # joins the outputs along out_axis -- confirm in bartz.jaxext. The output
    # axis is shifted left by the number of summed axes that precede axis i,
    # since those axes are absent from the output.
    max_io_nbytes = 2**27  # 128 MiB
    out_dim_shift = len(axes)
    for i in reversed(range(batch_ndim)):
        if i in axes:
            out_dim_shift -= 1
        else:
            func = autobatch(func, max_io_nbytes, i, i - out_dim_shift)
    assert out_dim_shift == 0

    return func(var_tree, split_tree)