bartz.grove.TreesTrace

class bartz.grove.TreesTrace(var_tree, split_tree, leaf_tree)

Implementation of bartz.grove.TreeHeaps for an MCMC trace.

var_tree: UInt[Array, '*batch_shape half_tree_size']

The axes along which the decision nodes operate, i.e., the index of the predictor that each decision node splits on. This array can be dirty, except for the always-unused node at index 0, which must be set to 0.

split_tree: UInt[Array, '*batch_shape half_tree_size']

The decision boundaries of the trees. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. A node is a leaf iff its corresponding ‘split’ element is 0; unused nodes also have split set to 0. This array can’t be dirty.
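Because leaves and unused nodes both have split set to 0, telling them apart requires looking at a node's parent. The minimal pure-Python sketch below labels every heap slot from split_tree alone. It assumes a heap layout with the root at index 1 and the children of node i at 2*i and 2*i + 1, which is consistent with index 0 being unused and with leaf_tree having twice as many slots, but is an assumption here; `classify_nodes` is a hypothetical helper, not part of bartz.

```python
def classify_nodes(split_tree, half_tree_size):
    """Label each heap slot as 'internal', 'leaf', or 'unused'.

    Hypothetical layout: root at index 1, children of node i at 2*i and
    2*i + 1; index 0 is never used.
    """
    size = 2 * half_tree_size  # total slots, matching leaf_tree's last axis
    labels = ['unused'] * size

    def is_split(i):
        # Slots at or past half_tree_size have no split entry: never decision nodes.
        return i < half_tree_size and split_tree[i] != 0

    # Parents have smaller indices than children, so each parent's label
    # is already known when its children are visited.
    for i in range(1, size):
        if i == 1 or labels[i // 2] == 'internal':
            labels[i] = 'internal' if is_split(i) else 'leaf'
    return labels


# Root splits, its left child splits, its right child is a leaf.
print(classify_nodes([0, 5, 3, 0], half_tree_size=4))
# ['unused', 'internal', 'internal', 'leaf', 'leaf', 'leaf', 'unused', 'unused']
```

The sketch distinguishes leaves from unused slots through the parent's label; a nonzero value left in a leaf's slot would instead turn that leaf into a decision node, which is one way to see why split_tree must be kept clean.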

leaf_tree: Float32[Array, '*batch_shape 2*half_tree_size'] | Float32[Array, '*batch_shape k 2*half_tree_size']

The values in the leaves of the trees. This array can be dirty, i.e., unused nodes can have whatever value. It may have an additional axis of length k for multivariate leaves.
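Taken together, the three arrays determine how a point is routed to a leaf. The plain-Python sketch below descends a single tree for a single point x under the same assumed heap layout (root at index 1, children at 2*i and 2*i + 1). It assumes x is already expressed on the same integer scale as split_tree, so entries can be compared directly. `find_leaf` is a hypothetical helper; the example deliberately leaves dirty entries in var_tree and leaf_tree to show that they are never read.

```python
def find_leaf(var_tree, split_tree, x):
    """Return the heap index of the leaf that point x falls into.

    Hypothetical layout: root at index 1, children of node i at 2*i and
    2*i + 1. Slots at or past half_tree_size exist only in leaf_tree,
    so they can only be leaves.
    """
    half_tree_size = len(split_tree)
    i = 1
    # Descend while the current node is a decision node (split != 0).
    while i < half_tree_size and split_tree[i] != 0:
        # Boundaries are open on the right: go left iff x[var] < split.
        if x[var_tree[i]] < split_tree[i]:
            i = 2 * i
        else:
            i = 2 * i + 1
    return i


# One tree with half_tree_size = 4: the root splits axis 0 at 5,
# node 2 splits axis 1 at 3, and node 3 is a leaf.
var_tree = [0, 0, 1, 7]    # index 0 must be 0; var_tree[3] is dirty (node 3 is a leaf)
split_tree = [0, 5, 3, 0]  # clean: 0 marks leaves and unused nodes
leaf_tree = [-1.0, 0.0, 0.0, 30.0, 10.0, 20.0, -1.0, -1.0]  # -1.0 in unused slots

# Multivariate leaves (k = 2): one extra leading axis.
leaf_tree_k = [
    [-1.0, 0.0, 0.0, 30.0, 10.0, 20.0, -1.0, -1.0],
    [-1.0, 0.0, 0.0, 3.0, 1.0, 2.0, -1.0, -1.0],
]

i = find_leaf(var_tree, split_tree, [2, 7])  # 2 < 5 goes left, 7 >= 3 goes right
print(i, leaf_tree[i], [row[i] for row in leaf_tree_k])  # 5 20.0 [20.0, 2.0]
```

A point with x[0] == 5 goes right, since the boundary is open on the right. With the leading *batch_shape axes present, one would first select a single tree before applying a lookup like this.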

classmethod from_dataclass(obj)

Create a TreesTrace from any bartz.grove.TreeHeaps.

Return type:

TreesTrace

axes_from_dataclass(obj)

Project the per-field vmap axis specs of obj onto this template.

self supplies the (array) pytree structure; the same-named fields of obj, which hold axis specs (ints or None), replace its leaves. The result is built with equinox.tree_at, which bypasses the type-checked __init__, so the axis values are accepted even though they deliberately do not match the annotated array types.

Return type:

TreesTrace