Coverage for src/bartz/prepcovars/_prepcovars.py: 93%
164 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/prepcovars/_prepcovars.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implementation of the predictor preprocessing utilities."""
27from abc import abstractmethod
28from functools import partial
29from typing import Any, Protocol, runtime_checkable
31from equinox import AbstractVar, Module, field
32from jax import numpy as jnp
33from jax import random, vmap
34from jax.typing import DTypeLike
35from jaxtyping import Array, Float, Float32, Integer, Key, Real, Shaped, UInt
37from bartz._jaxext import autobatch, jit, minimal_unsigned_dtype, unique
def _parse_xinfo(
    xinfo: Float[Array, 'p m'],
) -> tuple[Float[Array, 'p m'], UInt[Array, ' p']]:
    """Convert cutpoints in the R package BART `xinfo` layout to padded splits.

    Parameters
    ----------
    xinfo
        A matrix with the cutpoints to use to bin each predictor. Each row
        shall contain the sorted cutpoints of one predictor. Rows with fewer
        cutpoints than the number of columns shall be filled with NaN in the
        remaining cells.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.

    Returns
    -------
    splits : Float[Array, 'p m']
        `xinfo` with every NaN replaced by the value returned by
        `_huge_value` for its dtype.
    max_split : UInt[Array, 'p']
        The number of non-NaN elements in each row of `xinfo`.
    """
    padding = jnp.isnan(xinfo)

    # count the real cutpoints per row, stored in the smallest unsigned dtype
    # able to hold the number of columns
    num_cutpoints = jnp.sum(~padding, axis=1)
    count_dtype = minimal_unsigned_dtype(xinfo.shape[1])

    # swap the NaN padding for a huge value
    splits = jnp.where(padding, _huge_value(xinfo), xinfo)
    return splits, num_cutpoints.astype(count_dtype)
@jit(static_argnums=(2,))
def _subsample(
    key: Key[Array, ''], X: Real[Array, 'p n'], max_samples: int
) -> Real[Array, 'p m']:
    """Cap the number of observations per predictor by random thinning.

    Parameters
    ----------
    key
        A jax random key.
    X
        A matrix with `p` predictors and `n` observations.
    max_samples
        The target maximum number of samples per row.

    Returns
    -------
    A matrix with `p` rows and ``min(n, max_samples)`` columns. If
    ``n <= max_samples``, `X` is returned unchanged. Otherwise each row
    contains `max_samples` elements drawn without replacement from the
    corresponding row of `X`, with rows sampled independently. The order of
    values within each row is unspecified.

    Raises
    ------
    ValueError
        If `max_samples` is less than 1.
    """
    if max_samples < 1:
        msg = f'{max_samples=}, must be at least 1.'
        raise ValueError(msg)

    num_predictors, num_obs = X.shape
    if num_obs <= max_samples:
        return X

    def thin_row(k: Key[Array, ''], x: Real[Array, ' n']) -> Real[Array, ' m']:
        return random.choice(k, x, shape=(max_samples,), replace=False)

    # one independent key per predictor row; vectorize over rows and let
    # autobatch handle the vectorized call with its I/O size limit
    row_keys = random.split(key, num_predictors)
    thin_rows = autobatch(vmap(thin_row), max_io_nbytes=2**29)
    return thin_rows(row_keys, X)
@jit(static_argnums=(1,))
def _quantilized_splits_from_matrix(
    X: Real[Array, 'p n'], max_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Choose per-predictor cutpoints that even out each predictor's distribution.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    max_bins
        The maximum number of bins to produce.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is ``min(max_bins, n) - 1``, which is an upper bound on the number
        of splits. Each predictor may have a different number of splits; unused
        values at the end of each row are filled with the maximum value
        representable in the type of `X`.
    max_split : UInt[Array, ' p']
        The number of actually used values in each row of `splits`.

    Raises
    ------
    ValueError
        If `X` has no columns or if `max_bins` is less than 1.
    """
    num_obs = X.shape[1]
    out_length = min(max_bins, num_obs) - 1

    if out_length < 0:
        msg = f'{X.shape[1]=} and {max_bins=}, they should be both at least 1.'
        raise ValueError(msg)

    def quantilize(
        X: Real[Array, 'p n'],
    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
        # closure over `out_length`: autobatch needs traceable args only
        return _quantilized_splits_from_vector(X, out_length)

    batched_quantilize = autobatch(quantilize, max_io_nbytes=2**29)
    return batched_quantilize(X)
@partial(vmap, in_axes=(0, None))
def _quantilized_splits_from_vector(
    x: Real[Array, ' n'], out_length: int
) -> tuple[Real[Array, ' m'], UInt[Array, '']]:
    """Pick at most `out_length` cutpoints between the unique values of `x`.

    The body handles one predictor; the `vmap` decorator maps it over the
    leading axis of `x` while `out_length` is shared.

    Parameters
    ----------
    x
        The observed values of one predictor.
    out_length
        The length of the returned vector of cutpoints.

    Returns
    -------
    splits : Real[Array, ' m']
        Midpoints between consecutive unique values of `x`, thinned to
        `out_length` entries if there are more, and padded at the end with
        the value returned by `_huge_value` if there are fewer.
    max_split : UInt[Array, '']
        The number of entries of `splits` that are actual cutpoints.
    """
    # sorted unique values of x, padded at the end with a huge value
    sentinel = _huge_value(x)
    uniq, num_unique = unique(x, size=x.size, fill_value=sentinel)

    # midpoint between each pair of consecutive unique values, computed as
    # a + (b - a) / 2 instead of (a + b) / 2 to avoid overflow
    lower = uniq[:-1]
    gaps = uniq[1:] - lower
    if jnp.issubdtype(x.dtype, jnp.integer):
        half_gaps = _ensure_unsigned(gaps) // 2
    else:
        half_gaps = gaps / 2
    midpoints = lower + half_gaps

    # there is one midpoint less than unique values; the entry right after the
    # last real midpoint mixes a real value with padding, so overwrite it
    num_midpoints = num_unique - 1
    if midpoints.size:
        midpoints = midpoints.at[num_midpoints].set(sentinel)

    # evenly spaced selection, used if there are more midpoints than requested;
    # positions are computed in floating point to avoid potential overflow with
    # int32 and to round to nearest instead of rounding down
    positions = jnp.linspace(-1, num_midpoints, out_length + 2)[1:-1]
    index_dtype = minimal_unsigned_dtype(midpoints.size - 1)
    picks = jnp.around(positions).astype(index_dtype)

    # thin when there are too many midpoints, else keep the leading ones
    too_many = num_midpoints > out_length
    splits = jnp.where(too_many, midpoints[picks], midpoints[:out_length])

    num_splits = jnp.minimum(num_midpoints, out_length)
    return splits, num_splits.astype(minimal_unsigned_dtype(out_length))
def _huge_value(x: Shaped[Array, '...']) -> int | float:
    """
    Give the largest value representable in the dtype of `x`.

    Parameters
    ----------
    x
        A numerical numpy or jax array.

    Returns
    -------
    The maximum of `x`'s dtype: the `iinfo` maximum for integer dtypes, the
    (finite) `finfo` maximum converted to `float` otherwise.
    """
    dtype = x.dtype
    if not jnp.issubdtype(dtype, jnp.integer):
        return float(jnp.finfo(dtype).max)
    return jnp.iinfo(dtype).max
def _ensure_unsigned(x: Integer[Array, '*shape']) -> UInt[Array, '*shape']:
    """Cast an integer array to the unsigned dtype given by `_signed_to_unsigned`."""
    unsigned_dtype = _signed_to_unsigned(x.dtype)
    return x.astype(unsigned_dtype)
216def _signed_to_unsigned(int_dtype: DTypeLike) -> DTypeLike:
217 """
218 Map a signed integer type to its unsigned counterpart.
220 Unsigned types are passed through.
221 """
222 assert jnp.issubdtype(int_dtype, jnp.integer)
223 if jnp.issubdtype(int_dtype, jnp.unsignedinteger): 223 ↛ 224line 223 didn't jump to line 224 because the condition on line 223 was never true
224 return int_dtype
225 match int_dtype:
226 case jnp.int8: 226 ↛ 227line 226 didn't jump to line 227 because the pattern on line 226 never matched
227 return jnp.uint8
228 case jnp.int16: 228 ↛ 229line 228 didn't jump to line 229 because the pattern on line 228 never matched
229 return jnp.uint16
230 case jnp.int32: 230 ↛ 232line 230 didn't jump to line 232 because the pattern on line 230 always matched
231 return jnp.uint32
232 case jnp.int64:
233 return jnp.uint64
234 case _:
235 msg = f'unexpected integer type {int_dtype}'
236 raise TypeError(msg)
@jit(static_argnums=(1,))
def _uniform_splits_from_matrix(
    X: Real[Array, 'p n'], num_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Build an evenly spaced binning grid spanning the range of each predictor.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    num_bins
        The number of bins to produce.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row of
        `X`.
    max_split : UInt[Array, ' p']
        The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.
    """
    num_predictors = X.shape[0]
    num_cutpoints = num_bins - 1

    splits = _uniform_splits_from_range(
        jnp.min(X, axis=1), jnp.max(X, axis=1), num_bins
    )
    assert splits.shape == (num_predictors, num_cutpoints)

    # every predictor uses the same number of cutpoints
    max_split = jnp.full(
        num_predictors, num_cutpoints, minimal_unsigned_dtype(num_cutpoints)
    )
    return splits, max_split
@jit(static_argnums=(2,))
def _uniform_splits_from_range(
    low: Real[Array, ' p'], high: Real[Array, ' p'], num_bins: int
) -> Real[Array, 'p m']:
    """
    Build an evenly spaced binning grid between given per-predictor endpoints.

    Parameters
    ----------
    low
        The lower endpoint of the grid for each predictor.
    high
        The upper endpoint of the grid for each predictor.
    num_bins
        The number of bins to produce.

    Returns
    -------
    A `(p, num_bins - 1)` matrix of cutpoints, with `low` and `high` excluded.
    """
    # num_bins + 1 evenly spaced points per predictor, endpoints included
    grid = jnp.linspace(low, high, num_bins + 1, axis=1)

    # keep only the interior points
    interior = grid[:, 1:-1]

    (num_predictors,) = low.shape
    assert interior.shape == (num_predictors, num_bins - 1)
    return interior
@jit(static_argnums=(3,))
def _bin_predictors_uniform(
    X: Real[Array, 'p n'],
    low: Real[Array, ' p'],
    high: Real[Array, ' p'],
    num_bins: int,
) -> UInt[Array, 'p n']:
    """
    Assign predictors to the bins of an evenly spaced grid by arithmetic alone.

    No cutpoint matrix is built. This is the arithmetic equivalent of binning
    with the splits from `_uniform_splits_from_range`: cutpoint ``j`` is
    ``low + (j + 1) * step`` with ``step = (high - low) / num_bins``, and ``x``
    falls in bin ``i`` iff ``cutpoint[i - 1] < x <= cutpoint[i]``.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    low
        The minimum value of each predictor's grid.
    high
        The maximum value of each predictor's grid.
    num_bins
        The number of bins per predictor.

    Returns
    -------
    `X` with each value replaced by the index of the bin it falls into.
    """
    last_bin = num_bins - 1
    low_col = low[:, None]

    step = (high - low) / num_bins
    has_width = step > 0
    # placeholder step of 1 where the grid has no width, to avoid dividing by 0
    safe_step = jnp.where(has_width, step, 1)

    # bin = #{cutpoints < x}; right-closed bins make this ceil(t) - 1 (= floor(t)
    # away from cutpoints), matching `searchsorted(..., side='left')`
    t = (X - low_col) / safe_step[:, None]
    regular_bins = jnp.ceil(t) - 1

    # constant predictors (step == 0) have coincident cutpoints at `low`:
    # anything above `low` goes to the last bin, the rest to the first
    degenerate_bins = jnp.where(low_col < X, last_bin, 0)

    bins = jnp.where(has_width[:, None], regular_bins, degenerate_bins)
    # values outside [low, high] land in the first or last bin
    bins = jnp.clip(bins, 0, last_bin)
    return bins.astype(minimal_unsigned_dtype(last_bin))
@jit(static_argnames=('method',))
def _bin_predictors(
    X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw: Any
) -> UInt[Array, 'p n']:
    """
    Replace each predictor value with the index of its bin.

    A value ``x`` is mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    splits
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; each row may have shorter
        actual length, marked by padding unused locations at the end of the
        row with the maximum value allowed by the type.
    **kw
        Additional arguments are passed to `jax.numpy.searchsorted`. The
        `method` keyword is treated as static by `jit`.

    Returns
    -------
    `X` but with each value replaced by the index of the bin it falls into.
    """

    def bin_row(x: Real[Array, ' n'], row_splits: Real[Array, ' m']) -> UInt[Array, ' n']:
        # an index ranges from 0 to the number of splits, so size the dtype on that
        index_dtype = minimal_unsigned_dtype(row_splits.size)
        indices = jnp.searchsorted(row_splits, x, **kw)
        return indices.astype(index_dtype)

    # vectorize over predictors and wrap with autobatch and its I/O size limit
    bin_rows = autobatch(vmap(bin_row), max_io_nbytes=2**29)
    return bin_rows(X, splits)
class Binner(Module):
    """Abstract base class for predictor binners.

    A binner inspects the training predictors at construction time,
    chooses cutpoints for each predictor, and encapsulates the logic
    that maps any predictor matrix (training or test) to bin indices via
    `bin`.

    A predictor value ``x`` is mapped to bin ``i`` iff
    ``c[i - 1] < x <= c[i]``, where ``c`` are the cutpoints chosen for
    that predictor at construction. A predictor with ``k`` cutpoints
    therefore has ``k + 1`` bins indexed from ``0`` to ``k``. The number
    of cutpoints actually used per predictor is exposed as `max_split`
    and may differ across predictors; the remaining capacity, if any, is
    padded internally with the maximum value representable in the dtype
    of the cutpoints, so binning still produces a valid in-range index.

    The constructor takes the training predictors and an optional random
    key. Concrete subclasses may add their own keyword arguments. Binners
    that do not use the key still accept it for protocol uniformity and
    silently ignore it. Binners that need it raise `ValueError` if it is
    not provided.

    Subclasses must provide `max_split` and `_splits` (declared below as
    abstract variables) and implement `__init__` and `bin`.
    """

    max_split: AbstractVar[UInt[Array, ' p']]
    """The number of cutpoints actually used for each of the `p` predictors."""

    _splits: AbstractVar[Real[Array, 'p m']]
    """The cutpoints for each of the `p` predictors, padded to a common length."""

    # Common constructor signature: the training predictors `X` as a `(p, n)`
    # matrix, plus an optional keyword-only random `key` defaulting to None.
    @abstractmethod
    def __init__(
        self, X: Real[Array, 'p n'], *, key: Key[Array, ''] | None = None
    ) -> None: ...

    @abstractmethod
    def bin(self, X: Real[Array, 'p n']) -> UInt[Array, 'p n']:
        """Map predictors to bin indices using the cutpoints chosen at construction.

        Parameters
        ----------
        X
            A matrix with `p` predictors and `n` observations. Must have
            the same number of predictors as the training matrix passed
            to the constructor.

        Returns
        -------
        Quantized `X` with minimal data type.
        """
        ...
@runtime_checkable
class BinnerFactory(Protocol):
    """Callable that constructs a `Binner` from training predictors.

    This is the type of the `binner` argument of `bartz.Bart`. A bare
    `Binner` subclass satisfies this protocol, as does
    ``functools.partial(BinnerSubclass, **subclass_kwargs)``.

    The protocol is decorated with `runtime_checkable`, so it can be used
    with `isinstance`.
    """

    def __call__(
        self, X: Real[Array, 'p n'], *, key: Key[Array, ''] | None = None
    ) -> Binner:
        """Construct a `Binner` from `X` and an optional random key.

        Parameters
        ----------
        X
            Training predictors with `p` predictors and `n` observations.
        key
            Optional random key, keyword-only; defaults to `None`.

        Returns
        -------
        The constructed `Binner`.
        """
        ...
class RangeEvenBinner(Binner):
    """Binner whose cutpoints evenly divide the observed range of each predictor.

    For each predictor, ``max_bins - 1`` cutpoints are placed at
    equally spaced positions strictly between the minimum and the
    maximum value observed in the training matrix. All predictors use
    the same number of cutpoints.

    Parameters
    ----------
    X
        Training predictors with `p` predictors and `n` observations.
    max_bins
        The number of bins per predictor; ``max_bins - 1`` cutpoints
        are produced per predictor.
    key
        Accepted for protocol uniformity; unused.
    """

    _low: Real[Array, ' p']
    """Minimum observed value per predictor."""

    _high: Real[Array, ' p']
    """Maximum observed value per predictor."""

    # WORKAROUND(jax<0.9.1): use `jax.tree.static` instead of `field(static=True)`
    _max_bins: int = field(static=True)
    """Number of bins per predictor."""

    max_split: UInt[Array, ' p']

    def __init__(
        self,
        X: Real[Array, 'p n'],
        *,
        max_bins: int = 256,
        key: Key[Array, ''] | None = None,
    ) -> None:
        del key
        num_predictors = X.shape[0]
        num_cutpoints = max_bins - 1

        # only the observed range is stored, not the cutpoints themselves
        self._low = jnp.min(X, axis=1)
        self._high = jnp.max(X, axis=1)
        self._max_bins = max_bins

        # the same number of cutpoints for every predictor
        self.max_split = jnp.full(
            num_predictors, num_cutpoints, minimal_unsigned_dtype(num_cutpoints)
        )

    @property
    def _splits(self) -> Real[Array, 'p m']:
        """Rebuild the cutpoint matrix. Intended for testing only, not library use.

        The cutpoints are not stored: `bin` works arithmetically from the
        observed range, since they are evenly spaced. This property
        reconstructs them only to expose them; the library should rely on
        `bin` and `max_split` instead.
        """
        return _uniform_splits_from_range(self._low, self._high, self._max_bins)

    def bin(self, X: Real[Array, 'p n']) -> UInt[Array, 'p n']:
        # arithmetic binning on the stored range, no cutpoint matrix involved
        return _bin_predictors_uniform(X, self._low, self._high, self._max_bins)
class UniqueQuantileBinner(Binner):
    """Binner whose cutpoints are quantiles of the observed unique values.

    For each predictor, cutpoints are placed between sorted unique
    values so that the empirical distribution is approximately uniform
    across bins. The number of cutpoints is at most ``max_bins - 1``
    and at most one less than the number of unique values, so different
    predictors may end up with different effective cutpoint counts.
    Trailing unused entries of the cutpoint matrix are padded with the
    maximum value representable in the dtype of `X`.

    Note: the quantiles are over the *unique* values, not over the
    original distribution.

    When ``n > max_subsample``, the predictor matrix is randomly thinned
    along the observation axis to ``max_subsample`` columns before
    quantilization. Each predictor row is thinned independently and
    without replacement. This keeps quantilization tractable on very
    large datasets at the cost of approximate quantiles.

    Parameters
    ----------
    X
        Training predictors with `p` predictors and `n` observations.
    max_bins
        The maximum number of bins per predictor.
    max_subsample
        The maximum number of observations to use when computing
        quantiles. If `None`, no subsampling is performed. If `n`
        exceeds this, `key` is required.
    key
        Random key for subsampling. Required when ``X.shape[1] >
        max_subsample``; otherwise unused.

    Raises
    ------
    ValueError
        If subsampling would trigger but `key` is `None`.
    """

    _splits: Real[Array, 'p m']
    """Cutpoints per predictor, padded on the right with the dtype's maximum value."""

    max_split: UInt[Array, ' p']

    def __init__(
        self,
        X: Real[Array, 'p n'],
        *,
        max_bins: int = 256,
        max_subsample: int | None = 100_000,
        key: Key[Array, ''] | None = None,
    ) -> None:
        num_obs = X.shape[1]
        must_thin = max_subsample is not None and num_obs > max_subsample

        # thinning is random, so it cannot proceed without a key
        if must_thin and key is None:
            msg = (
                'UniqueQuantileBinner requires a `key` because '
                f'n={X.shape[1]} exceeds max_subsample={max_subsample}.'
            )
            raise ValueError(msg)

        if must_thin:
            X = _subsample(key, X, max_subsample)

        splits, max_split = _quantilized_splits_from_matrix(X, max_bins)
        self._splits = splits
        self.max_split = max_split

    def bin(self, X: Real[Array, 'p n']) -> UInt[Array, 'p n']:
        # look up each value among the stored cutpoints
        return _bin_predictors(X, self._splits)
class GivenSplitsBinner(Binner):
    """Binner that uses cutpoints provided by the caller in R BART `xinfo` format.

    The cutpoints are taken verbatim from `xinfo`: a `(p, m)` matrix
    whose rows hold per-predictor sorted cutpoints, with NaN-padded
    trailing entries marking unused capacity. Internally NaNs are
    replaced by the maximum representable value in the dtype of
    `xinfo`, and `max_split` is set to the count of non-NaN entries
    per row, so binning behaves as if the row had been declared with
    only its non-NaN cutpoints.

    Parameters
    ----------
    X
        Training predictors. Used only to validate the shape of `xinfo`.
    xinfo
        A `(p, m)` matrix of cutpoints. Each row holds a sorted list of
        cutpoints for one predictor, optionally padded on the right with
        NaN.
    key
        Accepted for protocol uniformity; unused.

    Raises
    ------
    ValueError
        If `xinfo` is not 2D, or if its first dimension does not match
        ``X.shape[0]``.
    """

    _splits: Float[Array, 'p m']
    """Cutpoints per predictor, with NaNs replaced by the dtype's maximum value."""

    max_split: UInt[Array, ' p']

    def __init__(
        self,
        X: Real[Array, 'p n'],
        *,
        xinfo: Float[Array, 'p m'],
        key: Key[Array, ''] | None = None,
    ) -> None:
        del key

        # `xinfo` needs one row of cutpoints per predictor in `X`
        shape_ok = xinfo.ndim == 2 and xinfo.shape[0] == X.shape[0]
        if not shape_ok:
            msg = f'{xinfo.shape=} different from expected ({X.shape[0]}, *)'
            raise ValueError(msg)

        splits, max_split = _parse_xinfo(xinfo)
        self._splits = splits
        self.max_split = max_split

    def bin(self, X: Real[Array, 'p n']) -> UInt[Array, 'p n']:
        # look up each value among the stored cutpoints
        return _bin_predictors(X, self._splits)
625@jit
626def _sigma2_from_ols(
627 x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n'] | Float32[Array, 'k n']
628) -> Float32[Array, ''] | Float32[Array, ' k']:
629 """Return the error variance estimated with OLS with intercept."""
630 x_centered = x_train.T - x_train.mean(axis=1)
631 y_centered = y_train.T - y_train.mean(axis=-1)
632 # centering is equivalent to adding an intercept column
633 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
634 chisq = chisq.reshape(y_train.shape[:-1])
635 dof = y_train.shape[-1] - rank
636 return chisq / dof