bartz.debug.sample_prior
- bartz.debug.sample_prior(key, trace_length, num_trees, max_split, p_nonterminal, sigma_mu)
Sample independent trees from the BART prior.
- Parameters:
key (Key[Array, '']) – A jax random key.
trace_length (int) – The number of iterations.
num_trees (int) – The number of trees for each iteration.
max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.
p_nonterminal (Float32[Array, 'd_minus_1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth. This determines the maximum depth of the trees.
sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.
- Returns:
TreesTrace – An object containing the generated trees, with batch shape (trace_length, num_trees).
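- Example:
The following is a minimal sketch of one way to build the arguments and call the function. It is an illustration under stated assumptions, not the library's reference usage. The helper `depth_probabilities`, the values `alpha=0.95` and `beta=2.0`, the number of predictors, and the cutpoint counts are all hypothetical choices. The formula `alpha / (1 + d) ** beta` is the depth prior commonly used for BART and is not stated in this entry. The sketch also assumes that the length of `p_nonterminal` is one less than the maximum number of tree levels, as the shape name `d_minus_1` suggests.

```python
def depth_probabilities(alpha: float, beta: float, max_depth: int) -> list[float]:
    """Return P(non-terminal) for depths 0 .. max_depth - 2.

    Uses the common BART depth prior alpha / (1 + d) ** beta
    (an assumption, not something this entry specifies).
    """
    return [alpha / (1 + d) ** beta for d in range(max_depth - 1)]


def sample_example_trees():
    """Draw trees from the prior with illustrative, hypothetical settings."""
    # Imports are local so the helper above can be used without JAX or bartz.
    import jax
    import jax.numpy as jnp
    from bartz.debug import sample_prior

    num_predictors = 3  # hypothetical 'p'
    # Assumed: 10 cutpoints available along every predictor.
    max_split = jnp.full(num_predictors, 10, dtype=jnp.uint32)
    # Five probabilities, i.e. trees with at most six levels under the
    # 'd_minus_1' assumption described above.
    p_nonterminal = jnp.array(
        depth_probabilities(0.95, 2.0, max_depth=6), dtype=jnp.float32
    )
    sigma_mu = jnp.array(0.5, dtype=jnp.float32)  # hypothetical leaf scale

    return sample_prior(
        jax.random.key(0),  # key
        100,                # trace_length
        20,                 # num_trees
        max_split,
        p_nonterminal,
        sigma_mu,
    )


if __name__ == "__main__":
    trees = sample_example_trees()
```

With these settings the returned TreesTrace should have batch shape (100, 20), following the Returns description above. Because the probabilities decrease with depth, deep trees are unlikely under this choice of prior.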