Coverage for src / bartz / testing / _dgp.py: 93%

115 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/testing/_dgp.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25 

26"""Define `gen_data` that generates simulated data for testing.""" 

27 

28from dataclasses import replace 

29from functools import partial 

30 

31from equinox import Module, error_if 

32from jax import jit, random 

33from jax import numpy as jnp 

34from jaxtyping import Array, Bool, Float, Int, Integer, Key 

35 

36from bartz.jaxext import split 

37 

38 

def generate_x(key: Key[Array, ''], n: int, p: int) -> Float[Array, 'p n']:
    """Draw a (p, n) matrix of i.i.d. standardized uniform predictors.

    Each entry is sampled from U(-√3, √3), the uniform distribution with
    mean 0 and variance 1.

    Parameters
    ----------
    key
        JAX random key
    n
        Number of observations (columns)
    p
        Number of predictors (rows)

    Returns
    -------
    Predictors of shape (p, n).
    """
    # A uniform on (-a, a) has variance a^2 / 3, so a = √3 gives variance 1.
    half_width = jnp.sqrt(3.0)
    shape = (p, n)
    return random.uniform(key, shape, minval=-half_width, maxval=half_width)

45 

46 

def generate_partition(key: Key[Array, ''], p: int, k: int) -> Bool[Array, 'k p']:
    """Partition x components amongst y components.

    Each row i has either p // k or p // k + 1 non-zero entries, and every
    column has exactly one non-zero entry.

    Parameters
    ----------
    key
        JAX random key
    p
        Number of predictors
    k
        Number of outcome components

    Returns
    -------
    Boolean membership matrix of shape (k, p).
    """
    keys = split(key)
    # Bin label of predictor i is floor(i * k / p). Computing it in integer
    # arithmetic keeps the bins exact; a floating point `linspace` followed by
    # truncation can round a value that should be an integer just below it and
    # put the predictor in the wrong bin.
    indices: Int[Array, ' p'] = jnp.arange(p, dtype=jnp.int32) * k // p
    # Shuffle which predictors belong to which bin ...
    indices = random.permutation(keys.pop(), indices)
    # ... and which bin is assigned to which outcome component.
    assignments: Int[Array, ' k'] = random.permutation(keys.pop(), k)
    return indices == assignments[:, None]

58 

59 

def generate_beta_shared(
    key: Key[Array, ''], p: int, sigma2_lin: Float[Array, '']
) -> Float[Array, ' p']:
    """Draw the linear coefficients shared by all components (lambda=1 case).

    The coefficients are i.i.d. normal with variance ``sigma2_lin / p``, so
    that the variances of the p terms add up to ``sigma2_lin``.
    """
    coef_sdev = jnp.sqrt(sigma2_lin / p)
    standard_draws = random.normal(key, (p,))
    return standard_draws * coef_sdev

66 

67 

def generate_beta_separate(
    key: Key[Array, ''], partition: Bool[Array, 'k p'], sigma2_lin: Float[Array, '']
) -> Float[Array, 'k p']:
    """Draw per-component linear coefficients (lambda=0 case).

    Component i only gets non-zero coefficients on the predictors selected by
    row i of `partition`. The coefficient variance is ``sigma2_lin / (p / k)``,
    i.e., `sigma2_lin` spread over the nominal number of predictors per
    component.
    """
    num_components, num_predictors = partition.shape
    coef_var = sigma2_lin / (num_predictors / num_components)
    standard_draws: Float[Array, 'k p'] = random.normal(key, partition.shape)
    # Zero out the coefficients of predictors not assigned to the component.
    masked_draws = jnp.where(partition, standard_draws, 0.0)
    return masked_draws * jnp.sqrt(coef_var)

76 

77 

def compute_linear_mean_shared(
    beta_shared: Float[Array, ' p'], x: Float[Array, 'p n']
) -> Float[Array, ' n']:
    """Compute the shared linear mean, mulin_j = sum_r beta_r x_rj.

    The result has no component axis because it is the same for every
    component.
    """
    return jnp.matmul(beta_shared, x)

