bartz.debug.PriorSample

class bartz.debug.PriorSample(var_tree, split_tree, leaf_tree, log_s, theta)

Output of sample_prior.

axes_from_dataclass(obj)

Project the per-field vmap axis specs of obj onto this template.

self supplies the (array) pytree; the same-named fields of obj (axis specs, i.e. ints or None) replace its leaves. Built with equinox.tree_at, which bypasses the type-checked __init__, so the deliberately off-type axis values are allowed.

Return type:

TreesTrace

classmethod from_dataclass(obj)

Create a TreesTrace from any bartz.grove.TreeHeaps.

Return type:

TreesTrace

var_tree: UInt[Array, '*batch_shape half_tree_size']

The axes along which the decision nodes operate. This array can be dirty, except for the always-unused node at index 0, which must be set to 0.

split_tree: UInt[Array, '*batch_shape half_tree_size']

The decision boundaries of the trees. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding ‘split’ element being 0. Unused nodes also have split set to 0. This array can’t be dirty.

leaf_tree: Float32[Array, '*batch_shape 2*half_tree_size'] | Float32[Array, '*batch_shape k 2*half_tree_size']

The values in the leaves of the trees. This array can be dirty, i.e., unused nodes can hold any value. It may have an additional axis for multivariate leaves.
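As a rough illustration of how the three tree arrays above fit together, the following plain-Python sketch walks a single tree for one point. It is not part of bartz: the function name evaluate_tree is hypothetical, and it assumes a heap layout with the root at index 1 and the children of node i at 2*i (left) and 2*i + 1 (right), univariate leaves, and no batch axes. Only the "left iff x < split" rule, the "split == 0 means leaf" rule, and the unused index 0 come from the field descriptions.

```python
def evaluate_tree(var_tree, split_tree, leaf_tree, x):
    """Return the leaf value reached by point x in one heap-ordered tree.

    Hypothetical helper, not a bartz function.
    Assumption: root at index 1, children of node i at 2*i and 2*i + 1.
    """
    half_tree_size = len(split_tree)
    index = 1  # index 0 is the always-unused node
    # Nodes at index >= half_tree_size have no split entry, so they are leaves.
    while index < half_tree_size and split_tree[index] != 0:
        axis = var_tree[index]
        # Boundaries are open on the right: go left iff x < split.
        go_left = x[axis] < split_tree[index]
        index = 2 * index if go_left else 2 * index + 1
    return leaf_tree[index]


# A small tree with half_tree_size = 4.
# Node 1 splits axis 1 at 3; node 2 splits axis 0 at 2; node 3 is a leaf.
var_tree = [0, 1, 0, 7]      # entry 3 is dirty (never read), entry 0 must be 0
split_tree = [0, 3, 2, 0]    # 0 marks leaves and unused nodes
leaf_tree = [9.9, 9.9, 9.9, 0.5, -1.0, 2.0, 9.9, 9.9]  # 9.9 marks dirty slots

value_left_left = evaluate_tree(var_tree, split_tree, leaf_tree, [1, 2])
value_left_right = evaluate_tree(var_tree, split_tree, leaf_tree, [5, 2])
value_right = evaluate_tree(var_tree, split_tree, leaf_tree, [0, 3])
```

The dirty entries (var_tree[3] and the 9.9 placeholders) are never read, which is why those arrays may hold arbitrary values there, whereas a stray nonzero entry in split_tree would change the traversal.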

log_s: Float32[Array, 'trace_length p'] | None

The per-iteration log unnormalized pmf for choosing variables to split on; None means a uniform distribution.
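Because log_s is unnormalized and on the log scale, turning one row of it into probabilities takes a normalization step. The sketch below shows one way to do that in plain Python. The helper name split_probabilities and its p argument are hypothetical and not part of bartz; it handles a single iteration (one row of length p), and the max-subtraction is only a standard numerical-stability measure.

```python
import math


def split_probabilities(log_s_row, p):
    """Convert one iteration's log unnormalized pmf into probabilities.

    Hypothetical helper, not a bartz function.
    log_s_row: list of p floats, or None for the uniform distribution.
    """
    if log_s_row is None:
        return [1.0 / p] * p
    # Subtract the maximum before exponentiating to avoid overflow.
    largest = max(log_s_row)
    weights = [math.exp(value - largest) for value in log_s_row]
    total = sum(weights)
    return [weight / total for weight in weights]


# Unnormalized weights 1 and 3 give probabilities 1/4 and 3/4.
weighted = split_probabilities([0.0, math.log(3.0)], 2)
uniform = split_probabilities(None, 4)
```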

theta: Float32[Array, 'trace_length'] | None

The per-iteration Dirichlet concentration, None if s is not drawn from a Dirichlet prior.