Coverage for src/bartz/grove/_grove.py: 94%
194 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/grove/_grove.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions to create and manipulate binary decision trees."""
27import math
28from dataclasses import fields
29from functools import partial
30from typing import Literal, Protocol, runtime_checkable
32from equinox import tree_at
33from jax import numpy as jnp
34from jax import vmap
35from jaxtyping import Array, Bool, Float32, Int32, Shaped, UInt
36from numpy.lib.array_utils import normalize_axis_tuple
38from bartz._jaxext import Module, autobatch, jit, minimal_unsigned_dtype, vmap_nodoc
@runtime_checkable
class TreeHeaps(Protocol):
    """Structural interface for dataclasses that store trees as array heaps.

    Each tree is laid out as a heap in an array: the root sits at index 1 and
    the node at index :math:`i` has its left child at :math:`2i` and its right
    child at :math:`2i + 1`. The element at index 0 is not used.

    The bottom level of the heap can only hold leaves, never decision nodes,
    which is why `var_tree` and `split_tree` are half as long as `leaf_tree`.

    Extra leading axes may be present on the arrays to hold multiple trees.
    """

    leaf_tree: (
        Float32[Array, '*batch_shape 2*half_tree_size']
        | Float32[Array, '*batch_shape k 2*half_tree_size']
    )
    """Leaf values of the trees. The array may be dirty: slots of unused nodes
    can contain anything. An additional axis may be present for multivariate
    leaves."""

    var_tree: UInt[Array, '*batch_shape half_tree_size']
    """Axis on which each decision node splits. The array may be dirty, except
    for the always-unused slot at index 0, which must be 0."""

    split_tree: UInt[Array, '*batch_shape half_tree_size']
    """Decision boundary of each node. Boundaries are open on the right: a
    point goes to the left child iff x < split. A split equal to 0 marks a
    leaf; unused nodes have split 0 as well. This array can not be dirty."""
def is_multivariate(trees: TreeHeaps) -> bool:
    """
    Tell whether the leaves of the trees are vector-valued.

    Parameters
    ----------
    trees
        The trees to inspect.

    Returns
    -------
    True if `leaf_tree` has more axes than `var_tree`, i.e., it carries the
    extra `k` axis of multivariate leaves.
    """
    leaf_rank = trees.leaf_tree.ndim
    heap_rank = trees.var_tree.ndim
    return leaf_rank > heap_rank
class TreesTrace(Module):
    """Concrete `bartz.grove.TreeHeaps` used to hold an MCMC trace."""

    # Field order matters: `var_tree` and `split_tree` come before `leaf_tree`
    # so that their single (union-free) annotations bind the variadic
    # `*batch_shape` axis first. Otherwise the runtime typechecker, which
    # evaluates union members in a hash-randomized order, can mis-bind it
    # against the `k` axis of the `leaf_tree` union for a multivariate tree
    # (the two layouts are rank-ambiguous). See `bartz.mcmcstep._state.Forest`.
    # The leaf-bearing axis is spelled `2*half_tree_size` instead of
    # `tree_size`, so the half-of-leaf relationship is still enforced:
    # `half_tree_size` is bound by the anchors, then `leaf_tree` is checked
    # against twice that.
    var_tree: UInt[Array, '*batch_shape half_tree_size']
    """Axis on which each decision node splits. The array may be dirty, except
    for the always-unused slot at index 0, which must be 0."""

    split_tree: UInt[Array, '*batch_shape half_tree_size']
    """Decision boundary of each node. Boundaries are open on the right: a
    point goes to the left child iff x < split. A split equal to 0 marks a
    leaf; unused nodes have split 0 as well. This array can not be dirty."""

    leaf_tree: (
        Float32[Array, '*batch_shape 2*half_tree_size']
        | Float32[Array, '*batch_shape k 2*half_tree_size']
    )
    """Leaf values of the trees. The array may be dirty: slots of unused nodes
    can contain anything. An additional axis may be present for multivariate
    leaves."""

    @classmethod
    def from_dataclass(cls, obj: TreeHeaps) -> 'TreesTrace':
        """Build a `TreesTrace` by copying the same-named fields of `obj`."""
        kwargs = {}
        for field in fields(cls):
            kwargs[field.name] = getattr(obj, field.name)
        return cls(**kwargs)

    def axes_from_dataclass(self, obj: TreeHeaps) -> 'TreesTrace':
        """Replace the leaves of this template with the vmap axis specs in `obj`.

        `self` provides the (array) pytree structure; each of its fields is
        swapped for the same-named attribute of `obj`, which holds an axis
        spec (an int or `None`). The result is assembled with
        `equinox.tree_at`, which does not go through the type-checked
        `__init__`, so the intentionally off-type axis values are accepted.
        """
        names = tuple(field.name for field in fields(type(self)))

        def pick(source: object) -> list:
            # same field order for the selector and for the replacements
            return [getattr(source, name) for name in names]

        return tree_at(pick, self, pick(obj))
def tree_depth(tree: Shaped[Array, '*batch_shape tree_size']) -> int:
    """
    Return the maximum depth representable by a tree array.

    Parameters
    ----------
    tree
        A tree array like those in a `TreeHeaps`. For an ND array, the heap is
        taken to lie along the last axis.

    Returns
    -------
    The base-2 logarithm of the length of the last axis, rounded to an integer.
    """
    heap_length = tree.shape[-1]
    return round(math.log2(heap_length))
def traverse_tree(
    x: UInt[Array, ' p'],
    var_tree: UInt[Array, ' half_tree_size'],
    split_tree: UInt[Array, ' half_tree_size'],
) -> UInt[Array, '']:
    """
    Locate the leaf a point ends up in.

    Parameters
    ----------
    x
        The coordinates to evaluate the tree at.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The index of the leaf.
    """
    # the index dtype must be able to hold the largest leaf_tree index
    node = jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1))
    reached_leaf = jnp.zeros((), bool)

    # the number of levels is a small static integer, so a plain python loop
    # is equivalent to (and clearer than) a fully-unrolled lax.scan
    num_levels = tree_depth(var_tree)
    for _ in range(num_levels):
        cutpoint = split_tree[node]
        axis = var_tree[node]

        # a zero split marks a leaf; once reached, the index stays frozen
        reached_leaf = reached_leaf | (cutpoint == 0)
        next_node = (node << 1) + (x[axis] >= cutpoint)
        node = jnp.where(reached_leaf, node, next_node)

    return node
@jit
def traverse_forest(
    X: UInt[Array, 'p n'],
    var_trees: UInt[Array, '*forest_shape half_tree_size'],
    split_trees: UInt[Array, '*forest_shape half_tree_size'],
) -> UInt[Array, '*forest_shape n']:
    """
    Find the leaves where points fall into, for each tree in a set.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    var_trees
        The decision axes of the trees.
    split_trees
        The decision boundaries of the trees.

    Returns
    -------
    The indices of the leaves.
    """
    # the vectorized implementation lives in a separate, undecorated-by-jit
    # function; this wrapper only adds compilation
    leaf_indices = _traverse_forest(X, var_trees, split_trees)
    return leaf_indices
@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
@partial(vmap_nodoc, in_axes=(1, None, None))
def _traverse_forest(
    X: UInt[Array, ' p'],
    var_trees: UInt[Array, ' half_tree_size'],
    split_trees: UInt[Array, ' half_tree_size'],
) -> UInt[Array, '']:
    """Implement `traverse_forest`.

    The body handles one point and one tree. The inner decorator maps it over
    axis 1 of `X` (the points); the outer one vectorizes it over the leading
    axes of the tree arrays, leaving `X` (argument 0) out of the
    vectorization.
    """
    leaf_index = traverse_tree(X, var_trees, split_trees)
    return leaf_index
@jit(static_argnames=('sum_batch_axis',))
def evaluate_forest(
    X: UInt[Array, 'p n'],
    trees: TreeHeaps,
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> (
    Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
):
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    trees
        The trees.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Negative indices count from the end of the batch dimensions; the core
        dimensions n and k can't be summed over by this function.

    Returns
    -------
    The (sum of) the values of the trees at the points in `X`.
    """
    # leaf index of each point in each tree, shape (*forest_shape, n)
    indices = traverse_forest(X, trees.var_tree, trees.split_tree)

    # add singleton axes so that indices and leaves broadcast against each
    # other, with the heap axis last for the gather
    if is_multivariate(trees):
        # indices: (*forest_shape, 1, n, 1); leaves: (*forest_shape, k, 1, tree_size)
        gather_index = indices[..., None, :, None]
        leaf_values = trees.leaf_tree[..., :, None, :]
    else:
        # indices: (*forest_shape, n, 1); leaves: (*forest_shape, 1, tree_size)
        gather_index = indices[..., None]
        leaf_values = trees.leaf_tree[..., None, :]

    gathered = jnp.take_along_axis(leaf_values, gather_index, -1)

    # drop the length-1 gather axis: (*forest_shape, n) or (*forest_shape, k, n)
    leaves = jnp.squeeze(gathered, -1)

    # axes are normalized against the batch dimensions only
    batch_ndim = trees.var_tree.ndim - 1
    summed_axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
    return jnp.sum(leaves, axis=summed_axes)
