bartz.grove

Functions to create, manipulate, and check binary decision trees.

Tree representation

TreeHeaps(*args, **kwargs)

A protocol for dataclasses that represent trees.

TreesTrace(var_tree, split_tree, leaf_tree)

Implementation of bartz.grove.TreeHeaps for an MCMC trace.

is_multivariate(trees)

Return whether the trees have vector-valued leaves.

Evaluation

evaluate_forest(X, trees, *[, sum_batch_axis])

Evaluate an ensemble of trees at an array of points.

traverse_tree(x, var_tree, split_tree)

Find the leaf that a point falls into.

traverse_forest(X, var_trees, split_trees)

Find the leaves that points fall into, for each tree in a set.
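
As an illustration of what traversal and evaluation compute, here is a minimal pure-Python sketch. It does not call bartz. It assumes a heap layout in which the root sits at index 1 (index 0 unused), the children of node i sit at 2*i and 2*i + 1, a split value of 0 marks a leaf, and a point goes right when its coordinate is greater than or equal to the split. It also assumes the ensemble prediction is the sum of the leaf values. The function names are hypothetical and these conventions should be checked against the library documentation.

```python
def traverse_heap_tree(x, var_tree, split_tree):
    """Return the heap index of the leaf reached by the point x."""
    index = 1  # assumed convention: root at index 1, index 0 unused
    # Descend until a node without a split, or the level below split_tree.
    while index < len(split_tree) and split_tree[index] != 0:
        go_right = x[var_tree[index]] >= split_tree[index]
        index = 2 * index + int(go_right)
    return index


def evaluate_sum_of_trees(x, trees):
    """Sum the leaf values reached by x over (var, split, leaf) triples."""
    total = 0.0
    for var_tree, split_tree, leaf_tree in trees:
        total += leaf_tree[traverse_heap_tree(x, var_tree, split_tree)]
    return total


# Root splits on variable 0 at 5; its left child splits on variable 1 at 3.
var_tree = [0, 0, 1, 0]
split_tree = [0, 5, 3, 0]
# Leaf values live in an array twice as long (assumed), at indices 3, 4, 5.
leaf_tree = [0.0, 0.0, 0.0, 1.5, -1.0, 2.0, 0.0, 0.0]

leaf_a = traverse_heap_tree([2, 7], var_tree, split_tree)  # left, then right
leaf_b = traverse_heap_tree([9, 0], var_tree, split_tree)  # right at the root
prediction = evaluate_sum_of_trees([2, 7], [(var_tree, split_tree, leaf_tree)] * 2)
```

Under these assumptions, the point [2, 7] lands in node 5 and the point [9, 0] in node 3. With two identical trees the summed prediction at [2, 7] is 4.0.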

Tree and node properties

is_actual_leaf(split_tree, *[, add_bottom_level])

Return a mask indicating the leaf nodes in a tree.

is_leaves_parent(split_tree)

Return a mask indicating the nodes with leaf (and only leaf) children.

tree_depth(tree)

Return the maximum depth of a tree.

tree_actual_depth(split_tree)

Measure the actual depth of a tree, i.e., the depth of its deepest leaf.

tree_depths(tree_size)

Return the depth of each node in a binary tree.
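
The node properties above follow from index arithmetic when trees are stored as heaps. The sketch below is plain Python and does not call bartz. It assumes the root is at index 1 with depth 0, that the unused index 0 is reported as depth 0, that the parent of node i is i // 2, and that a split value of 0 means "no split". The helper names are hypothetical.

```python
def heap_node_depths(tree_size):
    """Depth of each heap index: position of the highest set bit."""
    # int.bit_length() is 0 for index 0, so clamp the unused slot to depth 0.
    return [max(i.bit_length() - 1, 0) for i in range(tree_size)]


def leaf_mask(split_tree):
    """Mark nodes that have no split but whose parent does (or the root)."""
    size = len(split_tree)
    mask = [False] * size
    for i in range(1, size):
        parent_splits = i == 1 or split_tree[i // 2] != 0
        mask[i] = split_tree[i] == 0 and parent_splits
    return mask


depths = heap_node_depths(8)          # [0, 0, 1, 1, 2, 2, 2, 2]
mask = leaf_mask([0, 5, 3, 0])        # [False, False, False, True]
```

In this example node 2 splits, so its children (indices 4 and 5) are leaves one level below the end of the split array. The mask here does not cover that level.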

Forest summaries

forest_mean_leaves(split_tree)

Return the average number of leaves per tree in a set of trees.

forest_depth_distr(split_tree)

Histogram the depths of a set of trees.

points_per_node_distr(X, var_tree, ...[, ...])

Histogram points-per-node counts in a set of trees.

var_histogram(p, var_tree, split_tree, *[, ...])

Count how many times each variable appears in a tree.
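
To make the variable count concrete, here is a small pure-Python sketch for a single tree. It does not call bartz. It assumes that only nodes with a nonzero split value count as using their variable, and that p is the number of predictors. The function name is hypothetical.

```python
def count_split_variables(p, var_tree, split_tree):
    """Count, for each of p variables, the decision nodes that split on it."""
    counts = [0] * p
    for var, split in zip(var_tree, split_tree):
        if split != 0:  # assumed: split == 0 means the node is not a decision node
            counts[var] += 1
    return counts


# One split on variable 0 (root) and one on variable 1 (left child).
counts = count_split_variables(3, [0, 0, 1, 0], [0, 5, 3, 0])  # [1, 1, 0]
```

Entries of var_tree at nodes without a split are ignored, so leftover values there do not inflate the counts.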

Validation and inspection

check_trace(trace, max_split)

Check the validity of a set of trees.

describe_error(error)

Describe an error code returned by check_trace.

format_tree(tree, *[, print_all])

Convert a tree to a human-readable string.