Coverage for src / bartz / testing / _dgp.py: 93%
115 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/testing/_dgp.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
26"""Define `gen_data` that generates simulated data for testing."""
28from dataclasses import replace
29from functools import partial
31from equinox import Module, error_if
32from jax import jit, random
33from jax import numpy as jnp
34from jaxtyping import Array, Bool, Float, Int, Integer, Key
36from bartz.jaxext import split
def generate_x(key: Key[Array, ''], n: int, p: int) -> Float[Array, 'p n']:
    """Draw a (p, n) matrix of i.i.d. predictors with mean 0 and variance 1.

    Each entry is x_rj ~ U(-√3, √3). A uniform on [-a, a] has variance
    a² / 3, so a = √3 yields unit variance.
    """
    bound = jnp.sqrt(3.0)
    return random.uniform(key, shape=(p, n), minval=-bound, maxval=bound)
def generate_partition(key: Key[Array, ''], p: int, k: int) -> Bool[Array, 'k p']:
    """Randomly assign each of the p predictors to one of k outcome components.

    The predictors are split into k groups of nearly equal size, so each row
    of the returned (k, p) mask has either p // k or p // k + 1 True entries,
    and each column has exactly one True entry.
    """
    subkeys = split(key)
    # Group labels 0, ..., k - 1 laid out in contiguous, nearly equal runs.
    group = jnp.trunc(jnp.linspace(0, k, p, endpoint=False)).astype(jnp.int32)
    # Shuffle which predictors belong to which group...
    group = random.permutation(subkeys.pop(), group)
    # ...and which group each component receives.
    component_group = random.permutation(subkeys.pop(), k)
    return component_group[:, None] == group[None, :]
def generate_beta_shared(
    key: Key[Array, ''], p: int, sigma2_lin: Float[Array, '']
) -> Float[Array, ' p']:
    """Generate shared linear coefficients for the lambda=1 case.

    Each coefficient is N(0, sigma2_lin / p), so with p independent
    unit-variance predictors the linear term beta @ x has variance sigma2_lin.
    """
    scale = jnp.sqrt(sigma2_lin / p)
    draws = random.normal(key, (p,))
    return draws * scale
def generate_beta_separate(
    key: Key[Array, ''], partition: Bool[Array, 'k p'], sigma2_lin: Float[Array, '']
) -> Float[Array, 'k p']:
    """Generate separate linear coefficients for the lambda=0 case.

    Row i only has non-zero coefficients on the predictors that `partition`
    assigns to component i. Their variance is sigma2_lin / (p / k), because
    each component uses roughly p / k predictors.
    """
    n_comp, n_pred = partition.shape
    raw = random.normal(key, (n_comp, n_pred))
    masked = jnp.where(partition, raw, 0.0)
    per_coef_var = sigma2_lin / (n_pred / n_comp)
    return masked * jnp.sqrt(per_coef_var)
def compute_linear_mean_shared(
    beta_shared: Float[Array, ' p'], x: Float[Array, 'p n']
) -> Float[Array, ' n']:
    """Compute the shared linear mean, mulin_j = beta_r x_rj (sum over r)."""
    return jnp.matmul(beta_shared, x)
def compute_linear_mean_separate(
    beta_separate: Float[Array, 'k p'], x: Float[Array, 'p n']
) -> Float[Array, 'k n']:
    """Compute per-component linear means, mulin_ij = beta_ir x_rj (sum over r)."""
    return jnp.matmul(beta_separate, x)
def combine_mulin(
    mulin_shared: Float[Array, ' n'],
    mulin_separate: Float[Array, 'k n'],
    lam: Float[Array, ''],
) -> Float[Array, 'k n']:
    """Mix shared and separate linear means.

    Returns sqrt(1 - lam) * separate + sqrt(lam) * shared. The shared (n,)
    row is broadcast over the k components. The squared weights sum to 1,
    so the variance is preserved when the two parts are independent.
    """
    w_separate = jnp.sqrt(1.0 - lam)
    w_shared = jnp.sqrt(lam)
    return w_separate * mulin_separate + w_shared * mulin_shared
