Coverage for src/bartz/prepcovars/_prepcovars.py: 93%

164 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/prepcovars/_prepcovars.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implementation of the predictor preprocessing utilities.""" 

26 

27from abc import abstractmethod 

28from functools import partial 

29from typing import Any, Protocol, runtime_checkable 

30 

31from equinox import AbstractVar, Module, field 

32from jax import numpy as jnp 

33from jax import random, vmap 

34from jax.typing import DTypeLike 

35from jaxtyping import Array, Float, Float32, Integer, Key, Real, Shaped, UInt 

36 

37from bartz._jaxext import autobatch, jit, minimal_unsigned_dtype, unique 

38 

39 

def _parse_xinfo(
    xinfo: Float[Array, 'p m'],
) -> tuple[Float[Array, 'p m'], UInt[Array, ' p']]:
    """Convert pre-defined splits in the R BART `xinfo` format to padded form.

    Parameters
    ----------
    xinfo
        A matrix with the cutpoints to use to bin each predictor. Each row
        shall contain a sorted list of cutpoints for one predictor. Rows with
        fewer cutpoints than the number of columns shall have the remaining
        cells set to NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.

    Returns
    -------
    splits : Float[Array, 'p m']
        A copy of `xinfo` where every NaN is replaced by the value returned
        by `_huge_value` for the dtype of `xinfo`.
    max_split : UInt[Array, 'p']
        The count of non-NaN entries in each row of `xinfo`.
    """
    has_value = ~jnp.isnan(xinfo)
    # the count can be at most the number of columns, so size the dtype on that
    count_dtype = minimal_unsigned_dtype(xinfo.shape[1])
    num_cutpoints = jnp.sum(has_value, axis=1).astype(count_dtype)
    padded = jnp.where(has_value, xinfo, _huge_value(xinfo))
    return padded, num_cutpoints

68 

69 

@jit(static_argnums=(2,))
def _subsample(
    key: Key[Array, ''], X: Real[Array, 'p n'], max_samples: int
) -> Real[Array, 'p m']:
    """Randomly thin each predictor row to at most `max_samples` elements.

    Parameters
    ----------
    key
        A jax random key.
    X
        A matrix with `p` predictors and `n` observations.
    max_samples
        The target maximum number of samples per row.

    Returns
    -------
    A matrix with `p` rows and ``min(n, max_samples)`` columns. When
    ``n <= max_samples`` the input `X` is returned as is. Otherwise every row
    holds `max_samples` elements drawn without replacement from the matching
    row of `X`, each row using its own key obtained by splitting `key`. The
    order of the elements within a row is unspecified.

    Raises
    ------
    ValueError
        If `max_samples` is less than 1.
    """
    if max_samples < 1:
        msg = f'{max_samples=}, must be at least 1.'
        raise ValueError(msg)

    num_rows, num_obs = X.shape
    if num_obs <= max_samples:
        # nothing to thin
        return X

    def thin_row(row_key: Key[Array, ''], row: Real[Array, ' n']) -> Real[Array, ' m']:
        return random.choice(row_key, row, shape=(max_samples,), replace=False)

    # vectorize over rows, then let autobatch cap the memory used per batch
    thin_rows = autobatch(vmap(thin_row), max_io_nbytes=2**29)
    row_keys = random.split(key, num_rows)
    return thin_rows(row_keys, X)

110 

111 

@jit(static_argnums=(1,))
def _quantilized_splits_from_matrix(
    X: Real[Array, 'p n'], max_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Choose per-predictor cutpoints that spread the unique values evenly over bins.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    max_bins
        The maximum number of bins to produce.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is ``min(max_bins, n) - 1``, an upper bound on the number of
        splits. Predictors may use fewer splits; the unused trailing entries
        of a row are filled with the maximum value representable in the type
        of `X`.
    max_split : UInt[Array, ' p']
        The number of actually used values in each row of `splits`.

    Raises
    ------
    ValueError
        If `X` has no columns or if `max_bins` is less than 1.
    """
    num_splits = min(max_bins, X.shape[1]) - 1
    if num_splits < 0:
        msg = f'{X.shape[1]=} and {max_bins=}, they should be both at least 1.'
        raise ValueError(msg)

    def quantilize_rows(
        rows: Real[Array, 'p n'],
    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
        # closure over `num_splits`: autobatch needs all arguments traceable,
        # and the output length must stay a static python int
        return _quantilized_splits_from_vector(rows, num_splits)

    batched = autobatch(quantilize_rows, max_io_nbytes=2**29)
    return batched(X)

156 

157 

@partial(vmap, in_axes=(0, None))
def _quantilized_splits_from_vector(
    x: Real[Array, ' n'], out_length: int
) -> tuple[Real[Array, ' m'], UInt[Array, '']]:
    """Compute up to `out_length` cutpoints from the unique values of `x`.

    The function is vectorized over the first argument with `vmap`, so it is
    called with a matrix whose rows are processed independently.

    Parameters
    ----------
    x
        The observed values of one predictor.
    out_length
        The number of entries of the returned cutpoints vector.

    Returns
    -------
    splits : Real[Array, ' m']
        Midpoints between consecutive sorted unique values of `x`; if there
        are more midpoints than `out_length`, a subset picked at evenly
        spaced positions.
    max_split : UInt[Array, '']
        The smaller of the number of midpoints and `out_length`.
    """
    # sorted unique values, padded at the end with the largest representable value
    filler = _huge_value(x)
    uniq, num_unique = unique(x, size=x.size, fill_value=filler)

    # midpoints between neighbouring unique values, written as
    # a + (b - a) / 2 rather than (a + b) / 2 to avoid overflow
    lower = uniq[:-1]
    gap = uniq[1:] - lower
    if jnp.issubdtype(x.dtype, jnp.integer):
        half_gap = _ensure_unsigned(gap) // 2
    else:
        half_gap = gap / 2
    mids = lower + half_gap

    # one midpoint less than unique values; overwrite the slot right after the
    # last valid midpoint with the filler
    num_mids = num_unique - 1
    if mids.size:
        mids = mids.at[num_mids].set(filler)

    # evenly spaced positions among the valid midpoints; computed in floating
    # point to avoid potential int32 overflow and to round to nearest instead
    # of rounding down
    positions = jnp.linspace(-1, num_mids, out_length + 2)[1:-1]
    index_dtype = minimal_unsigned_dtype(mids.size - 1)
    picks = jnp.around(positions).astype(index_dtype)

    # decimate only when there are more midpoints than requested, otherwise
    # keep the leading ones
    splits = jnp.where(num_mids > out_length, mids[picks], mids[:out_length])
    count_dtype = minimal_unsigned_dtype(out_length)
    max_split = jnp.minimum(num_mids, out_length).astype(count_dtype)
    return splits, max_split

190 

191 

def _huge_value(x: Shaped[Array, '...']) -> int | float:
    """
    Return the largest value representable in the dtype of `x`.

    Parameters
    ----------
    x
        A numerical numpy or jax array.

    Returns
    -------
    The maximum of `x`'s dtype: the integer maximum for integer types, the
    largest finite value (as a python float) otherwise.
    """
    dtype = x.dtype
    if not jnp.issubdtype(dtype, jnp.integer):
        return float(jnp.finfo(dtype).max)
    return jnp.iinfo(dtype).max

209 

210 

def _ensure_unsigned(x: Integer[Array, '*shape']) -> UInt[Array, '*shape']:
    """Cast integer array `x` to the unsigned dtype given by `_signed_to_unsigned`."""
    unsigned_dtype = _signed_to_unsigned(x.dtype)
    return x.astype(unsigned_dtype)

214 

215 

216def _signed_to_unsigned(int_dtype: DTypeLike) -> DTypeLike: 

217 """ 

218 Map a signed integer type to its unsigned counterpart. 

219 

220 Unsigned types are passed through. 

221 """ 

222 assert jnp.issubdtype(int_dtype, jnp.integer) 

223 if jnp.issubdtype(int_dtype, jnp.unsignedinteger): 223 ↛ 224line 223 didn't jump to line 224 because the condition on line 223 was never true

224 return int_dtype 

225 match int_dtype: 

226 case jnp.int8: 226 ↛ 227line 226 didn't jump to line 227 because the pattern on line 226 never matched

227 return jnp.uint8 

228 case jnp.int16: 228 ↛ 229line 228 didn't jump to line 229 because the pattern on line 228 never matched

229 return jnp.uint16 

230 case jnp.int32: 230 ↛ 232line 230 didn't jump to line 232 because the pattern on line 230 always matched

231 return jnp.uint32 

232 case jnp.int64: 

233 return jnp.uint64 

234 case _: 

235 msg = f'unexpected integer type {int_dtype}' 

236 raise TypeError(msg) 

237 

238 

239@jit(static_argnums=(1,)) 

240def _uniform_splits_from_matrix( 

241 X: Real[Array, 'p n'], num_bins: int 

242) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]: 

