Debugging

Debugging utilities.

class bartz.debug.debug_gbart(*args, check_trees=True, check_replicated_trees=True, **kwargs)[source]

A subclass of gbart that adds debugging functionality.

Parameters:
  • *args (Any) – Passed to gbart.

  • check_trees (bool, default: True) – If True, check all trees with check_trace after running the MCMC, and assert that they are all valid.

  • check_replicated_trees (bool, default: True) – If the data is sharded across devices, check that the trees are equal on all devices in the final state. Set to False to allow jax tracing.

  • **kwargs (Any) – Passed to gbart.

class bartz.debug.debug_mc_gbart(*args, check_trees=True, check_replicated_trees=True, **kwargs)[source]

A subclass of mc_gbart that adds debugging functionality.

Parameters:
  • *args (Any) – Passed to mc_gbart.

  • check_trees (bool, default: True) – If True, check all trees with check_trace after running the MCMC, and assert that they are all valid.

  • check_replicated_trees (bool, default: True) – If the data is sharded across devices, check that the trees are equal on all devices in the final state. Set to False to allow jax tracing.

  • **kwargs (Any) – Passed to mc_gbart.

print_tree(i_chain, i_sample, i_tree, print_all=False)[source]

Print a single tree in human-readable format.

Parameters:
  • i_chain (int) – The index of the MCMC chain.

  • i_sample (int) – The index of the (post-burnin) sample in the chain.

  • i_tree (int) – The index of the tree in the sample.

  • print_all (bool, default: False) – If True, also print the content of unused node slots.

Return type:

None

sigma_harmonic_mean(prior=False)[source]

Return the harmonic mean of the error variance.

Parameters:

prior (bool, default: False) – If True, use the prior distribution, otherwise use the full conditional at the last MCMC iteration.

Returns:

Float32[Array, 'mc_cores'] – The harmonic mean 1/E[1/sigma^2] in the selected distribution.
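As an illustration of the quantity 1/E[1/sigma^2], the following minimal sketch estimates it from a list of draws of the error variance. The helper name harmonic_mean_variance and the use of plain draws are assumptions made for this example; they are not part of bartz, and the method above may evaluate the expectation differently.

```python
def harmonic_mean_variance(sigma2_draws):
    """Estimate 1 / E[1 / sigma^2] from draws of the error variance.

    Hypothetical helper for illustration only, not a bartz function.
    """
    if not sigma2_draws:
        raise ValueError("need at least one draw")
    # Average the precisions 1 / sigma^2, then invert the average.
    mean_precision = sum(1.0 / s2 for s2 in sigma2_draws) / len(sigma2_draws)
    return 1.0 / mean_precision


# Two toy draws of sigma^2; the harmonic mean is pulled towards the smaller one.
toy_draws = [1.0, 4.0]
toy_result = harmonic_mean_variance(toy_draws)
```

The harmonic mean is never larger than the arithmetic mean of the same draws, so it gives more weight to small variances.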

compare_resid(y=None)[source]

Re-compute residuals to compare them with the updated ones.

Return type:

KeyPath[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]

avg_acc()[source]

Compute the average acceptance rates of tree moves.

Returns:

  • acc_grow (Float32[Array, 'mc_cores']) – The average acceptance rate of grow moves.

  • acc_prune (Float32[Array, 'mc_cores']) – The average acceptance rate of prune moves.

avg_prop()[source]

Compute the average proposal rate of grow and prune moves.

Returns:

  • prop_grow (Float32[Array, 'mc_cores']) – The fraction of times grow was proposed instead of prune.

  • prop_prune (Float32[Array, 'mc_cores']) – The fraction of times prune was proposed instead of grow.

Notes

This function does not take into account cases where no move was proposed.

avg_move()[source]

Compute the move rate.

Returns:

  • rate_grow (Float32[Array, 'mc_cores']) – The fraction of times a grow move was proposed and accepted.

  • rate_prune (Float32[Array, 'mc_cores']) – The fraction of times a prune move was proposed and accepted.
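The three kinds of rates returned by avg_acc, avg_prop and avg_move differ only in their denominators. The sketch below computes them from a toy log of tree updates. The log format, the function name move_statistics, and the choice of dividing the move rate by all updates (including those with no proposal) are assumptions made for this example, not taken from the bartz implementation.

```python
def move_statistics(log):
    """Summarize a toy log of tree updates (hypothetical format).

    Each entry is (proposed, accepted) with proposed in {"grow", "prune", None};
    None stands for an update in which no move was proposed.
    """
    kinds = ("grow", "prune")
    n_proposed = {k: sum(1 for p, _ in log if p == k) for k in kinds}
    n_accepted = {k: sum(1 for p, a in log if p == k and a) for k in kinds}
    n_any_proposal = n_proposed["grow"] + n_proposed["prune"]
    nan = float("nan")
    stats = {}
    for k in kinds:
        # Acceptance rate: accepted moves of this kind over proposals of this kind.
        stats["acc_" + k] = n_accepted[k] / n_proposed[k] if n_proposed[k] else nan
        # Proposal fraction: ignores updates where nothing was proposed.
        stats["prop_" + k] = n_proposed[k] / n_any_proposal if n_any_proposal else nan
        # Move rate: assumed here to be relative to all updates in the log.
        stats["rate_" + k] = n_accepted[k] / len(log) if log else nan
    return stats


toy_log = [("grow", True), ("grow", False), ("prune", True), (None, False)]
toy_stats = move_statistics(toy_log)
```

In this toy log, prop_grow and prop_prune sum to one because the update without a proposal is excluded from their denominator.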

depth_distr()[source]

Histogram of tree depths for each state of the trees.

Return type:

Int32[Array, 'mc_cores ndpost/mc_cores d']
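To make the last axis of length d concrete, here is a minimal pure-Python sketch of a depth histogram for the trees of one iteration. It assumes a heap layout in which index 0 is unused, the root sits at index 1, the children of node i are at 2*i and 2*i + 1, and a split value of 0 marks a node that is not a decision node. These conventions and the function names are assumptions for this example; see bartz.grove for the actual representation.

```python
def tree_depth(split_tree):
    """Depth of the deepest leaf; a root-only tree has depth 0.

    Assumes a valid tree in the heap layout described above.
    """
    depth = 0
    for index in range(1, len(split_tree)):
        if split_tree[index] != 0:
            # Heap index i lies at depth floor(log2(i)); a decision node at
            # depth k has children (possibly leaves) at depth k + 1.
            node_depth = index.bit_length() - 1
            depth = max(depth, node_depth + 1)
    return depth


def depth_histogram(split_trees):
    """Count trees by depth; bin k holds the number of trees of depth k."""
    heap_size = len(split_trees[0])  # 2**(d-1) slots per split tree
    d = heap_size.bit_length()       # depths range over 0, ..., d-1
    counts = [0] * d
    for split_tree in split_trees:
        counts[tree_depth(split_tree)] += 1
    return counts


# Four toy trees with 2**(3-1) = 4 slots each, so d = 3.
toy_split_trees = [
    [0, 0, 0, 0],  # root only, depth 0
    [0, 5, 0, 0],  # root splits, depth 1
    [0, 5, 2, 0],  # left child also splits, depth 2
    [0, 5, 0, 3],  # right child also splits, depth 2
]
toy_histogram = depth_histogram(toy_split_trees)
```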

points_per_decision_node_distr()[source]

Histogram of number of points belonging to parent-of-leaf nodes.

Return type:

Int32[Array, 'mc_cores ndpost/mc_cores n+1']

points_per_leaf_distr()[source]

Histogram of number of points belonging to leaves.

Return type:

Int32[Array, 'mc_cores ndpost/mc_cores n+1']

check_trees()[source]

Apply check_trace to all the tree draws.

Return type:

UInt[Array, 'mc_cores ndpost/mc_cores ntree']

tree_goes_bad()[source]

Find iterations where a tree becomes invalid.

Returns:

Bool[Array, 'mc_cores ndpost/mc_cores ntree'] – An array where, for each chain, element (i, j) is True if tree j is invalid at iteration i but not at iteration i-1.
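The "invalid at iteration i but not at i-1" rule can be sketched on nested lists for a single chain. The function below is a hypothetical stand-in written for this example, and treating every tree as valid before the first saved iteration is an assumption, not documented behavior.

```python
def goes_bad(bad):
    """Flag the iterations at which each tree first turns invalid.

    bad[i][j] is True if tree j is invalid at iteration i (single chain).
    """
    result = []
    for i, row in enumerate(bad):
        if i == 0:
            # Assumption: all trees are valid before the first saved iteration.
            previous = [False] * len(row)
        else:
            previous = bad[i - 1]
        result.append([now and not before for now, before in zip(row, previous)])
    return result


# Rows are iterations, columns are trees.
toy_bad = [
    [False, True],
    [True, True],
    [False, False],
    [True, False],
]
toy_transitions = goes_bad(toy_bad)
```

A tree that stays invalid over consecutive iterations is flagged only once, at the iteration where the problem first appears.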

class bartz.debug.SamplePriorTrees(leaf_tree, var_tree, split_tree)[source]

Object holding the trees generated by sample_prior.

leaf_tree: Float32[Array, '* 2**d']

The array representing the trees, see bartz.grove.

var_tree: UInt[Array, '* 2**(d-1)']

The array representing the trees, see bartz.grove.

split_tree: UInt[Array, '* 2**(d-1)']

The array representing the trees, see bartz.grove.
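The shapes above (2**d leaf slots, 2**(d-1) slots for variables and splits) are consistent with a heap-style layout, which the following sketch uses to evaluate one tree at a point. The conventions are assumptions for this example: index 0 unused, root at index 1, children of node i at 2*i and 2*i + 1, a split of 0 marking a leaf, and going right when the feature is greater than or equal to the split. Check bartz.grove for the authoritative definition.

```python
def evaluate_tree(x, leaf_tree, var_tree, split_tree):
    """Return the leaf value reached by point x (hypothetical helper)."""
    index = 1  # start at the root
    # Nodes beyond the end of split_tree can only be leaves.
    while index < len(split_tree) and split_tree[index] != 0:
        feature = x[var_tree[index]]
        go_right = feature >= split_tree[index]
        index = 2 * index + int(go_right)
    return leaf_tree[index]


# A toy tree with d = 2: one decision node on variable 1 with split 3.
toy_leaf_tree = [0.0, 0.0, -1.0, 1.0]  # 2**2 slots
toy_var_tree = [0, 1]                  # 2**(2-1) slots
toy_split_tree = [0, 3]

toy_left = evaluate_tree([9, 2], toy_leaf_tree, toy_var_tree, toy_split_tree)
toy_right = evaluate_tree([0, 3], toy_leaf_tree, toy_var_tree, toy_split_tree)
```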

classmethod initial(key, sigma_mu, p_nonterminal, max_split)[source]

Initialize the trees.

The leaves are already correct and do not need to be changed.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

Returns:

SamplePriorTrees – Trees initialized with random leaves and stub tree structures.

bartz.debug.sample_prior(key, trace_length, num_trees, max_split, p_nonterminal, sigma_mu)[source]

Sample independent trees from the BART prior.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • trace_length (int) – The number of iterations.

  • num_trees (int) – The number of trees for each iteration.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth. This determines the maximum depth of the trees.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

Returns:

SamplePriorTrees – An object containing the generated trees, with batch shape (trace_length, num_trees).
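The way p_nonterminal bounds the tree depth can be seen in a simplified sketch of sampling one tree structure. It is not the sample_prior implementation: it uses the Python standard library instead of jax, ignores max_split and the condition on available decision rules, and assumes heap indexing with the root at index 1. The function name sample_structure is hypothetical.

```python
import random


def sample_structure(p_nonterminal, rng):
    """Sample the heap indices of the decision nodes of one tree.

    p_nonterminal[k] is the probability that a node at depth k splits.
    """
    decision_nodes = set()
    frontier = [1]  # heap index of the root
    while frontier:
        index = frontier.pop()
        depth = index.bit_length() - 1
        # Nodes deeper than the last entry of p_nonterminal never split,
        # which caps the depth of the leaves at len(p_nonterminal).
        if depth < len(p_nonterminal) and rng.random() < p_nonterminal[depth]:
            decision_nodes.add(index)
            frontier.extend([2 * index, 2 * index + 1])
    return decision_nodes


rng = random.Random(0)
always_split = sample_structure([1.0, 1.0], rng)  # full tree down to depth 2
never_split = sample_structure([0.0], rng)        # root-only tree
random_tree = sample_structure([0.9, 0.5, 0.2], rng)
```

With an array of length d-1, decision nodes can occur at depths 0 to d-2, so leaves are at depth d-1 at most.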

class bartz.debug.BARTTraceMeta(ndpost, ntree, numcut, heap_size)[source]

Metadata of R BART tree traces.

ndpost: int

The number of posterior draws.

ntree: int

The number of trees in the model.

numcut: UInt[Array, 'p']

The maximum split value for each variable.

heap_size: int

The size of the heap required to store the trees.

class bartz.debug.TraceWithOffset(leaf_tree, var_tree, split_tree, offset)[source]

Implementation of bartz.mcmcloop.Trace.

classmethod from_trees_trace(trees, offset)[source]

Create a TraceWithOffset from a TreeHeaps.

Return type:

TraceWithOffset

bartz.debug.trees_BART_to_bartz(trees, *, min_maxdepth=0, offset=None)[source]

Convert trees from the R BART format to the bartz format.

Parameters:
  • trees (str) – The string representation of a trace of trees of the R BART package. Can be accessed from mc_gbart(...).treedraws['trees'].

  • min_maxdepth (int, default: 0) – The maximum tree depth of the output will be set to the maximum observed depth in the input trees. Use this parameter to require at least this maximum depth in the output format.

  • offset (float | Float[Any, ''] | None, default: None) – The trace returned by bartz.mcmcloop.run_mcmc contains an offset to be summed to the sum of trees. To match that behavior, this function returns an offset as well, zero by default. Set with this parameter otherwise.

Returns:

  • trace (TraceWithOffset) – A representation of the trees compatible with the trace returned by bartz.mcmcloop.run_mcmc.

  • meta (BARTTraceMeta) – The metadata of the trace, containing the number of iterations, trees, and the maximum split value.