Tree manipulation

Functions to create and manipulate binary decision trees.

class bartz.grove.TreeHeaps(*args, **kwargs)[source]

A protocol for dataclasses that represent trees.

A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index \(i\) are at indices \(2i\) (left child) and \(2i + 1\) (right child). The array element at index 0 is unused.

Since the nodes at the bottom can only be leaves and not decision nodes, var_tree and split_tree are half as long as leaf_tree.

Arrays may have additional initial axes to represent multiple trees.

leaf_tree: Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']

The values in the leaves of the trees. This array can be dirty, i.e., unused nodes can contain arbitrary values. It may have an additional axis for multivariate leaves.

var_tree: UInt[Array, '*batch_shape 2**(d-1)']

The axes along which the decision nodes operate. This array can be dirty, except for the always-unused node at index 0, which must be set to 0.

split_tree: UInt[Array, '*batch_shape 2**(d-1)']

The decision boundaries of the trees. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding ‘split’ element being 0. Unused nodes also have split set to 0. This array can’t be dirty.
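The heap layout above can be illustrated with a minimal sketch in plain Python (lists stand in for the JAX arrays; the example tree and variable names are invented for illustration):

```python
# A tree stored as a heap with d = 3: leaf_tree has 2**3 = 8 slots,
# var_tree and split_tree have 2**2 = 4 slots; index 0 is unused.
# The root splits on variable 0 at cutpoint 5; its left child is a leaf
# with value 1.0 and its right child is a leaf with value 2.0.
leaf_tree = [0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]  # dirty slots hold 0.0
var_tree = [0, 0, 0, 0]    # index 0 must be 0; unused slots could be dirty
split_tree = [0, 5, 0, 0]  # split == 0 marks leaves and unused nodes

root = 1
left, right = 2 * root, 2 * root + 1  # children of node i are at 2i and 2i + 1
```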

bartz.grove.tree_depth(tree)[source]

Return the maximum depth of a tree.

Parameters:

tree (Shaped[Array, '*batch_shape 2**d']) – A tree array like those in a TreeHeaps. If the array is ND, the tree structure is assumed to be along the last axis.

Returns:

int – The maximum depth of the tree.
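Since the heap for a tree of depth d has 2**d slots along its last axis, the depth can be recovered from the array length alone. A hedged sketch of this relation, assuming the 2**d sizing convention stated in the parameter description (the helper name is invented):

```python
def tree_depth_sketch(tree) -> int:
    """Recover d from a heap array of length 2**d along its last axis."""
    size = len(tree)  # assumed to be an exact power of two, 2**d
    return size.bit_length() - 1  # log2 of a power of two
```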

bartz.grove.traverse_tree(x, var_tree, split_tree)[source]

Find the leaf a point falls into.

Parameters:
  • x (UInt[Array, 'p']) – The coordinates to evaluate the tree at.

  • var_tree (UInt[Array, '2**(d-1)']) – The decision axes of the tree.

  • split_tree (UInt[Array, '2**(d-1)']) – The decision boundaries of the tree.

Returns:

UInt[Array, ''] – The index of the leaf.
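A hedged pure-Python sketch of the traversal rule stated above: descend from the root, going left iff x[var] < split, until reaching a node with split == 0 or the bottom level (whose indices fall past the end of split_tree). The helper name is invented:

```python
def traverse_tree_sketch(x, var_tree, split_tree):
    """Return the leaf_tree index of the leaf that x falls into."""
    i = 1  # the root is at index 1
    while i < len(split_tree) and split_tree[i] != 0:
        # open-on-the-right rule: go to the left child iff x[var] < split
        if x[var_tree[i]] < split_tree[i]:
            i = 2 * i      # left child
        else:
            i = 2 * i + 1  # right child
    return i  # valid index into leaf_tree (length 2 * len(split_tree))
```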

bartz.grove.traverse_forest(X: UInt[Array, 'p n'], var_trees: UInt[Array, '*forest_shape 2**(d-1)'], split_trees: UInt[Array, '*forest_shape 2**(d-1)']) UInt[Array, '*forest_shape n'][source]

Find the leaves the points fall into, for each tree in a set.

Parameters:
  • X (UInt[Array, 'p n']) – The coordinates to evaluate the trees at.

  • var_trees (UInt[Array, '*forest_shape 2**(d-1)']) – The decision axes of the trees.

  • split_trees (UInt[Array, '*forest_shape 2**(d-1)']) – The decision boundaries of the trees.

Returns:

UInt[Array, '*forest_shape n'] – The indices of the leaves.

bartz.grove.evaluate_forest(X, trees, *, sum_batch_axis=())[source]

Evaluate an ensemble of trees at an array of points.

Parameters:
  • X (UInt[Array, 'p n']) – The coordinates to evaluate the trees at.

  • trees (TreeHeaps) – The trees.

  • sum_batch_axis (DTypeLike[int, KeyPath[int, ...]], default: ()) – The batch axes to sum over. By default, no summation is performed. Negative indices count from the end of the batch dimensions; the core dimensions n and k can’t be summed over by this function.

Returns:

DTypeLike[Float32[Array, '*reduced_batch_size n'], Float32[Array, '*reduced_batch_size k n']] – The values of the trees at the points in X, or their sums over the requested batch axes.
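As a hedged illustration (not the library's implementation), the evaluation rule can be sketched in plain Python by combining the heap traversal with a leaf lookup and a sum over trees; all names here are invented, and X follows the documented p-by-n layout:

```python
def evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees):
    """Sum, over trees, the leaf value each of the n points falls into."""
    n = len(X[0])  # X is p x n: one row per variable, one column per point
    out = [0.0] * n
    for leaf_tree, var_tree, split_tree in zip(leaf_trees, var_trees, split_trees):
        for j in range(n):
            i = 1  # descend the heap from the root
            while i < len(split_tree) and split_tree[i] != 0:
                i = 2 * i if X[var_tree[i]][j] < split_tree[i] else 2 * i + 1
            out[j] += leaf_tree[i]
    return out
```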

bartz.grove.is_actual_leaf(split_tree, *, add_bottom_level=False)[source]

Return a mask indicating the leaf nodes in a tree.

Parameters:
  • split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.

  • add_bottom_level (bool, default: False) – If True, the bottom level of the tree is also considered.

Returns:

DTypeLike[Bool[Array, '2**(d-1)'], Bool[Array, '2**d']] – The mask marking the leaf nodes. Length doubled if add_bottom_level is True.
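A hedged sketch of the mask logic implied by the array conventions above: since split_tree cannot be dirty, a node is an actual leaf exactly when its own split is 0 and it is used, i.e. it is the root or its parent is a decision node. Names are invented:

```python
def is_actual_leaf_sketch(split_tree, add_bottom_level=False):
    """Mark actual leaves; optionally extend the mask to the bottom level."""
    n = len(split_tree)  # 2**(d-1)
    size = 2 * n if add_bottom_level else n
    mask = [False] * size
    for i in range(1, size):
        split = split_tree[i] if i < n else 0  # bottom-level slots are leaf-only
        parent_is_decision = (i == 1) or split_tree[i // 2] != 0
        mask[i] = split == 0 and parent_is_decision
    return mask
```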

bartz.grove.is_leaves_parent(split_tree)[source]

Return a mask indicating the nodes with leaf (and only leaf) children.

Parameters:

split_tree (UInt[Array, '2**(d-1)']) – The decision boundaries of the tree.

Returns:

Bool[Array, '2**(d-1)'] – The mask indicating which nodes have leaf children.
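A hedged sketch of this mask (invented name): a node qualifies when it is a decision node and both of its children are leaves, where children whose indices fall past the end of split_tree sit at the bottom level and are therefore leaves automatically.

```python
def is_leaves_parent_sketch(split_tree):
    """Mark decision nodes whose children are both leaves."""
    n = len(split_tree)  # 2**(d-1)
    mask = [False] * n
    for i in range(1, n):
        if split_tree[i] == 0:
            continue  # a leaf or unused node is not a parent of anything
        left, right = 2 * i, 2 * i + 1
        # children beyond split_tree live at the bottom level, hence are leaves
        left_leaf = left >= n or split_tree[left] == 0
        right_leaf = right >= n or split_tree[right] == 0
        mask[i] = left_leaf and right_leaf
    return mask
```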

bartz.grove.tree_depths(tree_size)[source]

Return the depth of each node in a binary tree.

Parameters:

tree_size (int) – The length of the tree array, i.e., 2 ** d.

Returns:

Int32[Array, '{tree_size}'] – The depth of each node.

Notes

The root node (index 1) has depth 0. The depth is the position of the most significant non-zero bit in the index. The first element (the unused node) is marked as depth 0.
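The rule in the notes can be sketched directly (the helper name is invented): the depth of node i is the position of its most significant bit, and the unused slot 0 is reported as depth 0.

```python
def tree_depths_sketch(tree_size):
    """Depth of each heap slot: MSB position of the index, 0 for slot 0."""
    return [0 if i == 0 else i.bit_length() - 1 for i in range(tree_size)]
```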

bartz.grove.is_used(split_tree)[source]

Return a mask indicating the used nodes in a tree.

Parameters:

split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision boundaries of the tree.

Returns:

Bool[Array, '*batch_shape 2**d'] – A mask indicating which nodes are actually used.
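Since split_tree cannot be dirty, a node is used exactly when it is the root or its parent is a decision node. A hedged sketch (invented name), with the output covering the full 2**d slots as documented:

```python
def is_used_sketch(split_tree):
    """Mark used nodes over the full 2**d heap (twice split_tree's length)."""
    n = len(split_tree)  # 2**(d-1)
    mask = [False] * (2 * n)
    for i in range(1, 2 * n):
        # the root is always used; any other node is used iff its parent splits
        mask[i] = (i == 1) or split_tree[i // 2] != 0
    return mask
```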

bartz.grove.forest_fill(split_tree)[source]

Return the fraction of used nodes in a set of trees.

Parameters:

split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision boundaries of the trees.

Returns:

Float32[Array, ''] – The number of used nodes over the maximum number of nodes that could be stored.

bartz.grove.var_histogram(p, var_tree, split_tree, *, sum_batch_axis=())[source]

Count how many times each variable appears in a tree.

Parameters:
  • p (int) – The number of variables (the maximum value that can occur in var_tree is p - 1).

  • var_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision axes of the tree.

  • split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The decision boundaries of the tree.

  • sum_batch_axis (DTypeLike[int, KeyPath[int, ...]], default: ()) – The batch axes to sum over. By default, no summation is performed. Negative indices count from the end of the batch dimensions; the core dimension p can’t be summed over by this function.

Returns:

Int32[Array, '*reduced_batch_shape {p}'] – The histogram(s) of the variables used in the tree.
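A hedged sketch of the counting rule for a single tree (names invented): only decision nodes, i.e. nodes with a nonzero split, actually use a variable, so var_tree entries are counted only where split_tree is nonzero.

```python
def var_histogram_sketch(p, var_tree, split_tree):
    """Count, per variable, how many decision nodes split on it."""
    counts = [0] * p  # one bin per variable, values in var_tree are < p
    for i in range(1, len(split_tree)):
        if split_tree[i] != 0:  # only decision nodes use a variable
            counts[var_tree[i]] += 1
    return counts
```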

bartz.grove.format_tree(tree, *, print_all=False)[source]

Convert a tree to a human-readable string.

Parameters:
  • tree (TreeHeaps) – A single tree to format.

  • print_all (bool, default: False) – If True, also print the contents of unused node slots in the arrays.

Returns:

str – A string representation of the tree.

bartz.grove.tree_actual_depth(split_tree)[source]

Measure the depth of the tree.

Parameters:

split_tree (UInt[Array, '2**(d-1)']) – The cutpoints of the decision rules.

Returns:

Int32[Array, ''] – The depth of the deepest leaf in the tree. The root is at depth 0.
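A hedged sketch combining the leaf mask with the depth-from-index rule noted under tree_depths (invented name): scan all slots, including the bottom level one past split_tree, and take the maximum depth over actual leaves.

```python
def tree_actual_depth_sketch(split_tree):
    """Depth of the deepest actual leaf; the root counts as depth 0."""
    n = len(split_tree)  # 2**(d-1); leaves can also live at the bottom level
    depth = 0
    for i in range(1, 2 * n):
        split = split_tree[i] if i < n else 0  # bottom-level slots are leaf-only
        used = (i == 1) or split_tree[i // 2] != 0
        if used and split == 0:  # an actual leaf
            depth = max(depth, i.bit_length() - 1)  # depth = MSB position
    return depth
```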

bartz.grove.forest_depth_distr(split_tree)[source]

Histogram the depths of a set of trees.

Parameters:

split_tree (UInt[Array, '*batch_shape num_trees 2**(d-1)']) – The cutpoints of the decision rules of the trees.

Returns:

Int32[Array, '*batch_shape d'] – An integer vector where the i-th element counts how many trees have depth i.

bartz.grove.points_per_node_distr(X, var_tree, split_tree, node_type, *, sum_batch_axis=())[source]

Histogram points-per-node counts in a set of trees.

Count how many nodes in a tree select each possible number of points, over a given subset of nodes.

Parameters:
  • X (UInt[Array, 'p n']) – The set of points to count.

  • var_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The variables of the decision rules.

  • split_tree (UInt[Array, '*batch_shape 2**(d-1)']) – The cutpoints of the decision rules.

  • node_type (Literal['leaf', 'leaf-parent']) –

    The type of nodes to consider. Can be:

    'leaf'
      Count only leaf nodes.

    'leaf-parent'
      Count only parent-of-leaf nodes.

  • sum_batch_axis (DTypeLike[int, KeyPath[int, ...]], default: ()) – Aggregate the histogram over these batch axes, counting how many nodes have each possible number of points over subsets of trees instead of in each tree separately.

Returns:

Int32[Array, '*reduced_batch_shape n+1'] – A vector where the i-th element counts how many nodes have i points.