Coverage for src / bartz / grove / _grove.py: 93%
187 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
1# bartz/src/bartz/grove/_grove.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions to create and manipulate binary decision trees."""
27import math
28from dataclasses import fields
29from functools import partial
30from typing import Literal, Protocol
32from equinox import Module
33from jax import jit, lax, vmap
34from jax import numpy as jnp
35from jaxtyping import Array, Bool, Float32, Int32, Shaped, UInt
36from numpy.lib.array_utils import normalize_axis_tuple
38from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc
class TreeHeaps(Protocol):
    """A protocol for dataclasses that represent trees.

    A tree is represented with arrays as a heap. The root node is at index 1.
    The children nodes of a node at index :math:`i` are at indices :math:`2i`
    (left child) and :math:`2i + 1` (right child). The array element at index 0
    is unused.

    Since the nodes at the bottom can only be leaves and not decision nodes,
    `var_tree` and `split_tree` are half as long as `leaf_tree`.

    Arrays may have additional initial axes to represent multiple trees.
    """

    leaf_tree: (
        Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']
    )
    """The values in the leaves of the trees. This array can be dirty, i.e.,
    unused nodes can have whatever value. It may have an additional axis
    for multivariate leaves."""

    var_tree: UInt[Array, '*batch_shape 2**(d-1)']
    """The axes along which the decision nodes operate. This array can be
    dirty but for the always unused node at index 0 which must be set to 0."""

    split_tree: UInt[Array, '*batch_shape 2**(d-1)']
    """The decision boundaries of the trees. The boundaries are open on the
    right, i.e., a point belongs to the left child iff x < split. Whether a
    node is a leaf is indicated by the corresponding 'split' element being
    0. Unused nodes also have split set to 0. This array can't be dirty."""
class TreesTrace(Module):
    """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace."""

    # See `TreeHeaps` for the meaning of these arrays; here they carry
    # additional leading axes for the trace and the trees in the forest.
    leaf_tree: (
        Float32[Array, '*trace_shape num_trees 2**d']
        | Float32[Array, '*trace_shape num_trees k 2**d']
    )
    var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
    split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: TreeHeaps) -> 'TreesTrace':
        """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`."""
        # copy only the fields declared on this class, ignoring any extra
        # attributes `obj` may carry
        return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})
def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int:
    """
    Return the maximum depth of a tree.

    Parameters
    ----------
    tree
        A tree array like those in a `TreeHeaps`. If the array is ND, the tree
        structure is assumed to be along the last axis.

    Returns
    -------
    The maximum depth of the tree.
    """
    # the heap axis has length 2**d; round guards against float log2 error
    heap_size = tree.shape[-1]
    return round(math.log2(heap_size))
