Coverage for src/bartz/stochtree/_preprocess.py: 93%
209 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/stochtree/_preprocess.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Auto-preprocessing of covariates for the stochtree-compatible BART interface.
27Two parallel implementations are provided, `PandasPreprocessor` and
28`PolarsPreprocessor`, each handling the corresponding dataframe library. Both
29classes have the same interface::
31 pp = PandasPreprocessor() # or PolarsPreprocessor()
32 varprob = pp.fit(X_train, variable_weights=w)
33 x_train = pp.transform(X_train)
34 x_new = pp.transform(X_new) # at prediction time
36`fit` records the per-column encoding and returns the variable weights expanded
37to match the new column count (or `None`); `transform` returns the
38post-processing covariate matrix as a 2-D numpy float32 array (rows=observations,
39columns=expanded features).
41Per-column handling:
43- ordered categorical (pandas ordered `Categorical`): ordinal encoded into a
44 single integer-valued column, with the declared category order giving the
45 integer mapping. polars has no ordered categorical dtype; pass an integer
46 column for ordinal encoding.
47- unordered categorical (pandas unordered `Categorical`, polars `Enum`): one-hot
48 encoded into one binary column per declared category. A polars `Enum`
49 round-trips to a pandas *unordered* `Categorical`, so the two are treated
50 identically.
51- boolean: cast to ``{0.0, 1.0}``, single column.
52- numeric (integer, unsigned, float): pass-through as float.
53- anything else (strings, ``object``, datetime, polars `Categorical`, etc.):
54 raises `ValueError`. polars `Categorical` has no reliable per-column category
55 list (the categories live in a process-wide string cache shared across
56 columns), so it must be cast to an `Enum` (one-hot) or an integer (ordinal)
57 first.
59When a single original column expands into ``k`` output columns (one-hot), the
60original `variable_weights` entry for that column is split evenly across the
61``k`` expansions, preserving each original variable's total splitting budget
62(matching stochtree's `bart.py` behavior).
64Unknown category values encountered during `transform` raise `ValueError`.
65"""
67from collections.abc import Sequence
68from dataclasses import dataclass
69from typing import Any, Literal, TypeAlias, overload
71import numpy as np
72from jaxtyping import Float32, Shaped
73from numpy.typing import ArrayLike
# Duck-typed stand-ins for the optional dataframe libraries. bartz does not
# depend on pandas or polars at runtime, so we cannot reference their real
# classes here; these aliases resolve to `Any` but give the signatures below
# legible names.
DataFrame: TypeAlias = Any  # a pandas or polars DataFrame
Series: TypeAlias = Any  # a pandas or polars Series
PolarsModule: TypeAlias = Any  # the polars top-level module

# Maximum number of unseen values, and of known categories, quoted in the
# error message built by `_unseen_error`.
_UNSEEN_PREVIEW = 10

# The encodings that a fitted column can be assigned.
ColumnKind: TypeAlias = Literal['numeric', 'bool', 'ordered_cat', 'unordered_cat']
@dataclass(frozen=True)
class _ColumnSpec:
    """Fitted encoding state for one original dataframe column."""

    kind: ColumnKind
    """Which encoding is applied to the column."""

    name: str
    """Name of the original column, used in error messages."""

    categories: tuple[Any, ...] | None = None
    """Declared categories, for ordered_cat / unordered_cat specs."""

    @property
    def width(self) -> int:
        """Count of output columns generated from this column."""
        # Only one-hot (unordered categorical) columns expand; every other
        # kind maps to exactly one output column.
        if self.kind != 'unordered_cat':
            return 1
        assert self.categories is not None
        return len(self.categories)
def _unseen_error(name: str, unseen: Sequence[Any], known: Sequence[Any]) -> ValueError:
    """Create (without raising) the error for values missing from the fitted list.

    The message reports ``len(unseen)``, then at most `_UNSEEN_PREVIEW`
    distinct unseen values (as `repr` strings, sorted) and at most
    `_UNSEEN_PREVIEW` of the `known` categories.
    """
    distinct_reprs = {repr(value) for value in unseen}
    unseen_preview = sorted(distinct_reprs)[:_UNSEEN_PREVIEW]
    known_preview = list(known)[:_UNSEEN_PREVIEW]
    msg = (
        f'column {name!r}: {len(unseen)} value(s) at transform time are not in'
        f' the fitted category list; unseen sample: {unseen_preview};'
        f' known categories: {known_preview}'
    )
    return ValueError(msg)
121def _unsupported_dtype_error(name: str, dtype: object) -> ValueError:
122 """Build the error for a column whose dtype has no supported encoding."""
123 msg = (
124 f'column {name!r} has unsupported dtype {dtype!r}; supported types are'
125 ' numeric, boolean, pandas ordered/unordered Categorical, and polars'
126 ' Enum. Convert strings, objects, datetimes, etc. to one of these (e.g.'
127 ' an explicit Categorical / Enum) before fitting.'
128 )
129 return ValueError(msg)
132def _polars_categorical_error(name: str) -> ValueError:
133 """Build the error rejecting a polars `Categorical` column."""
134 msg = (
135 f'column {name!r} is a polars Categorical, which has no reliable'
136 ' per-column category list (the categories live in a process-wide'
137 ' string cache shared across columns). Cast it to a polars Enum with an'
138 ' explicit category list (pl.Enum([...])) for one-hot encoding, or to an'
139 ' integer column for ordinal encoding.'
140 )
141 return ValueError(msg)
def _ordinal_encode(
    values: Shaped[np.ndarray, ' n'], categories: Sequence[Any], name: str
) -> Float32[np.ndarray, 'n 1']:
    """Encode `values` as their index in `categories`, as an ``(n, 1)`` column.

    Raises the `ValueError` built by `_unseen_error` if any value is not one
    of `categories`.
    """
    position = {category: index for index, category in enumerate(categories)}
    # -1 marks a value that has no entry in the fitted category list.
    codes = [position.get(value, -1) for value in values]
    missing = [value for value, code in zip(values, codes) if code < 0]
    if missing:
        raise _unseen_error(name, missing, categories)
    return np.asarray(codes, dtype=np.float32)[:, None]
def _one_hot_encode(
    values: Shaped[np.ndarray, ' n'], categories: Sequence[Any], name: str
) -> Float32[np.ndarray, 'n k']:
    """Encode `values` as an ``(n, k)`` indicator matrix, columns in `categories` order.

    Raises the `ValueError` built by `_unseen_error` if any value is not one
    of `categories`.
    """
    position = {category: index for index, category in enumerate(categories)}
    # -1 marks a value that has no entry in the fitted category list.
    codes = [position.get(value, -1) for value in values]
    missing = [value for value, code in zip(values, codes) if code < 0]
    if missing:
        raise _unseen_error(name, missing, categories)
    n_rows = len(values)
    encoded = np.zeros((n_rows, len(categories)), dtype=np.float32)
    # Explicit integer dtype so that an empty `codes` list is still a valid index.
    encoded[np.arange(n_rows), np.asarray(codes, dtype=np.intp)] = 1.0
    return encoded
def _polars_one_hot(
    pl: PolarsModule, series: Series, categories: Sequence[Any], name: str
) -> Float32[np.ndarray, 'n k']:
    """One-hot encode a polars series after validating it against `categories`.

    Validation is delegated to polars: casting to `pl.Enum(categories)` raises
    on any value not in `categories`, and `to_physical` then yields the integer
    codes in the declared-category order. numpy is only used to look the codes
    up in an identity matrix.

    Raises `ValueError` on values outside `categories` or on null values.
    """
    declared = list(categories)
    try:
        as_enum = series.cast(pl.Enum(declared))
    except pl.exceptions.InvalidOperationError as exc:
        # Collect the offending values for a friendly error. Go through String
        # because the input may itself be an Enum with different categories,
        # which would make a direct is_in(declared) fail coercing the list.
        fitted = set(declared)
        offending = set()
        for value in series.cast(pl.String).to_list():
            if value is not None and value not in fitted:
                offending.add(value)
        raise _unseen_error(name, sorted(offending), declared) from exc
    if as_enum.null_count():
        msg = f'column {name!r}: null values are not supported in categorical columns'
        raise ValueError(msg)
    identity = np.eye(len(declared), dtype=np.float32)
    return identity[as_enum.to_physical().to_numpy()]
def _expand_variable_weights(
    weights: Shaped[ArrayLike, '...'], original_var_indices: Sequence[int], n_orig: int
) -> Float32[np.ndarray, ' p']:
    """Split each original weight evenly across its one-hot expansions.

    `weights` holds one weight per original column (shape ``(n_orig,)``) and
    `original_var_indices` gives, for each output column, the index of the
    original column it came from. Output column ``i`` receives
    ``weights[j] / count(j)`` with ``j = original_var_indices[i]`` and
    ``count(j)`` the number of output columns that came from ``j``, so the
    weights of each original column's expansions sum to its original weight.

    Raises `ValueError` if `weights` does not have shape ``(n_orig,)``.
    """
    w = np.asarray(weights, dtype=np.float32)
    if w.shape != (n_orig,):
        msg = (
            f'variable_weights must have shape ({n_orig},) matching the number'
            f' of original columns; got {w.shape}'
        )
        raise ValueError(msg)
    # `len` rather than truthiness, so an array-like index sequence is also fine.
    if len(original_var_indices) == 0:
        return np.empty((0,), dtype=np.float32)
    owners = np.asarray(original_var_indices)
    counts = np.bincount(owners, minlength=n_orig)
    # Vectorized gather + divide; the division happens in float64 (float32 by
    # int64) and is cast back to float32, like the scalar form it replaces.
    return (w[owners] / counts[owners]).astype(np.float32)
def _stack(
    cols: Sequence[Float32[np.ndarray, 'n _']], n_rows: int
) -> Float32[np.ndarray, 'n p']:
    """Join the per-column blocks side by side; no blocks gives an ``(n_rows, 0)`` array."""
    return (
        np.concatenate(cols, axis=1)
        if cols
        else np.empty((n_rows, 0), dtype=np.float32)
    )
class _PreprocessorBase:
    """Shared state and driver logic of `PandasPreprocessor` and `PolarsPreprocessor`.

    Subclasses set `_library` and implement the three static hooks
    `_get_column`, `_fit_column` and `_transform_column`.
    """

    _library: str = ''
    """Top-level module prefix of the supported dataframe library."""

    _fitted: bool = False
    _specs: Sequence[_ColumnSpec] = ()
    _original_var_indices: Sequence[int] = ()

    @property
    def fitted(self) -> bool:
        """True once `fit` has completed."""
        return self._fitted

    @property
    def n_original_columns(self) -> int:
        """Column count of the dataframe passed to `fit`."""
        return len(self._specs)

    @property
    def n_processed_columns(self) -> int:
        """Column count of the matrix produced by `transform`."""
        return len(self._original_var_indices)

    @property
    def original_var_indices(self) -> tuple[int, ...]:
        """Index of the originating input column, for each output column."""
        return tuple(self._original_var_indices)

    @overload
    def fit(
        self, X: DataFrame, *, variable_weights: Shaped[ArrayLike, '...']
    ) -> Float32[np.ndarray, ' p']: ...

    @overload
    def fit(
        self, X: DataFrame, *, variable_weights: None = None
    ) -> Float32[np.ndarray, ' p'] | None: ...

    def fit(
        self, X: DataFrame, *, variable_weights: Shaped[ArrayLike, '...'] | None = None
    ) -> Float32[np.ndarray, ' p'] | None:
        """Learn the encoding of every column and return the expanded variable weights.

        The result is `None` only when no weights were given and no column
        expands into several output columns, which lets the caller use the
        native uniform-weights path. In every other case the weights (uniform
        over the original columns if none were given) are returned split
        across each original column's one-hot expansion.
        """
        self._check_library(X)
        fitted_specs: list[_ColumnSpec] = []
        owners: list[int] = []
        for position in range(X.shape[1]):
            column_name, column = self._get_column(X, position)
            column_spec = self._fit_column(column, str(column_name))
            fitted_specs.append(column_spec)
            # One entry per output column, pointing back at its source column.
            owners += [position] * column_spec.width
        self._specs = tuple(fitted_specs)
        self._original_var_indices = tuple(owners)
        self._fitted = True
        if variable_weights is None:
            # A repeated owner index means some column was one-hot expanded.
            has_expansion = len(set(owners)) != len(owners)
            if not has_expansion:
                return None
            variable_weights = np.full(len(fitted_specs), 1.0 / len(fitted_specs))
        return _expand_variable_weights(
            variable_weights, self._original_var_indices, len(self._specs)
        )

    def transform(self, X: DataFrame) -> Float32[np.ndarray, 'n p']:
        """Encode a new dataframe with the transformation recorded by `fit`."""
        self._check_fitted()
        self._check_library(X)
        self._check_n_columns(X.shape[1])
        blocks = []
        for position, column_spec in enumerate(self._specs):
            _, column = self._get_column(X, position)
            blocks.append(self._transform_column(column, column_spec))
        return _stack(blocks, X.shape[0])

    def _check_fitted(self) -> None:
        """Raise `RuntimeError` unless `fit` has been called."""
        if self._fitted:
            return
        msg = 'preprocessor has not been fitted yet; call fit first'
        raise RuntimeError(msg)

    def _check_n_columns(self, n_cols: int) -> None:
        """Raise `ValueError` if `n_cols` differs from the fitted column count."""
        n_fitted = len(self._specs)
        if n_cols == n_fitted:
            return
        msg = (
            f'transform input has {n_cols} columns; preprocessor was fitted'
            f' on {n_fitted} columns'
        )
        raise ValueError(msg)

    def _check_library(self, X: DataFrame) -> None:
        """Raise `TypeError` if the module of `X`'s type does not start with `_library`."""
        module = type(X).__module__
        if module.startswith(self._library):
            return
        msg = (
            f'this preprocessor handles {self._library} dataframes, but got'
            f' an object from {module!r}; fit and transform must use the same'
            ' dataframe library'
        )
        raise TypeError(msg)

    @staticmethod
    def _get_column(X: DataFrame, orig_idx: int) -> tuple[Any, Series]:
        """Return the ``(name, series)`` pair of the column at position `orig_idx`."""
        raise NotImplementedError

    @staticmethod
    def _fit_column(series: Series, name: str) -> _ColumnSpec:
        """Choose the encoding spec of a column from its dtype."""
        raise NotImplementedError

    @staticmethod
    def _transform_column(
        series: Series, spec: _ColumnSpec
    ) -> Float32[np.ndarray, 'n _']:
        """Encode one column as dictated by its fitted spec."""
        raise NotImplementedError
class PandasPreprocessor(_PreprocessorBase):
    """Covariate preprocessor in the style of stochtree, for `pandas.DataFrame` inputs."""

    _library = 'pandas'

    @staticmethod
    def _get_column(X: DataFrame, orig_idx: int) -> tuple[Any, Series]:
        # Positional access for the data, paired with the label at that position.
        label = X.columns[orig_idx]
        return label, X.iloc[:, orig_idx]

    @staticmethod
    def _fit_column(series: Series, name: str) -> _ColumnSpec:
        import pandas as pd  # noqa: PLC0415 # optional runtime dependency

        dtype = series.dtype
        if isinstance(dtype, pd.CategoricalDtype):
            kind: ColumnKind = 'unordered_cat'
            if dtype.ordered:
                kind = 'ordered_cat'
            return _ColumnSpec(kind, name, categories=tuple(dtype.categories))
        type_checks = pd.api.types
        # The boolean test runs before the numeric one.
        if type_checks.is_bool_dtype(dtype):
            return _ColumnSpec('bool', name)
        if type_checks.is_numeric_dtype(dtype):
            return _ColumnSpec('numeric', name)
        raise _unsupported_dtype_error(name, dtype)

    @staticmethod
    def _transform_column(
        series: Series, spec: _ColumnSpec
    ) -> Float32[np.ndarray, 'n _']:
        categorical_encoders = {
            'ordered_cat': _ordinal_encode,
            'unordered_cat': _one_hot_encode,
        }
        encoder = categorical_encoders.get(spec.kind)
        if encoder is None:
            # Remaining kinds are converted to a single float32 column.
            return series.to_numpy(dtype=np.float32)[:, None]
        assert spec.categories is not None
        return encoder(series.to_numpy(), spec.categories, spec.name)
class PolarsPreprocessor(_PreprocessorBase):
    """Covariate preprocessor in the style of stochtree, for `polars.DataFrame` inputs."""

    _library = 'polars'

    @staticmethod
    def _get_column(X: DataFrame, orig_idx: int) -> tuple[Any, Series]:
        label = X.columns[orig_idx]
        return label, X[label]

    @staticmethod
    def _fit_column(series: Series, name: str) -> _ColumnSpec:
        import polars as pl  # noqa: PLC0415 # optional runtime dependency

        dtype = series.dtype
        if isinstance(dtype, pl.Enum):
            # A polars Enum round-trips to a pandas *unordered* Categorical, so
            # it is handled as unordered (one-hot). Ordinal encoding requires
            # passing an integer column instead.
            declared = tuple(dtype.categories.to_list())
            return _ColumnSpec('unordered_cat', name, categories=declared)
        if isinstance(dtype, pl.Categorical):
            raise _polars_categorical_error(name)
        if dtype == pl.Boolean:
            return _ColumnSpec('bool', name)
        if not dtype.is_numeric():
            raise _unsupported_dtype_error(name, dtype)
        return _ColumnSpec('numeric', name)

    @staticmethod
    def _transform_column(
        series: Series, spec: _ColumnSpec
    ) -> Float32[np.ndarray, 'n _']:
        import polars as pl  # noqa: PLC0415 # optional runtime dependency

        if spec.kind != 'unordered_cat':
            # Every other kind is cast to a single float32 column.
            as_float = series.cast(pl.Float32)
            return as_float.to_numpy()[:, None]
        assert spec.categories is not None
        return _polars_one_hot(pl, series, spec.categories, spec.name)
def make_preprocessor(X: object) -> _PreprocessorBase | None:
    """Build the preprocessor for `X`'s dataframe library, or return `None`.

    The library is recognised from the prefix of ``type(X).__module__``, which
    avoids hard imports of pandas/polars. `None` is returned when the module
    starts with neither ``'polars'`` nor ``'pandas'``.
    """
    module_name = type(X).__module__
    factories = (
        ('polars', PolarsPreprocessor),
        ('pandas', PandasPreprocessor),
    )
    for prefix, factory in factories:
        if module_name.startswith(prefix):
            return factory()
    return None