Debugging
Debugging utilities.
check_trace: check the validity of a set of trees.
debug_mc_gbart: version of mc_gbart with debug checks and methods.
trees_BART_to_bartz: convert an R package BART3 trace to a bartz trace.
sample_prior: sample the BART prior.
- bartz.debug.check_trace(trace, max_split)
Check the validity of a set of trees.
Use describe_error to parse the error codes returned by this function.
- Parameters:
trace (TreeHeaps) – The set of trees to check. This object can have additional attributes beyond the tree arrays; they are ignored.
max_split (UInt[Array, 'p']) – The maximum split value for each variable.
- Returns:
UInt[Array, '*batch_shape'] – An array of error codes, one for each tree.
- bartz.debug.describe_error(error)
Describe an error code returned by check_trace.
- Parameters:
error (DTypeLike[int, Integer[Array, '']]) – An error code returned by check_trace.
- Returns:
list[str] – A list of the function names that implement the failed checks.
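Because describe_error returns a list of failed checks rather than a single message, one error code can record several failures at once. A minimal sketch of that idea, assuming each check owns one bit of the integer code; the check names below are hypothetical placeholders, not the names bartz actually reports.

```python
# Hypothetical check names, one per bit of the error code (bit 0 first).
# These are placeholders; bartz.debug defines its own set of checks.
HYPOTHETICAL_CHECKS = [
    "check_types",
    "check_sizes",
    "check_unused_nodes",
    "check_split_ranges",
]


def combine_errors(failed: list[bool]) -> int:
    """Pack per-check failure flags into a single integer error code."""
    code = 0
    for bit, did_fail in enumerate(failed):
        if did_fail:
            code |= 1 << bit
    return code


def describe_error_sketch(error: int) -> list[str]:
    """Return the names of the checks whose bit is set in `error`."""
    return [name for bit, name in enumerate(HYPOTHETICAL_CHECKS) if error & (1 << bit)]


code = combine_errors([False, True, False, True])
print(code)  # 10
print(describe_error_sketch(code))  # ['check_sizes', 'check_split_ranges']
```

Under this assumption, a code of 0 means that no check failed.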
- class bartz.debug.debug_gbart(*args, check_trees=True, check_replicated_trees=True, **kwargs)
A subclass of gbart that adds debugging functionality.
- Parameters:
*args (Any) – Passed to gbart.
check_trees (bool, default: True) – If True, check all trees with check_trace after running the MCMC, and assert that they are all valid.
check_replicated_trees (bool, default: True) – If the data is sharded across devices, check that the trees are equal on all devices in the final state. Set to False to allow jax tracing.
**kwargs (Any) – Passed to gbart.
- class bartz.debug.debug_mc_gbart(*args, check_trees=True, check_replicated_trees=True, **kwargs)
A subclass of mc_gbart that adds debugging functionality.
- Parameters:
*args (Any) – Passed to mc_gbart.
check_trees (bool, default: True) – If True, check all trees with check_trace after running the MCMC, and assert that they are all valid.
check_replicated_trees (bool, default: True) – If the data is sharded across devices, check that the trees are equal on all devices in the final state. Set to False to allow jax tracing.
**kwargs (Any) – Passed to mc_gbart.
- print_tree(i_chain, i_sample, i_tree, print_all=False)
Print a single tree in human-readable format.
- Parameters:
i_chain (int) – The index of the MCMC chain.
i_sample (int) – The index of the (post-burnin) sample in the chain.
i_tree (int) – The index of the tree in the sample.
print_all (bool, default: False) – If True, also print the content of unused node slots.
- Return type:
None
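To illustrate what a human-readable rendering of a heap-stored tree can look like, here is a minimal sketch. It does not reproduce print_tree's actual output format, and it assumes a hypothetical heap layout: root at index 1, children of node i at 2i and 2i+1, and a split value of 0 marking a node that is not a decision node. format_heap_tree is a hypothetical helper, not part of bartz.

```python
def format_heap_tree(leaf_tree, var_tree, split_tree):
    """Return one indented line per node of a heap-stored binary tree.

    Assumed (hypothetical) layout: root at index 1, children of node i at
    2*i and 2*i + 1, and split_tree[i] == 0 (or i past the end of
    split_tree) marks a leaf.
    """
    lines = []

    def visit(node, depth):
        indent = "  " * depth
        is_decision = node < len(split_tree) and split_tree[node] != 0
        if is_decision:
            lines.append(f"{indent}[{node}] split x{var_tree[node]} at cutpoint {split_tree[node]}")
            visit(2 * node, depth + 1)
            visit(2 * node + 1, depth + 1)
        else:
            lines.append(f"{indent}[{node}] leaf {leaf_tree[node]:+.3f}")

    visit(1, 0)
    return lines


# Heap arrays for up to 3 levels: the root splits on variable 0 and its left child on variable 1.
leaf_tree = [0.0, 0.0, 0.0, 0.5, -0.2, 0.1, 0.0, 0.0]
var_tree = [0, 0, 1, 0]
split_tree = [0, 2, 3, 0]
print("\n".join(format_heap_tree(leaf_tree, var_tree, split_tree)))
```

Slots that the traversal never reaches, such as index 0 or the children of a leaf, are skipped. This corresponds to print_all=False.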
- sigma_harmonic_mean(prior=False)
Return the harmonic mean of the error variance.
- Parameters:
prior (bool, default: False) – If True, use the prior distribution; otherwise, use the full conditional at the last MCMC iteration.
- Returns:
Float32[Array, 'mc_cores'] – The harmonic mean 1/E[1/σ²] under the selected distribution, for each chain.
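For intuition, suppose σ² has an inverse-gamma distribution with shape a and scale b. This is the conjugate choice commonly used for the BART error variance, but whether bartz parametrizes it exactly this way is an assumption here. Then 1/σ² is Gamma with shape a and rate b, so E[1/σ²] = a/b and the harmonic mean is b/a. The sketch below checks this by simulation with the standard library.

```python
import random


def harmonic_mean_sigma2_mc(a, b, n_samples=200_000, seed=0):
    """Monte Carlo estimate of 1 / E[1/sigma^2] for sigma^2 ~ InvGamma(shape=a, scale=b)."""
    rng = random.Random(seed)
    # 1/sigma^2 ~ Gamma(shape=a, rate=b); random.gammavariate takes a *scale*, i.e. 1/b.
    precisions = [rng.gammavariate(a, 1.0 / b) for _ in range(n_samples)]
    return 1.0 / (sum(precisions) / n_samples)


a, b = 3.0, 2.0
print(harmonic_mean_sigma2_mc(a, b))  # approximately b / a
print(b / a)
```

With these values the ordinary mean E[σ²] = b/(a-1) = 1 is larger than the harmonic mean 2/3. A harmonic mean never exceeds the arithmetic mean.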
- compare_resid()
Recompute the residuals to compare them with the updated ones.
- Returns:
resid1 (Float32[Array, 'mc_cores n']) – The final state of the residuals updated during the MCMC.
resid2 (Float32[Array, 'mc_cores n']) – The residuals computed from the final state of the trees.
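One reason this comparison is useful is that residuals maintained by incremental updates can drift from a fresh recomputation, for example through float32 rounding. The following is a minimal sketch of that comparison using NumPy and made-up tree predictions, not bartz internals. Each tree update adds back the old tree's predictions and subtracts the new ones, and at the end the residuals are recomputed as y minus the sum of trees.

```python
import numpy as np

rng = np.random.default_rng(0)
n, num_trees, num_iters = 100, 20, 500

y = rng.normal(size=n).astype(np.float32)
trees = np.zeros((num_trees, n), dtype=np.float32)  # per-tree predictions on the data
resid1 = y.copy()  # residuals maintained incrementally, as during the MCMC

for it in range(num_iters):
    t = it % num_trees
    new_pred = rng.normal(scale=0.1, size=n).astype(np.float32)
    # Incremental update: add back the old tree's contribution, subtract the new one.
    resid1 += trees[t] - new_pred
    trees[t] = new_pred

# Residuals recomputed from the final state of the trees.
resid2 = y - trees.sum(axis=0)

print(np.max(np.abs(resid1 - resid2)))  # small, but typically not exactly zero
```

A large discrepancy in such a check would point to a bug in the update logic rather than to rounding.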
- avg_acc()
Compute the average acceptance rates of tree moves.
- Returns:
acc_grow (Float32[Array, 'mc_cores']) – The average acceptance rate of grow moves.
acc_prune (Float32[Array, 'mc_cores']) – The average acceptance rate of prune moves.
- avg_prop()
Compute the average proposal rate of grow and prune moves.
- Returns:
prop_grow (Float32[Array, 'mc_cores']) – The fraction of times grow was proposed instead of prune.
prop_prune (Float32[Array, 'mc_cores']) – The fraction of times prune was proposed instead of grow.
Notes
This function does not take into account cases where no move was proposed.
- avg_move()
Compute the rates of accepted grow and prune moves.
- Returns:
rate_grow (Float32[Array, 'mc_cores']) – The fraction of times a grow move was proposed and accepted.
rate_prune (Float32[Array, 'mc_cores']) – The fraction of times a prune move was proposed and accepted.
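The three summaries avg_acc, avg_prop and avg_move are related. Here is a sketch that computes all three from a hypothetical per-iteration log of proposals ("grow", "prune", or None when no move was proposed) and acceptance flags. Two things are assumptions, not necessarily what bartz does internally: the log format, and dividing the move rate by the total number of iterations. When no iteration has None, rate_grow equals prop_grow × acc_grow.

```python
def _ratio(num, den):
    """Safe division that returns NaN when nothing was counted."""
    return num / den if den else float("nan")


def move_stats(proposals, accepted):
    """Summarize a hypothetical log of tree moves.

    proposals: list of "grow", "prune", or None (no move proposed)
    accepted: list of bools of the same length, ignored where proposals[i] is None
    """
    n_grow = sum(p == "grow" for p in proposals)
    n_prune = sum(p == "prune" for p in proposals)
    acc_grow = sum(p == "grow" and a for p, a in zip(proposals, accepted))
    acc_prune = sum(p == "prune" and a for p, a in zip(proposals, accepted))
    total = len(proposals)
    return {
        # acceptance rates among proposed moves (cf. avg_acc)
        "acc_grow": _ratio(acc_grow, n_grow),
        "acc_prune": _ratio(acc_prune, n_prune),
        # proposal rates, ignoring iterations with no proposal (cf. avg_prop)
        "prop_grow": _ratio(n_grow, n_grow + n_prune),
        "prop_prune": _ratio(n_prune, n_grow + n_prune),
        # accepted moves over all iterations (cf. avg_move)
        "rate_grow": _ratio(acc_grow, total),
        "rate_prune": _ratio(acc_prune, total),
    }


proposals = ["grow", "grow", "prune", "grow", "prune", None]
accepted = [True, False, True, True, False, False]
print(move_stats(proposals, accepted))
```

In this log the final iteration proposes nothing. That iteration is why rate_grow (1/3) differs from prop_grow × acc_grow (2/5).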
- depth_distr()
Histogram of tree depths for each state of the trees.
- Returns:
Int32[Array, 'mc_cores ndpost/mc_cores d'] – For each chain, a matrix where each row contains a histogram of tree depths.
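Here is a minimal sketch of one row of such a histogram, counting a batch of trees by depth. The heap convention is assumed: the root is at index 1, node i sits at depth floor(log2 i), its children are 2i and 2i+1, and a nonzero split value marks a decision node. The depth of a tree is taken to be the depth of its deepest leaf, so a lone root has depth 0. bartz may count depth differently.

```python
def tree_depth(split_tree):
    """Depth of the deepest leaf under the assumed heap convention.

    split_tree[i] != 0 marks node i as a decision node; index 0 is unused.
    """
    deepest = 0
    for i in range(1, len(split_tree)):
        if split_tree[i] != 0:
            # the children 2*i and 2*i + 1 sit one level below node i
            deepest = max(deepest, (2 * i).bit_length() - 1)
    return deepest


def depth_histogram(split_trees, d):
    """Count trees by depth, with one bin for each possible depth 0..d-1."""
    counts = [0] * d
    for split_tree in split_trees:
        counts[tree_depth(split_tree)] += 1
    return counts


# Four trees stored in split arrays of length 2**(d-1) with d = 3.
trees = [[0, 0, 0, 0], [0, 5, 0, 0], [0, 5, 2, 0], [0, 1, 0, 3]]
print(depth_histogram(trees, 3))  # [1, 1, 2]
```

With split arrays of length 2**(d-1), the deepest possible decision node is at depth d-2, so the histogram needs exactly d bins.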
- points_per_decision_node_distr()
Histogram of the number of points belonging to parent-of-leaf nodes.
- Returns:
Int32[Array, 'mc_cores ndpost/mc_cores n+1'] – For each chain, a matrix where each row contains a histogram of the number of points.
- points_per_leaf_distr()
Histogram of the number of points belonging to leaves.
- Returns:
Int32[Array, 'mc_cores ndpost/mc_cores n+1'] – For each chain, a matrix where each row contains a histogram of the number of points.
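The last axis has length n+1 because a node can hold anywhere from 0 to n of the n data points. Below is a minimal sketch for a single tree. It assumes a hypothetical input that gives, for each data point, the index of the leaf the point falls into. Every leaf of the tree is counted, so empty leaves land in bin 0.

```python
from collections import Counter


def points_per_leaf_histogram(leaf_of_point, leaf_indices, n):
    """Histogram over leaves of how many of the n points each leaf holds.

    leaf_of_point: for each data point, the index of the leaf it falls into
    leaf_indices: indices of all leaves in the tree, including empty ones
    Returns a list of length n + 1 where entry k counts leaves holding k points.
    """
    per_leaf = Counter(leaf_of_point)
    hist = [0] * (n + 1)
    for leaf in leaf_indices:
        hist[per_leaf[leaf]] += 1  # Counter returns 0 for leaves with no points
    return hist


# 5 points spread over leaves 3, 4 and 5; leaf 5 is empty.
print(points_per_leaf_histogram([4, 4, 3, 4, 3], [3, 4, 5], 5))  # [1, 0, 1, 1, 0, 0]
```

The entries of each row sum to the number of leaves. The same counting applied to parent-of-leaf nodes gives the kind of histogram described for points_per_decision_node_distr.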
- check_trees()
Apply check_trace to all the tree draws.
- Return type:
UInt[Array, 'mc_cores ndpost/mc_cores ntree']
- class bartz.debug.SamplePriorTrees(leaf_tree, var_tree, split_tree)
Object holding the trees generated by sample_prior.
- leaf_tree: Float32[Array, '* 2**d']
The array representing the trees, see bartz.grove.
- var_tree: UInt[Array, '* 2**(d-1)']
The array representing the trees, see bartz.grove.
- split_tree: UInt[Array, '* 2**(d-1)']
The array representing the trees, see bartz.grove.
- classmethod initial(key, sigma_mu, p_nonterminal, max_split)
Initialize the trees.
The leaves are already correct and do not need to be changed.
- Parameters:
key (Key[Array, '']) – A jax random key.
sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.
p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth.
max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.
- Returns:
SamplePriorTrees – Trees initialized with random leaves and stub tree structures.
- bartz.debug.sample_prior(key, trace_length, num_trees, max_split, p_nonterminal, sigma_mu)
Sample independent trees from the BART prior.
- Parameters:
key (Key[Array, '']) – A jax random key.
trace_length (int) – The number of iterations.
num_trees (int) – The number of trees for each iteration.
max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.
p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth. This determines the maximum depth of the trees.
sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.
- Returns:
SamplePriorTrees – An object containing the generated trees, with batch shape (trace_length, num_trees).
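The parameter description above can be read as a recursive sampler. The following pure-Python sketch illustrates that reading and is not the bartz implementation, which works with jax keys and arrays. Its heap conventions are assumptions:

- the root is at index 1, and the children of node i are 2i and 2i+1;
- split values run from 1 to max_split[v], and 0 marks a node that is not a decision node;
- a split at value s leaves only smaller values usable in the left child and only larger values usable in the right child.

A node at depth k becomes a decision node with probability p_nonterminal[k], but only if some variable still has a usable split. The variable and then the split value are drawn uniformly, and every leaf gets a N(0, sigma_mu²) value.

```python
import random


def sample_tree(rng, p_nonterminal, max_split, sigma_mu):
    """Draw one tree from a BART-style prior into heap arrays (assumed conventions)."""
    d = len(p_nonterminal) + 1  # maximum number of levels
    leaf_tree = [0.0] * 2**d
    var_tree = [0] * 2 ** (d - 1)
    split_tree = [0] * 2 ** (d - 1)  # 0 means "not a decision node"
    # Each entry: (node index, depth, inclusive range of usable split values per variable).
    stack = [(1, 0, [(1, m) for m in max_split])]
    while stack:
        node, depth, ranges = stack.pop()
        available = [v for v, (lo, hi) in enumerate(ranges) if lo <= hi]
        if depth < d - 1 and available and rng.random() < p_nonterminal[depth]:
            v = rng.choice(available)  # uniform over variables with a usable split
            lo, hi = ranges[v]
            s = rng.randint(lo, hi)  # uniform over the usable split values
            var_tree[node], split_tree[node] = v, s
            left, right = list(ranges), list(ranges)
            left[v], right[v] = (lo, s - 1), (s + 1, hi)
            stack.append((2 * node, depth + 1, left))
            stack.append((2 * node + 1, depth + 1, right))
        else:
            leaf_tree[node] = rng.gauss(0.0, sigma_mu)  # leaf value ~ N(0, sigma_mu^2)
    return leaf_tree, var_tree, split_tree


def sample_prior_sketch(seed, trace_length, num_trees, max_split, p_nonterminal, sigma_mu):
    """Independent trees arranged with batch shape (trace_length, num_trees)."""
    rng = random.Random(seed)
    return [
        [sample_tree(rng, p_nonterminal, max_split, sigma_mu) for _ in range(num_trees)]
        for _ in range(trace_length)
    ]
```

The array sizes in this sketch match the shapes listed for SamplePriorTrees: a p_nonterminal of length d-1 allows decision nodes down to depth d-2, so var_tree and split_tree need 2**(d-1) slots while leaf_tree needs 2**d.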
- class bartz.debug.BARTTraceMeta(ndpost, ntree, numcut, heap_size)
Metadata of R BART tree traces.
- ndpost: int
The number of posterior draws.
- ntree: int
The number of trees in the model.
- numcut: UInt[Array, 'p']
The maximum split value for each variable.
- heap_size: int
The size of the heap required to store the trees.
- class bartz.debug.TraceWithOffset(leaf_tree, var_tree, split_tree, offset)
Implementation of bartz.mcmcloop.Trace.
- classmethod from_trees_trace(trees, offset)
Create a TraceWithOffset from a TreeHeaps.
- Return type:
TraceWithOffset
- bartz.debug.trees_BART_to_bartz(trees, *, min_maxdepth=0, offset=None)
Convert trees from the R BART format to the bartz format.
- Parameters:
trees (str) – The string representation of a trace of trees of the R BART package. Can be accessed from mc_gbart(...).treedraws['trees'].
min_maxdepth (int, default: 0) – The maximum tree depth of the output is set to the maximum depth observed in the input trees. Use this parameter to require the output format to have at least this maximum depth.
offset (DTypeLike[float, Float[Any, ''], None], default: None) – The trace returned by bartz.mcmcloop.run_mcmc contains an offset to be added to the sum of trees. To match that behavior, this function also returns an offset, which is zero by default; set this parameter to use a different value.
- Returns:
trace (TraceWithOffset) – A representation of the trees compatible with the trace returned by bartz.mcmcloop.run_mcmc.
meta (BARTTraceMeta) – The metadata of the trace: the number of posterior draws, the number of trees, the maximum split value for each variable, and the heap size.
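The heap size reported in the metadata follows from the deepest node in the input, raised to at least min_maxdepth. A minimal sketch of that step, built on an assumed layout for the R BART string:

- a header line "ndpost ntree p";
- then, for each tree, its node count followed by one "id var cut value" record per node;
- node ids are heap-style, with the root at 1 and the children of node i at 2i and 2i+1.

Verify these layout details against the R package before relying on them. bart_trace_depth is a hypothetical helper, and depth here counts levels, so a heap of 2**depth slots holds every node.

```python
def bart_trace_depth(trees, min_maxdepth=0):
    """Return (ndpost, ntree, p, maxdepth, heap_size) for an R BART tree string.

    Assumed layout: a header 'ndpost ntree p', then for each tree its number of
    nodes followed by one 'id var cut value' record per node, with heap-style ids.
    """
    tokens = trees.split()
    ndpost, ntree, p = (int(t) for t in tokens[:3])
    pos = 3
    max_id = 1
    for _ in range(ndpost * ntree):
        num_nodes = int(tokens[pos])
        pos += 1
        for _ in range(num_nodes):
            node_id = int(tokens[pos])  # the var, cut and value fields are skipped here
            max_id = max(max_id, node_id)
            pos += 4
    # Heap index i lies on level i.bit_length(), counting the root as level 1.
    maxdepth = max(max_id.bit_length(), min_maxdepth)
    return ndpost, ntree, p, maxdepth, 2**maxdepth


# One draw of two trees on one variable: a stump, and a root with two leaves.
example = "1 2 1\n1\n1 0 0 0.5\n3\n1 0 2 0.0\n2 0 0 -0.1\n3 0 0 0.2\n"
print(bart_trace_depth(example))  # (1, 2, 1, 2, 4)
```

With min_maxdepth=4, the same input would be reported with depth 4 and a heap size of 16.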