243 """ 

244 Make an evenly spaced binning grid. 

245 

246 Parameters 

247 ---------- 

248 X 

249 A matrix with `p` predictors and `n` observations. 

250 num_bins 

251 The number of bins to produce. 

252 

253 Returns 

254 ------- 

255 splits : Real[Array, 'p m'] 

256 A matrix containing, for each predictor, the boundaries between bins. 

257 The excluded endpoints are the minimum and maximum value in each row of 

258 `X`. 

259 max_split : UInt[Array, ' p'] 

260 The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``. 

261 """ 

262 low = jnp.min(X, axis=1) 

263 high = jnp.max(X, axis=1) 

264 splits = _uniform_splits_from_range(low, high, num_bins) 

265 assert splits.shape == (X.shape[0], num_bins - 1) 

266 max_split = jnp.full(X.shape[0], num_bins - 1, minimal_unsigned_dtype(num_bins - 1)) 

267 return splits, max_split 

268 

269 

@jit(static_argnums=(2,))
def _uniform_splits_from_range(
    low: Real[Array, ' p'], high: Real[Array, ' p'], num_bins: int
) -> Real[Array, 'p m']:
    """
    Build an evenly spaced binning grid between given per-predictor endpoints.

    Parameters
    ----------
    low
        The lower endpoint of the grid for each predictor.
    high
        The upper endpoint of the grid for each predictor.
    num_bins
        The number of bins to produce.

    Returns
    -------
    A `(p, num_bins - 1)` matrix of cutpoints, with `low` and `high` excluded.
    """
    # num_bins + 1 grid nodes per predictor, including both endpoints
    grid = jnp.linspace(low, high, num_bins + 1, axis=1)
    # keep the interior nodes only
    interior = grid[:, 1:-1]
    (num_predictors,) = low.shape
    assert interior.shape == (num_predictors, num_bins - 1)
    return interior

