Tree manipulation¶
Functions to create and manipulate binary decision trees.
- class bartz.grove.TreeHeaps(*args, **kwargs)[source]¶
A protocol for dataclasses that represent trees.
A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index \(i\) are at indices \(2i\) (left child) and \(2i + 1\) (right child). The array element at index 0 is unused.
Since the nodes at the bottom level can only be leaves and not decision nodes, var_tree and split_tree are half as long as leaf_tree. Arrays may have additional initial axes to represent multiple trees.
- leaf_tree: Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']¶
The values in the leaves of the trees. This array can be dirty, i.e., unused nodes can have whatever value. It may have an additional axis for multivariate leaves.
- var_tree: UInt[Array, '*batch_shape 2**(d-1)']¶
The axes along which the decision nodes operate. This array can be dirty but for the always unused node at index 0 which must be set to 0.
- split_tree: UInt[Array, '*batch_shape 2**(d-1)']¶
The decision boundaries of the trees. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding ‘split’ element being 0. Unused nodes also have split set to 0. This array can’t be dirty.
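To make the heap layout concrete, here is a minimal hand-built tree (an illustrative sketch; the values are hypothetical and this is not bartz's own code):

```python
# Hypothetical tree with maximum depth d = 3: leaf_tree has 2**3 = 8 slots,
# var_tree and split_tree have 2**2 = 4 slots; index 0 is unused in all three.

# The root (index 1) splits on variable 0 at cutpoint 5. Its left child
# (index 2) is a leaf (split 0); its right child (index 3) splits on
# variable 1 at cutpoint 3, so its children live at leaf_tree indices 6, 7.
var_tree   = [0, 0, 0, 1]   # decision axes; index 0 must be set to 0
split_tree = [0, 5, 0, 3]   # 0 marks a leaf or an unused slot
leaf_tree  = [0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 2.0, 3.0]

def children(i):
    """Heap indices of the children of node i."""
    return 2 * i, 2 * i + 1

print(children(1))  # (2, 3)
print(children(3))  # (6, 7)
```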
- bartz.grove.tree_depth(tree)[source]¶
Return the maximum depth of a tree.
- Parameters:
tree (Shaped[Array, '*batch_shape 2**d']) – A tree array like those in a TreeHeaps. If the array is ND, the tree structure is assumed to be along the last axis.
- Returns:
int – The maximum depth of the tree.
- bartz.grove.traverse_tree(x, var_tree, split_tree)[source]¶
Find the leaf where a point falls into.
- Parameters:
x (UInt[Array, 'p']) – The coordinates to evaluate the tree at.
var_tree (UInt[Array, '2**(d-1)']) – The decision axes of the tree.
split_tree (UInt[Array, '2**(d-1)']) – The decision boundaries of the tree.
- Returns:
UInt[Array, ''] – The index of the leaf.
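The traversal rule follows from the conventions above: descend while the current node's split is nonzero, going left iff x < split. A plain-Python sketch of that rule (illustrative only, not bartz's implementation):

```python
# Illustrative reimplementation of the documented traversal semantics.
def traverse(x, var_tree, split_tree):
    """Return the heap index of the leaf that x falls into."""
    node = 1  # root
    # Nodes past the end of split_tree are bottom-level leaves.
    while node < len(split_tree) and split_tree[node] != 0:
        if x[var_tree[node]] < split_tree[node]:  # boundary open on the right
            node = 2 * node        # left child
        else:
            node = 2 * node + 1    # right child
    return node

var_tree   = [0, 0, 0, 1]
split_tree = [0, 5, 0, 3]
print(traverse([4, 0], var_tree, split_tree))  # 2: stops at the leaf child of the root
print(traverse([7, 9], var_tree, split_tree))  # 7: falls through to the bottom level
```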
- bartz.grove.traverse_forest(X: UInt[Array, 'p n'], var_trees: UInt[Array, '*forest_shape 2**(d-1)'], split_trees: UInt[Array, '*forest_shape 2**(d-1)']) UInt[Array, '*forest_shape n'][source]¶
Find the leaves the points fall into, for each tree in a set.
- Parameters:
X (UInt[Array, 'p n']) – The coordinates to evaluate the trees at.
var_trees (UInt[Array, '*forest_shape 2**(d-1)']) – The decision axes of the trees.
split_trees (UInt[Array, '*forest_shape 2**(d-1)']) – The decision boundaries of the trees.
- Returns:
UInt[Array, '*forest_shape n'] – The indices of the leaves.
- bartz.grove.evaluate_forest(X, trees, *, sum_batch_axis=())[source]¶
Evaluate an ensemble of trees at an array of points.
- Parameters:
X (UInt[Array, 'p n']) – The coordinates to evaluate the trees at.
trees (TreeHeaps) – The trees.
sum_batch_axis (int | tuple[int, ...], default: ()) – The batch axes to sum over. By default, no summation is performed. Negative indices count from the end of the batch dimensions; the core dimensions n and k can't be summed over by this function.
- Returns:
Float32[Array, '*reduced_batch_shape n'] | Float32[Array, '*reduced_batch_shape k n'] – The (sum of the) values of the trees at the points in X.
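Ensemble evaluation combines the pieces above: each tree contributes the value of the leaf the point falls into, and summing over the tree axis gives the ensemble prediction. A self-contained sketch (illustrative, not bartz's implementation):

```python
# Illustrative single-point, loop-based version of forest evaluation.
def traverse(x, var_tree, split_tree):
    """Heap index of the leaf that x falls into (see traverse_tree)."""
    node = 1
    while node < len(split_tree) and split_tree[node] != 0:
        node = 2 * node + (0 if x[var_tree[node]] < split_tree[node] else 1)
    return node

def evaluate_forest(x, leaf_trees, var_trees, split_trees):
    """Sum of single-tree predictions at point x."""
    return sum(
        leaf[traverse(x, var, split)]
        for leaf, var, split in zip(leaf_trees, var_trees, split_trees)
    )

# Two toy trees (d = 2): one splits on variable 0 at cutpoint 5,
# the other is a bare root leaf with value 0.5.
leaf_trees  = [[0.0, 0.0, -1.0, 1.0], [0.0, 0.5, 0.0, 0.0]]
var_trees   = [[0, 0], [0, 0]]
split_trees = [[0, 5], [0, 0]]
print(evaluate_forest([3], leaf_trees, var_trees, split_trees))  # -1.0 + 0.5 = -0.5
```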
- bartz.grove.is_actual_leaf(split_tree, *, add_bottom_level=False)[source]¶
Return a mask indicating the leaf nodes in a tree.
- Parameters:
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
add_bottom_level (bool, default: False) – If True, the bottom level of the tree is also considered.
- Returns:
Bool[Array, '2**(d-1)'] | Bool[Array, '2**d'] – The mask marking the leaf nodes. Length doubled if add_bottom_level is True.
- bartz.grove.is_leaves_parent(split_tree)[source]¶
Return a mask indicating the nodes with leaf (and only leaf) children.
- Parameters:
split_tree (UInt[Array, '2**(d-1)']) – The decision boundaries of the tree.
- Returns:
Bool[Array, '2**(d-1)'] – The mask indicating which nodes have leaf children.
- bartz.grove.tree_depths(tree_size)[source]¶
Return the depth of each node in a binary tree.
- Parameters:
tree_size (int) – The length of the tree array, i.e., 2 ** d.
- Returns:
Int32[Array, '{tree_size}'] – The depth of each node.
Notes
The root node (index 1) has depth 0. The depth is the position of the most significant non-zero bit in the index. The first element (the unused node) is marked as depth 0.
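The bit rule in the note can be sketched in a few lines of plain Python (illustrative, not bartz's implementation):

```python
# Depth of heap index i = position of its most significant nonzero bit,
# with the unused slot at index 0 conventionally marked depth 0.
def tree_depths(tree_size):
    """Depth of each node in a heap-stored tree of the given array length."""
    # int.bit_length() - 1 is the MSB position: 1 -> 0, 2..3 -> 1, 4..7 -> 2, ...
    return [max(i.bit_length() - 1, 0) for i in range(tree_size)]

print(tree_depths(8))  # [0, 0, 1, 1, 2, 2, 2, 2]
```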
- bartz.grove.is_used(split_tree)[source]¶
Return a mask indicating the used nodes in a tree.
- Parameters:
split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision boundaries of the tree.
- Returns:
Bool[Array, '*batch_shape 2**d'] – A mask indicating which nodes are actually used.
- bartz.grove.forest_fill(split_tree)[source]¶
Return the fraction of used nodes in a set of trees.
- Parameters:
split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision boundaries of the trees.
- Returns:
Float32[Array, ''] – The number of used nodes over the maximum number that could be stored.
- bartz.grove.var_histogram(p, var_tree, split_tree, *, sum_batch_axis=())[source]¶
Count how many times each variable appears in a tree.
- Parameters:
p (int) – The number of variables (the maximum value that can occur in var_tree is p - 1).
var_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision axes of the tree.
split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision boundaries of the tree.
sum_batch_axis (int | tuple[int, ...], default: ()) – The batch axes to sum over. By default, no summation is performed. Negative indices count from the end of the batch dimensions; the core dimension p can't be summed over by this function.
- Returns:
Int32[Array, '*reduced_batch_shape {p}'] – The histogram(s) of the variables used in the tree.
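The counting rule follows from the heap conventions: a node is an actual decision node iff its split is nonzero, and its var_tree entry says which variable it uses. A minimal sketch for a single tree (illustrative, not bartz's implementation):

```python
# Illustrative variable-usage histogram for one tree.
def var_histogram(p, var_tree, split_tree):
    """counts[v] = number of decision nodes that split on variable v."""
    counts = [0] * p
    for var, split in zip(var_tree, split_tree):
        if split != 0:  # only actual decision nodes count; leaves have split 0
            counts[var] += 1
    return counts

var_tree   = [0, 0, 0, 1]
split_tree = [0, 5, 0, 3]   # decision nodes at heap indices 1 and 3
print(var_histogram(2, var_tree, split_tree))  # [1, 1]
```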
- bartz.grove.format_tree(tree, *, print_all=False)[source]¶
Convert a tree to a human-readable string.
- Parameters:
tree (TreeHeaps) – A single tree to format.
print_all (bool, default: False) – If True, also print the contents of unused node slots in the arrays.
- Returns:
str – A string representation of the tree.
- bartz.grove.tree_actual_depth(split_tree)[source]¶
Measure the depth of the tree.
- Parameters:
split_tree (UInt[Array, '2**(d-1)']) – The cutpoints of the decision rules.
- Returns:
Int32[Array, ''] – The depth of the deepest leaf in the tree. The root is at depth 0.
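One way to picture this measurement is a recursion from the root that reports the deepest leaf reached (an illustrative sketch, not bartz's implementation):

```python
# Illustrative recursive computation of the actual depth of one tree.
def tree_actual_depth(split_tree):
    """Depth of the deepest leaf; the root is at depth 0."""
    def depth(node, d):
        # A node is a leaf if it lies past the end of split_tree
        # (bottom level) or its split is 0.
        if node >= len(split_tree) or split_tree[node] == 0:
            return d
        return max(depth(2 * node, d + 1), depth(2 * node + 1, d + 1))
    return depth(1, 0)

print(tree_actual_depth([0, 5, 0, 3]))  # 2: the deepest leaves hang below node 3
print(tree_actual_depth([0, 0]))        # 0: the root itself is a leaf
```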
- bartz.grove.forest_depth_distr(split_tree)[source]¶
Histogram the depths of a set of trees.
- Parameters:
split_tree (UInt[Array, '*batch_shape num_trees 2**(d-1)']) – The cutpoints of the decision rules of the trees.
- Returns:
Int32[Array, '*batch_shape d'] – An integer vector where the i-th element counts how many trees have depth i.
- bartz.grove.points_per_node_distr(X, var_tree, split_tree, node_type, *, sum_batch_axis=())[source]¶
Histogram points-per-node counts in a set of trees.
Count how many nodes in a tree select each possible number of points, over a certain subset of nodes.
- Parameters:
X (UInt[Array, 'p n']) – The set of points to count.
var_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The variables of the decision rules.
split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The cutpoints of the decision rules.
node_type (Literal['leaf', 'leaf-parent']) – The type of nodes to consider. Can be:
- 'leaf': count only leaf nodes.
- 'leaf-parent': count only parent-of-leaf nodes.
sum_batch_axis (int | tuple[int, ...], default: ()) – Aggregate the histogram over these batch axes, counting how many nodes have each possible number of points over subsets of trees instead of in each tree separately.
- Returns:
Int32[Array, '*reduced_batch_shape n+1'] – A vector where the i-th element counts how many nodes have i points.