bartz.debug.sample_prior
- bartz.debug.sample_prior(key, trace_length, num_trees, max_split, p_nonterminal, sigma_mu, log_s=None, theta=None, a=None, b=None, rho=None)
Sample independent trees from the BART prior.
- Parameters:
key (Key[Array, '']) – A JAX random key.
trace_length (int) – The number of iterations.
num_trees (int) – The number of trees for each iteration.
max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.
p_nonterminal (Float32[Array, 'd_minus_1']) – The prior probability of a node being non-terminal, conditional on its ancestors and on having available decision rules, at each depth. The length of this array determines the maximum depth of the trees.
sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.
log_s (Float32[Array, 'p'] | Float32[Array, 'trace_length p'] | None, default: None) – The logarithm of the unnormalized prior probability of splitting on each variable, either shared across all trees (shape (p,)) or specified per iteration (shape (trace_length, p)). If None, variables are chosen uniformly at random. Mutually exclusive with theta, a, b, rho, which instead sample it from its prior.
theta (float | Float32[Array, ''] | None, default: None) – The Dirichlet concentration parameter. If set, and rho, a, b are not, log_s is sampled with this fixed concentration.
rho (float | Float32[Array, ''] | None, default: None) – Together with a and b, the parameters of the prior \(\theta/(\theta+\rho) \sim \mathrm{Beta}(a, b)\). If all three are set, both theta and log_s are sampled.
- Returns:
PriorSample – A trace of trees sampled from the prior.
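
The array arguments described above can be prepared along the following lines. This is a minimal sketch, not taken from the bartz documentation: it assumes the usual BART depth prior, where a node at depth d is non-terminal with probability alpha / (1 + d) ** beta, and the names p, max_depth, alpha, beta and all numeric values are hypothetical choices made only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical settings, chosen only for illustration.
p = 3                     # number of predictor variables
max_depth = 4             # assumed to correspond to d, so p_nonterminal has d - 1 entries
alpha, beta = 0.95, 2.0   # base and power of the usual BART depth prior (an assumption)

# Number of cutpoints along each variable: unsigned integers, shape (p,).
max_split = jnp.full(p, 10, dtype=jnp.uint8)

# Probability of being non-terminal at depths 0, ..., max_depth - 2.
depth = jnp.arange(max_depth - 1, dtype=jnp.float32)
p_nonterminal = (alpha / (1 + depth) ** beta).astype(jnp.float32)

# Prior standard deviation of each leaf, as a float32 scalar array.
sigma_mu = jnp.asarray(1.0, dtype=jnp.float32)

# Unnormalized log-probabilities of splitting on each variable, shape (p,).
# Normalizing them shows the splitting probabilities they imply.
log_s = jnp.log(jnp.array([1.0, 2.0, 1.0], dtype=jnp.float32))
split_prob = jax.nn.softmax(log_s)

# A JAX random key for the sampler.
key = jax.random.key(0)
```

With these in hand, a call following the signature above would pass key, trace_length, num_trees, max_split, p_nonterminal and sigma_mu positionally, and log_s=log_s as a keyword. Since log_s is mutually exclusive with theta, a, b and rho, those would be left at None in that case.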