bartz.grove.TreeHeaps

class bartz.grove.TreeHeaps(*args, **kwargs)

A protocol for dataclasses that represent trees.

A tree is represented with arrays as a heap. The root node is at index 1. The children of the node at index \(i\) are at indices \(2i\) (left child) and \(2i + 1\) (right child). The array element at index 0 is unused.
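The index arithmetic of this heap layout can be sketched in plain Python. The helper names below (left_child, right_child, parent, depth) are hypothetical and chosen for illustration; they are not part of bartz.

```python
# Illustrative helpers for the 1-based heap layout described above.
# These names are assumptions for this sketch, not bartz functions.

def left_child(i: int) -> int:
    """Index of the left child of the node at index i."""
    return 2 * i


def right_child(i: int) -> int:
    """Index of the right child of the node at index i."""
    return 2 * i + 1


def parent(i: int) -> int:
    """Index of the parent of the node at index i (for i >= 2)."""
    return i // 2


def depth(i: int) -> int:
    """Depth of the node at index i, counting the root (index 1) as depth 0."""
    return i.bit_length() - 1
```

Leaving index 0 unused is what keeps these formulas this simple: with a 0-based root the children would sit at \(2i + 1\) and \(2i + 2\) instead.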

Since the nodes at the bottom can only be leaves and not decision nodes, var_tree and split_tree are half as long as leaf_tree.

Arrays may have additional initial axes to represent multiple trees.

leaf_tree: Float32[Array, '*batch_shape 2*half_tree_size'] | Float32[Array, '*batch_shape k 2*half_tree_size']

The values in the leaves of the trees. This array can be dirty, i.e., unused nodes can hold arbitrary values. It may have an additional axis for multivariate leaves.

var_tree: UInt[Array, '*batch_shape half_tree_size']

The axes along which the decision nodes operate. This array can be dirty, except for the always-unused node at index 0, which must be set to 0.

split_tree: UInt[Array, '*batch_shape half_tree_size']

The decision boundaries of the trees. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. A node is marked as a leaf by setting its split element to 0. Unused nodes also have split set to 0. This array can’t be dirty.
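To show how the three arrays fit together, here is a minimal sketch that evaluates a single tree on one point, using plain Python lists in place of JAX arrays. The function name predict_one, the example tree, and the assumption that the point holds integer values comparable to the splits are all illustrative; this is not the library's own traversal code.

```python
# A minimal sketch of evaluating one tree stored in the heap layout.
# predict_one and the example arrays are hypothetical, for illustration only.

def predict_one(x, var_tree, split_tree, leaf_tree):
    """Return the leaf value reached by point x in a single tree."""
    index = 1  # start from the root
    # Nodes at index >= len(split_tree) are on the bottom level and can
    # only be leaves; other nodes are leaves when their split is 0.
    while index < len(split_tree) and split_tree[index] != 0:
        go_left = x[var_tree[index]] < split_tree[index]
        index = 2 * index if go_left else 2 * index + 1
    return leaf_tree[index]


# Example tree with half_tree_size = 4 (so leaf_tree has 8 entries).
# Node 1 splits on variable 0 at 3; node 2 splits on variable 1 at 2;
# node 3 is a leaf; nodes 4 and 5 are bottom-level leaves.
var_tree = [0, 0, 1, 0]
split_tree = [0, 3, 2, 0]
# Entries for non-leaf or unused nodes (indices 0, 1, 2, 6, 7) are "dirty".
leaf_tree = [0.0, 9.9, 9.9, 0.5, -1.0, 1.0, 7.7, 7.7]

left_left = predict_one([1, 1], var_tree, split_tree, leaf_tree)   # node 4
left_right = predict_one([1, 5], var_tree, split_tree, leaf_tree)  # node 5
right = predict_one([4, 0], var_tree, split_tree, leaf_tree)       # node 3
```

Because leaf status is read only from split_tree, the dirty values in var_tree and leaf_tree are never consulted, which is why split_tree is the one array that must be kept clean.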