83 

84 

def compute_linear_mean_separate(
    beta_separate: Float[Array, 'k p'], x: Float[Array, 'p n']
) -> Float[Array, 'k n']:
    """Compute the per-component linear mean, mulin_ij = sum_r beta_ir x_rj."""
    return jnp.matmul(beta_separate, x)

90 

91 

def combine_mulin(
    mulin_shared: Float[Array, ' n'],
    mulin_separate: Float[Array, 'k n'],
    lam: Float[Array, ''],
) -> Float[Array, 'k n']:
    """Mix the separate and shared linear means according to `lam`.

    The weights are ``sqrt(1 - lam)`` on the separate part and ``sqrt(lam)``
    on the shared part; the shared part is broadcast across components.
    """
    weight_separate = jnp.sqrt(1.0 - lam)
    weight_shared = jnp.sqrt(lam)
    return weight_separate * mulin_separate + weight_shared * mulin_shared

99 

100 

def interaction_pattern(p: int, q: Integer[Array, ''] | int) -> Bool[Array, 'p p']:
    """Build a circular band pattern with q interactions per variable.

    Variables are placed on a ring of length p; two variables interact when
    their distance along the ring is at most q // 2. The diagonal is always
    included.

    Parameters
    ----------
    p
        Number of predictors
    q
        Number of interactions per predictor (must be even and less than p)

    Returns
    -------
    Symmetric binary pattern of shape (p, p) where each row/col sums to q+1
    """
    q = error_if(q, q % 2 != 0, 'q must be even')
    q = error_if(q, q >= p, 'q must be less than p')

    position = jnp.arange(p)
    gap = jnp.abs(position[:, None] - position[None, :])
    # Distance on the ring: go whichever way around is shorter.
    ring_dist = jnp.minimum(gap, p - gap)
    half_band = q // 2
    return ring_dist <= half_band

121 

122 

def generate_A_shared(
    key: Key[Array, ''],
    p: int,
    q: Integer[Array, ''],
    sigma2_quad: Float[Array, ''],
    kurt_x: float,
) -> Float[Array, 'p p']:
    """Draw the quadratic coefficients shared by all components (lambda=1 case).

    Entries inside the pattern returned by `interaction_pattern` are i.i.d.
    normal with variance ``sigma2_quad / (p * (kurt_x - 1 + q))``; all other
    entries are zero.
    """
    coef_var = sigma2_quad / (p * (kurt_x - 1 + q))
    active: Bool[Array, 'p p'] = interaction_pattern(p, q)
    standard_draws: Float[Array, 'p p'] = random.normal(key, (p, p))
    sparse_draws = jnp.where(active, standard_draws, 0.0)
    return sparse_draws * jnp.sqrt(coef_var)

136 

137 

