Testing

Testing utilities.

class bartz.testing.DGP(x, y, partition, beta_shared, beta_separate, mulin_shared, mulin_separate, mulin, A_shared, A_separate, muquad_shared, muquad_separate, muquad, mu, q, lam, sigma2_lin, sigma2_quad, sigma2_eps, kurt_x=1.8)

Output of gen_data.

Parameters:
  • x (Float[Array, 'p n']) – Predictors of shape (p, n), variance 1

  • y (Float[Array, 'k n'] | Float[Array, 'n']) – Noisy outcomes of shape (k, n) or (n,)

  • partition (Bool[Array, 'k p']) – Predictor-outcome assignment partition of shape (k, p)

  • beta_shared (Float[Array, 'p']) – Shared linear coefficients of shape (p,)

  • beta_separate (Float[Array, 'k p']) – Separate linear coefficients of shape (k, p)

  • mulin_shared (Float[Array, 'n']) – Linear mean at lambda=1 (shared), shape (k, n), rows identical

  • mulin_separate (Float[Array, 'k n']) – Linear mean at lambda=0 (separate), shape (k, n), rows independent

  • mulin (Float[Array, 'k n']) – Linear part of latent mean of shape (k, n)

  • A_shared (Float[Array, 'p p']) – Shared quadratic coefficients of shape (p, p)

  • A_separate (Float[Array, 'k p p']) – Separate quadratic coefficients of shape (k, p, p)

  • muquad_shared (Float[Array, 'n']) – Quadratic mean at lambda=1 (shared), shape (k, n), rows identical

  • muquad_separate (Float[Array, 'k n']) – Quadratic mean at lambda=0 (separate), shape (k, n), rows independent

  • muquad (Float[Array, 'k n']) – Quadratic part of latent mean of shape (k, n)

  • mu (Float[Array, 'k n']) – True latent means of shape (k, n)

  • q (Integer[Array, '']) – Number of interactions per predictor

  • lam (Float[Array, '']) – Coupling parameter in [0, 1]

  • sigma2_lin (Float[Array, '']) – Prior and expected population variance of mulin

  • sigma2_quad (Float[Array, '']) – Expected population variance of muquad

  • sigma2_eps (Float[Array, '']) – Variance of the error
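The field descriptions suggest how the pieces fit together: linear and quadratic means are built from coefficients and predictors, shared (lam=1) and separate (lam=0) variants are blended by the coupling parameter, and mu is the latent mean made of a linear and a quadratic part. The sketch below is a hypothetical reconstruction under stated assumptions, not bartz's actual implementation: the helper names `mix` and `latent_mean`, the per-observation quadratic form x_i^T A x_i, the square-root mixing weights, and mu = mulin + muquad are all assumptions. It ignores `partition` and `q`, whose exact role the source does not spell out.

```python
import jax
import jax.numpy as jnp


def mix(shared, separate, lam):
    """Hypothetical blend of a shared (n,) mean with separate (k, n) means.

    Assumed weights sqrt(lam) and sqrt(1 - lam): lam=1 puts the shared mean on
    every row, lam=0 returns the separate means unchanged.
    """
    return jnp.sqrt(lam) * shared + jnp.sqrt(1.0 - lam) * separate


def latent_mean(x, beta_shared, beta_separate, A_shared, A_separate, lam):
    """Hypothetical reconstruction of mulin, muquad and mu (not bartz's code)."""
    # Linear parts: (p,) @ (p, n) -> (n,) and (k, p) @ (p, n) -> (k, n)
    mulin_shared = beta_shared @ x
    mulin_separate = beta_separate @ x
    # Quadratic parts: x_i^T A x_i for every observation (column) i of x
    muquad_shared = jnp.einsum('in,ij,jn->n', x, A_shared, x)
    muquad_separate = jnp.einsum('in,kij,jn->kn', x, A_separate, x)
    # Blend shared and separate variants; broadcasting yields shape (k, n)
    mulin = mix(mulin_shared, mulin_separate, lam)
    muquad = mix(muquad_shared, muquad_separate, lam)
    # Assumed: the latent mean is the sum of its linear and quadratic parts
    return mulin, muquad, mulin + muquad
```

With lam=1 the rows of the result are identical, matching the "rows identical" description of the shared fields.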

property sigma2_pri: Float[Array, '']

Prior variance of y.

property sigma2_pop: Float[Array, '']

Expected population variance of y.

property sigma2_mean: Float[Array, '']

Variance of the mean function.
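These properties are variances of related quantities. Assuming, which the source does not state explicitly, that the error is drawn independently of the mean and added to it, the population variance of y decomposes as the variance of the mean function plus sigma2_eps. The simulation below checks that identity on synthetic arrays; it does not call bartz, and the values of `sigma2_mean` and `sigma2_eps` are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp


def empirical_variances(mu, y):
    """Pooled sample variances of the mean function and of the outcome."""
    return jnp.var(mu), jnp.var(y)


# Hypothetical settings, chosen only for the illustration
n = 200_000
sigma2_mean, sigma2_eps = 2.0, 0.5

key_mu, key_eps = jax.random.split(jax.random.PRNGKey(0))
mu = jnp.sqrt(sigma2_mean) * jax.random.normal(key_mu, (n,))
# Independent additive noise, so var(y) should be near sigma2_mean + sigma2_eps
y = mu + jnp.sqrt(sigma2_eps) * jax.random.normal(key_eps, (n,))

var_mu, var_y = empirical_variances(mu, y)
```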

split(n_train=None)

Split the data into training and test sets.

Parameters:

n_train (int | None, default: None) – Number of training observations. If None, split in half.

Returns:

tuple[DGP, DGP] – Two DGP objects with the train and test splits.
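A split along the observation axis can be sketched on the raw arrays. The helper `split_obs` below is hypothetical and works on plain arrays rather than DGP objects; it follows the documented default (half the observations when n_train is None, rounded down here for odd n), while taking the first n_train columns as the training set is an assumption about ordering.

```python
import jax.numpy as jnp


def split_obs(x, y, n_train=None):
    """Hypothetical helper: split predictors (p, n) and outcomes (k, n) or (n,).

    Observations live on the last axis, so slicing with `...` handles both
    outcome layouts documented for y.
    """
    n = x.shape[-1]
    if n_train is None:
        n_train = n // 2  # "split in half"; floor for odd n is an assumption
    train = (x[..., :n_train], y[..., :n_train])
    test = (x[..., n_train:], y[..., n_train:])
    return train, test
```

Calling `split_obs(x, y, n_train=7)` on 10 observations yields 7 training and 3 test columns.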

bartz.testing.gen_data(key, *, n, p, k=None, q, lam, sigma2_lin, sigma2_quad, sigma2_eps)

Generate data from a quadratic multivariate DGP.

Parameters:
  • key (Key[Array, '']) – JAX random key

  • n (int) – Number of observations

  • p (int) – Number of predictors

  • k (int | None, default: None) – Number of outcome components

  • q (Integer[Array, ''] | int) – Number of interactions per predictor (must be even and < p // k)

  • lam (Float[Array, ''] | float) – Coupling parameter in [0, 1]: 0 = independent components, 1 = identical components

  • sigma2_lin (Float[Array, ''] | float) – Prior and expected population variance of the linear term

  • sigma2_quad (Float[Array, ''] | float) – Expected population variance of the quadratic term

  • sigma2_eps (Float[Array, ''] | float) – Variance of the error term

Returns:

DGP – An object with all generated data and parameters.
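The documented constraints on q and lam can be checked before calling gen_data. The function `check_gen_data_args` below is a hypothetical pre-flight helper, not part of bartz; treating k=None as a single outcome when evaluating p // k is an assumption. The commented call at the end uses only the keyword signature documented above.

```python
def check_gen_data_args(*, p, q, lam, k=None):
    """Hypothetical check of the documented gen_data constraints."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError(f'lam must be in [0, 1], got {lam}')
    if q % 2 != 0:
        raise ValueError(f'q must be even, got {q}')
    n_groups = 1 if k is None else k  # assumption: k=None acts like k=1
    limit = p // n_groups
    if not q < limit:
        raise ValueError(f'q must be < p // k = {limit}, got {q}')


# With bartz installed, a valid configuration could then be passed on:
# import jax
# from bartz.testing import gen_data
# check_gen_data_args(p=10, k=2, q=2, lam=0.5)
# dgp = gen_data(jax.random.key(0), n=100, p=10, k=2, q=2, lam=0.5,
#                sigma2_lin=0.4, sigma2_quad=0.4, sigma2_eps=0.2)
```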