def is_actual_leaf(
    split_tree: UInt[Array, ' half_tree_size'], *, add_bottom_level: bool = False
) -> Bool[Array, ' half_tree_size'] | Bool[Array, ' 2*half_tree_size']:
    """
    Return a mask that marks the leaf nodes of a tree.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    add_bottom_level
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
    """
    no_split = split_tree == 0
    num_nodes = split_tree.size
    if add_bottom_level:
        # nodes on the bottom level have no split entry and count as split-less
        no_split = jnp.concatenate([no_split, jnp.ones_like(no_split)])
        num_nodes *= 2

    node = jnp.arange(num_nodes, dtype=minimal_unsigned_dtype(num_nodes - 1))
    parent = node >> 1

    # a split-less slot is a leaf only if its parent is a decision node; the
    # entry of the root (index 1) is forced to True
    has_decision_parent = split_tree[parent].astype(bool).at[1].set(True)
    return no_split & has_decision_parent
def is_leaves_parent(
    split_tree: UInt[Array, ' half_tree_size'],
) -> Bool[Array, ' half_tree_size']:
    """
    Return a mask that marks the nodes whose children are both leaves.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The mask indicating which nodes have leaf children.
    """
    num_nodes = split_tree.size
    # the dtype must be able to hold the largest child index
    node = jnp.arange(num_nodes, dtype=minimal_unsigned_dtype(2 * num_nodes - 1))
    left_child = node << 1
    right_child = left_child + 1

    def has_no_split(child: UInt[Array, ' half_tree_size']) -> Bool[Array, ' half_tree_size']:
        # children beyond the end of the array read as 0, i.e., as leaves
        return split_tree.at[child].get(mode='fill', fill_value=0) == 0

    is_decision = split_tree.astype(bool)
    # the 0-th item has split == 0, so it's not counted
    return is_decision & has_no_split(left_child) & has_no_split(right_child)
