bartz.debug.sample_prior

bartz.debug.sample_prior(key, trace_length, num_trees, max_split, p_nonterminal, sigma_mu)

Sample independent trees from the BART prior.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • trace_length (int) – The number of iterations.

  • num_trees (int) – The number of trees for each iteration.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

  • p_nonterminal (Float32[Array, 'd_minus_1']) – The prior probability, at each depth, of a node being non-terminal, conditional on its ancestors and on having available decision rules. The length of this array determines the maximum depth of the trees.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.
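
As an illustration of how the p_nonterminal values might be chosen, the sketch below uses the depth prior from the BART literature, in which a node at depth d is non-terminal with probability alpha / (1 + d) ** beta. The helper name depth_prior, its arguments, and the values alpha=0.95 and beta=2 are assumptions for this example and are not part of the documented signature. The sketch also assumes that the first entry corresponds to the root and that a tree with max_depth levels needs max_depth - 1 entries, as the 'd_minus_1' shape suggests.

```python
def depth_prior(alpha: float, beta: float, max_depth: int) -> list[float]:
    """Hypothetical helper: non-terminal probabilities for each depth.

    Entry d is alpha / (1 + d) ** beta, with d = 0 for the root.
    The list has max_depth - 1 entries, so nodes on the deepest
    level have no entry and can only be leaves.
    """
    if max_depth < 1:
        raise ValueError("max_depth must be at least 1")
    return [alpha / (1.0 + d) ** beta for d in range(max_depth - 1)]


# Values commonly used in the BART literature (an assumption here).
p_nonterminal = depth_prior(alpha=0.95, beta=2.0, max_depth=6)

# The probabilities shrink with depth, which favors shallow trees.
for depth, prob in enumerate(p_nonterminal):
    print(f"depth {depth}: {prob:.4f}")
```

Before being passed to sample_prior, the list would need to be converted to a float32 JAX array, for example with jnp.asarray(p_nonterminal, dtype=jnp.float32).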

Returns:

TreesTrace – An object containing the generated trees, with batch shape (trace_length, num_trees).