Coverage for src/bartz/stochtree/_preprocess.py: 93%

209 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/stochtree/_preprocess.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Auto-preprocessing of covariates for the stochtree-compatible BART interface. 

26 

27Two parallel implementations are provided, `PandasPreprocessor` and 

28`PolarsPreprocessor`, each handling the corresponding dataframe library. Both 

29classes have the same interface:: 

30 

31 pp = PandasPreprocessor() # or PolarsPreprocessor() 

32 varprob = pp.fit(X_train, variable_weights=w) 

33 x_train = pp.transform(X_train) 

34 x_new = pp.transform(X_new) # at prediction time 

35 

36`fit` records the per-column encoding and returns the variable weights expanded 

37to match the new column count (or `None`); `transform` returns the 

38post-processing covariate matrix as a 2-D numpy float32 array (rows=observations, 

39columns=expanded features). 

40 

41Per-column handling: 

42 

43- ordered categorical (pandas ordered `Categorical`): ordinal encoded into a 

44 single integer-valued column, with the declared category order giving the 

45 integer mapping. polars has no ordered categorical dtype; pass an integer 

46 column for ordinal encoding. 

47- unordered categorical (pandas unordered `Categorical`, polars `Enum`): one-hot 

48 encoded into one binary column per declared category. A polars `Enum` 

49 round-trips to a pandas *unordered* `Categorical`, so the two are treated 

50 identically. 

51- boolean: cast to ``{0.0, 1.0}``, single column. 

52- numeric (integer, unsigned, float): pass-through as float. 

53- anything else (strings, ``object``, datetime, polars `Categorical`, etc.): 

54 raises `ValueError`. polars `Categorical` has no reliable per-column category 

55 list (the categories live in a process-wide string cache shared across 

56 columns), so it must be cast to an `Enum` (one-hot) or an integer (ordinal) 

57 first. 

58 

59When a single original column expands into ``k`` output columns (one-hot), the 

60original `variable_weights` entry for that column is split evenly across the 

61``k`` expansions, preserving each original variable's total splitting budget 

