bartz.grove.var_histogram

bartz.grove.var_histogram(p, var_tree, split_tree, *, sum_batch_axis=())

Count how many times each variable appears in a tree.

Parameters:
  • p (int) – The number of variables (the maximum value that can occur in var_tree is p - 1).

  • var_tree (UInt[Array, '*batch_shape half_tree_size']) – The decision axes of the tree.

  • split_tree (UInt[Array, '*batch_shape half_tree_size']) – The decision boundaries of the tree.

  • sum_batch_axis (int | tuple[int, ...], default: ()) – The batch axes to sum over. By default, no summation is performed. Negative indices count from the end of the batch dimensions, not from the end of the full array; the core dimension p cannot be summed over by this function.
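To make the counting concrete for a single tree (no batch axes), here is a minimal NumPy sketch. It is a hypothetical re-implementation for illustration only (the name var_histogram_one_tree is made up, and the real function works on JAX arrays). It assumes that a node is a decision node exactly when its split_tree entry is nonzero, so leaves and unused slots are not counted.

```python
import numpy as np

def var_histogram_one_tree(p, var_tree, split_tree):
    """Hypothetical single-tree sketch; not bartz's implementation.

    Assumption: a node counts iff its split_tree entry is nonzero.
    """
    var_tree = np.asarray(var_tree)
    is_decision = np.asarray(split_tree) != 0
    hist = np.zeros(p, dtype=np.int32)
    # Unbuffered scatter-add: each occurrence of a variable index is counted,
    # even when the same index appears at several decision nodes.
    np.add.at(hist, var_tree[is_decision], 1)
    return hist

# Made-up example tree; entries whose split is 0 are ignored.
var_tree = np.array([0, 2, 1, 2, 0, 0, 0, 0], dtype=np.uint8)
split_tree = np.array([0, 5, 3, 1, 0, 0, 0, 0], dtype=np.uint8)
print(var_histogram_one_tree(3, var_tree, split_tree))  # [0 1 2]
```

Variable 2 is used by two decision nodes and variable 1 by one, while the node with split 0 contributes nothing even though its var_tree entry is 0.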

Returns:

Int32[Array, '*reduced_batch_shape {p}'] – The histogram(s) of the variables used in the tree, where reduced_batch_shape is the batch shape with the axes listed in sum_batch_axis removed.
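The following NumPy sketch illustrates how the batch axes, sum_batch_axis, and the reduced result shape fit together. It is a hypothetical re-implementation (the name var_histogram_batched is made up), not bartz's code, which operates on JAX arrays and may compute the counts differently. It assumes a node is counted only if its split_tree entry is nonzero, and it uses a one-hot comparison for brevity, which costs memory proportional to half_tree_size × p per tree.

```python
import numpy as np

def var_histogram_batched(p, var_tree, split_tree, *, sum_batch_axis=()):
    """Hypothetical NumPy sketch of the batched behaviour; not bartz's code.

    Assumption: a node counts iff its split_tree entry is nonzero.
    """
    var_tree = np.asarray(var_tree)
    split_tree = np.asarray(split_tree)
    batch_ndim = var_tree.ndim - 1  # the last axis is the core tree axis
    if isinstance(sum_batch_axis, int):
        sum_batch_axis = (sum_batch_axis,)
    for ax in sum_batch_axis:
        if not -batch_ndim <= ax < batch_ndim:
            raise ValueError(f"axis {ax} is not a batch axis")
    # Negative axes are resolved against the batch dims only, so the
    # core axis can never be selected.
    axes = tuple(ax % batch_ndim for ax in sum_batch_axis)
    # One-hot comparison, shape (*batch_shape, half_tree_size, p).
    onehot = var_tree[..., None] == np.arange(p)
    is_decision = (split_tree != 0)[..., None]
    # Per-tree histograms, shape (*batch_shape, p).
    hist = (onehot & is_decision).sum(axis=-2, dtype=np.int32)
    # Sum the requested batch axes; axis=() leaves the shape unchanged.
    return hist.sum(axis=axes, dtype=np.int32)

# A made-up batch of 2 x 3 identical trees, each using variables 2 and 1 once.
tree_var = np.array([0, 2, 1, 2], dtype=np.uint8)
tree_split = np.array([0, 4, 1, 0], dtype=np.uint8)
var_tree = np.broadcast_to(tree_var, (2, 3, 4))
split_tree = np.broadcast_to(tree_split, (2, 3, 4))
print(var_histogram_batched(3, var_tree, split_tree).shape)  # (2, 3, 3)
print(var_histogram_batched(3, var_tree, split_tree, sum_batch_axis=(0, 1)))  # [0 6 6]
```

With sum_batch_axis=-1 the last batch axis (size 3) is summed, giving shape (2, 3) with every row equal to [0 3 3]; the trailing axis of length p is never touched.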