bartz.grove.TreesTrace
- class bartz.grove.TreesTrace(var_tree, split_tree, leaf_tree)
Implementation of bartz.grove.TreeHeaps for an MCMC trace.
- var_tree: UInt[Array, '*batch_shape half_tree_size']
The axes (predictor indices) along which the decision nodes split. This array can be dirty, i.e., unused nodes can hold arbitrary values, except for the always-unused node at index 0, which must be set to 0.
- split_tree: UInt[Array, '*batch_shape half_tree_size']
The decision boundaries of the trees. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. A node is a leaf iff its split element is 0; unused nodes also have split set to 0. This array must not be dirty.
- leaf_tree: Float32[Array, '*batch_shape 2*half_tree_size'] | Float32[Array, '*batch_shape k 2*half_tree_size']
The values in the leaves of the trees. This array can be dirty, i.e., unused nodes can hold arbitrary values. For multivariate leaves it has an additional axis of size k just before the node axis.
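To make the heap layout concrete, here is a minimal sketch that evaluates a single tree (no batch axes) on one point. It assumes a 1-based heap (root at index 1, children of node i at 2*i and 2*i + 1, index 0 never visited) and integer (binned) predictor values; the helper names `find_leaf` and `evaluate_tree` and the example arrays are hypothetical and not part of the bartz API.

```python
import numpy as np


def find_leaf(x, var_tree, split_tree):
    """Return the heap index of the leaf that the binned point `x` falls into.

    Assumes a 1-based heap: root at index 1, children of node i at 2*i (left)
    and 2*i + 1 (right); index 0 is never visited.
    """
    half_tree_size = len(split_tree)
    index = 1
    # Nodes at index >= half_tree_size are on the bottom level, hence always leaves.
    while index < half_tree_size and split_tree[index] != 0:
        # Boundaries are open on the right: go left iff x[var] < split.
        go_right = x[var_tree[index]] >= split_tree[index]
        index = 2 * index + int(go_right)
    return index


def evaluate_tree(x, var_tree, split_tree, leaf_tree):
    """Look up the leaf value; `...` also handles a leading multivariate axis k."""
    return leaf_tree[..., find_leaf(x, var_tree, split_tree)]


# Hypothetical tree with half_tree_size = 4, so leaf_tree has 2 * 4 = 8 entries.
var_tree = np.array([0, 0, 1, 5], dtype=np.uint8)  # index 0 set to 0; index 3 is dirty
split_tree = np.array([0, 3, 2, 0], dtype=np.uint8)  # node 3 is a leaf (split == 0)
leaf_tree = np.array([-9, -9, -9, 3, 1, 2, -9, -9], dtype=np.float32)  # -9 = dirty

for x in ([1, 0], [1, 5], [3, 0], [4, 0]):
    print(x, evaluate_tree(np.array(x), var_tree, split_tree, leaf_tree))
```

The four points land in leaves 4, 5, 3 and 3; `[3, 0]` goes right at the root because x == split is not strictly less than the boundary. The dirty entries of `var_tree` and `leaf_tree` are never read, which is why they may hold arbitrary values, whereas a dirty `split_tree` would change the traversal. Passing a multivariate `leaf_tree` of shape (k, 8) returns a length-k vector from the same lookup.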
- classmethod from_dataclass(obj)
Create a TreesTrace from any bartz.grove.TreeHeaps.
- Return type:
- axes_from_dataclass(obj)
Project the per-field vmap axis specs of obj onto this template. self supplies the pytree structure (whose leaves are arrays); the same-named fields of obj (axis specs, i.e., ints or None) replace its leaves. The result is built with equinox.tree_at, which bypasses the type-checked __init__, so the deliberately off-type axis values are allowed.
- Return type:
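The reason the projection must skip `__init__` can be illustrated without JAX. The sketch below is hypothetical: `ToyTrees`, `AxisSpecs` and `axes_onto_template` are invented names, plain frozen dataclasses stand in for the equinox module, and `copy.copy` plus `object.__setattr__` stand in for `equinox.tree_at`. It only mirrors the idea of replacing each field of a template with the same-named axis spec.

```python
import copy
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class ToyTrees:
    """Hypothetical stand-in for a type-checked array container."""

    var_tree: list
    split_tree: list
    leaf_tree: list

    def __post_init__(self):
        # Mimics a type-checked __init__: every field must hold "array" data.
        for field in dataclasses.fields(self):
            if not isinstance(getattr(self, field.name), list):
                raise TypeError(f"{field.name} must be a list")


@dataclasses.dataclass(frozen=True)
class AxisSpecs:
    """Hypothetical per-field vmap axes: an int (mapped axis) or None (unmapped)."""

    var_tree: Optional[int]
    split_tree: Optional[int]
    leaf_tree: Optional[int]


def axes_onto_template(template, axes):
    """Return a copy of `template` whose fields are the same-named fields of `axes`.

    Attributes are written directly, so __init__/__post_init__ never runs and the
    off-type int/None values are accepted (the role equinox.tree_at plays).
    """
    out = copy.copy(template)
    for field in dataclasses.fields(template):
        object.__setattr__(out, field.name, getattr(axes, field.name))
    return out


trees = ToyTrees(var_tree=[0, 1], split_tree=[0, 3], leaf_tree=[0.0, 0.0, 1.0, 2.0])
axes = AxisSpecs(var_tree=0, split_tree=0, leaf_tree=None)
spec = axes_onto_template(trees, axes)
print(spec)  # ToyTrees(var_tree=0, split_tree=0, leaf_tree=None)
```

Building the same object with `dataclasses.replace(trees, var_tree=0)` would go through `__init__` and raise `TypeError`. In the real class, the result is a pytree with the template's structure whose leaves are ints or None, which is the form `jax.vmap` accepts for `in_axes`: an int names the mapped axis and None marks a field that is not mapped.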