bartz.testing.gen_data

bartz.testing.gen_data(key, *, n, p, k=None, q, lambda_=None, sigma2_lin, sigma2_quad, sigma2_eps, offset=0.0, x_distr=Uniform(), beta_distr=DiscreteUniform(m=2), A_distr=DiscreteUniform(m=2), gamma_distr=DiscreteUniform(m=2), error_distr=Normal(), s_distr=Constant(), outcome_type='continuous', het_strength=None, het_shape=None, error_corr=None)[source]

Generate data from a quadratic multivariate data-generating process (DGP).

A thin wrapper that calls gen_params and then gen_data_from_params. To generate data in batches across n (e.g., to fit in memory), call gen_params once and then call gen_data_from_params once per batch. See Params for the generative model and DGP for the returned fields.

Parameters:
  • key (Key[Array, '']) – JAX random key.

  • n (int) – Number of observations.

  • p (int) – Number of predictors.

  • k (int | None, default: None) – Number of outcome components. If None, produces a univariate output with y.shape == (n,) instead of taking the multivariate code path.

  • q (Integer[Array, ''] | int)

  • lambda_ (Float[Array, ''] | float | None, default: None)

  • sigma2_lin (Float[Array, ''] | float)

  • sigma2_quad (Float[Array, ''] | float)

  • sigma2_eps (Float[Array, ''] | float)

  • offset (Float[Array, ''] | Float[Array, 'k'] | float, default: 0.0)

  • x_distr (Distr, default: Uniform())

  • beta_distr (Distr, default: DiscreteUniform(m=2))

  • A_distr (Distr, default: DiscreteUniform(m=2))

  • gamma_distr (Distr, default: DiscreteUniform(m=2))

  • error_distr (Distr, default: Normal())

  • s_distr (ScaleDistr, default: Constant())

  • outcome_type (OutcomeType | str | tuple[OutcomeType | str, ...], default: 'continuous')

  • het_strength (Float[Array, ''] | float | None, default: None)

  • het_shape (Literal['scalar', 'vector'] | None, default: None)

  • error_corr (Float[Array, 'k k'] | None, default: None) – Forwarded to gen_params; see Params.

Returns:

DGP – A DGP object with the sampled data and parameters.
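
Example:

The following is a minimal sketch of a call, not an excerpt from the library's own documentation. It passes only the keyword-only arguments that have no default in the signature above. All numeric values are illustrative assumptions and may need adjusting to satisfy constraints documented in Params. The import is guarded so that the snippet still runs when bartz or JAX is not installed.

```python
# Guarded import: the call below is skipped if bartz or JAX is unavailable.
try:
    import jax
    from bartz.testing import gen_data
except ImportError:
    gen_data = None

# Keyword-only arguments without defaults, per the signature above.
# The values are assumptions chosen for illustration only.
required = dict(
    n=100,            # number of observations
    p=5,              # number of predictors
    q=2,              # assumed value; see Params for its meaning
    sigma2_lin=0.4,   # assumed value
    sigma2_quad=0.4,  # assumed value
    sigma2_eps=0.2,   # assumed value
)

if gen_data is not None:
    key = jax.random.key(0)  # JAX random key, passed positionally
    dgp = gen_data(key, **required)
    # k is left at its default (None), so the outcome is univariate.
    print(type(dgp).__name__)
```

To request a multivariate outcome instead, pass an integer k in addition to the arguments above.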