Tree manipulation

Functions to create and manipulate binary decision trees.

class bartz.grove.TreeHeaps(*args, **kwargs)[source]

A protocol for dataclasses that represent trees.

A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index \(i\) are at indices \(2i\) (left child) and \(2i + 1\) (right child). The array element at index 0 is unused.

Since the nodes at the bottom can only be leaves and not decision nodes, var_tree and split_tree are half as long as leaf_tree.

Arrays may have additional initial axes to represent multiple trees.
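The index arithmetic described above can be sketched in a few lines. This is only an illustration with plain Python integers and lists, not part of bartz; the helper names `children` and `parent` and the depth value `d = 3` are made up for the example.

```python
def children(i):
    """Return the (left, right) heap indices of the children of node i."""
    return 2 * i, 2 * i + 1


def parent(i):
    """Return the heap index of the parent of node i (the root is index 1)."""
    return i // 2


d = 3  # hypothetical depth parameter, as in the shape annotations
leaf_len = 2 ** d            # length of leaf_tree: indices 0..7, index 0 unused
decision_len = 2 ** (d - 1)  # length of var_tree and split_tree: indices 0..3

# Nodes grouped by level: [[1], [2, 3], [4, 5, 6, 7]]
levels = [list(range(2 ** k, 2 ** (k + 1))) for k in range(d)]
```

The bottom level (indices 4 to 7 here) lies entirely beyond the end of the shorter arrays, which is why var_tree and split_tree need only half the length of leaf_tree.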

leaf_tree: Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']

The values in the leaves of the trees. This array can be dirty, i.e., unused nodes can hold arbitrary values. It may have an additional axis for multivariate leaves.

var_tree: UInt[Array, '*batch_shape 2**(d-1)']

The axes along which the decision nodes operate. This array can be dirty, except for the always-unused node at index 0, which must be set to 0.

split_tree: UInt[Array, '*batch_shape 2**(d-1)']

The decision boundaries of the trees. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. A split value of 0 marks a node as a leaf; unused nodes also have split set to 0. This array can’t be dirty.

bartz.grove.make_tree(depth, dtype, batch_shape=())[source]

Make an array to represent a binary tree.

Parameters:
  • depth (int) – The maximum depth of the tree. Depth 1 means that there is only a root node.

  • dtype (str | type[Any] | dtype | SupportsDType) – The dtype of the array.

  • batch_shape (tuple[int, ...], default: ()) – The leading shape of the array, to represent multiple trees and/or multivariate trees.

Returns:

Shaped[Array, '*batch_shape 2**{depth}'] – An array of zeroes with the appropriate shape.

bartz.grove.tree_depth(tree)[source]

Return the maximum depth of a tree.

Parameters:

tree (Shaped[Array, '*batch_shape 2**d']) – A tree created by make_tree. If the array is ND, the tree structure is assumed to be along the last axis.

Returns:

int – The maximum depth of the tree.
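Since make_tree returns an array whose last axis has length 2**depth, the depth can be recovered from that length alone. A minimal sketch of the relationship, using a plain list in place of an array; `tree_depth_sketch` is a hypothetical name and not the library function:

```python
def tree_depth_sketch(tree):
    """Depth implied by the length of a heap array of length 2**depth."""
    size = len(tree)
    # A heap array as described above always has a power-of-two length.
    if size <= 0 or size & (size - 1) != 0:
        raise ValueError("tree length must be a power of 2")
    return size.bit_length() - 1


leaf_tree = [0.0] * 2 ** 3  # stands in for an array of shape (2**3,)
depth = tree_depth_sketch(leaf_tree)
```

Under this convention, an array of length 2 (unused slot plus root) has depth 1, matching the statement that depth 1 means there is only a root node.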

bartz.grove.traverse_tree(x, var_tree, split_tree)[source]

Find the leaf that a point falls into.

Parameters:
  • x (UInt[Array, 'p']) – The coordinates to evaluate the tree at.

  • var_tree (UInt[Array, '2**(d-1)']) – The decision axes of the tree.

  • split_tree (UInt[Array, '2**(d-1)']) – The decision boundaries of the tree.

Returns:

UInt[Array, ''] – The index of the leaf.
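The traversal can be pictured as a walk from the root that stops at the first node whose split is 0, or at the bottom level. The following is a plain-Python sketch of that logic based on the conventions documented for TreeHeaps; it is not the library's implementation, and `traverse_tree_sketch` is a hypothetical name.

```python
def traverse_tree_sketch(x, var_tree, split_tree):
    """Return the heap index of the leaf that the point x falls into."""
    index = 1  # start at the root
    # Stop at a node with split 0, or once past the end of split_tree
    # (bottom-level nodes can only be leaves).
    while index < len(split_tree) and split_tree[index] != 0:
        var = var_tree[index]
        split = split_tree[index]
        # Left child iff x < split, so go right when x >= split.
        index = 2 * index + (1 if x[var] >= split else 0)
    return index


# Example with d = 3: the root splits variable 0 at 5, and its right
# child (index 3) splits variable 1 at 2. Node 2 is a leaf.
var_tree = [0, 0, 0, 1]
split_tree = [0, 5, 0, 2]

leaf_a = traverse_tree_sketch([3, 9], var_tree, split_tree)
leaf_b = traverse_tree_sketch([5, 1], var_tree, split_tree)
leaf_c = traverse_tree_sketch([7, 2], var_tree, split_tree)
```

The returned index can exceed the length of split_tree, because bottom-level leaves exist only in leaf_tree.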

bartz.grove.traverse_forest(X: UInt[Array, 'p n'], var_trees: UInt[Array, '*forest_shape 2**(d-1)'], split_trees: UInt[Array, '*forest_shape 2**(d-1)']) → UInt[Array, '*forest_shape n'][source]

Find the leaves that points fall into, for each tree in a set.

Parameters:
  • X (UInt[Array, 'p n']) – The coordinates to evaluate the trees at.

  • var_trees (UInt[Array, '*forest_shape 2**(d-1)']) – The decision axes of the trees.

  • split_trees (UInt[Array, '*forest_shape 2**(d-1)']) – The decision boundaries of the trees.

Returns:

UInt[Array, '*forest_shape n'] – The indices of the leaves.

bartz.grove.evaluate_forest(X, trees, *, sum_batch_axis=())[source]

Evaluate an ensemble of trees at an array of points.

Parameters:
  • X (UInt[Array, 'p n']) – The coordinates to evaluate the trees at.

  • trees (TreeHeaps) – The trees.

  • sum_batch_axis (int | tuple[int, ...], default: ()) – The batch axes to sum over. By default, no summation is performed. Negative indices count from the end of the batch dimensions, so the core dimensions n and k can’t be summed over by this function.

Returns:

Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n'] – The (sum of the) values of the trees at the points in X.
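Conceptually, evaluation looks up the leaf value at the index each point reaches in each tree, then optionally sums over trees. Below is a plain-Python sketch for a one-dimensional batch of univariate trees. The function names, the separate per-field arguments, and the `sum_trees` flag are illustrative assumptions; the real function takes a TreeHeaps object and a sum_batch_axis argument.

```python
def traverse(x, var_tree, split_tree):
    """Walk from the root to the leaf that x falls into."""
    index = 1
    while index < len(split_tree) and split_tree[index] != 0:
        index = 2 * index + (1 if x[var_tree[index]] >= split_tree[index] else 0)
    return index


def evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees, sum_trees=False):
    """Evaluate each tree at each column of X, where X has shape (p, n)."""
    n = len(X[0])
    points = [[row[j] for row in X] for j in range(n)]  # n points of p coordinates
    values = [
        [leaf[traverse(x, var, split)] for x in points]
        for leaf, var, split in zip(leaf_trees, var_trees, split_trees)
    ]
    if sum_trees:
        return [sum(column) for column in zip(*values)]
    return values


# Two trees with d = 2. The first splits variable 0 at 5; the second is a
# lone root leaf. Unused leaf_tree entries hold arbitrary ("dirty") values.
leaf_trees = [[0.0, 99.0, -1.0, 1.0], [0.0, 0.5, 7.0, 7.0]]
var_trees = [[0, 0], [0, 1]]
split_trees = [[0, 5], [0, 0]]
X = [[3, 5, 8], [0, 0, 0]]  # p = 2 variables, n = 3 points

per_tree = evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees)
total = evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees, sum_trees=True)
```

The dirty entries (99.0 and 7.0) never reach the output because traversal cannot land on them.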

bartz.grove.is_actual_leaf(split_tree, *, add_bottom_level=False)[source]

Return a mask indicating the leaf nodes in a tree.

Parameters:
  • split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.

  • add_bottom_level (bool, default: False) – If True, the bottom level of the tree is also considered.

Returns:

Bool[Array, '2**(d-1)'] | Bool[Array, '2**d'] – The mask marking the leaf nodes. Length doubled if add_bottom_level is True.
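Because unused nodes also have split 0, a zero split alone does not identify a leaf: the node's parent must also be a decision node (or the node must be the root). A plain-Python sketch of that rule, derived from the conventions documented for split_tree rather than from the library source; `is_actual_leaf_sketch` is a hypothetical name.

```python
def is_actual_leaf_sketch(split_tree, add_bottom_level=False):
    """Mask of nodes that are leaves and are actually part of the tree."""
    is_leaf = [s == 0 for s in split_tree]
    if add_bottom_level:
        # Bottom-level nodes have no split entry and can only be leaves.
        is_leaf += [True] * len(split_tree)
    mask = []
    for i, leaf in enumerate(is_leaf):
        # The root counts as reachable; other nodes need a decision-node parent.
        reachable = i == 1 or (i > 1 and split_tree[i // 2] != 0)
        mask.append(leaf and reachable)
    return mask


split_tree = [0, 5, 0, 2]  # root and node 3 are decision nodes
short_mask = is_actual_leaf_sketch(split_tree)
full_mask = is_actual_leaf_sketch(split_tree, add_bottom_level=True)
```

In the example, nodes 4 and 5 are excluded from the full mask because their parent (node 2) is itself a leaf.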

bartz.grove.is_leaves_parent(split_tree)[source]

Return a mask indicating the nodes with leaf (and only leaf) children.

Parameters:

split_tree (UInt[Array, '2**(d-1)']) – The decision boundaries of the tree.

Returns:

Bool[Array, '2**(d-1)'] – The mask indicating which nodes have leaf children.
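A node qualifies when it is a decision node and both of its children are leaves, where children that fall beyond the end of split_tree are bottom-level nodes and therefore leaves. A minimal sketch of this check with plain lists; the name `is_leaves_parent_sketch` is hypothetical and the code only illustrates the documented conventions.

```python
def is_leaves_parent_sketch(split_tree):
    """Mask of decision nodes whose two children are both leaves."""
    size = len(split_tree)

    def child_is_leaf(c):
        # Children past the end of split_tree are on the bottom level.
        return c >= size or split_tree[c] == 0

    return [
        i >= 1
        and split_tree[i] != 0
        and child_is_leaf(2 * i)
        and child_is_leaf(2 * i + 1)
        for i in range(size)
    ]


split_tree = [0, 5, 0, 2]
mask = is_leaves_parent_sketch(split_tree)
```

Here the root is not marked, because its right child (node 3) is a decision node, while node 3 is marked because both of its children are on the bottom level.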

bartz.grove.tree_depths(tree_size)[source]

Return the depth of each node in a binary tree.

Parameters:

tree_size (int) – The length of the tree array, i.e., 2 ** d.

Returns:

Int32[Array, '{tree_size}'] – The depth of each node.

Notes

The root node (index 1) has depth 0. The depth is the position of the most significant non-zero bit in the index. The first element (the unused node) is marked as depth 0.
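The bit-position rule in the note can be checked with Python's built-in `int.bit_length`. This is an illustrative sketch with a hypothetical function name, returning a list rather than an array:

```python
def tree_depths_sketch(tree_size):
    """Depth of each heap index, with the unused index 0 reported as depth 0."""
    # bit_length() - 1 is the position of the most significant set bit;
    # it would be -1 for index 0, so clamp that case to 0.
    return [max(i.bit_length() - 1, 0) for i in range(tree_size)]


depths = tree_depths_sketch(8)
```

For a tree array of length 8 this yields one node at depth 0 (plus the unused slot), two at depth 1, and four at depth 2.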

bartz.grove.is_used(split_tree)[source]

Return a mask indicating the used nodes in a tree.

Parameters:

split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision boundaries of the tree.

Returns:

Bool[Array, '*batch_shape 2**d'] – A mask indicating which nodes are actually used.
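For a valid (non-dirty) split_tree, a node is used exactly when it is the root or its parent is a decision node. The sketch below applies that rule to a single tree using plain lists; it assumes a valid split_tree, uses the hypothetical name `is_used_sketch`, and is not the library's implementation.

```python
def is_used_sketch(split_tree):
    """Mask over all 2**d heap slots marking nodes that belong to the tree."""
    full_size = 2 * len(split_tree)  # include the bottom level
    return [
        i == 1 or (i > 1 and split_tree[i // 2] != 0)
        for i in range(full_size)
    ]


split_tree = [0, 5, 0, 2]
used = is_used_sketch(split_tree)
```

The output is twice as long as split_tree, matching the documented return shape, and slot 0 is never marked as used.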

bartz.grove.forest_fill(split_tree)[source]

Return the fraction of used nodes in a set of trees.

Parameters:

split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision boundaries of the trees.

Returns:

Float32[Array, ''] – Number of tree nodes over the maximum number that could be stored.
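The ratio can be illustrated by counting used nodes per tree and dividing by the total capacity. The sketch assumes that each tree can store at most 2**d - 1 nodes, i.e., that the unused slot at index 0 is excluded from the denominator; that choice, like the name `forest_fill_sketch`, is an assumption made for the example.

```python
def forest_fill_sketch(split_trees):
    """Fraction of storable node slots that are occupied across all trees."""
    used = 0
    capacity = 0
    for split_tree in split_trees:
        full_size = 2 * len(split_tree)
        # A node is used if it is the root or its parent is a decision node.
        used += sum(
            1 for i in range(1, full_size)
            if i == 1 or split_tree[i // 2] != 0
        )
        capacity += full_size - 1  # assumption: slot 0 is not counted
    return used / capacity


split_trees = [
    [0, 5, 0, 2],  # 5 used nodes: 1, 2, 3, 6, 7
    [0, 0, 0, 0],  # 1 used node: the root alone
]
fill = forest_fill_sketch(split_trees)
```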

bartz.grove.var_histogram(p, var_tree, split_tree, *, sum_batch_axis=())[source]

Count how many times each variable appears in a tree.

Parameters:
  • p (int) – The number of variables (the maximum value that can occur in var_tree is p - 1).

  • var_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision axes of the tree.

  • split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision boundaries of the tree.

  • sum_batch_axis (int | tuple[int, ...], default: ()) – The batch axes to sum over. By default, no summation is performed. Negative indices count from the end of the batch dimensions, so the core dimension p can’t be summed over by this function.

Returns:

Int32[Array, '*reduced_batch_shape {p}'] – The histogram(s) of the variables used in the tree.
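Since var_tree can be dirty, a count has to consult split_tree to tell which entries belong to decision nodes. A plain-Python sketch for a single tree, assuming that only nodes with a non-zero split are counted; the function name is hypothetical and batching is omitted.

```python
def var_histogram_sketch(p, var_tree, split_tree):
    """Count how many decision nodes split on each of the p variables."""
    counts = [0] * p
    for var, split in zip(var_tree, split_tree):
        if split != 0:  # skip leaves and unused nodes, whose var may be dirty
            counts[var] += 1
    return counts


var_tree = [0, 0, 2, 1]    # the entry at index 2 is dirty (node 2 is a leaf)
split_tree = [0, 5, 0, 2]
histogram = var_histogram_sketch(3, var_tree, split_tree)
```

The dirty value 2 at index 2 is ignored, so variable 2 gets a count of zero.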