bartz.grove.evaluate_forest

bartz.grove.evaluate_forest(X, trees, *, sum_batch_axis=())

Evaluate an ensemble of trees at an array of points.

Parameters:
  • X (UInt[Array, 'p n']) – The coordinates to evaluate the trees at.

  • trees (TreeHeaps) – The trees.

  • sum_batch_axis (int | tuple[int, ...], default: ()) – The batch axes to sum over. By default, no summation is performed. Negative indices count from the end of the batch dimensions, not from the end of the full array; the core dimensions n and k cannot be summed over by this function.

Returns:

Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n'] – The values of the trees at the points in X, summed over the batch axes listed in sum_batch_axis, if any.
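To illustrate what this function computes, here is a minimal, self-contained sketch in JAX. It is not bartz's implementation. The heap layout is an assumption made for illustration: 1-based node indices with children 2i and 2i+1, a split value of 0 marking a leaf, and "go right when x[var] >= split". The names evaluate_tree_sketch and evaluate_forest_sketch are hypothetical. The sketch covers only scalar leaves (no k dimension) and a single batch axis over trees. It does show how sum_batch_axis indices are resolved against the batch dimensions only.

```python
import jax
import jax.numpy as jnp


def evaluate_tree_sketch(x, leaf_tree, var_tree, split_tree):
    """Evaluate one heap-stored tree at a single point x of shape [p].

    Hypothetical layout (assumption, not bartz's documented format):
    leaf_tree has 2**depth slots, var_tree and split_tree have 2**(depth-1);
    node i has children 2i and 2i+1, index 0 is unused,
    and split_tree[i] == 0 marks node i as a leaf.
    """
    depth = leaf_tree.shape[0].bit_length() - 1  # static Python int
    index = jnp.int32(1)  # start at the root
    for _ in range(depth - 1):  # at most depth - 1 descents
        is_leaf = split_tree[index] == 0
        go_right = x[var_tree[index]] >= split_tree[index]
        child = 2 * index + go_right.astype(jnp.int32)
        index = jnp.where(is_leaf, index, child)  # stay put once at a leaf
    return leaf_tree[index]


def evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees, sum_batch_axis=()):
    """Evaluate m trees at the n columns of X (shape [p, n]); returns [m, n] or reduced."""
    over_points = jax.vmap(evaluate_tree_sketch, in_axes=(1, None, None, None))
    over_trees = jax.vmap(over_points, in_axes=(None, 0, 0, 0))
    values = over_trees(X, leaf_trees, var_trees, split_trees)  # shape [m, n]

    batch_ndim = values.ndim - 1  # the last axis (n) is a core dimension
    if isinstance(sum_batch_axis, int):
        sum_batch_axis = (sum_batch_axis,)
    # Negative indices count from the end of the batch dims, never reaching n.
    axes = tuple(a + batch_ndim if a < 0 else a for a in sum_batch_axis)
    return values.sum(axis=axes) if axes else values


# Two depth-2 trees: tree 0 splits on variable 0 at 2, tree 1 is a single leaf.
leaf_trees = jnp.array([[0.0, 0.0, -1.0, 1.0], [0.0, 0.5, 0.0, 0.0]], dtype=jnp.float32)
var_trees = jnp.array([[0, 0], [0, 0]], dtype=jnp.int32)
split_trees = jnp.array([[0, 2], [0, 0]], dtype=jnp.uint32)
X = jnp.array([[0, 2, 5], [1, 1, 1]], dtype=jnp.uint32)  # p = 2, n = 3

per_tree = evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees)
summed = evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees, sum_batch_axis=-1)
```

Here per_tree has shape (2, 3), one row per tree. Passing sum_batch_axis=-1 sums over the tree axis, because -1 refers to the last batch dimension rather than to n. The result is the per-point forest prediction.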