def partitioned_interaction_pattern(
    partition: Bool[Array, 'k p'], q: Integer[Array, ''] | int
) -> Bool[Array, 'k p p']:
    """Build one circular band pattern per component over disjoint variables.

    Within each component, the variables selected by `partition` are ranked
    in order of appearance and placed on a ring; two of them interact when
    their distance along the ring is at most q // 2.

    Parameters
    ----------
    partition
        Binary partition of shape (k, p) indicating variable assignments
        to components
    q
        Number of interactions per predictor (must be even and < p // k)

    Returns
    -------
    Interaction patterns of shape (k, p, p)
    """
    k, p = partition.shape
    q = error_if(q, q % 2 != 0, 'q must be even')
    q = error_if(q, q >= p // k, 'q must be less than p // k')

    # Running count along each row: the variables that belong to a component
    # get ranks 1, 2, ..., and the last value is the component size.
    rank: Int[Array, 'k p'] = jnp.cumsum(partition, axis=1)
    ring_size: Int[Array, 'k'] = jnp.max(rank, axis=1)

    straight: Int[Array, 'k p p'] = jnp.abs(rank[:, :, None] - rank[:, None, :])
    around: Int[Array, 'k p p'] = ring_size[:, None, None] - straight
    ring_dist: Int[Array, 'k p p'] = jnp.minimum(straight, around)

    close_enough: Bool[Array, 'k p p'] = ring_dist <= (q // 2)

    # Keep only pairs where both variables belong to the component; ranks of
    # outside variables are meaningless.
    row_member = partition[:, :, None]
    col_member = partition[:, None, :]
    close_enough = jnp.where(row_member, close_enough, False)
    return jnp.where(col_member, close_enough, False)

170 

171 

def generate_A_separate(
    key: Key[Array, ''],
    partition: Bool[Array, 'k p'],
    q: Integer[Array, ''],
    sigma2_quad: Float[Array, ''],
    kurt_x: float,
) -> Float[Array, 'k p p']:
    """Draw per-component quadratic coefficients (lambda=0 case).

    Entries inside the pattern returned by `partitioned_interaction_pattern`
    are i.i.d. normal with variance
    ``sigma2_quad / (p / k * (kurt_x - 1 + q))``; all other entries are zero.
    """
    num_components, num_predictors = partition.shape
    full_shape = (num_components, num_predictors, num_predictors)
    standard_draws: Float[Array, 'k p p'] = random.normal(key, full_shape)
    active: Bool[Array, 'k p p'] = partitioned_interaction_pattern(partition, q)
    sparse_draws = jnp.where(active, standard_draws, 0.0)
    coef_var = sigma2_quad / (num_predictors / num_components * (kurt_x - 1 + q))
    return sparse_draws * jnp.sqrt(coef_var)

188 

189 

def compute_muquad_shared(
    A_shared: Float[Array, 'p p'], x: Float[Array, 'p n']
) -> Float[Array, ' n']:
    """Evaluate the shared quadratic form on each observation (lambda=1 case).

    muquad_j = sum_rs A_rs x_rj x_sj

    The result has no component axis because it is the same for every
    component.
    """
    contraction = 'rs,rj,sj->j'
    return jnp.einsum(contraction, A_shared, x, x)

199 

200 

def compute_muquad_separate(
    A_separate: Float[Array, 'k p p'], x: Float[Array, 'p n']
) -> Float[Array, 'k n']:
    """Evaluate each component's quadratic form on each observation (lambda=0 case).

    muquad_ij = sum_rs A_irs x_rj x_sj
    """
    contraction = 'irs,rj,sj->ij'
    return jnp.einsum(contraction, A_separate, x, x)

210 

211 

def combine_muquad(
    muquad_shared: Float[Array, ' n'],
    muquad_separate: Float[Array, 'k n'],
    lam: Float[Array, ''],
) -> Float[Array, 'k n']:
    """Mix the separate and shared quadratic means according to `lam`.

    The weights are ``sqrt(1 - lam)`` on the separate part and ``sqrt(lam)``
    on the shared part; the shared part is broadcast across components.
    """
    weight_separate = jnp.sqrt(1.0 - lam)
    weight_shared = jnp.sqrt(lam)
    return weight_separate * muquad_separate + weight_shared * muquad_shared

219 

220 

def compute_quadratic_mean(
    A: Float[Array, 'k p p'], x: Float[Array, 'p n']
) -> Float[Array, 'k n']:
    """Evaluate the quadratic part of the latent mean.

    Computes muquad_ij = sum_rs A_irs x_rj x_sj for every component i and
    observation j.
    """
    contraction = 'irs,rj,sj->ij'
    return jnp.einsum(contraction, A, x, x)

226 

227 

def generate_outcome(
    key: Key[Array, ''], mu: Float[Array, 'k n'], sigma2_eps: Float[Array, '']
) -> Float[Array, 'k n']:
    """Add i.i.d. normal noise with variance `sigma2_eps` to the latent mean."""
    noise_sdev = jnp.sqrt(sigma2_eps)
    standard_noise: Float[Array, 'k n'] = random.normal(key, mu.shape)
    return mu + standard_noise * noise_sdev

234 

235 

class DGP(Module):
    """Output of `gen_data`.

    Parameters
    ----------
    x
        Predictors of shape (p, n), variance 1
    y
        Noisy outcomes of shape (k, n) or (n,)
    partition
        Predictor-outcome assignment partition of shape (k, p)
    beta_shared
        Shared linear coefficients of shape (p,)
    beta_separate
        Separate linear coefficients of shape (k, p)
    mulin_shared
        Linear mean at lambda=1 (shared), shape (n,), common to all components
    mulin_separate
        Linear mean at lambda=0 (separate), shape (k, n), rows independent
    mulin
        Linear part of latent mean of shape (k, n)
    A_shared
        Shared quadratic coefficients of shape (p, p)
    A_separate
        Separate quadratic coefficients of shape (k, p, p)
    muquad_shared
        Quadratic mean at lambda=1 (shared), shape (n,), common to all
        components
    muquad_separate
        Quadratic mean at lambda=0 (separate), shape (k, n), rows independent
    muquad
        Quadratic part of latent mean of shape (k, n)
    mu
        True latent means of shape (k, n)
    q
        Number of interactions per predictor
    lam
        Coupling parameter in [0, 1]
    sigma2_lin
        Prior and expected population variance of mulin
    sigma2_quad
        Expected population variance of muquad
    sigma2_eps
        Variance of the error
    kurt_x
        Kurtosis of the predictors; the default 9/5 is the kurtosis of the
        uniform distribution
    """

    # Main outputs
    x: Float[Array, 'p n']
    y: Float[Array, 'k n'] | Float[Array, ' n']

    # Intermediate results
    partition: Bool[Array, 'k p']
    beta_shared: Float[Array, ' p']
    beta_separate: Float[Array, 'k p']
    mulin_shared: Float[Array, ' n']
    mulin_separate: Float[Array, 'k n']
    mulin: Float[Array, 'k n']
    A_shared: Float[Array, 'p p']
    A_separate: Float[Array, 'k p p']
    muquad_shared: Float[Array, ' n']
    muquad_separate: Float[Array, 'k n']
    muquad: Float[Array, 'k n']
    mu: Float[Array, 'k n']

    # Params
    q: Integer[Array, '']
    lam: Float[Array, '']
    sigma2_lin: Float[Array, '']
    sigma2_quad: Float[Array, '']
    sigma2_eps: Float[Array, '']

    kurt_x: float = 9 / 5  # kurtosis of uniform distribution

    @property
    def sigma2_pri(self) -> Float[Array, '']:
        """Prior variance of y."""
        return self.sigma2_pop + self.sigma2_mean

    @property
    def sigma2_pop(self) -> Float[Array, '']:
        """Expected population variance of y."""
        return self.sigma2_lin + self.sigma2_quad + self.sigma2_eps

    @property
    def sigma2_mean(self) -> Float[Array, '']:
        """Variance of the mean function.

        Computed as ``sigma2_quad / (kurt_x - 1 + q)``.
        """
        return self.sigma2_quad / (self.kurt_x - 1 + self.q)

    def _take(self, obs: slice) -> 'DGP':
        """Return a copy restricted to the observations selected by `obs`.

        Only the arrays that have an observation axis are sliced; the
        coefficients, the partition and the scalar parameters are kept as is.
        """
        return replace(
            self,
            x=self.x[:, obs],
            # `y` has shape (n,) instead of (k, n) when the data was generated
            # with k=None, so address the observation axis as the last one.
            y=self.y[..., obs],
            mulin_shared=self.mulin_shared[obs],
            mulin_separate=self.mulin_separate[:, obs],
            mulin=self.mulin[:, obs],
            muquad_shared=self.muquad_shared[obs],
            muquad_separate=self.muquad_separate[:, obs],
            muquad=self.muquad[:, obs],
            mu=self.mu[:, obs],
        )

    def split(self, n_train: int | None = None) -> tuple['DGP', 'DGP']:
        """Split the data into training and test sets.

        The first `n_train` observations go to the training set, the
        remaining ones to the test set.

        Parameters
        ----------
        n_train
            Number of training observations. If None, split in half.

        Returns
        -------
        Two `DGP` objects with the train and test splits.
        """
        n = self.x.shape[1]
        if n_train is None:
            n_train = n // 2
        assert 0 < n_train < n, 'n_train must be in (0, n)'
        train = self._take(slice(None, n_train))
        test = self._take(slice(n_train, None))
        return train, test

363 

364 

@partial(jit, static_argnames=('n', 'p', 'k'))
def gen_data(
    key: Key[Array, ''],
    *,
    n: int,
    p: int,
    k: int | None = None,
    q: Integer[Array, ''] | int,
    lam: Float[Array, ''] | float,
    sigma2_lin: Float[Array, ''] | float,
    sigma2_quad: Float[Array, ''] | float,
    sigma2_eps: Float[Array, ''] | float,
) -> DGP:
    """Generate data from a quadratic multivariate DGP.

    The latent mean is the sum of a linear and a quadratic function of the
    predictors. Each of the two is a mix, controlled by `lam`, of a part
    shared by all outcome components and a part specific to each component.

    Parameters
    ----------
    key
        JAX random key
    n
        Number of observations
    p
        Number of predictors
    k
        Number of outcome components. If None, a single component is
        generated and `y` is returned with shape (n,) instead of (1, n); the
        other per-component arrays keep their leading axis.
    q
        Number of interactions per predictor (must be even and < p // k)
    lam
        Coupling parameter in [0, 1]. 0=independent, 1=identical components
    sigma2_lin
        Prior and expected population variance of the linear term
    sigma2_quad
        Expected population variance of the quadratic term
    sigma2_eps
        Variance of the error term

    Returns
    -------
    An object with all generated data and parameters.
    """
    univariate = k is None
    if univariate:
        k = 1

    assert p >= k, 'p must be at least k'

    # check q
    q = error_if(q, q % 2 != 0, 'q must be even')
    q = error_if(q, q >= p // k, 'q must be less than p // k')

    keys = split(key, 7)

    # The random draws below consume the keys in a fixed order: predictors,
    # partition, linear coefficients (shared, separate), quadratic
    # coefficients (shared, separate), noise.
    x = generate_x(keys.pop(), n, p)
    partition = generate_partition(keys.pop(), p, k)

    # Linear part of the latent mean.
    beta_shared = generate_beta_shared(keys.pop(), p, sigma2_lin)
    beta_separate = generate_beta_separate(keys.pop(), partition, sigma2_lin)
    mulin_shared = compute_linear_mean_shared(beta_shared, x)
    mulin_separate = compute_linear_mean_separate(beta_separate, x)
    mulin = combine_mulin(mulin_shared, mulin_separate, lam)

    # Quadratic part of the latent mean.
    kurt_x = DGP.kurt_x
    A_shared = generate_A_shared(keys.pop(), p, q, sigma2_quad, kurt_x)
    A_separate = generate_A_separate(keys.pop(), partition, q, sigma2_quad, kurt_x)
    muquad_shared = compute_muquad_shared(A_shared, x)
    muquad_separate = compute_muquad_separate(A_separate, x)
    muquad = combine_muquad(muquad_shared, muquad_separate, lam)

    # Latent mean and noisy outcome.
    mu = mulin + muquad
    y = generate_outcome(keys.pop(), mu, sigma2_eps)
    if univariate:
        y = y.squeeze(0)

    results = dict(
        x=x,
        y=y,
        partition=partition,
        beta_shared=beta_shared,
        beta_separate=beta_separate,
        mulin_shared=mulin_shared,
        mulin_separate=mulin_separate,
        mulin=mulin,
        A_shared=A_shared,
        A_separate=A_separate,
        muquad_shared=muquad_shared,
        muquad_separate=muquad_separate,
        muquad=muquad,
        mu=mu,
    )
    params = dict(
        q=q,
        lam=lam,
        sigma2_lin=sigma2_lin,
        sigma2_quad=sigma2_quad,
        sigma2_eps=sigma2_eps,
    )
    return DGP(**results, **params)