Debugging

Debugging utilities.

bartz.debug.check_trace(trace, max_split)

Check the validity of a set of trees.

Use describe_error to parse the error codes returned by this function.

Parameters:
  • trace (TreeHeaps) – The set of trees to check. The object may carry additional attributes beyond the tree arrays; these are ignored.

  • max_split (UInt[Array, 'p']) – The maximum split value for each variable.

Returns:

UInt[Array, '*batch_shape'] – A tensor of error codes, one for each tree.

bartz.debug.describe_error(error)

Describe an error code returned by check_trace.

Parameters:

error (DTypeLike[int, Integer[Array, '']]) – An error code returned by check_trace.

Returns:

list[str] – A list of the names of the functions that implement the failed checks.
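Because describe_error maps one integer to a list of failed checks, each error code presumably packs several pass/fail results into a single number. The sketch below illustrates that idea under the assumption of a bit-flag encoding. The check names and bit positions are hypothetical and are not taken from bartz.

```python
# Hypothetical check names; bartz defines its own list and ordering.
CHECKS = ["check_types", "check_sizes", "check_unused_node", "check_leaf_values"]


def describe(error: int) -> list[str]:
    """Return the names of the checks whose bit is set in `error`."""
    return [name for bit, name in enumerate(CHECKS) if (error >> bit) & 1]


print(describe(0))       # no bit set: the tree passed every check
print(describe(0b0101))  # bits 0 and 2 set: two checks failed
```

Under this assumed encoding, a code of zero means the tree is valid, and any nonzero code lists every failed check.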

class bartz.debug.debug_gbart(*args, check_trees=True, check_replicated_trees=True, **kwargs)

A subclass of gbart that adds debugging functionality.

Parameters:
  • *args (Any) – Passed to gbart.

  • check_trees (bool, default: True) – If True, check all trees with check_trace after running the MCMC, and assert that they are all valid.

  • check_replicated_trees (bool, default: True) – If the data is sharded across devices, check that the trees are equal on all devices in the final state. Set to False to allow jax tracing.

  • **kwargs (Any) – Passed to gbart.

class bartz.debug.debug_mc_gbart(*args, check_trees=True, check_replicated_trees=True, **kwargs)

A subclass of mc_gbart that adds debugging functionality.

Parameters:
  • *args (Any) – Passed to mc_gbart.

  • check_trees (bool, default: True) – If True, check all trees with check_trace after running the MCMC, and assert that they are all valid.

  • check_replicated_trees (bool, default: True) – If the data is sharded across devices, check that the trees are equal on all devices in the final state. Set to False to allow jax tracing.

  • **kwargs (Any) – Passed to mc_gbart.

print_tree(i_chain, i_sample, i_tree, print_all=False)

Print a single tree in human-readable format.

Parameters:
  • i_chain (int) – The index of the MCMC chain.

  • i_sample (int) – The index of the (post-burnin) sample in the chain.

  • i_tree (int) – The index of the tree in the sample.

  • print_all (bool, default: False) – If True, also print the content of unused node slots.

Return type:

None

sigma_harmonic_mean(prior=False)

Return the harmonic mean of the error variance.

Parameters:

prior (bool, default: False) – If True, use the prior distribution, otherwise use the full conditional at the last MCMC iteration.

Returns:

Float32[Array, 'mc_cores'] – The harmonic mean 1/E[1/sigma^2] in the selected distribution.
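To make the quantity 1/E[1/sigma^2] concrete, the sketch below estimates a harmonic mean from a handful of made-up variance values. It is a Monte Carlo analogue written for illustration only. The method itself evaluates the expectation under the prior or the full conditional, and the helper name harmonic_mean is hypothetical.

```python
def harmonic_mean(values):
    """Harmonic mean 1 / mean(1 / v) of a sequence of positive values."""
    return 1.0 / (sum(1.0 / v for v in values) / len(values))


variances = [1.0, 2.0, 4.0]  # made-up values of sigma^2
hm = harmonic_mean(variances)
am = sum(variances) / len(variances)  # arithmetic mean, for comparison
print(hm, am)
```

The harmonic mean never exceeds the arithmetic mean, so it gives more weight to small variances.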

compare_resid()

Re-compute residuals to compare them with the updated ones.

Returns:

  • resid1 (Float32[Array, 'mc_cores n']) – The final state of the residuals updated during the MCMC.

  • resid2 (Float32[Array, 'mc_cores n']) – The residuals computed from the final state of the trees.
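The comparison is useful because residuals that are updated incrementally can drift from the values implied by the trees. The following sketch shows the two computations on made-up numbers, assuming the residual is the response minus the sum of the tree predictions (any offset is ignored here). All names are hypothetical.

```python
def recompute_resid(y, tree_preds):
    """Residuals from scratch: response minus the sum of all tree predictions."""
    return [yi - sum(pred[i] for pred in tree_preds) for i, yi in enumerate(y)]


y = [1.0, 2.0, 3.0]
tree_preds = [[0.1, 0.2, 0.3], [0.5, 0.5, 0.5]]  # one row per tree

# Start from residuals that agree with the trees.
resid1 = recompute_resid(y, tree_preds)

# Incremental update: tree 0 changes, so add back its old prediction
# and subtract the new one, without touching the other trees.
new_pred = [0.0, 0.4, 0.3]
resid1 = [r + old - new for r, old, new in zip(resid1, tree_preds[0], new_pred)]
tree_preds[0] = new_pred

# Recompute from the final trees and compare.
resid2 = recompute_resid(y, tree_preds)
max_diff = max(abs(a - b) for a, b in zip(resid1, resid2))
print(max_diff)
```

A large difference between the two vectors would point to a bug in the incremental update, while a small one is ordinary floating-point error.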

avg_acc()

Compute the average acceptance rates of tree moves.

Returns:

  • acc_grow (Float32[Array, 'mc_cores']) – The average acceptance rate of grow moves.

  • acc_prune (Float32[Array, 'mc_cores']) – The average acceptance rate of prune moves.

avg_prop()

Compute the average proposal rate of grow and prune moves.

Returns:

  • prop_grow (Float32[Array, 'mc_cores']) – The fraction of times grow was proposed instead of prune.

  • prop_prune (Float32[Array, 'mc_cores']) – The fraction of times prune was proposed instead of grow.

Notes

This function does not take into account cases where no move was proposed.

avg_move()

Compute the move rate.

Returns:

  • rate_grow (Float32[Array, 'mc_cores']) – The fraction of times a grow move was proposed and accepted.

  • rate_prune (Float32[Array, 'mc_cores']) – The fraction of times a prune move was proposed and accepted.
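The methods avg_acc, avg_prop and avg_move report related but distinct ratios. The sketch below computes all three from a made-up log of move attempts, to show how the denominators differ. The log format is hypothetical. The choice of denominators is one reading of the descriptions above and not a statement of how bartz computes them: proposed moves of each kind for acceptance, all proposed moves for the proposal fractions, all steps for the move rates.

```python
# One record per step: (proposed move, accepted). "none" marks a step
# where no move was proposed. The data is made up.
records = [
    ("grow", True), ("grow", False), ("prune", True),
    ("none", False), ("grow", True), ("prune", False),
]


def rates(records):
    """Acceptance, proposal and move rates for grow and prune moves."""
    grow = [acc for move, acc in records if move == "grow"]
    prune = [acc for move, acc in records if move == "prune"]
    proposed = len(grow) + len(prune)  # excludes the "none" steps
    total = len(records)
    return {
        "acc_grow": sum(grow) / len(grow),
        "acc_prune": sum(prune) / len(prune),
        "prop_grow": len(grow) / proposed,
        "prop_prune": len(prune) / proposed,
        "rate_grow": sum(grow) / total,
        "rate_prune": sum(prune) / total,
    }


result = rates(records)
print(result)
```

With these denominators the two proposal fractions always sum to one, which matches the note that steps without a proposal are not counted.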

depth_distr()

Histogram of tree depths for each state of the trees.

Returns:

Int32[Array, 'mc_cores ndpost/mc_cores d'] – A matrix where each row contains a histogram of tree depths.
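The sketch below shows how one row of such a histogram could be built for trees stored as heaps. It assumes a layout in which index 0 is unused, the root sits at index 1, the children of node i sit at 2i and 2i+1, and a split value of 0 marks a node that is not a decision node. It also assumes a root-only tree has depth 0. These conventions and the helper names are assumptions; see bartz.grove for the actual format.

```python
def tree_depth(split_tree):
    """Depth of one tree, given its heap of split values (0 = no split)."""
    depth = 0
    for index, split in enumerate(split_tree):
        if index > 0 and split != 0:
            # A node at heap index i is at level bit_length(i) - 1,
            # so its children are at level bit_length(i).
            depth = max(depth, index.bit_length())
    return depth


def depth_histogram(split_trees):
    """Count how many trees have each depth from 0 to d - 1."""
    # A heap of length 2**(d-1) allows depths 0 .. d-1, i.e. d bins.
    num_bins = len(split_trees[0]).bit_length()
    hist = [0] * num_bins
    for split_tree in split_trees:
        hist[tree_depth(split_tree)] += 1
    return hist


# Four made-up trees with heap size 4, i.e. d = 3.
trees = [[0, 0, 0, 0], [0, 5, 0, 0], [0, 5, 2, 0], [0, 3, 0, 7]]
hist = depth_histogram(trees)
print(hist)
```

The number of bins equals d, which matches the last axis of the returned array.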

points_per_decision_node_distr()

Histogram of number of points belonging to parent-of-leaf nodes.

Returns:

Int32[Array, 'mc_cores ndpost/mc_cores n+1'] – For each chain, a matrix where each row contains a histogram of the number of points.

points_per_leaf_distr()

Histogram of number of points belonging to leaves.

Returns:

Int32[Array, 'mc_cores ndpost/mc_cores n+1'] – For each chain, a matrix where each row contains a histogram of the number of points.

check_trees()

Apply check_trace to all the tree draws.

Return type:

UInt[Array, 'mc_cores ndpost/mc_cores ntree']

tree_goes_bad()

Find iterations where a tree becomes invalid.

Returns:

Bool[Array, 'mc_cores ndpost/mc_cores ntree'] – For each chain, a matrix where element (i, j) is True if tree j is invalid at iteration i but was valid at iteration i-1.
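The transition logic can be sketched on a small matrix of error codes for a single chain. The sketch assumes that a code of 0 means a valid tree and that every tree counts as valid before the first iteration. Both are assumptions, and the function name goes_bad is hypothetical.

```python
def goes_bad(error_codes):
    """Mark (i, j) where tree j is invalid at iteration i but valid at i - 1.

    error_codes[i][j] is the code of tree j at iteration i, with 0 meaning valid.
    """
    out = []
    prev_bad = [False] * len(error_codes[0])  # assumption: trees start valid
    for row in error_codes:
        bad = [code != 0 for code in row]
        out.append([b and not p for b, p in zip(bad, prev_bad)])
        prev_bad = bad
    return out


# Made-up codes: 4 iterations, 2 trees.
codes = [[0, 0], [0, 4], [1, 4], [0, 4]]
flags = goes_bad(codes)
print(flags)
```

Tree 1 is flagged only at iteration 1, where it first becomes invalid, and not at the later iterations where it stays invalid.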

class bartz.debug.SamplePriorTrees(leaf_tree, var_tree, split_tree)

Object holding the trees generated by sample_prior.

leaf_tree: Float32[Array, '* 2**d']

The array representing the trees, see bartz.grove.

var_tree: UInt[Array, '* 2**(d-1)']

The array representing the trees, see bartz.grove.

split_tree: UInt[Array, '* 2**(d-1)']

The array representing the trees, see bartz.grove.

classmethod initial(key, sigma_mu, p_nonterminal, max_split)

Initialize the trees.

The leaves are already correct and do not need to be changed.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

Returns:

SamplePriorTrees – Trees initialized with random leaves and stub tree structures.

bartz.debug.sample_prior(key, trace_length, num_trees, max_split, p_nonterminal, sigma_mu)

Sample independent trees from the BART prior.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • trace_length (int) – The number of iterations.

  • num_trees (int) – The number of trees for each iteration.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth. This determines the maximum depth of the trees.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

Returns:

SamplePriorTrees – An object containing the generated trees, with batch shape (trace_length, num_trees).
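Since p_nonterminal has length d-1 and fixes the maximum depth, it may help to see how such a vector relates to the array sizes listed for SamplePriorTrees. The sketch below uses the usual BART depth prior alpha / (1 + depth)**beta with the conventional values alpha = 0.95 and beta = 2. Whether these match the values used elsewhere in bartz is an assumption, and the helper name is hypothetical.

```python
def make_p_nonterminal(alpha, beta, d):
    """Split probability alpha / (1 + depth)**beta for depths 0 .. d-2."""
    return [alpha / (1 + depth) ** beta for depth in range(d - 1)]


p = make_p_nonterminal(0.95, 2.0, d=4)

# Sizes of the heap arrays implied by the length of p (which is d - 1).
d = len(p) + 1
leaf_size = 2 ** d            # matches leaf_tree: '* 2**d'
decision_size = 2 ** (d - 1)  # matches var_tree and split_tree: '* 2**(d-1)'
print(p, leaf_size, decision_size)
```

The probabilities decrease with depth, so deep trees are unlikely under the prior even when the heap has room for them.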

class bartz.debug.BARTTraceMeta(ndpost, ntree, numcut, heap_size)

Metadata of R BART tree traces.

ndpost: int

The number of posterior draws.

ntree: int

The number of trees in the model.

numcut: UInt[Array, 'p']

The maximum split value for each variable.

heap_size: int

The size of the heap required to store the trees.

class bartz.debug.TraceWithOffset(leaf_tree, var_tree, split_tree, offset)

Implementation of bartz.mcmcloop.Trace.

classmethod from_trees_trace(trees, offset)

Create a TraceWithOffset from a TreeHeaps.

Return type:

TraceWithOffset

bartz.debug.trees_BART_to_bartz(trees, *, min_maxdepth=0, offset=None)

Convert trees from the R BART format to the bartz format.

Parameters:
  • trees (str) – The string representation of a trace of trees of the R BART package. Can be accessed from mc_gbart(...).treedraws['trees'].

  • min_maxdepth (int, default: 0) – The maximum tree depth of the output will be set to the maximum observed depth in the input trees. Use this parameter to require at least this maximum depth in the output format.

  • offset (DTypeLike[float, Float[Any, ''], None], default: None) – The trace returned by bartz.mcmcloop.run_mcmc contains an offset to be added to the sum of trees. To match that behavior, this function returns an offset as well, zero by default. Use this parameter to set a different value.

Returns:

  • trace (TraceWithOffset) – A representation of the trees compatible with the trace returned by bartz.mcmcloop.run_mcmc.

  • meta (BARTTraceMeta) – The metadata of the trace, containing the number of iterations, trees, and the maximum split value.