Coverage for src / bartz / prepcovars.py: 86%

79 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/prepcovars.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions to preprocess data.""" 

26 

27from functools import partial 

28from typing import Any 

29 

30from jax import jit, vmap 

31from jax import numpy as jnp 

32from jaxtyping import Array, Float, Integer, Real, UInt 

33 

34from bartz.jaxext import autobatch, minimal_unsigned_dtype, unique 

35 

36 

def parse_xinfo(
    xinfo: Float[Array, 'p m'],
) -> tuple[Float[Array, 'p m'], UInt[Array, ' p']]:
    """Parse pre-defined splits in the format of the R package BART.

    Parameters
    ----------
    xinfo
        A matrix with the cutpoints to use to bin each predictor. Each row
        shall contain a sorted list of cutpoints for a predictor. If there are
        fewer cutpoints than the number of columns in the matrix, fill the
        remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.

    Returns
    -------
    splits : Float[Array, 'p m']
        `xinfo` with every NaN replaced by the largest finite value
        representable in its dtype.
    max_split : UInt[Array, ' p']
        The number of non-NaN elements in each row of `xinfo`, cast to the
        unsigned dtype chosen by `minimal_unsigned_dtype` for the number of
        columns of `xinfo`.
    """
    is_not_nan = ~jnp.isnan(xinfo)

    # count the real cutpoints of each predictor; the count can be at most
    # the number of columns, so that bound determines the dtype
    max_split = jnp.sum(is_not_nan, axis=1)
    max_split = max_split.astype(minimal_unsigned_dtype(xinfo.shape[1]))

    # replace the NaN padding with the maximum finite value of the dtype
    huge = _huge_value(xinfo)
    splits = jnp.where(is_not_nan, xinfo, huge)
    return splits, max_split

65 

66 

@partial(jit, static_argnums=(1,))
def quantilized_splits_from_matrix(
    X: Real[Array, 'p n'], max_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Determine bins that make the distribution of each predictor uniform.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    max_bins
        The maximum number of bins to produce.

    Returns
    -------
    splits : Real[Array, 'p m']
        For each predictor, the boundaries between consecutive bins. `m` is
        ``min(max_bins, n) - 1``, an upper bound on the number of splits.
        Rows with fewer splits are padded at the end with the maximum value
        representable in the dtype of `X`.
    max_split : UInt[Array, ' p']
        How many leading entries of each row of `splits` are actual splits.

    Raises
    ------
    ValueError
        If `X` has no columns or if `max_bins` is less than 1.
    """
    usable_bins = min(max_bins, X.shape[1])
    if usable_bins < 1:
        msg = f'{X.shape[1]=} and {max_bins=}, they should be both at least 1.'
        raise ValueError(msg)
    out_length = usable_bins - 1

    def _per_batch(
        X: Real[Array, 'p n'],
    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
        # autobatch wants only traceable arguments, so the static
        # `out_length` is captured by the closure instead of passed in
        return _quantilized_splits_from_matrix(X, out_length)

    batched = autobatch(_per_batch, max_io_nbytes=2**29)
    return batched(X)

111 

112 

@partial(vmap, in_axes=(0, None))
def _quantilized_splits_from_matrix(
    x: Real[Array, 'p n'], out_length: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Compute the quantilized splits of a single predictor.

    The function is vmapped over the first axis of its first argument, so
    the body handles one row (one predictor) at a time, while `out_length`
    is shared by every row.

    Returns the `out_length` split values of the row, padded with the
    maximum value of the dtype, and the number of them that are real splits.
    """
    pad = _huge_value(x)

    # sorted distinct values of the row, padded with `pad` up to x.size;
    # `n_distinct` is presumably the number of real distinct values
    # (see `bartz.jaxext.unique`)
    distinct, n_distinct = unique(x, size=x.size, fill_value=pad)

    # midpoints between consecutive distinct values, computed as
    # lo + (hi - lo) / 2 instead of (lo + hi) / 2 to avoid overflow
    lo = distinct[:-1]
    gap = distinct[1:] - lo
    if jnp.issubdtype(x.dtype, jnp.integer):
        # the signed difference may wrap around; halving it as unsigned
        # keeps the half-gap correct
        candidates = lo + _ensure_unsigned(gap) // 2
    else:
        candidates = lo + gap / 2
    n_candidates = n_distinct - 1

    # the midpoint between the last real value and the first padding entry
    # is not a real split, so it is overwritten with the padding value; when
    # all values are distinct the index is one past the end and the update
    # is dropped (JAX's default for out-of-bounds scatters)
    if candidates.size:
        candidates = candidates.at[n_candidates].set(pad)

    # evenly spaced indices into the real candidates, computed in float to
    # avoid int32 overflow and to round to the nearest index instead of down
    picks = jnp.linspace(-1, n_candidates, out_length + 2)[1:-1]
    pick_dtype = minimal_unsigned_dtype(candidates.size - 1)
    picks = jnp.around(picks).astype(pick_dtype)

    # subsample only when there are more candidates than requested
    too_many = n_candidates > out_length
    splits = jnp.where(too_many, candidates[picks], candidates[:out_length])

    max_split = jnp.minimum(n_candidates, out_length)
    return splits, max_split.astype(minimal_unsigned_dtype(out_length))

145 

146 

147def _huge_value(x: Array) -> int | float: 

148 """ 

149 Return the maximum value that can be stored in `x`. 

150 

151 Parameters 

152 ---------- 

153 x 

154 A numerical numpy or jax array. 

155 

156 Returns 

157 ------- 

158 The maximum value allowed by `x`'s type (finite for floats). 

159 """ 

160 if jnp.issubdtype(x.dtype, jnp.integer): 1qfghijrstuvwxyzklAabecmnod

161 return jnp.iinfo(x.dtype).max 1abecd

162 else: 

163 return float(jnp.finfo(x.dtype).max) 1qfghijrstuvwxyzklAmno

164 

165 

def _ensure_unsigned(x: Integer[Array, '*shape']) -> UInt[Array, '*shape']:
    """
    Cast an integer array to the unsigned dtype of the same width.

    Arrays that are already unsigned keep their dtype.
    """
    target_dtype = _signed_to_unsigned(x.dtype)
    return x.astype(target_dtype)

169 

170 

def _signed_to_unsigned(int_dtype: jnp.dtype) -> jnp.dtype:
    """
    Return the unsigned integer type with the same width as `int_dtype`.

    Unsigned types are returned unchanged. Signed types other than the 8, 16,
    32 and 64-bit ones raise `TypeError`.
    """
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype
    counterparts = (
        (jnp.int8, jnp.uint8),
        (jnp.int16, jnp.uint16),
        (jnp.int32, jnp.uint32),
        (jnp.int64, jnp.uint64),
    )
    for signed, unsigned in counterparts:
        # plain `==`, the same comparison a `match` value pattern performs
        if int_dtype == signed:
            return unsigned
    msg = f'unexpected integer type {int_dtype}'
    raise TypeError(msg)

192 

193 

@partial(jit, static_argnums=(1,))
def uniform_splits_from_matrix(
    X: Real[Array, 'p n'], num_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Make an evenly spaced binning grid.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    num_bins
        The number of bins to produce.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row of
        `X`.
    max_split : UInt[Array, ' p']
        The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.

    Raises
    ------
    ValueError
        If `num_bins` is less than 1.
    """
    # validate explicitly instead of relying on an `assert` (which is
    # stripped under `python -O`), consistently with
    # `quantilized_splits_from_matrix`
    if num_bins < 1:
        msg = f'{num_bins=}, it should be at least 1.'
        raise ValueError(msg)

    low = jnp.min(X, axis=1)
    high = jnp.max(X, axis=1)

    # num_bins + 1 evenly spaced points from low to high, dropping the two
    # endpoints, leave the num_bins - 1 interior cutpoints
    splits = jnp.linspace(low, high, num_bins + 1, axis=1)[:, 1:-1]
    num_splits = num_bins - 1
    assert splits.shape == (X.shape[0], num_splits)

    # every predictor uses all of its cutpoints
    dtype = minimal_unsigned_dtype(num_splits)
    max_split = jnp.full(X.shape[0], num_splits, dtype)
    return splits, max_split

223 

224 

@partial(jit, static_argnames=('side', 'method'))
def bin_predictors(
    X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw: Any
) -> UInt[Array, 'p n']:
    """
    Bin the predictors according to the given splits.

    With the default ``side='left'``, a value ``x`` is mapped to bin ``i``
    iff ``splits[i - 1] < x <= splits[i]``; with ``side='right'`` the rule
    becomes ``splits[i - 1] <= x < splits[i]``.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    splits
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; each row may have shorter
        actual length, marked by padding unused locations at the end of the
        row with the maximum value allowed by the type.
    **kw
        Additional arguments are passed to `jax.numpy.searchsorted`. `side`
        and `method` are compile-time (static) arguments of the jitted
        function, so they must be hashable Python values such as strings.

    Returns
    -------
    `X` but with each value replaced by the index of the bin it falls into.
    """

    @partial(autobatch, max_io_nbytes=2**29)
    @vmap
    def _bin_rows(
        x: Real[Array, ' n'], row_splits: Real[Array, ' m']
    ) -> UInt[Array, ' n']:
        # searchsorted returns indices in [0, m], so the dtype must hold m
        dtype = minimal_unsigned_dtype(row_splits.size)
        return jnp.searchsorted(row_splits, x, **kw).astype(dtype)

    return _bin_rows(X, splits)