62(matching stochtree's `bart.py` behavior). 

63 

64Unknown category values encountered during `transform` raise `ValueError`. 

65""" 

66 

67from collections.abc import Sequence 

68from dataclasses import dataclass 

69from typing import Any, Literal, TypeAlias, overload 

70 

71import numpy as np 

72from jaxtyping import Float32, Shaped 

73from numpy.typing import ArrayLike 

74 

# Duck-typed stand-ins for the optional dataframe libraries. bartz does not
# depend on pandas or polars at runtime, so we cannot reference their real
# classes here; these aliases resolve to `Any` but give the signatures below
# legible names.
DataFrame: TypeAlias = Any  # a pandas or polars DataFrame
Series: TypeAlias = Any  # a pandas or polars Series
PolarsModule: TypeAlias = Any  # the polars top-level module

# Maximum number of entries shown for each list (unseen values and known
# categories) in the message built by `_unseen_error`.
_UNSEEN_PREVIEW = 10

# Closed set of per-column encodings; this is the type of `_ColumnSpec.kind`.
ColumnKind: TypeAlias = Literal['numeric', 'bool', 'ordered_cat', 'unordered_cat']

86 

87 

@dataclass(frozen=True)
class _ColumnSpec:
    """Fitted encoding recipe for one column of the original dataframe."""

    kind: ColumnKind
    """Which encoding is applied to the column."""

    name: str
    """Name of the source column, used when building error messages."""

    categories: tuple[Any, ...] | None = None
    """Declared categories for ordered_cat / unordered_cat columns."""

    @property
    def width(self) -> int:
        """Count of output columns generated from this column."""
        # Only one-hot (unordered categorical) columns fan out; every other
        # kind maps to exactly one output column.
        if self.kind != 'unordered_cat':
            return 1
        assert self.categories is not None
        return len(self.categories)

108 

109 

def _unseen_error(name: str, unseen: Sequence[Any], known: Sequence[Any]) -> ValueError:
    """Create the `ValueError` reporting values missing from the fitted categories.

    The message states ``len(unseen)``, a sorted, de-duplicated preview of the
    ``repr`` of the unseen values, and a preview of `known`; both previews are
    capped at `_UNSEEN_PREVIEW` entries. The exception is returned, not raised.
    """
    unseen_reprs = sorted({repr(value) for value in unseen})
    unseen_preview = unseen_reprs[:_UNSEEN_PREVIEW]
    known_preview = list(known)[:_UNSEEN_PREVIEW]
    return ValueError(
        f'column {name!r}: {len(unseen)} value(s) at transform time are not in'
        f' the fitted category list; unseen sample: {unseen_preview};'
        f' known categories: {known_preview}'
    )

119 

120 

121def _unsupported_dtype_error(name: str, dtype: object) -> ValueError: 

122 """Build the error for a column whose dtype has no supported encoding.""" 

123 msg = ( 

124 f'column {name!r} has unsupported dtype {dtype!r}; supported types are' 

125 ' numeric, boolean, pandas ordered/unordered Categorical, and polars' 

126 ' Enum. Convert strings, objects, datetimes, etc. to one of these (e.g.' 

127 ' an explicit Categorical / Enum) before fitting.' 

128 ) 

129 return ValueError(msg) 

130 

131 

132def _polars_categorical_error(name: str) -> ValueError: 

133 """Build the error rejecting a polars `Categorical` column.""" 

134 msg = ( 

135 f'column {name!r} is a polars Categorical, which has no reliable' 

136 ' per-column category list (the categories live in a process-wide' 

137 ' string cache shared across columns). Cast it to a polars Enum with an' 

138 ' explicit category list (pl.Enum([...])) for one-hot encoding, or to an' 

139 ' integer column for ordinal encoding.' 

140 ) 

141 return ValueError(msg) 

142 

143 

def _ordinal_encode(
    values: Shaped[np.ndarray, ' n'], categories: Sequence[Any], name: str
) -> Float32[np.ndarray, 'n 1']:
    """Replace each value by its position in `categories`, as one float32 column.

    Raises the error built by `_unseen_error` if any value is not a key of the
    category lookup; every offending value (duplicates included) is reported.
    """
    position = {cat: idx for idx, cat in enumerate(categories)}
    # -1 marks a value that is not among the declared categories.
    codes = [position.get(value, -1) for value in values]
    missing = [value for value, code in zip(values, codes) if code < 0]
    if missing:
        raise _unseen_error(name, missing, categories)
    return np.asarray(codes, dtype=np.float32)[:, None]

160 

161 

def _one_hot_encode(
    values: Shaped[np.ndarray, ' n'], categories: Sequence[Any], name: str
) -> Float32[np.ndarray, 'n k']:
    """Return the ``(n, k)`` indicator matrix of `values`, columns in `categories` order.

    Raises the error built by `_unseen_error` if any value is not a key of the
    category lookup; every offending value (duplicates included) is reported.
    """
    position = {cat: idx for idx, cat in enumerate(categories)}
    # -1 marks a value that is not among the declared categories.
    codes = [position.get(value, -1) for value in values]
    missing = [value for value, code in zip(values, codes) if code < 0]
    if missing:
        raise _unseen_error(name, missing, categories)
    encoded = np.zeros((len(values), len(categories)), dtype=np.float32)
    # Explicit integer dtype keeps the index array valid when there are no rows.
    hot = np.asarray(codes, dtype=np.intp)
    encoded[np.arange(len(codes)), hot] = 1.0
    return encoded

180 

181 

def _polars_one_hot(
    pl: PolarsModule, series: Series, categories: Sequence[Any], name: str
) -> Float32[np.ndarray, 'n k']:
    """One-hot encode a polars series by casting it to `pl.Enum(categories)`.

    Polars's `Enum` cast natively raises on any value not in `categories`, and
    `to_physical` returns the integer codes in the declared-category order, so
    the categorical bookkeeping stays inside polars; numpy is only used to
    pick rows out of an identity matrix.

    Raises `ValueError` when the cast is rejected (reported through
    `_unseen_error`) or when the cast result contains nulls.
    """
    declared = list(categories)
    try:
        as_enum = series.cast(pl.Enum(declared))
    except pl.exceptions.InvalidOperationError as exc:
        # Recover the offending values for a friendly message. Go through
        # String because the input may itself be an Enum with a different
        # category list, which would make a direct is_in(declared) fail while
        # coercing the list.
        declared_set = set(declared)
        offending = set()
        for value in series.cast(pl.String).to_list():
            if value is not None and value not in declared_set:
                offending.add(value)
        raise _unseen_error(name, sorted(offending), declared) from exc
    if as_enum.null_count():
        msg = f'column {name!r}: null values are not supported in categorical columns'
        raise ValueError(msg)
    physical_codes = as_enum.to_physical().to_numpy()
    identity = np.eye(len(declared), dtype=np.float32)
    return identity[physical_codes]

213 

214 

def _expand_variable_weights(
    weights: Shaped[ArrayLike, '...'], original_var_indices: Sequence[int], n_orig: int
) -> Float32[np.ndarray, ' p']:
    """Spread each original column's weight evenly over its output columns.

    `original_var_indices` gives, for every output column, the index of the
    original column it came from; an original column that owns ``c`` output
    columns contributes ``weight / c`` to each of them.

    Raises `ValueError` if `weights` does not have shape ``(n_orig,)``.
    """
    w = np.asarray(weights, dtype=np.float32)
    expected_shape = (n_orig,)
    if w.shape != expected_shape:
        msg = (
            f'variable_weights must have shape ({n_orig},) matching the number'
            f' of original columns; got {w.shape}'
        )
        raise ValueError(msg)
    if not original_var_indices:
        return np.empty((0,), dtype=np.float32)
    owners = np.asarray(original_var_indices)
    # copies[j] = how many output columns original column j expanded into.
    copies = np.bincount(owners, minlength=n_orig)
    return (w[owners] / copies[owners]).astype(np.float32)

230 

231 

def _stack(
    cols: Sequence[Float32[np.ndarray, 'n _']], n_rows: int
) -> Float32[np.ndarray, 'n p']:
    """Join per-column blocks side by side; with no blocks return an ``(n_rows, 0)`` array."""
    if cols:
        return np.concatenate(cols, axis=1)
    # np.concatenate rejects an empty list, so build the zero-column matrix directly.
    return np.empty((n_rows, 0), dtype=np.float32)

238 

239 

class _PreprocessorBase:
    """Common state for `PandasPreprocessor` and `PolarsPreprocessor`.

    Subclasses set `_library` and implement the three static hooks
    (`_get_column`, `_fit_column`, `_transform_column`); `fit` and `transform`
    defined here drive those hooks column by column.
    """

    _library: str = ''
    """Top-level module prefix of the supported dataframe library."""

    _fitted: bool = False
    _specs: Sequence[_ColumnSpec] = ()
    _original_var_indices: Sequence[int] = ()

    @property
    def fitted(self) -> bool:
        """Whether `fit` has completed successfully."""
        return self._fitted

    @property
    def n_original_columns(self) -> int:
        """Number of columns in the dataframe given to `fit`."""
        return len(self._specs)

    @property
    def n_processed_columns(self) -> int:
        """Number of columns in the matrix returned by `transform`."""
        return len(self._original_var_indices)

    @property
    def original_var_indices(self) -> tuple[int, ...]:
        """For each output column, the index of the original column it came from."""
        return tuple(self._original_var_indices)

    @overload
    def fit(
        self, X: DataFrame, *, variable_weights: Shaped[ArrayLike, '...']
    ) -> Float32[np.ndarray, ' p']: ...

    @overload
    def fit(
        self, X: DataFrame, *, variable_weights: None = None
    ) -> Float32[np.ndarray, ' p'] | None: ...

    def fit(
        self, X: DataFrame, *, variable_weights: Shaped[ArrayLike, '...'] | None = None
    ) -> Float32[np.ndarray, ' p'] | None:
        """Record the per-column encoding and return the expanded variable weights.

        Returns `None` when no weights are supplied and no column expands into
        several output columns, so the caller can fall back to the native
        uniform-weights path; otherwise returns the weights split across each
        original column's one-hot expansion. When weights are needed but not
        supplied, uniform weights over the original columns are used.

        The fitted state is only stored once every column has been inspected
        and the weights have been validated: if this method raises, the
        preprocessor keeps whatever state it had before the call.

        Raises `TypeError` if `X` does not come from the supported library and
        `ValueError` if `variable_weights` does not have one entry per
        original column; errors raised by `_fit_column` propagate unchanged.
        """
        self._check_library(X)
        specs: list[_ColumnSpec] = []
        original_var_indices: list[int] = []
        for orig_idx in range(X.shape[1]):
            name, series = self._get_column(X, orig_idx)
            spec = self._fit_column(series, str(name))
            specs.append(spec)
            original_var_indices.extend([orig_idx] * spec.width)

        # Resolve the weights BEFORE touching `self`: `_expand_variable_weights`
        # raises on a wrong shape, and a failed fit must not leave the object
        # flagged as fitted (or clobber a previous successful fit).
        expanded = len(set(original_var_indices)) != len(original_var_indices)
        if variable_weights is None and expanded:
            variable_weights = np.full(len(specs), 1.0 / len(specs))
        if variable_weights is None:
            weights = None
        else:
            weights = _expand_variable_weights(
                variable_weights, original_var_indices, len(specs)
            )

        self._specs = tuple(specs)
        self._original_var_indices = tuple(original_var_indices)
        self._fitted = True
        return weights

    def transform(self, X: DataFrame) -> Float32[np.ndarray, 'n p']:
        """Apply the fitted transformation to a new dataframe.

        Columns are matched to the fitted specs by position. Raises
        `RuntimeError` if `fit` has not been called, `TypeError` if `X` does
        not come from the supported library, and `ValueError` if the number of
        columns differs from the fitted dataframe.
        """
        self._check_fitted()
        self._check_library(X)
        self._check_n_columns(X.shape[1])
        cols = [
            self._transform_column(self._get_column(X, orig_idx)[1], spec)
            for orig_idx, spec in enumerate(self._specs)
        ]
        return _stack(cols, X.shape[0])

    def _check_fitted(self) -> None:
        """Raise `RuntimeError` unless `fit` has completed."""
        if not self._fitted:
            msg = 'preprocessor has not been fitted yet; call fit first'
            raise RuntimeError(msg)

    def _check_n_columns(self, n_cols: int) -> None:
        """Raise `ValueError` if `n_cols` differs from the fitted column count."""
        if n_cols != len(self._specs):
            msg = (
                f'transform input has {n_cols} columns; preprocessor was fitted'
                f' on {len(self._specs)} columns'
            )
            raise ValueError(msg)

    def _check_library(self, X: DataFrame) -> None:
        """Raise `TypeError` unless the module of ``type(X)`` starts with `_library`."""
        module = type(X).__module__
        if not module.startswith(self._library):
            msg = (
                f'this preprocessor handles {self._library} dataframes, but got'
                f' an object from {module!r}; fit and transform must use the same'
                ' dataframe library'
            )
            raise TypeError(msg)

    @staticmethod
    def _get_column(X: DataFrame, orig_idx: int) -> tuple[Any, Series]:
        """Return the ``(name, series)`` of the column at position `orig_idx`."""
        raise NotImplementedError

    @staticmethod
    def _fit_column(series: Series, name: str) -> _ColumnSpec:
        """Inspect a column's dtype and return its encoding spec."""
        raise NotImplementedError

    @staticmethod
    def _transform_column(
        series: Series, spec: _ColumnSpec
    ) -> Float32[np.ndarray, 'n _']:
        """Encode a single column according to its fitted spec."""
        raise NotImplementedError

360 

361 

class PandasPreprocessor(_PreprocessorBase):
    """Stochtree-style covariate preprocessor for `pandas.DataFrame` inputs."""

    _library = 'pandas'

    @staticmethod
    def _get_column(X: DataFrame, orig_idx: int) -> tuple[Any, Series]:
        """Return the label and the series of the column at position `orig_idx`."""
        label = X.columns[orig_idx]
        # Positional access, so the lookup never goes through the label.
        return label, X.iloc[:, orig_idx]

    @staticmethod
    def _fit_column(series: Series, name: str) -> _ColumnSpec:
        """Choose the encoding for a pandas column from its dtype.

        Categorical dtypes are checked first, then boolean, then numeric; any
        other dtype raises the error built by `_unsupported_dtype_error`.
        """
        import pandas as pd  # noqa: PLC0415 # optional runtime dependency

        dtype = series.dtype
        if isinstance(dtype, pd.CategoricalDtype):
            declared = tuple(dtype.categories)
            kind: ColumnKind = 'unordered_cat' if not dtype.ordered else 'ordered_cat'
            return _ColumnSpec(kind, name, categories=declared)
        type_checks = pd.api.types
        if type_checks.is_bool_dtype(dtype):
            return _ColumnSpec('bool', name)
        if type_checks.is_numeric_dtype(dtype):
            return _ColumnSpec('numeric', name)
        raise _unsupported_dtype_error(name, dtype)

    @staticmethod
    def _transform_column(
        series: Series, spec: _ColumnSpec
    ) -> Float32[np.ndarray, 'n _']:
        """Encode one pandas column according to `spec`.

        Categorical kinds go through the ordinal / one-hot helpers; every other
        kind is converted to a single float32 column.
        """
        if spec.kind in ('ordered_cat', 'unordered_cat'):
            assert spec.categories is not None
            encode = _ordinal_encode if spec.kind == 'ordered_cat' else _one_hot_encode
            return encode(series.to_numpy(), spec.categories, spec.name)
        return series.to_numpy(dtype=np.float32)[:, None]

397 

398 

class PolarsPreprocessor(_PreprocessorBase):
    """Stochtree-style covariate preprocessor for `polars.DataFrame` inputs."""

    _library = 'polars'

    @staticmethod
    def _get_column(X: DataFrame, orig_idx: int) -> tuple[Any, Series]:
        """Return the name and the series of the column at position `orig_idx`."""
        column_name = X.columns[orig_idx]
        return column_name, X[column_name]

    @staticmethod
    def _fit_column(series: Series, name: str) -> _ColumnSpec:
        """Choose the encoding for a polars column from its dtype.

        `Enum` becomes one-hot, `Categorical` is rejected, then boolean and
        numeric dtypes are accepted; any other dtype raises the error built by
        `_unsupported_dtype_error`.
        """
        import polars as pl  # noqa: PLC0415 # optional runtime dependency

        dtype = series.dtype
        if isinstance(dtype, pl.Enum):
            # A polars Enum round-trips to a pandas *unordered* Categorical, so
            # it is treated as unordered (one-hot). For ordinal encoding, pass
            # an integer column.
            declared = tuple(dtype.categories.to_list())
            return _ColumnSpec('unordered_cat', name, categories=declared)
        if isinstance(dtype, pl.Categorical):
            raise _polars_categorical_error(name)
        if dtype == pl.Boolean:
            return _ColumnSpec('bool', name)
        if dtype.is_numeric():
            return _ColumnSpec('numeric', name)
        raise _unsupported_dtype_error(name, dtype)

    @staticmethod
    def _transform_column(
        series: Series, spec: _ColumnSpec
    ) -> Float32[np.ndarray, 'n _']:
        """Encode one polars column according to `spec`.

        Unordered categoricals are one-hot encoded through `_polars_one_hot`;
        every other kind is cast to `Float32` and returned as a single column.
        """
        import polars as pl  # noqa: PLC0415 # optional runtime dependency

        if spec.kind != 'unordered_cat':
            return series.cast(pl.Float32).to_numpy()[:, None]
        assert spec.categories is not None
        return _polars_one_hot(pl, series, spec.categories, spec.name)

439 

440 

def make_preprocessor(X: object) -> _PreprocessorBase | None:
    """Pick the preprocessor class for `X` from the module its type is defined in.

    Returns a new `PolarsPreprocessor` or `PandasPreprocessor` when
    ``type(X).__module__`` starts with ``'polars'`` or ``'pandas'``
    respectively, and `None` otherwise. Looking at the module name avoids
    hard imports of pandas/polars.
    """
    module_name = type(X).__module__
    # polars is tested first, then pandas.
    factories = (('polars', PolarsPreprocessor), ('pandas', PandasPreprocessor))
    for prefix, factory in factories:
        if module_name.startswith(prefix):
            return factory()
    return None