Coverage for src/bartz/grove/_check.py: 98%

94 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/grove/_check.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement functions to check validity of trees.""" 

26 

27from typing import Protocol, runtime_checkable 

28 

29from jax import numpy as jnp 

30from jaxtyping import Array, Bool, Integer, UInt 

31 

32from bartz._jaxext import autobatch, jit, minimal_unsigned_dtype 

33from bartz.grove._grove import TreeHeaps, TreesTrace, is_actual_leaf, is_multivariate 

34 

# Registry of the tree-checking functions, in registration order. It is
# filled by the `check` decorator; the position of a function in this list is
# the bit it controls in the error code built by `check_tree`.
CHECK_FUNCTIONS: list['CheckFunc'] = []

36 

37 

@runtime_checkable
class CheckFunc(Protocol):
    """Structural type of the callables that validate a single tree."""

    def __call__(
        self, tree: TreeHeaps, max_split: UInt[Array, ' p'], /
    ) -> bool | Bool[Array, '']:
        """Decide whether `tree` passes this particular check.

        Both arguments are positional-only.

        Parameters
        ----------
        tree
            The tree to check.
        max_split
            The maximum split value for each variable.

        Returns
        -------
        A boolean scalar (Python or array), true if the tree is valid.
        """
        ...

59 

60 

def check(func: CheckFunc) -> CheckFunc:
    """Register a tree-checking function.

    Meant to be used as a decorator. The function is appended to
    `CHECK_FUNCTIONS`, the list that `check_tree` (and hence `check_trace`)
    iterates over.

    Parameters
    ----------
    func
        The checker to register. It takes a `TreeHeaps` and a `max_split`
        array, and returns a boolean scalar that is true when the tree is ok.

    Returns
    -------
    The very same function, not wrapped in any way.
    """
    # Registration order matters: it fixes the bit assigned to each check.
    CHECK_FUNCTIONS.append(func)
    return func

80 

81 

@check
def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
    """Check that the integer dtypes are minimal and mutually consistent.

    The variable indices must use the smallest unsigned dtype able to hold
    ``max_split.size - 1``, the split values must share the dtype of
    `max_split`, and that dtype must be an unsigned integer.
    """
    if tree.var_tree.dtype != minimal_unsigned_dtype(max_split.size - 1):
        return False
    if tree.split_tree.dtype != max_split.dtype:
        return False
    return jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)

92 

93 

@check
def check_shapes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
    """Check that the three tree arrays have mutually compatible shapes.

    The leaf array may be 1- or 2-dimensional, the variable and split arrays
    must be 1-dimensional and equally long, and the last leaf axis must be
    twice as long as them.
    """
    # Test the ranks first so that `shape[-1]` is never read on a 0-d array.
    if tree.leaf_tree.ndim not in (1, 2):
        return False
    if tree.var_tree.ndim != 1 or tree.split_tree.ndim != 1:
        return False
    leaf_slots = tree.leaf_tree.shape[-1]
    return (
        leaf_slots == 2 * tree.var_tree.size
        and tree.var_tree.size == tree.split_tree.size
    )

105 

106 

@check
def check_unused_node(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that slot 0, which holds no node, is zeroed in both arrays."""
    var_slot_clean = tree.var_tree[0] == 0
    split_slot_clean = tree.split_tree[0] == 0
    return var_slot_clean & split_slot_clean

113 

114 