def traverse_tree(
    x: UInt[Array, ' p'],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
) -> UInt[Array, '']:
    """
    Find the leaf where a point falls into.

    Parameters
    ----------
    x
        The coordinates to evaluate the tree at.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The index of the leaf.
    """
    # carry = (whether a leaf was already reached, current node index);
    # start the descent at the root, which lives at heap index 1
    carry = (
        jnp.zeros((), bool),
        jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
    )

    def loop(
        carry: tuple[Bool[Array, ''], UInt[Array, '']], _: None
    ) -> tuple[tuple[Bool[Array, ''], UInt[Array, '']], None]:
        leaf_found, index = carry

        split = split_tree[index]
        var = var_tree[index]

        # split == 0 marks a leaf node
        leaf_found |= split == 0
        # boundary open on the right: x[var] < split goes left, >= goes right
        child_index = (index << 1) + (x[var] >= split)
        # once a leaf has been found, stop moving and keep its index
        index = jnp.where(leaf_found, index, child_index)

        return (leaf_found, index), None

    # the descent takes at most `depth` steps, so run a fixed-length scan
    depth = tree_depth(var_tree)
    (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
    return index
@jit
@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
@partial(vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(
    X: UInt[Array, 'p n'],
    var_trees: UInt[Array, '*forest_shape 2**(d-1)'],
    split_trees: UInt[Array, '*forest_shape 2**(d-1)'],
) -> UInt[Array, '*forest_shape n']:
    """
    Find the leaf each point falls into, for each tree in a set.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    var_trees
        The decision axes of the trees.
    split_trees
        The decision boundaries of the trees.

    Returns
    -------
    The indices of the leaves.
    """
    # the inner vmap maps `traverse_tree` over the points (columns of X),
    # while jnp.vectorize broadcasts over the leading batch axes of the
    # tree arrays (X itself is excluded from the vectorization)
    return traverse_tree(X, var_trees, split_trees)
@partial(jit, static_argnames=('sum_batch_axis',))
def evaluate_forest(
    X: UInt[Array, 'p n'],
    trees: TreeHeaps,
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> (
    Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
):
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    trees
        The trees.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Note that negative indices count from the end of the batch dimensions,
        the core dimensions n and k can't be summed over by this function.

    Returns
    -------
    The (sum of) the values of the trees at the points in `X`.
    """
    indices: UInt[Array, '*forest_shape n']
    indices = traverse_forest(X, trees.var_tree, trees.split_tree)

    # multivariate leaves add a `k` axis to leaf_tree just before the heap axis
    is_mv = trees.leaf_tree.ndim != trees.var_tree.ndim

    # insert singleton axes so indices and leaf values broadcast against each
    # other in the gather below (the trailing 1 is the gathered heap axis)
    bc_indices: UInt[Array, '*forest_shape n 1'] | UInt[Array, '*forest_shape 1 n 1']
    bc_indices = indices[..., None, :, None] if is_mv else indices[..., None]

    bc_leaf_tree: (
        Float32[Array, '*forest_shape 1 tree_size']
        | Float32[Array, '*forest_shape k 1 tree_size']
    )
    bc_leaf_tree = (
        trees.leaf_tree[..., :, None, :] if is_mv else trees.leaf_tree[..., None, :]
    )

    bc_leaves: (
        Float32[Array, '*forest_shape n 1'] | Float32[Array, '*forest_shape k n 1']
    )
    bc_leaves = jnp.take_along_axis(bc_leaf_tree, bc_indices, -1)

    leaves: Float32[Array, '*forest_shape n'] | Float32[Array, '*forest_shape k n']
    leaves = jnp.squeeze(bc_leaves, -1)

    # normalize the requested axes w.r.t. the batch ndim so that the core
    # dimensions (n, and k if present) can never be summed over
    axis = normalize_axis_tuple(sum_batch_axis, trees.var_tree.ndim - 1)
    return jnp.sum(leaves, axis=axis)
def is_actual_leaf(
    split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
    """
    Return a mask indicating the leaf nodes in a tree.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    add_bottom_level
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
    """
    # a node is a leaf candidate iff its split is 0
    leaf_mask = split_tree == 0
    n_nodes = split_tree.size
    if add_bottom_level:
        # bottom-level nodes have no split entry, so they are all candidates
        n_nodes *= 2
        leaf_mask = jnp.concatenate([leaf_mask, jnp.ones_like(leaf_mask)])
    # a candidate is an actual leaf only if its parent is a decision node
    node_index = jnp.arange(n_nodes, dtype=minimal_unsigned_dtype(n_nodes - 1))
    parent_is_decision = split_tree[node_index >> 1].astype(bool)
    # the root (index 1) has no real parent: always consider it reachable
    parent_is_decision = parent_is_decision.at[1].set(True)
    return leaf_mask & parent_is_decision
def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
    """
    Return a mask indicating the nodes with leaf (and only leaf) children.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The mask indicating which nodes have leaf children.
    """
    node_index = jnp.arange(
        split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1)
    )
    # children of node i sit at 2i and 2i + 1 in the heap
    left_child = node_index << 1
    # out-of-bounds children (bottom level) read as split 0, i.e., leaves
    left_is_leaf = split_tree.at[left_child].get(mode='fill', fill_value=0) == 0
    right_is_leaf = split_tree.at[left_child + 1].get(mode='fill', fill_value=0) == 0
    # only decision nodes qualify; the unused node 0 has split == 0, so it is
    # excluded automatically
    is_decision = split_tree.astype(bool)
    return is_decision & left_is_leaf & right_is_leaf
def tree_depths(tree_size: int) -> Int32[Array, ' {tree_size}']:
    """
    Return the depth of each node in a binary tree.

    Parameters
    ----------
    tree_size
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    The depth of each node.

    Notes
    -----
    The root node (index 1) has depth 0. The depth is the position of the most
    significant non-zero bit in the index. The first element (the unused node)
    is marked as depth 0.
    """
    # depth of node i is bit_length(i) - 1 (position of the MSB); clamp the
    # unused node 0 (bit_length 0) to depth 0
    depths = [max(i.bit_length() - 1, 0) for i in range(tree_size)]
    return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
@partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)')
def is_used(
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
) -> Bool[Array, '*batch_shape 2**d']:
    """
    Return a mask indicating the used nodes in a tree.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    A mask indicating which nodes are actually used.
    """
    # decision nodes have a nonzero split; pad with False for the bottom
    # level, which can never contain decision nodes
    decision_mask = split_tree.astype(bool)
    padding = jnp.zeros_like(decision_mask)
    decision_mask = jnp.concatenate([decision_mask, padding])
    # a node is used iff it is a decision node or an actual leaf
    return decision_mask | is_actual_leaf(split_tree, add_bottom_level=True)
@jit
def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']:
    """
    Return the fraction of used nodes in a set of trees.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees.

    Returns
    -------
    Number of tree nodes over the maximum number that could be stored.
    """
    used_mask = is_used(split_tree)
    n_used = jnp.count_nonzero(used_mask)
    # slot 0 of each tree is never usable, so subtract one slot per tree
    n_trees = split_tree.size // split_tree.shape[-1]
    capacity = used_mask.size - n_trees
    return n_used / capacity
358@partial(jit, static_argnames=('p', 'sum_batch_axis'))
359def var_histogram(
360 p: int,
361 var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
362 split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
363 *,
364 sum_batch_axis: int | tuple[int, ...] = (),
365) -> Int32[Array, '*reduced_batch_shape {p}']:
366 """
367 Count how many times each variable appears in a tree.
369 Parameters
370 ----------
371 p
372 The number of variables (the maximum value that can occur in `var_tree`
373 is ``p - 1``).
374 var_tree
375 The decision axes of the tree.
376 split_tree
377 The decision boundaries of the tree.
378 sum_batch_axis
379 The batch axes to sum over. By default, no summation is performed. Note
380 that negative indices count from the end of the batch dimensions, the
381 core dimension p can't be summed over by this function.
383 Returns
384 -------
385 The histogram(s) of the variables used in the tree.
386 """
387 is_internal = split_tree.astype(bool) 1ahc
389 def scatter_add( 1ahc
390 var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
391 is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
392 ) -> Int32[Array, ' p']:
393 return jnp.zeros(p, int).at[var_tree].add(is_internal) 1ahc
395 # vmap scatter_add over non-batched dims
396 batch_ndim = var_tree.ndim - 1 1ahc
397 axes = normalize_axis_tuple(sum_batch_axis, batch_ndim) 1ahc
398 for i in reversed(range(batch_ndim)): 1ahc
399 neg_i = i - var_tree.ndim 1ahc
400 if i not in axes: 1pahqc
401 scatter_add = vmap(scatter_add, in_axes=neg_i) 1pqc
403 return scatter_add(var_tree, is_internal) 1ahc
def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
    """Convert a tree to a human-readable string.

    Parameters
    ----------
    tree
        A single tree to format.
    print_all
        If `True`, also print the contents of unused node slots in the arrays.

    Returns
    -------
    A string representation of the tree.
    """
    # box-drawing fragments used to render the tree structure
    # NOTE(review): tee/corner are 3 chars wide, so join/space are assumed to
    # be 3 chars wide too for alignment — confirm against upstream source
    tee = '├──'
    corner = '└──'
    join = '│  '
    space = '   '
    down = '┐'
    bottom = '╢'  # '┨' #

    def traverse_tree(
        lines: list[str],
        index: int,
        depth: int,
        indent: str,
        first_indent: str,
        next_indent: str,
        unused: bool,
    ) -> None:
        # recursively append one formatted line per visited node
        if index >= len(tree.leaf_tree):
            return

        # out-of-bounds reads (bottom-level children) fall back to 0
        var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
        split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()

        is_leaf = split == 0
        left_child = 2 * index
        right_child = 2 * index + 1

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})'
        else:
            assert not unused
            if is_leaf:
                node_str = f'{tree.leaf_tree[index]:#.2g}'
            else:
                node_str = f'x{var} < {split}'

        # pick the connector after the node label: a downward hook if the
        # node has printed children, a stop mark at the bottom level
        if not is_leaf or (print_all and left_child < len(tree.leaf_tree)):
            link = down
        elif not print_all and left_child >= len(tree.leaf_tree):
            link = bottom
        else:
            link = ' '

        # right-align the node index in a fixed-width column
        max_number = len(tree.leaf_tree) - 1
        ndigits = len(str(max_number))
        number = str(index).rjust(ndigits)

        lines.append(f' {number} {indent}{first_indent}{link}{node_str}')

        indent += next_indent
        unused = unused or is_leaf

        # below a leaf everything is unused; skip it unless print_all
        if unused and not print_all:
            return

        traverse_tree(lines, left_child, depth + 1, indent, tee, join, unused)
        traverse_tree(lines, right_child, depth + 1, indent, corner, space, unused)

    lines = []
    # start from the root (index 1) with no indentation
    traverse_tree(lines, 1, 0, '', '', '', False)
    return '\n'.join(lines)
def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
    """Measure the depth of the tree.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules.

    Returns
    -------
    The depth of the deepest leaf in the tree. The root is at depth 0.
    """
    # mark the leaves (including the bottom level) and take the deepest one;
    # this could be done just with split_tree != 0
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)
    node_depth = tree_depths(leaf_mask.size)
    leaf_depth = jnp.where(leaf_mask, node_depth, 0)
    return jnp.max(leaf_depth)
@jit
@partial(jnp.vectorize, signature='(nt,hts)->(d)')
def forest_depth_distr(
    split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'],
) -> Int32[Array, '*batch_shape d']:
    """Histogram the depths of a set of trees.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules of the trees.

    Returns
    -------
    An integer vector where the i-th element counts how many trees have depth i.
    """
    # one depth per tree, then bin them; possible depths are 0 .. d inclusive
    per_tree_depth = vmap(tree_actual_depth)(split_tree)
    n_bins = 1 + tree_depth(split_tree)
    return jnp.bincount(per_tree_depth, length=n_bins)
@partial(jit, static_argnames=('node_type', 'sum_batch_axis'))
def points_per_node_distr(
    X: UInt[Array, 'p n'],
    var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    node_type: Literal['leaf', 'leaf-parent'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape n+1']:
    """Histogram points-per-node counts in a set of trees.

    Count how many nodes in a tree select each possible amount of points,
    over a certain subset of nodes.

    Parameters
    ----------
    X
        The set of points to count.
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    node_type
        The type of nodes to consider. Can be:

        'leaf'
            Count only leaf nodes.
        'leaf-parent'
            Count only parent-of-leaf nodes.
    sum_batch_axis
        Aggregate the histogram over these batch axes, counting how many nodes
        have each possible amount of points over subsets of trees instead of
        in each tree separately.

    Returns
    -------
    A vector where the i-th element counts how many nodes have i points.

    Raises
    ------
    ValueError
        If `node_type` is not one of the allowed values.
    """
    batch_ndim = var_tree.ndim - 1
    axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)

    def func(
        var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
        split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    ) -> Int32[Array, '*reduced_batch_shape n+1']:
        # find the leaf each point of X ends up in, for every tree
        indices: UInt[Array, '*batch_shape n']
        indices = traverse_forest(X, var_tree, split_tree)

        @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
        def count_points(
            split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
            indices: UInt[Array, '*batch_shape n'],
        ) -> (
            tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']]
            | tuple[
                UInt[Array, '*batch_shape 2**(d-1)'],
                Bool[Array, '*batch_shape 2**(d-1)'],
            ]
        ):
            # per-node point counts and the mask of nodes to histogram
            if node_type == 'leaf-parent':
                # attribute each point to its leaf's parent instead
                indices >>= 1
                predicate = is_leaves_parent(split_tree)
            elif node_type == 'leaf':
                predicate = is_actual_leaf(split_tree, add_bottom_level=True)
            else:
                raise ValueError(node_type)
            # node 0 is never used, so zero out whatever landed there
            count_tree = jnp.zeros(predicate.size, int).at[indices].add(1).at[0].set(0)
            return count_tree, predicate

        count_tree, predicate = count_points(split_tree, indices)

        def count_nodes(
            count_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
            predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
        ) -> Int32[Array, ' n+1']:
            # histogram the counts (0 .. n points) over the selected nodes
            return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate)

        # vmap count_nodes over non-batched dims
        for i in reversed(range(batch_ndim)):
            neg_i = i - var_tree.ndim
            if i not in axes:
                count_nodes = vmap(count_nodes, in_axes=neg_i)

        return count_nodes(count_tree, predicate)

    # automatically batch over all batch dimensions to bound memory use;
    # summed axes disappear from the output, hence the shifting bookkeeping
    max_io_nbytes = 2**27  # 128 MiB
    out_dim_shift = len(axes)
    for i in reversed(range(batch_ndim)):
        if i in axes:
            out_dim_shift -= 1
        else:
            func = autobatch(func, max_io_nbytes, i, i - out_dim_shift)
    assert out_dim_shift == 0

    return func(var_tree, split_tree)