bartz.grove.traverse_forest

bartz.grove.traverse_forest(X, var_trees, split_trees)

For each tree in a forest, find the leaf that each point falls into.

Parameters:
  • X (UInt[Array, 'p n']) – The coordinates of the n points at which to evaluate the trees, with one row per predictor.

  • var_trees (UInt[Array, '*forest_shape half_tree_size']) – The decision axes of the trees.

  • split_trees (UInt[Array, '*forest_shape half_tree_size']) – The decision boundaries of the trees.

Returns:

UInt[Array, '*forest_shape n'] – The indices of the leaves.
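A minimal NumPy sketch of such a traversal follows. It rests on several assumptions that this page does not state. Nodes are stored in heap order, with the root at index 1 and the children of node i at 2i and 2i + 1. A split value of 0 marks a node as a leaf. A point goes right when its coordinate is >= the split value. The function names are hypothetical, and bartz's own JAX implementation may differ in these conventions.

```python
import numpy as np


def traverse_tree_sketch(x, var_tree, split_tree):
    """Return the heap index of the leaf that the single point x falls into.

    Hypothetical conventions (assumptions, not taken from bartz):
    root at index 1, children of node i at 2*i and 2*i + 1,
    split_tree[i] == 0 marks node i as a leaf,
    and the point goes right when x[var_tree[i]] >= split_tree[i].
    """
    half_tree_size = split_tree.shape[0]
    index = 1
    # Nodes with index >= half_tree_size are bottom-level leaves with no
    # stored split, so stop there or at any node marked as a leaf.
    while index < half_tree_size and split_tree[index] != 0:
        go_right = x[var_tree[index]] >= split_tree[index]
        index = 2 * index + int(go_right)
    return index


def traverse_forest_sketch(X, var_trees, split_trees):
    """Apply traverse_tree_sketch to every tree and every point.

    X has shape (p, n); var_trees and split_trees have shape
    forest_shape + (half_tree_size,); the result has shape forest_shape + (n,).
    """
    n = X.shape[1]
    forest_shape = split_trees.shape[:-1]
    out = np.empty(forest_shape + (n,), dtype=np.uint32)
    # np.ndindex(*()) yields a single empty tuple, so a lone tree also works.
    for tree_idx in np.ndindex(*forest_shape):
        for j in range(n):
            out[tree_idx + (j,)] = traverse_tree_sketch(
                X[:, j], var_trees[tree_idx], split_trees[tree_idx]
            )
    return out


# Two trees with half_tree_size = 4. In the first tree, the root splits on
# predictor 0 at 5, its left child splits on predictor 1 at 3, and its right
# child is a leaf. The second tree is a single leaf (the root).
X = np.array([[2, 2, 9], [1, 7, 0]], dtype=np.uint32)
var_trees = np.array([[0, 0, 1, 0], [0, 0, 0, 0]], dtype=np.uint32)
split_trees = np.array([[0, 5, 3, 0], [0, 0, 0, 0]], dtype=np.uint32)
leaves = traverse_forest_sketch(X, var_trees, split_trees)
print(leaves)  # [[4 5 3]
               #  [1 1 1]]
```

The explicit loops keep the descent rule easy to read. A vectorized version would instead advance all points through one tree level at a time.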