def tree_depths(tree_size: int) -> UInt[Array, ' {tree_size}']:
    """
    Return the depth of each node in a binary tree.

    Parameters
    ----------
    tree_size
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    The depth of each node.

    Notes
    -----
    The root node (index 1) has depth 0. The depth is the position of the most
    significant non-zero bit in the index. The first element (the unused node)
    is marked as depth 0.
    """
    # position of the most significant set bit of each index
    depths = [index.bit_length() - 1 for index in range(tree_size)]
    # index 0 has no set bit; the unused node is marked as depth 0
    depths[0] = 0
    return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
@jit
def forest_mean_leaves(
    split_tree: UInt[Array, '*batch_shape half_tree_size'],
) -> Float32[Array, '']:
    """
    Return the average number of leaves per tree in a set of trees.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees.

    Returns
    -------
    The mean number of leaves across the trees.
    """
    # the nonzero entries of split_tree are the internal nodes, and a tree
    # with k internal nodes has k + 1 leaves; the maximum possible is
    # split_tree.shape[-1]
    internal_per_tree = jnp.count_nonzero(split_tree, axis=-1)
    leaves_per_tree = internal_per_tree + 1
    return leaves_per_tree.mean()
@jit(static_argnames=('p', 'sum_batch_axis'))
def var_histogram(
    p: int,
    var_tree: UInt[Array, '*batch_shape half_tree_size'],
    split_tree: UInt[Array, '*batch_shape half_tree_size'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape {p}']:
    """
    Count how many times each variable appears in a tree.

    Parameters
    ----------
    p
        The number of variables (the maximum value that can occur in `var_tree`
        is ``p - 1``).
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Negative indices count from the end of the batch dimensions; the core
        dimension p can't be summed over by this function.

    Returns
    -------
    The histogram(s) of the variables used in the tree.
    """
    batch_ndim = var_tree.ndim - 1
    summed_axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
    kept_axes = [i for i in range(batch_ndim) if i not in summed_axes]

    def count_vars(
        var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
        is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
    ) -> Int32[Array, ' p']:
        # only decision nodes contribute; leaves and unused slots add 0
        return jnp.zeros(p, int).at[var_tree].add(is_internal)

    # vmap over the batch axes that are kept, innermost first; the axes are
    # addressed from the end so they stay valid as outer axes are mapped away
    for i in reversed(kept_axes):
        count_vars = vmap(count_vars, in_axes=i - var_tree.ndim)

    return count_vars(var_tree, split_tree.astype(bool))
def _format_leaf(leaf: Float32[Array, ''] | Float32[Array, ' k'], is_mv: bool) -> str:
    """Format a (possibly multivariate) leaf value to 2 significant digits.

    A multivariate leaf is rendered as a bracketed, comma-separated list of
    its components.
    """
    if not is_mv:
        return f'{leaf:#.2g}'
    components = [f'{v:#.2g}' for v in leaf]
    return '[' + ', '.join(components) + ']'
def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
    """Convert a tree to a human-readable string.

    Each node is printed on its own line, prefixed by its heap index, with
    box-drawing characters connecting children to their parent.

    Parameters
    ----------
    tree
        A single tree to format.
    print_all
        If `True`, also print the contents of unused node slots in the arrays.

    Returns
    -------
    A string representation of the tree.
    """
    # Branch connectors and continuation indents. All four must have the same
    # display width (3 columns), otherwise the children of nested nodes do not
    # line up under the `down` link of their parent.
    tee = '├──'
    corner = '└──'
    join = '│  '
    space = '   '
    down = '┐'
    bottom = '╢'  # '┨' #

    *_, tree_size = tree.leaf_tree.shape
    is_mv = is_multivariate(tree)

    # width of the index column, sized on the largest possible index
    ndigits = len(str(tree_size - 1))

    lines: list[str] = []

    def visit(
        index: int,
        indent: str,
        first_indent: str,
        next_indent: str,
        unused: bool,
    ) -> None:
        """Append the line of node `index`, then recurse into its children."""
        if index >= tree_size:
            return

        # indices on the bottom level are out of bounds for these arrays and
        # read as 0, which makes the node a leaf
        var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
        split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()

        is_leaf = split == 0
        left_child = 2 * index
        right_child = left_child + 1

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {tree.leaf_tree[..., index]})'
        else:
            # without print_all the recursion never descends below a leaf
            assert not unused
            if is_leaf:
                node_str = _format_leaf(tree.leaf_tree[..., index], is_mv)
            else:
                node_str = f'x{var} < {split}'

        if not is_leaf or (print_all and left_child < tree_size):
            link = down
        elif not print_all and left_child >= tree_size:
            link = bottom
        else:
            link = ' '

        number = str(index).rjust(ndigits)
        lines.append(f' {number} {indent}{first_indent}{link}{node_str}')

        # everything below a leaf is an unused slot
        unused = unused or is_leaf
        if unused and not print_all:
            return

        child_indent = indent + next_indent
        visit(left_child, child_indent, tee, join, unused)
        visit(right_child, child_indent, corner, space, unused)

    visit(1, '', '', '', False)
    return '\n'.join(lines)
def tree_actual_depth(split_tree: UInt[Array, ' half_tree_size']) -> UInt[Array, '']:
    """Measure the depth of the tree.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules.

    Returns
    -------
    The depth of the deepest leaf in the tree. The root is at depth 0.
    """
    # this could be done just with split_tree != 0
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)
    node_depths = tree_depths(leaf_mask.size)
    # non-leaf slots contribute depth 0 to the maximum
    leaf_depths = jnp.where(leaf_mask, node_depths, 0)
    return jnp.max(leaf_depths)