def interaction_pattern(p: int, q: Integer[Array, ''] | int) -> Bool[Array, 'p p']:
    """Create a symmetric interaction pattern for q interactions per variable.

    Entry (r, s) is True when the circular distance between indices r and s
    (on a ring of length p) is at most q // 2.

    Parameters
    ----------
    p
        Number of predictors
    q
        Number of interactions per predictor (must be even and < p)

    Returns
    -------
    Symmetric boolean pattern of shape (p, p) where each row/col has q + 1
    True entries (the diagonal plus q // 2 neighbours on each side).
    """
    q = error_if(q, q % 2 != 0, 'q must be even')
    q = error_if(q, q >= p, 'q must be less than p')

    rows, cols = jnp.ogrid[:p, :p]
    offset = jnp.abs(rows - cols)
    circular_offset = jnp.minimum(offset, p - offset)
    return circular_offset <= (q // 2)
def generate_A_shared(
    key: Key[Array, ''],
    p: int,
    q: Integer[Array, ''],
    sigma2_quad: Float[Array, ''],
    kurt_x: float,
) -> Float[Array, 'p p']:
    """Generate shared quadratic coefficients for the lambda=1 case.

    Non-zero entries follow `interaction_pattern(p, q)` and are
    N(0, sigma2_quad / (p * (kurt_x - 1 + q))).
    """
    mask = interaction_pattern(p, q)
    coef_var = sigma2_quad / (p * (kurt_x - 1 + q))
    draws = random.normal(key, (p, p))
    return jnp.where(mask, draws, 0.0) * jnp.sqrt(coef_var)
def partitioned_interaction_pattern(
    partition: Bool[Array, 'k p'], q: Integer[Array, ''] | int
) -> Bool[Array, 'k p p']:
    """Create k interaction patterns that use disjoint variable sets.

    For component i, the variables assigned to it by `partition` are ranked
    by their original index. Two of them interact when their circular rank
    distance (on a ring the size of the component) is at most q // 2. All
    rows and columns of variables outside the component are False.

    Parameters
    ----------
    partition
        Binary partition of shape (k, p) indicating variable assignments
        to components
    q
        Number of interactions per predictor (must be even and < p // k)

    Returns
    -------
    Interaction patterns of shape (k, p, p)
    """
    k, p = partition.shape
    q = error_if(q, q % 2 != 0, 'q must be even')
    q = error_if(q, q >= p // k, 'q must be less than p // k')

    # Running count of component members; equals the 1-based rank for members.
    rank = jnp.cumsum(partition, axis=1)
    size = jnp.max(rank, axis=1)
    gap = jnp.abs(rank[:, :, None] - rank[:, None, :])
    ring_gap = jnp.minimum(gap, size[:, None, None] - gap)
    # Ranks of non-members are meaningless, so restrict to member pairs.
    both_members = partition[:, :, None] & partition[:, None, :]
    return (ring_gap <= (q // 2)) & both_members
def generate_A_separate(
    key: Key[Array, ''],
    partition: Bool[Array, 'k p'],
    q: Integer[Array, ''],
    sigma2_quad: Float[Array, ''],
    kurt_x: float,
) -> Float[Array, 'k p p']:
    """Generate separate quadratic coefficients for the lambda=0 case.

    Slice i is non-zero only on `partitioned_interaction_pattern(partition, q)[i]`.
    Entries are N(0, sigma2_quad / (p / k * (kurt_x - 1 + q))), since each
    component uses roughly p / k predictors.
    """
    k, p = partition.shape
    draws = random.normal(key, (k, p, p))
    mask = partitioned_interaction_pattern(partition, q)
    coef_var = sigma2_quad / (p / k * (kurt_x - 1 + q))
    return jnp.where(mask, draws, 0.0) * jnp.sqrt(coef_var)
def compute_muquad_shared(
    A_shared: Float[Array, 'p p'], x: Float[Array, 'p n']
) -> Float[Array, ' n']:
    """Compute the quadratic mean for the lambda=1 case.

    muquad_j = A_rs x_rj x_sj (sum over r and s). A single (n,) row is
    returned; it is shared by all components.
    """
    subscripts = 'rs,rj,sj->j'
    return jnp.einsum(subscripts, A_shared, x, x)
def compute_muquad_separate(
    A_separate: Float[Array, 'k p p'], x: Float[Array, 'p n']
) -> Float[Array, 'k n']:
    """Compute the quadratic mean for the lambda=0 case.

    muquad_ij = A_irs x_rj x_sj (sum over r and s), one row per component,
    each using its own coefficient slice.
    """
    subscripts = 'irs,rj,sj->ij'
    return jnp.einsum(subscripts, A_separate, x, x)
def combine_muquad(
    muquad_shared: Float[Array, ' n'],
    muquad_separate: Float[Array, 'k n'],
    lam: Float[Array, ''],
) -> Float[Array, 'k n']:
    """Mix shared and separate quadratic means.

    Returns sqrt(1 - lam) * separate + sqrt(lam) * shared, broadcasting the
    shared (n,) row over the k components.
    """
    w_separate = jnp.sqrt(1.0 - lam)
    w_shared = jnp.sqrt(lam)
    return w_separate * muquad_separate + w_shared * muquad_shared
def compute_quadratic_mean(
    A: Float[Array, 'k p p'], x: Float[Array, 'p n']
) -> Float[Array, 'k n']:
    """Compute the quadratic part of the latent mean.

    mu_ij = A_irs x_rj x_sj (sum over r and s); this is the same contraction
    as `compute_muquad_separate`.
    """
    subscripts = 'irs,rj,sj->ij'
    return jnp.einsum(subscripts, A, x, x)
def generate_outcome(
    key: Key[Array, ''], mu: Float[Array, 'k n'], sigma2_eps: Float[Array, '']
) -> Float[Array, 'k n']:
    """Add i.i.d. N(0, sigma2_eps) noise to the latent mean `mu`."""
    noise = random.normal(key, mu.shape)
    noise_sd = jnp.sqrt(sigma2_eps)
    return mu + noise * noise_sd
class DGP(Module):
    """Output of `gen_data`.

    Parameters
    ----------
    x
        Predictors of shape (p, n), variance 1
    y
        Noisy outcomes of shape (k, n), or (n,) if `gen_data` was called
        without `k`
    partition
        Predictor-outcome assignment partition of shape (k, p)
    beta_shared
        Shared linear coefficients of shape (p,)
    beta_separate
        Separate linear coefficients of shape (k, p)
    mulin_shared
        Linear mean at lambda=1 (shared), shape (n,); it is broadcast over
        the k components when combined into `mulin`
    mulin_separate
        Linear mean at lambda=0 (separate), shape (k, n), rows independent
    mulin
        Linear part of latent mean of shape (k, n)
    A_shared
        Shared quadratic coefficients of shape (p, p)
    A_separate
        Separate quadratic coefficients of shape (k, p, p)
    muquad_shared
        Quadratic mean at lambda=1 (shared), shape (n,); it is broadcast
        over the k components when combined into `muquad`
    muquad_separate
        Quadratic mean at lambda=0 (separate), shape (k, n), rows independent
    muquad
        Quadratic part of latent mean of shape (k, n)
    mu
        True latent means of shape (k, n)
    q
        Number of interactions per predictor
    lam
        Coupling parameter in [0, 1]
    sigma2_lin
        Prior and expected population variance of mulin
    sigma2_quad
        Expected population variance of muquad
    sigma2_eps
        Variance of the error
    kurt_x
        Kurtosis of the predictor distribution (9/5 for the uniform used by
        `gen_data`)
    """

    # Main outputs
    x: Float[Array, 'p n']
    y: Float[Array, 'k n'] | Float[Array, ' n']

    # Intermediate results
    partition: Bool[Array, 'k p']
    beta_shared: Float[Array, ' p']
    beta_separate: Float[Array, 'k p']
    mulin_shared: Float[Array, ' n']
    mulin_separate: Float[Array, 'k n']
    mulin: Float[Array, 'k n']
    A_shared: Float[Array, 'p p']
    A_separate: Float[Array, 'k p p']
    muquad_shared: Float[Array, ' n']
    muquad_separate: Float[Array, 'k n']
    muquad: Float[Array, 'k n']
    mu: Float[Array, 'k n']

    # Params
    q: Integer[Array, '']
    lam: Float[Array, '']
    sigma2_lin: Float[Array, '']
    sigma2_quad: Float[Array, '']
    sigma2_eps: Float[Array, '']

    kurt_x: float = 9 / 5  # kurtosis of uniform distribution

    @property
    def sigma2_pri(self) -> Float[Array, '']:
        """Prior variance of y: `sigma2_pop` plus `sigma2_mean`."""
        return self.sigma2_pop + self.sigma2_mean

    @property
    def sigma2_pop(self) -> Float[Array, '']:
        """Expected population variance of y (linear + quadratic + error)."""
        return self.sigma2_lin + self.sigma2_quad + self.sigma2_eps

    @property
    def sigma2_mean(self) -> Float[Array, '']:
        """Variance of the mean function.

        Computed as sigma2_quad / (kurt_x - 1 + q). The quadratic
        coefficients are scaled by this same (kurt_x - 1 + q) factor in
        `gen_data`.
        """
        return self.sigma2_quad / (self.kurt_x - 1 + self.q)

    def split(self, n_train: int | None = None) -> tuple['DGP', 'DGP']:
        """Split the data into training and test sets.

        Parameters
        ----------
        n_train
            Number of training observations. If None, split in half.

        Returns
        -------
        Two `DGP` objects with the train and test splits. Per-observation
        arrays are sliced along their last (observation) axis; coefficients,
        the partition, and parameters are shared unchanged.
        """
        n = self.x.shape[1]
        if n_train is None:
            n_train = n // 2
        assert 0 < n_train < n, 'n_train must be in (0, n)'
        train = self._select_observations(slice(None, n_train))
        test = self._select_observations(slice(n_train, None))
        return train, test

    def _select_observations(self, obs: slice) -> 'DGP':
        """Return a copy restricted to the observations selected by `obs`."""
        return replace(
            self,
            x=self.x[:, obs],
            # `...` handles both the (k, n) and the squeezed (n,) layout of y.
            y=self.y[..., obs],
            mulin_shared=self.mulin_shared[obs],
            mulin_separate=self.mulin_separate[:, obs],
            mulin=self.mulin[:, obs],
            muquad_shared=self.muquad_shared[obs],
            muquad_separate=self.muquad_separate[:, obs],
            muquad=self.muquad[:, obs],
            mu=self.mu[:, obs],
        )
@partial(jit, static_argnames=('n', 'p', 'k'))
def gen_data(
    key: Key[Array, ''],
    *,
    n: int,
    p: int,
    k: int | None = None,
    q: Integer[Array, ''] | int,
    lam: Float[Array, ''] | float,
    sigma2_lin: Float[Array, ''] | float,
    sigma2_quad: Float[Array, ''] | float,
    sigma2_eps: Float[Array, ''] | float,
) -> DGP:
    """Generate data from a quadratic multivariate DGP.

    Parameters
    ----------
    key
        JAX random key
    n
        Number of observations
    p
        Number of predictors
    k
        Number of outcome components. If None, a single component is
        generated and `y` has shape (n,) instead of (1, n).
    q
        Number of interactions per predictor (must be even and < p // k)
    lam
        Coupling parameter in [0, 1]. 0=independent, 1=identical components
    sigma2_lin
        Prior and expected population variance of the linear term
    sigma2_quad
        Expected population variance of the quadratic term
    sigma2_eps
        Variance of the error term

    Returns
    -------
    An object with all generated data and parameters.

    Raises
    ------
    AssertionError
        At trace time, if k < 1 or p < k.
    Runtime error (via `equinox.error_if`)
        If q is odd, if q >= p // k, if lam is outside [0, 1], or if any
        variance is negative.
    """
    # With k=None, generate one component and drop its axis from y at the end.
    squeeze = k is None
    if squeeze:
        k = 1

    # n, p and k are static, so these are plain Python checks at trace time.
    assert k >= 1, 'k must be at least 1'
    assert p >= k, 'p must be at least k'

    # Traced parameters are validated at runtime.
    q = error_if(q, q % 2 != 0, 'q must be even')
    q = error_if(q, q >= p // k, 'q must be less than p // k')
    # lam and the variances enter through jnp.sqrt; invalid values would
    # otherwise silently turn the generated data into NaN.
    lam = error_if(lam, (lam < 0) | (lam > 1), 'lam must be in [0, 1]')
    sigma2_lin = error_if(
        sigma2_lin, sigma2_lin < 0, 'sigma2_lin must be non-negative'
    )
    sigma2_quad = error_if(
        sigma2_quad, sigma2_quad < 0, 'sigma2_quad must be non-negative'
    )
    sigma2_eps = error_if(
        sigma2_eps, sigma2_eps < 0, 'sigma2_eps must be non-negative'
    )

    # One key per random draw below; the pop order fixes which stream each
    # draw consumes, so keep it stable for reproducibility.
    keys = split(key, 7)

    x = generate_x(keys.pop(), n, p)
    partition = generate_partition(keys.pop(), p, k)
    beta_shared = generate_beta_shared(keys.pop(), p, sigma2_lin)
    beta_separate = generate_beta_separate(keys.pop(), partition, sigma2_lin)
    mulin_shared = compute_linear_mean_shared(beta_shared, x)
    mulin_separate = compute_linear_mean_separate(beta_separate, x)
    mulin = combine_mulin(mulin_shared, mulin_separate, lam)
    A_shared = generate_A_shared(keys.pop(), p, q, sigma2_quad, DGP.kurt_x)
    A_separate = generate_A_separate(keys.pop(), partition, q, sigma2_quad, DGP.kurt_x)
    muquad_shared = compute_muquad_shared(A_shared, x)
    muquad_separate = compute_muquad_separate(A_separate, x)
    muquad = combine_muquad(muquad_shared, muquad_separate, lam)
    mu = mulin + muquad
    y = generate_outcome(keys.pop(), mu, sigma2_eps)
    if squeeze:
        y = y.squeeze(0)

    return DGP(
        x=x,
        y=y,
        partition=partition,
        beta_shared=beta_shared,
        beta_separate=beta_separate,
        mulin_shared=mulin_shared,
        mulin_separate=mulin_separate,
        mulin=mulin,
        A_shared=A_shared,
        A_separate=A_separate,
        muquad_shared=muquad_shared,
        muquad_separate=muquad_separate,
        muquad=muquad,
        mu=mu,
        q=q,
        lam=lam,
        sigma2_lin=sigma2_lin,
        sigma2_quad=sigma2_quad,
        sigma2_eps=sigma2_eps,
    )