bartz.grove.points_per_node_distr

bartz.grove.points_per_node_distr(X, var_tree, split_tree, node_type, *, sum_batch_axis=())

Histogram points-per-node counts in a set of trees.

Count how many nodes in a tree select each possible number of points, restricted to a given subset of nodes.

Parameters:
  • X (UInt[Array, 'p n']) – The set of points to count.

  • var_tree (UInt[Array, '*batch_shape half_tree_size']) – The variables of the decision rules.

  • split_tree (UInt[Array, '*batch_shape half_tree_size']) – The cutpoints of the decision rules.

  • node_type (Literal['leaf', 'leaf-parent']) –

    The type of nodes to consider. Can be:

    'leaf'

    Count only leaf nodes.

    'leaf-parent'

    Count only parent-of-leaf nodes.

  • sum_batch_axis (int | tuple[int, ...], default: ()) – Aggregate the histogram over these batch axes, counting how many nodes have each possible number of points over subsets of trees instead of in each tree separately.

Returns:

Int32[Array, '*reduced_batch_shape n+1'] – A vector where the i-th element counts how many nodes have i points.
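
The shape of the result can be illustrated with a small pure-Python sketch. It does not call bartz and is not the library's implementation; the helper `points_per_node_hist` and the occupancy lists `tree_a` and `tree_b` are hypothetical, made up only to show what "the i-th element counts how many nodes have i points" means and what aggregating over a batch axis of trees amounts to.

```python
def points_per_node_hist(node_counts, n):
    """Return a list of length n + 1 whose i-th entry is the number of
    nodes that hold exactly i points (hypothetical helper, not bartz API)."""
    hist = [0] * (n + 1)
    for count in node_counts:
        hist[count] += 1
    return hist


n = 5  # assumed number of points

# Assumed leaf occupancies for two trees. Every point falls in exactly
# one leaf, so each list sums to n.
tree_a = [2, 3]
tree_b = [1, 1, 3]

# One histogram per tree, as when no batch axis is summed over.
per_tree = [points_per_node_hist(t, n) for t in (tree_a, tree_b)]

# Summing the histograms elementwise over the tree axis gives the
# aggregated histogram over the whole set of trees.
aggregated = [sum(column) for column in zip(*per_tree)]

print(per_tree)    # [[0, 0, 1, 1, 0, 0], [0, 2, 0, 1, 0, 0]]
print(aggregated)  # [0, 2, 1, 2, 0, 0]
```

Each histogram has n + 1 entries because a node can hold anywhere from 0 to n points.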