294 

295 

@jit(static_argnums=(3,))
def _bin_predictors_uniform(
    X: Real[Array, 'p n'],
    low: Real[Array, ' p'],
    high: Real[Array, ' p'],
    num_bins: int,
) -> UInt[Array, 'p n']:
    """
    Bin predictors on an evenly spaced grid computed arithmetically.

    The cutpoints are never materialized. Cutpoint ``j`` is
    ``low + (j + 1) * step`` with ``step = (high - low) / num_bins``, as
    produced by `_uniform_splits_from_range`, and ``x`` falls in bin ``i`` iff
    ``cutpoint[i - 1] < x <= cutpoint[i]``.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    low
        The minimum value of each predictor's grid.
    high
        The maximum value of each predictor's grid.
    num_bins
        The number of bins per predictor.

    Returns
    -------
    `X` with each value replaced by the index of the bin it falls into.
    """
    last_bin = num_bins - 1
    low_column = low[:, None]

    width = (high - low) / num_bins
    has_width = width > 0
    # substitute a harmless divisor where the width is not positive; those
    # rows are overridden below
    divisor = jnp.where(has_width, width, 1)

    # bin = number of cutpoints strictly below x; with right-closed bins this
    # is ceil(t) - 1 (= floor(t) away from cutpoints), matching
    # `searchsorted(..., side='left')`
    scaled = (X - low_column) / divisor[:, None]
    regular = jnp.ceil(scaled) - 1

    # constant predictors (zero width) have all cutpoints coincident at `low`:
    # values above `low` go to the last bin, everything else to bin 0
    degenerate = jnp.where(low_column < X, last_bin, 0)

    index = jnp.where(has_width[:, None], regular, degenerate)
    index = jnp.clip(index, 0, last_bin)
    return index.astype(minimal_unsigned_dtype(last_bin))

