bartz.debug.sample_prior

bartz.debug.sample_prior(key, trace_length, num_trees, max_split, p_nonterminal, sigma_mu, log_s=None, theta=None, a=None, b=None, rho=None)

Sample independent trees from the BART prior.
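To make the generative process concrete, here is a simplified pure-Python sketch of drawing trees from a BART-style prior with this function's parameters. It is not bartz's implementation. The nested-tuple tree representation, the helper names (sample_tree, sample_trace, tree_depth, split_vars), the decision rule "go left if x[var] < cut", and the assumption that p_nonterminal is indexed by depth with the root at depth 0 are all illustrative assumptions.

```python
import math
import random


def sample_tree(rng, max_split, p_nonterminal, sigma_mu, log_s=None):
    """Sample one tree from a simplified BART-style prior (hypothetical sketch).

    Trees are nested tuples: ("leaf", value) or ("node", var, cut, left, right),
    with the assumed rule "go left if x[var] < cut".
    """
    p = len(max_split)
    if log_s is None:
        log_s = [0.0] * p  # equal weights: variables chosen uniformly

    def grow(depth, ranges):
        # ranges[v] = (lo, hi): cutpoints of variable v still usable at this node.
        avail = [v for v in range(p) if ranges[v][0] <= ranges[v][1]]
        # Depths beyond the end of p_nonterminal are forced to be leaves.
        p_split = p_nonterminal[depth] if depth < len(p_nonterminal) else 0.0
        if not avail or rng.random() >= p_split:
            return ("leaf", rng.gauss(0.0, sigma_mu))
        # Choose the variable with probability proportional to exp(log_s[v]);
        # subtracting the max keeps the weights from over/underflowing.
        top = max(log_s[v] for v in avail)
        weights = [math.exp(log_s[v] - top) for v in avail]
        var = rng.choices(avail, weights=weights)[0]
        lo, hi = ranges[var]
        cut = rng.randint(lo, hi)  # uniform over the available cutpoints
        left, right = list(ranges), list(ranges)
        left[var] = (lo, cut - 1)   # left child (x < cut): later cuts must be < cut
        right[var] = (cut + 1, hi)  # right child (x >= cut): later cuts must be > cut
        return ("node", var, cut, grow(depth + 1, left), grow(depth + 1, right))

    # Cutpoints of variable v are numbered 1 .. max_split[v].
    return grow(0, [(1, m) for m in max_split])


def sample_trace(seed, trace_length, num_trees, max_split, p_nonterminal, sigma_mu):
    """trace_length iterations, each holding num_trees independent trees."""
    rng = random.Random(seed)
    return [[sample_tree(rng, max_split, p_nonterminal, sigma_mu)
             for _ in range(num_trees)]
            for _ in range(trace_length)]


def tree_depth(tree):
    """Number of split levels on the longest root-to-leaf path."""
    if tree[0] == "leaf":
        return 0
    return 1 + max(tree_depth(tree[3]), tree_depth(tree[4]))


def split_vars(tree):
    """Variables used by the internal nodes, in pre-order."""
    if tree[0] == "leaf":
        return []
    return [tree[1]] + split_vars(tree[3]) + split_vars(tree[4])
```

For example, sample_trace(0, 3, 4, [5, 0, 2], [0.95, 0.5], 1.0) returns 3 iterations of 4 trees each; variable 1 has no cutpoints, so it never appears in a split, and no tree has more than 2 split levels because depth 2 is past the end of p_nonterminal.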

Parameters:
  • key (Key[Array, '']) – A JAX random key.

  • trace_length (int) – The number of iterations, i.e. the length of the returned trace.

  • num_trees (int) – The number of trees for each iteration.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

  • p_nonterminal (Float32[Array, 'd_minus_1']) – The prior probability, at each depth, of a node being non-terminal, conditional on its ancestors and on having available decision rules. The length of this array determines the maximum depth of the trees.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

  • log_s (Float32[Array, 'p'] | Float32[Array, 'trace_length p'] | None, default: None) – The logarithm of the unnormalized prior probability of splitting on each variable, either shared across all iterations (shape (p,)) or specified per iteration (shape (trace_length, p)). If None, variables are chosen uniformly at random. Mutually exclusive with theta, a, b, and rho, which instead cause log_s to be sampled from its prior.

  • theta (float | Float32[Array, ''] | None, default: None) – The Dirichlet concentration parameter. If theta is set while a, b, and rho are not, log_s is sampled from the Dirichlet prior with this fixed concentration.

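  • a (float | Float32[Array, ''] | None, default: None) – First shape parameter of the Beta prior on theta; see rho.

  • b (float | Float32[Array, ''] | None, default: None) – Second shape parameter of the Beta prior on theta; see rho.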

  • rho (float | Float32[Array, ''] | None, default: None) – Together with a and b, the parameters of the prior \(\theta/(\theta+\rho) \sim \mathrm{Beta}(a, b)\) on theta. If a, b, and rho are all set, both theta and log_s are sampled.
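The way log_s, theta, a, b, and rho interact can be sketched in pure Python. This is a hypothetical illustration, not bartz code: it assumes the DART-style convention in which the split probabilities follow a symmetric Dirichlet whose components each have concentration theta / p (the parameter list only says theta is "the Dirichlet concentration parameter"), and all helper names are made up. Because log_s is the log of an unnormalized probability, the sketch returns log-Gamma draws directly instead of normalizing them.

```python
import math
import random


def sample_theta(rng, a, b, rho):
    """Draw theta from the prior theta / (theta + rho) ~ Beta(a, b)."""
    u = rng.betavariate(a, b)
    # Invert u = theta / (theta + rho) to recover theta.
    return rho * u / (1.0 - u)


def log_gamma_variate(rng, shape):
    """Log of a Gamma(shape, 1) draw, computed stably for small shapes."""
    if shape >= 1.0:
        return math.log(rng.gammavariate(shape, 1.0))
    # Gamma(shape) = Gamma(shape + 1) * U ** (1 / shape); logs avoid underflow.
    u = 1.0 - rng.random()  # in (0, 1], so log(u) is finite
    return math.log(rng.gammavariate(shape + 1.0, 1.0)) + math.log(u) / shape


def sample_log_s(rng, theta, p):
    """Unnormalized log split weights with s ~ Dirichlet(theta/p, ..., theta/p).

    Independent Gamma(theta/p) draws equal a Dirichlet sample up to
    normalization, and log_s is only needed up to an additive constant.
    """
    return [log_gamma_variate(rng, theta / p) for _ in range(p)]


def sample_split_prior(rng, p, theta=None, a=None, b=None, rho=None):
    """Hypothetical dispatcher mirroring the documented cases."""
    if a is not None and b is not None and rho is not None:
        theta = sample_theta(rng, a, b, rho)  # theta is random as well
    if theta is None:
        return None, None  # log_s unset: variables chosen uniformly
    return theta, sample_log_s(rng, theta, p)
```

The dispatcher simply lets a, b, and rho take precedence over theta; how the real function treats conflicting arguments is not stated in the parameter list above.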

Returns:

PriorSample – A trace of trees sampled from the prior.
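A common way to fill p_nonterminal is the depth prior of the original BART model, alpha * (1 + depth) ** (-beta), with the conventional defaults alpha = 0.95 and beta = 2. The sketch below assumes, following the shape name 'd_minus_1', that the array holds one probability per depth 0, 1, ..., d - 2; the helper name classic_p_nonterminal is hypothetical.

```python
def classic_p_nonterminal(d, alpha=0.95, beta=2.0):
    """Hypothetical helper: p_nonterminal[k] = alpha * (1 + k) ** (-beta).

    Returns d - 1 values, one per depth k = 0 .. d - 2 (root at depth 0).
    """
    return [alpha * (1 + k) ** (-beta) for k in range(d - 1)]
```

Before being passed to sample_prior, the resulting list would still need converting to a float32 JAX array, for example with jnp.asarray(values, dtype=jnp.float32).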