@check
def check_leaf_values(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that no leaf value is inf or nan."""
    return jnp.isfinite(tree.leaf_tree).all()

121 

122 

@check
def check_stray_nodes(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that no node marked as non-leaf hangs below a node marked as leaf.

    A node is "marked non-leaf" when its split is nonzero. Indices span twice
    the length of the split array; reads beyond the stored splits yield 0,
    i.e., those nodes count as leaves.
    """
    num_slots = 2 * tree.split_tree.size
    node = jnp.arange(num_slots, dtype=minimal_unsigned_dtype(num_slots - 1))
    parent = node >> 1  # heap layout: the parent of node i is i // 2

    marked_internal = tree.split_tree.at[node].get(mode='fill', fill_value=0) != 0
    parent_marked_leaf = tree.split_tree[parent] == 0

    # The root (index 1) has slot 0 as nominal parent, which is not a real
    # node, so the root can never be considered stray.
    stray = (marked_internal & parent_marked_leaf).at[1].set(False)
    return ~jnp.any(stray)

138 

139 

@check
def check_rule_consistency(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> bool | Bool[Array, '']:
    """Check that each decision rule is a proper subset of its ancestors' rules.

    The tree is walked from the root with a Python-level recursion, carrying
    for each variable the interval of splits still allowed; a nonzero split
    outside ``(lower[var], upper[var]]`` makes the check fail.
    """
    if tree.var_tree.size < 4:
        return True

    # Start from the extreme values of the split dtype rather than from
    # 0 and max_split, so that this function does not double as a bounds
    # check.
    dtype = tree.split_tree.dtype
    limits = jnp.iinfo(dtype)
    small = limits.min
    large = limits.max
    num_vars = max_split.size
    root_lower = jnp.full(num_vars, small, dtype)
    root_upper = jnp.full(num_vars, large, dtype)

    # nodes with index below this have both children stored in the arrays
    first_childless = tree.var_tree.size // 2

    def _check_recursive(
        node: int, lower: UInt[Array, ' p'], upper: UInt[Array, ' p']
    ) -> Bool[Array, '']:
        # decision rule stored at this node
        var = tree.var_tree[node]
        split = tree.split_tree[node]

        # Interval inherited from the ancestors. An out-of-bounds `var` reads
        # the loosest bounds, because bounds checking is not this function's
        # job.
        lower_var = lower.at[var].get(mode='fill', fill_value=small)
        upper_var = upper.at[var].get(mode='fill', fill_value=large)

        # a zero split marks a leaf, which carries no rule to validate
        outside = (split <= lower_var) | (split > upper_var)
        bad = jnp.where(split, outside, False)

        if node < first_childless:
            # For a leaf, point the update at index `num_vars`, past the end
            # of the bounds arrays, so that it is meant to touch nothing.
            # NOTE(review): the sentinel is a Python int combined with `var`;
            # confirm it cannot overflow the dtype of `var` when `num_vars`
            # is exactly 2**8 or 2**16.
            idx = jnp.where(split, var, num_vars)
            # left subtree: values up to split - 1; right subtree: above split
            bad = bad | _check_recursive(
                2 * node, lower, upper.at[idx].set(split - 1)
            )
            bad = bad | _check_recursive(
                2 * node + 1, lower.at[idx].set(split), upper
            )

        return bad

    return ~_check_recursive(1, root_lower, root_upper)

181 

182 

@check
def check_num_nodes(
    tree: TreeHeaps,
    max_split: UInt[Array, ' p'],  # noqa: ARG001
) -> Bool[Array, '']:
    """Check that the leaves are exactly one more than the internal nodes."""
    leaf_mask = is_actual_leaf(tree.split_tree, add_bottom_level=True)
    leaf_count = jnp.count_nonzero(leaf_mask)
    # internal nodes are those with a nonzero split
    internal_count = jnp.count_nonzero(tree.split_tree)
    return leaf_count == internal_count + 1

193 

194 

@check
def check_var_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that decision nodes use variables in [0, max_split.size)."""
    has_rule = tree.split_tree.astype(bool)
    var_ok = (tree.var_tree >= 0) & (tree.var_tree < max_split.size)
    # nodes without a decision rule are exempt from the check
    return jnp.all(~has_rule | var_ok)

203 

204 

@check
def check_split_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that each split lies in [0, max_split[var]]."""
    # An out-of-bounds variable index reads the largest int32 as its limit,
    # so an invalid variable is not reported by this check.
    no_limit = jnp.iinfo(jnp.int32).max
    limits = max_split.astype(jnp.int32)
    node_limit = limits.at[tree.var_tree].get(mode='fill', fill_value=no_limit)

    not_below = tree.split_tree >= 0
    not_above = tree.split_tree <= node_limit
    return jnp.all(not_below & not_above)

216 

217 

def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
    """Run every registered check on a tree and pack the outcomes in an integer.

    Use `describe_error` to parse the error code returned by this function.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    An unsigned integer whose i-th bit is set if the i-th function in `CHECK_FUNCTIONS` reported a failure.
    """
    num_checks = len(CHECK_FUNCTIONS)
    # smallest unsigned dtype with one bit per registered check
    error = jnp.zeros((), minimal_unsigned_dtype(2**num_checks - 1))
    for bit_pos, func in enumerate(CHECK_FUNCTIONS):
        # checks may return a Python bool; convert before negating
        failed = ~jnp.bool_(func(tree, max_split))
        error |= failed << bit_pos
    return error

242 

243 

def describe_error(error: int | Integer[Array, '']) -> list[str]:
    """Describe an error code returned by `check_tree` or `check_trace`.

    Parameters
    ----------
    error
        A scalar error code: the value returned by `check_tree`, or a single
        element of the array returned by `check_trace`.

    Returns
    -------
    A list of the function names that implement the failed checks.
    """
    # Convert to a Python int once, so that the bit tests below are plain
    # integer arithmetic instead of one array operation per registered check.
    code = int(error)
    return [
        func.__name__
        for bit_pos, func in enumerate(CHECK_FUNCTIONS)
        if code & (1 << bit_pos)
    ]

257 

258 

259@jit 

260def check_trace( 

261 trace: TreeHeaps, max_split: UInt[Array, ' p'] 

262) -> UInt[Array, '*batch_shape']: 

263 """Check the validity of a set of trees. 

264 

265 Use `describe_error` to parse the error codes returned by this function. 

266 

267 Parameters 

268 ---------- 

269 trace 

270 The set of trees to check. This object can have additional attributes 

271 beyond the tree arrays, they are ignored. 

272 max_split 

273 The maximum split value for each variable. 

274 

275 Returns 

276 ------- 

277 A tensor of error codes for each tree. 

278 """ 

279 # vectorize check_tree over all batch dimensions 

280 unpack_check_tree = lambda l, v, s: check_tree( 

281 TreesTrace(leaf_tree=l, var_tree=v, split_tree=s), max_split 

282 ) 

283 is_mv = is_multivariate(trace) 

284 signature = '(k,ts),(hts),(hts)->()' if is_mv else '(ts),(hts),(hts)->()' 

285 vec_check_tree = jnp.vectorize(unpack_check_tree, signature=signature) 

286 

287 # automatically batch over all batch dimensions 

288 max_io_nbytes = 2**24 # 16 MiB 

289 batch_ndim = trace.split_tree.ndim - 1 

290 batched_check_tree = vec_check_tree 

291 for i in reversed(range(batch_ndim)): 

292 batched_check_tree = autobatch(batched_check_tree, max_io_nbytes, i, i) 

293 

294 return batched_check_tree(trace.leaf_tree, trace.var_tree, trace.split_tree)