338 

339 

@jit(static_argnames=('method',))
def _bin_predictors(
    X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw: Any
) -> UInt[Array, 'p n']:
    """
    Map each predictor value to the index of its bin, given the cutpoints.

    With the default `jax.numpy.searchsorted` arguments, a value ``x`` is
    mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    splits
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; a row may use fewer, in which
        case the unused trailing entries hold the maximum value allowed by
        the type.
    **kw
        Additional arguments are passed to `jax.numpy.searchsorted`.

    Returns
    -------
    `X` but with each value replaced by the index of the bin it falls into.
    """

    def bin_row(
        row: Real[Array, ' n'], row_splits: Real[Array, ' m']
    ) -> UInt[Array, ' n']:
        positions = jnp.searchsorted(row_splits, row, **kw)
        # an index can be as large as the number of splits
        return positions.astype(minimal_unsigned_dtype(row_splits.size))

    # vectorize over predictors, then let autobatch cap the memory per batch
    bin_rows = autobatch(vmap(bin_row), max_io_nbytes=2**29)
    return bin_rows(X, splits)

375 

376 

class Binner(Module):
    """Abstract interface shared by all predictor binners.

    A binner is built from the training predictors: at construction it
    picks the cutpoints of every predictor, and afterwards its `bin`
    method converts any predictor matrix (training or test) to bin
    indices.

    With ``c`` the cutpoints selected for a predictor, a value ``x``
    lands in bin ``i`` iff ``c[i - 1] < x <= c[i]``. Hence ``k``
    cutpoints define ``k + 1`` bins, numbered ``0`` to ``k``. The count
    of cutpoints in use for each predictor is available as `max_split`
    and can vary from predictor to predictor; any leftover capacity is
    padded internally with the largest value representable in the
    cutpoints dtype, so that binning always yields an in-range index.

    Constructors receive the training predictors and, optionally, a
    random key; subclasses are free to add keyword arguments of their
    own. A binner with no use for the key accepts and silently ignores
    it, to keep the protocol uniform; a binner that needs the key raises
    `ValueError` when it is missing.
    """

    max_split: AbstractVar[UInt[Array, ' p']]
    """The number of cutpoints actually used for each of the `p` predictors."""

    _splits: AbstractVar[Real[Array, 'p m']]
    """The cutpoints for each of the `p` predictors, padded to a common length."""

    @abstractmethod
    def __init__(
        self, X: Real[Array, 'p n'], *, key: Key[Array, ''] | None = None
    ) -> None: ...

    @abstractmethod
    def bin(self, X: Real[Array, 'p n']) -> UInt[Array, 'p n']:
        """Convert predictors to bin indices with the cutpoints fixed at construction.

        Parameters
        ----------
        X
            A matrix with `p` predictors and `n` observations. The
            number of predictors must equal that of the training matrix
            given to the constructor.

        Returns
        -------
        Quantized `X` with minimal data type.
        """
        ...

428 

429 

@runtime_checkable
class BinnerFactory(Protocol):
    """Protocol of callables that build a `Binner` from training predictors.

    The `binner` argument of `bartz.Bart` has this type. It is satisfied
    both by a plain `Binner` subclass and by
    ``functools.partial(BinnerSubclass, **subclass_kwargs)``.
    """

    def __call__(
        self, X: Real[Array, 'p n'], *, key: Key[Array, ''] | None = None
    ) -> Binner:
        """Build a `Binner` given `X` and, optionally, a random key."""
        ...