@jit
@partial(jnp.vectorize, signature='(nt,hts)->(d)')
def forest_depth_distr(
    split_tree: UInt[Array, '*batch_shape num_trees half_tree_size'],
) -> Int32[Array, '*batch_shape d']:
    """Histogram the depths of a set of trees.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules of the trees.

    Returns
    -------
    An integer vector where the i-th element counts how many trees have depth i.
    """
    # one bin for each depth from 0 up to the depth of the bottom level
    num_bins = tree_depth(split_tree) + 1
    per_tree_depth = vmap(tree_actual_depth)(split_tree)
    return jnp.bincount(per_tree_depth, length=num_bins)
@jit(static_argnames=('node_type', 'sum_batch_axis'))
def points_per_node_distr(
    X: UInt[Array, 'p n'],
    var_tree: UInt[Array, '*batch_shape half_tree_size'],
    split_tree: UInt[Array, '*batch_shape half_tree_size'],
    node_type: Literal['leaf', 'leaf-parent'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape n+1']:
    """Histogram points-per-node counts in a set of trees.

    Count how many nodes in a tree select each possible amount of points,
    over a certain subset of nodes.

    Parameters
    ----------
    X
        The set of points to count.
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    node_type
        The type of nodes to consider. Can be:

        'leaf'
            Count only leaf nodes.
        'leaf-parent'
            Count only parent-of-leaf nodes.
    sum_batch_axis
        Aggregate the histogram over these batch axes, counting how many nodes
        have each possible amount of points over subsets of trees instead of
        in each tree separately.

    Returns
    -------
    A vector where the i-th element counts how many nodes have i points.
    """
    batch_ndim = var_tree.ndim - 1
    axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)

    def func(
        var_tree: UInt[Array, '*batch_shape half_tree_size'],
        split_tree: UInt[Array, '*batch_shape half_tree_size'],
    ) -> Int32[Array, '*reduced_batch_shape n_plus_1']:
        # leaf reached by each point in each tree, shape (*batch_shape, n)
        indices = traverse_forest(X, var_tree, split_tree)

        @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
        def count_points(
            split_tree: UInt[Array, '*batch_shape half_tree_size'],
            indices: UInt[Array, '*batch_shape n'],
        ) -> (
            tuple[
                Int32[Array, '*batch_shape 2*half_tree_size'],
                Bool[Array, '*batch_shape 2*half_tree_size'],
            ]
            | tuple[
                Int32[Array, '*batch_shape half_tree_size'],
                Bool[Array, '*batch_shape half_tree_size'],
            ]
        ):
            if node_type == 'leaf':
                predicate = is_actual_leaf(split_tree, add_bottom_level=True)
            elif node_type == 'leaf-parent':
                # move each point from its leaf to the parent of that leaf
                indices = indices >> 1
                predicate = is_leaves_parent(split_tree)
            else:
                raise ValueError(node_type)

            # number of points landing on each node; slot 0 is zeroed out
            tally = jnp.zeros(predicate.size, int).at[indices].add(1)
            return tally.at[0].set(0), predicate

        count_tree, predicate = count_points(split_tree, indices)

        def count_nodes(
            count_tree: Int32[Array, '*summed_batch_axes half_tree_size'],
            predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
        ) -> Int32[Array, ' n_plus_1']:
            # a node holds between 0 and n points, hence n + 1 bins; only the
            # nodes selected by the predicate are counted
            return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate)

        # vmap count_nodes over non-batched dims
        for i in reversed(range(batch_ndim)):
            if i in axes:
                continue
            count_nodes = vmap(count_nodes, in_axes=i - var_tree.ndim)

        return count_nodes(count_tree, predicate)

    # automatically batch over all batch dimensions
    max_io_nbytes = 2**27  # 128 MiB
    batched_func = func
    # each summed axis that precedes a kept axis shifts the position of that
    # kept axis in the output by one
    out_dim_shift = len(axes)
    for i in reversed(range(batch_ndim)):
        if i in axes:
            out_dim_shift -= 1
            continue
        batched_func = autobatch(batched_func, max_io_nbytes, i, i - out_dim_shift)
    assert out_dim_shift == 0

    return batched_func(var_tree, split_tree)