bartz.testing.gen_params

bartz.testing.gen_params(key, *, p, k, q, lambda_=None, sigma2_lin, sigma2_quad, sigma2_eps, offset=0.0, x_distr=Uniform(), beta_distr=DiscreteUniform(m=2), A_distr=DiscreteUniform(m=2), gamma_distr=DiscreteUniform(m=2), error_distr=Normal(), s_distr=Constant(), outcome_type='continuous', het_strength=None, het_shape=None, error_corr=None)

Sample the coefficients and parameters of the data-generating process (DGP). None of them depend on n.

See Params for the meaning of every parameter and the generative model they parametrize.

Parameters:
  • key (Key[Array, '']) – JAX random key.

  • p (int) – Number of predictors.

  • k (int | None) – Number of outcome components. If None, generate a univariate DGP and skip the separate code path: partition, beta_separate, A_separate and lambda_ are all set to None on the returned Params, and only the shared coefficients are drawn.

  • q (Integer[Array, ''] | int) – See Params.

  • lambda_ (Float[Array, ''] | float | None, default: None) – Coupling parameter; must be None iff k is None. See Params.

  • sigma2_lin (Float[Array, ''] | float) – See Params.

  • sigma2_quad (Float[Array, ''] | float) – See Params.

  • sigma2_eps (Float[Array, ''] | float) – Error variance. See Params.

  • offset (Float[Array, ''] | Float[Array, 'k'] | float, default: 0.0) – See Params.

  • x_distr (Distr, default: Uniform()) – Distribution family of the predictors. Binary predictors (DiscreteUniform with m=2, which has kurtosis 1) require q >= 2 because their squares are constant.

  • beta_distr (Distr, default: DiscreteUniform(m=2))

  • A_distr (Distr, default: DiscreteUniform(m=2)) – Families of the linear (beta_distr) and quadratic (A_distr) coefficient draws. The default random signs (DiscreteUniform with m=2) give every predictor exactly its share of the variance budgets, so with the default Constant scales all predictors are equally important.

  • gamma_distr (Distr, default: DiscreteUniform(m=2)) – Family of the noise projection draws; its kurtosis enters var_v.

  • error_distr (Distr, default: Normal()) – Marginal family of the additive errors, realized through the Gaussian copula defined by error_corr. Any standardized family keeps the error variance at sigma2_eps; non-Normal families attenuate the realized error correlation relative to error_corr. See Params.

  • s_distr (ScaleDistr, default: Constant()) – Scale family of the per-predictor importance scales s (e.g. Gamma or SpikeSlab); more dispersed scales make the dependence on the predictors sparser. Constant (default) gives uniform importance. Use ScaleDistr.from_peff to set the dispersion via an effective number of active predictors instead of the raw family parameter.

  • outcome_type (OutcomeType | str | tuple[OutcomeType | str, ...], default: 'continuous') – 'continuous', 'binary', an OutcomeType, or a tuple of length k for mixed outcomes. Tuples with all elements equal are collapsed to the scalar form. Tuples are not allowed when k is None. See Params for the semantics.

  • het_strength (Float[Array, ''] | float | None, default: None) – Heteroskedasticity knob rho in [0, 1] (0 homoskedastic, 1 maximally heterogeneous); must be None iff het_shape is None. See Params.

  • het_shape (Literal['scalar', 'vector'] | None, default: None) – Heteroskedasticity mode: None (homoskedastic), 'scalar' (one error_scale per datapoint, scaling the whole outcome vector), or 'vector' (per-component scales, multivariate only). See Params.

  • error_corr (Float[Array, 'k k'] | None, default: None) – Across-component error correlation. A symmetric positive-definite matrix of shape (k, k), normalized to unit diagonal before use (only its correlation structure matters; the noise scale is sigma2_eps). None (default) gives independent errors. Multivariate only. See Params.
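
The unit-diagonal normalization described for error_corr can be illustrated with a minimal sketch. It assumes the usual rescaling corr[i][j] = m[i][j] / sqrt(m[i][i] * m[j][j]). The helper name to_correlation is hypothetical and is not part of bartz, whose internal implementation may differ.

```python
import math


def to_correlation(m):
    """Rescale a symmetric positive-definite matrix to unit diagonal.

    Hypothetical helper for illustration only; `m` is a list of lists.
    """
    k = len(m)
    if any(len(row) != k for row in m):
        raise ValueError("expected a square (k, k) matrix")
    # Square roots of the diagonal entries (the per-component scales).
    d = [math.sqrt(m[i][i]) for i in range(k)]
    # Divide each entry by the scales of its row and column.
    return [[m[i][j] / (d[i] * d[j]) for j in range(k)] for i in range(k)]


cov = [[4.0, 1.0],
       [1.0, 1.0]]
corr = to_correlation(cov)
print(corr)  # the diagonal becomes 1, the off-diagonal becomes 1 / (2 * 1)
```

Under this assumption, passing cov or corr would describe the same correlation structure, which is consistent with the statement that the noise scale comes from sigma2_eps rather than from error_corr.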

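The remarks on DiscreteUniform with m=2 (random signs, kurtosis 1, constant squares) can be checked numerically. The sketch below assumes that the standardized two-point family takes the values -1 and +1 with equal probability, and it uses only the Python standard library rather than the bartz Distr classes.

```python
import random

rng = random.Random(0)  # fixed seed, only for reproducibility
signs = [rng.choice((-1.0, 1.0)) for _ in range(1000)]

# Every square equals 1, so the square of a binary predictor is constant.
distinct_squares = {s * s for s in signs}

# Moments taken about the population mean, which is 0 for random signs.
second_moment = sum(s**2 for s in signs) / len(signs)
fourth_moment = sum(s**4 for s in signs) / len(signs)
kurtosis = fourth_moment / second_moment**2

print(distinct_squares, kurtosis)
```

Because the squares never vary, a purely quadratic term in a single binary predictor reduces to a constant, which matches the documented reason for the q >= 2 requirement.
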
Returns:

Params – A Params with the sampled coefficients and forwarded hyperparameters.

Raises:

ValueError – If any of the following holds:

  • outcome_type is a tuple whose length does not match k, or a tuple outcome_type is combined with k=None.

  • (lambda_ is None) != (k is None).

  • (het_strength is None) != (het_shape is None).

  • het_shape='vector' is combined with k=None.

  • A vector offset is combined with k=None or has a length other than k.

  • error_corr is combined with k=None or does not have shape (k, k).
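
The consistency rules listed under Raises can be restated as a small standalone checker. This is a sketch, not the bartz implementation: the function check_gen_params_args and its offset_len and error_corr_shape arguments are hypothetical stand-ins, and the order in which the real function performs its checks is not documented here.

```python
def check_gen_params_args(*, k, lambda_=None, outcome_type="continuous",
                          het_strength=None, het_shape=None,
                          offset_len=None, error_corr_shape=None):
    """Hypothetical restatement of the documented consistency rules.

    offset_len: None for a scalar offset, else the length of a vector offset.
    error_corr_shape: None when error_corr is None, else its shape tuple.
    Returns outcome_type, with an all-equal tuple collapsed to its element.
    """
    if isinstance(outcome_type, tuple):
        if k is None:
            raise ValueError("tuple outcome_type is not allowed when k is None")
        if len(outcome_type) != k:
            raise ValueError("outcome_type tuple must have length k")
        if len(set(outcome_type)) == 1:
            outcome_type = outcome_type[0]  # collapse to the scalar form
    if (lambda_ is None) != (k is None):
        raise ValueError("lambda_ must be None if and only if k is None")
    if (het_strength is None) != (het_shape is None):
        raise ValueError("het_strength and het_shape must be set together")
    if het_shape == "vector" and k is None:
        raise ValueError("het_shape='vector' is multivariate only")
    if offset_len is not None and (k is None or offset_len != k):
        raise ValueError("a vector offset must have length k")
    if error_corr_shape is not None and (k is None or error_corr_shape != (k, k)):
        raise ValueError("error_corr must have shape (k, k)")
    return outcome_type


# A tuple with all elements equal collapses to the scalar form.
print(check_gen_params_args(k=3, lambda_=0.5, outcome_type=("binary",) * 3))
```

Running such a check before calling gen_params can surface an inconsistent combination early, for example a lambda_ supplied together with k=None.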