444 

445 

class RangeEvenBinner(Binner):
    """Binner with cutpoints evenly spaced over the observed range.

    For each predictor, ``max_bins - 1`` cutpoints are placed at
    equally spaced positions strictly between the minimum and the
    maximum value observed in the training matrix. All predictors use
    the same number of cutpoints.

    Parameters
    ----------
    X
        Training predictors with `p` predictors and `n` observations.
    max_bins
        The number of bins per predictor; ``max_bins - 1`` cutpoints
        are produced per predictor.
    key
        Accepted for protocol uniformity; unused.

    Raises
    ------
    ValueError
        If `max_bins` is less than 1.
    """

    _low: Real[Array, ' p']
    """Minimum observed value per predictor."""

    _high: Real[Array, ' p']
    """Maximum observed value per predictor."""

    # WORKAROUND(jax<0.9.1): use `jax.tree.static` instead of `field(static=True)`
    _max_bins: int = field(static=True)
    """Number of bins per predictor."""

    max_split: UInt[Array, ' p']

    def __init__(
        self,
        X: Real[Array, 'p n'],
        *,
        max_bins: int = 256,
        key: Key[Array, ''] | None = None,
    ) -> None:
        del key
        # Reject a non-positive bin count up front: it would otherwise lead to
        # a division by zero in `bin` and to a negative number of cutpoints.
        if max_bins < 1:
            msg = f'{max_bins=}, must be at least 1.'
            raise ValueError(msg)
        self._low = jnp.min(X, axis=1)
        self._high = jnp.max(X, axis=1)
        self._max_bins = max_bins
        # every predictor uses all the `max_bins - 1` cutpoints
        self.max_split = jnp.full(
            X.shape[0], max_bins - 1, minimal_unsigned_dtype(max_bins - 1)
        )

    @property
    def _splits(self) -> Real[Array, 'p m']:
        """Materialize the cutpoints. Intended for testing only, not library use.

        The cutpoints are not stored: `bin` works arithmetically from the
        observed range, since they are evenly spaced. This property reconstructs
        them only to expose them; the library should rely on `bin` and
        `max_split` instead.
        """
        return _uniform_splits_from_range(self._low, self._high, self._max_bins)

    def bin(self, X: Real[Array, 'p n']) -> UInt[Array, 'p n']:
        """Map `X` to bin indices on the evenly spaced grid fixed at construction."""
        return _bin_predictors_uniform(X, self._low, self._high, self._max_bins)

505 

506 

class UniqueQuantileBinner(Binner):
    """Binner whose cutpoints are quantiles of the observed unique values.

    The cutpoints of each predictor lie between its sorted unique
    values and are chosen to make the empirical distribution roughly
    uniform across bins. A predictor gets at most ``max_bins - 1``
    cutpoints, and at most one less than its number of unique values,
    so the effective cutpoint count can differ between predictors.
    Unused trailing entries of the cutpoint matrix hold the maximum
    value representable in the dtype of `X`.

    Note: the quantiles are over the *unique* values, not over the
    original distribution.

    If ``n > max_subsample``, the predictor matrix is first randomly
    thinned along the observation axis down to ``max_subsample``
    columns. Each predictor row is thinned independently and without
    replacement. This keeps quantilization tractable on very large
    datasets, at the price of approximate quantiles.

    Parameters
    ----------
    X
        Training predictors with `p` predictors and `n` observations.
    max_bins
        The maximum number of bins per predictor.
    max_subsample
        The maximum number of observations to use when computing
        quantiles. If `None`, no subsampling is performed. If `n`
        exceeds this, `key` is required.
    key
        Random key for subsampling. Required when ``X.shape[1] >
        max_subsample``; otherwise unused.

    Raises
    ------
    ValueError
        If subsampling would trigger but `key` is `None`.
    """

    _splits: Real[Array, 'p m']
    """Cutpoints per predictor, padded on the right with the dtype's maximum value."""

    max_split: UInt[Array, ' p']

    def __init__(
        self,
        X: Real[Array, 'p n'],
        *,
        max_bins: int = 256,
        max_subsample: int | None = 100_000,
        key: Key[Array, ''] | None = None,
    ) -> None:
        needs_thinning = max_subsample is not None and X.shape[1] > max_subsample
        if needs_thinning:
            # the key is mandatory only when randomness is actually consumed
            if key is None:
                msg = (
                    'UniqueQuantileBinner requires a `key` because '
                    f'n={X.shape[1]} exceeds max_subsample={max_subsample}.'
                )
                raise ValueError(msg)
            X = _subsample(key, X, max_subsample)

        splits, max_split = _quantilized_splits_from_matrix(X, max_bins)
        self._splits = splits
        self.max_split = max_split

    def bin(self, X: Real[Array, 'p n']) -> UInt[Array, 'p n']:
        """Map `X` to bin indices by searching the stored cutpoints."""
        return _bin_predictors(X, self._splits)

