Coverage for src / bartz / prepcovars.py: 86%
79 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/prepcovars.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions to preprocess data."""
27from functools import partial
28from typing import Any
30from jax import jit, vmap
31from jax import numpy as jnp
32from jaxtyping import Array, Float, Integer, Real, UInt
34from bartz.jaxext import autobatch, minimal_unsigned_dtype, unique
def parse_xinfo(
    xinfo: Float[Array, 'p m'],
) -> tuple[Float[Array, 'p m'], UInt[Array, ' p']]:
    """Convert a cutpoint matrix in the R package BART format to padded splits.

    Parameters
    ----------
    xinfo
        A matrix with the cutpoints to use to bin each predictor. Each row
        shall contain a sorted list of cutpoints for one predictor. If a row
        has fewer cutpoints than the matrix has columns, fill the remaining
        cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.

    Returns
    -------
    splits : Float[Array, 'p m']
        `xinfo` with every NaN replaced by the value returned by
        `_huge_value` for `xinfo`.
    max_split : UInt[Array, ' p']
        The number of non-NaN elements in each row of `xinfo`.
    """
    valid = jnp.logical_not(jnp.isnan(xinfo))

    # a row holds at most `m` cutpoints, so size the counter dtype for that
    count_dtype = minimal_unsigned_dtype(xinfo.shape[1])
    max_split = valid.sum(axis=1).astype(count_dtype)

    splits = jnp.where(valid, xinfo, _huge_value(xinfo))
    return splits, max_split
@partial(jit, static_argnums=(1,))
def quantilized_splits_from_matrix(
    X: Real[Array, 'p n'], max_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Compute, per predictor, cutpoints that spread the data evenly over bins.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    max_bins
        The maximum number of bins to produce. It is a static argument of
        the jitted function.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is ``min(max_bins, n) - 1``, which is an upper bound on the number
        of splits. Each predictor may have a different number of splits; the
        unused values at the end of each row are filled with the maximum
        value representable in the type of `X`.
    max_split : UInt[Array, ' p']
        The number of actually used values in each row of `splits`.

    Raises
    ------
    ValueError
        If `X` has no columns or if `max_bins` is less than 1.
    """
    num_splits = min(max_bins, X.shape[1]) - 1
    if num_splits < 0:
        msg = f'{X.shape[1]=} and {max_bins=}, they should be both at least 1.'
        raise ValueError(msg)

    # autobatch needs traceable arguments, so the static output length is
    # captured by closure rather than passed through the batched function
    @partial(autobatch, max_io_nbytes=2**29)
    def batched_quantilize(
        rows: Real[Array, 'p n'],
    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
        return _quantilized_splits_from_matrix(rows, num_splits)

    return batched_quantilize(X)
@partial(vmap, in_axes=(0, None))
def _quantilized_splits_from_matrix(
    x: Real[Array, 'p n'], out_length: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Compute up to `out_length` cutpoints for each row of `x`.

    The function is vmapped over the first axis of `x`, so the body below
    operates on a single row; `out_length` is shared by all rows. The
    annotations describe the vmapped interface.

    Returns
    -------
    splits
        Midpoints between consecutive unique values of the row, subsampled
        if there are more than `out_length` of them, and padded at the end
        with the value returned by `_huge_value`.
    max_split
        The number of meaningful entries in `splits`.
    """
    pad = _huge_value(x)

    # sorted unique values of the row, padded with `pad` up to the full size
    uniq, num_uniq = unique(x, size=x.size, fill_value=pad)

    # midpoints between consecutive unique values, written as
    # x_i + (x_{i+1} - x_i) / 2 instead of (x_i + x_{i+1}) / 2 to avoid
    # overflow
    lower = uniq[:-1]
    gap = uniq[1:] - lower
    if jnp.issubdtype(x.dtype, jnp.integer):
        half_gap = _ensure_unsigned(gap) // 2
    else:
        half_gap = gap / 2
    mids = lower + half_gap

    # there is one midpoint less than there are unique values; the entry
    # right after the last meaningful midpoint is overwritten with the
    # padding value
    num_mids = num_uniq - 1
    if mids.size:
        mids = mids.at[num_mids].set(pad)

    # candidate subset of the midpoints, used when there are more than the
    # requested maximum; positions are computed in floating point rather
    # than int to avoid potential overflow with int32, and to round to
    # nearest instead of rounding down
    positions = jnp.linspace(-1, num_mids, out_length + 2)[1:-1]
    index_dtype = minimal_unsigned_dtype(mids.size - 1)
    positions = jnp.around(positions).astype(index_dtype)
    decimated = mids[positions]

    # candidate used when all the midpoints fit in the output
    truncated = mids[:out_length]

    splits = jnp.where(num_mids > out_length, decimated, truncated)
    count_dtype = minimal_unsigned_dtype(out_length)
    max_split = jnp.minimum(num_mids, out_length).astype(count_dtype)
    return splits, max_split
147def _huge_value(x: Array) -> int | float:
148 """
149 Return the maximum value that can be stored in `x`.
151 Parameters
152 ----------
153 x
154 A numerical numpy or jax array.
156 Returns
157 -------
158 The maximum value allowed by `x`'s type (finite for floats).
159 """
160 if jnp.issubdtype(x.dtype, jnp.integer): 1qfghijrstuvwxyzklAabecmnod
161 return jnp.iinfo(x.dtype).max 1abecd
162 else:
163 return float(jnp.finfo(x.dtype).max) 1qfghijrstuvwxyzklAmno
def _ensure_unsigned(x: Integer[Array, '*shape']) -> UInt[Array, '*shape']:
    """Cast an integer array to the unsigned type chosen by `_signed_to_unsigned`."""
    unsigned_dtype = _signed_to_unsigned(x.dtype)
    return x.astype(unsigned_dtype)
def _signed_to_unsigned(int_dtype: jnp.dtype) -> jnp.dtype:
    """
    Return the unsigned integer type with the same width as `int_dtype`.

    Unsigned types are returned unchanged.

    Raises
    ------
    TypeError
        If `int_dtype` is a signed integer type other than int8, int16,
        int32 or int64.
    """
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype

    counterparts = (
        (jnp.int8, jnp.uint8),
        (jnp.int16, jnp.uint16),
        (jnp.int32, jnp.uint32),
        (jnp.int64, jnp.uint64),
    )
    for signed, unsigned in counterparts:
        # equality (not identity) comparison against each signed type
        if int_dtype == signed:
            return unsigned

    msg = f'unexpected integer type {int_dtype}'
    raise TypeError(msg)
@partial(jit, static_argnums=(1,))
def uniform_splits_from_matrix(
    X: Real[Array, 'p n'], num_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Build, for each predictor, a grid of evenly spaced cutpoints.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    num_bins
        The number of bins to produce. It is a static argument of the jitted
        function.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row
        of `X`.
    max_split : UInt[Array, ' p']
        The number of cutpoints in each row of `splits`, i.e.,
        ``num_bins - 1``.
    """
    lo = X.min(axis=1)
    hi = X.max(axis=1)

    # num_bins + 1 evenly spaced edges per row; the two outermost edges are
    # the row minimum and maximum, which are not cutpoints
    edges = jnp.linspace(lo, hi, num_bins + 1, axis=1)
    splits = edges[:, 1:-1]
    assert splits.shape == (X.shape[0], num_bins - 1)

    # every predictor uses all the cutpoints in its row
    num_rows, num_cuts = splits.shape
    count_dtype = minimal_unsigned_dtype(num_bins - 1)
    max_split = jnp.full((num_rows,), num_cuts, count_dtype)
    return splits, max_split
@partial(jit, static_argnames=('method',))
def bin_predictors(
    X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw: Any
) -> UInt[Array, 'p n']:
    """
    Replace each predictor value with the index of the bin containing it.

    With the default `jax.numpy.searchsorted` options, a value ``x`` is
    mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    splits
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; each row may have shorter
        actual length, marked by padding unused locations at the end of the
        row with the maximum value allowed by the type.
    **kw
        Additional arguments are passed to `jax.numpy.searchsorted`. The
        `method` keyword is treated as a static argument by `jit`.

    Returns
    -------
    `X` but with each value replaced by the index of the bin it falls into.
    """

    @partial(autobatch, max_io_nbytes=2**29)
    @vmap
    def bin_rows(
        row: Real[Array, 'p n'], row_splits: Real[Array, 'p m']
    ) -> UInt[Array, 'p n']:
        # under vmap this sees one predictor at a time; the output dtype is
        # sized from the number of split slots in the row
        index_dtype = minimal_unsigned_dtype(row_splits.size)
        bin_index = jnp.searchsorted(row_splits, row, **kw)
        return bin_index.astype(index_dtype)

    return bin_rows(X, splits)