572 

573 

class GivenSplitsBinner(Binner):
    """Binner using cutpoints provided by the user in R BART `xinfo` format.

    `xinfo` is a `(p, m)` matrix used verbatim: each row lists the
    sorted cutpoints of one predictor, and trailing NaNs mark unused
    capacity. Internally the NaNs are replaced with the maximum value
    representable in the dtype of `xinfo`, while `max_split` records
    the number of non-NaN entries of each row, so that binning behaves
    as if each row had been declared with its non-NaN cutpoints only.

    Parameters
    ----------
    X
        Training predictors. Used only to validate the shape of `xinfo`.
    xinfo
        A `(p, m)` matrix of cutpoints. Each row holds a sorted list of
        cutpoints for one predictor, optionally padded on the right with
        NaN.
    key
        Accepted for protocol uniformity; unused.

    Raises
    ------
    ValueError
        If `xinfo` is not 2D, or if its first dimension does not match
        ``X.shape[0]``.
    """

    _splits: Float[Array, 'p m']
    """Cutpoints per predictor, with NaNs replaced by the dtype's maximum value."""

    max_split: UInt[Array, ' p']

    def __init__(
        self,
        X: Real[Array, 'p n'],
        *,
        xinfo: Float[Array, 'p m'],
        key: Key[Array, ''] | None = None,
    ) -> None:
        del key
        # one row of cutpoints per predictor is required
        shape_ok = xinfo.ndim == 2 and xinfo.shape[0] == X.shape[0]
        if not shape_ok:
            msg = f'{xinfo.shape=} different from expected ({X.shape[0]}, *)'
            raise ValueError(msg)

        splits, max_split = _parse_xinfo(xinfo)
        self._splits = splits
        self.max_split = max_split

    def bin(self, X: Real[Array, 'p n']) -> UInt[Array, 'p n']:
        """Map `X` to bin indices by searching the stored cutpoints."""
        return _bin_predictors(X, self._splits)

623 

624 

@jit
def _sigma2_from_ols(
    x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n'] | Float32[Array, 'k n']
) -> Float32[Array, ''] | Float32[Array, ' k']:
    """Estimate the error variance from an OLS fit with intercept.

    Predictors and responses are centered, which is equivalent to adding an
    intercept column, then the residual term returned by
    `jax.numpy.linalg.lstsq` is divided by ``n - rank``, where ``rank`` is
    the rank reported for the centered predictors.
    """
    # observations along the rows, as expected by lstsq
    design = x_train.T - x_train.mean(axis=1)
    response = y_train.T - y_train.mean(axis=-1)

    _coef, resid, rank, _sing = jnp.linalg.lstsq(design, response)

    # one residual term per response: scalar for 1D `y_train`, (k,) for 2D
    resid = resid.reshape(y_train.shape[:-1])

    # NOTE(review): the intercept is not counted in the degrees of freedom
    # (n - rank rather than n - rank - 1) — confirm this is intended
    num_obs = y_train.shape[-1]
    return resid / (num_obs - rank)