Coverage for src / bartz / grove / _check.py: 98%

94 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 18:11 +0000

1# bartz/src/bartz/grove/_check.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement functions to check validity of trees.""" 

26 

27from typing import Protocol 

28 

29from jax import jit 

30from jax import numpy as jnp 

31from jaxtyping import Array, Bool, Integer, UInt 

32 

33from bartz.grove._grove import TreeHeaps, TreesTrace, is_actual_leaf 

34from bartz.jaxext import autobatch, minimal_unsigned_dtype 

35 

# Registry of tree validators. The `check` decorator appends to it, and
# `check_tree` / `describe_error` map the position of each function in this
# list to one bit of the error code.
CHECK_FUNCTIONS = []

37 

38 

class CheckFunc(Protocol):
    """Structural type of the validators collected in `CHECK_FUNCTIONS`.

    Any callable taking a tree and the per-variable maximum splits as
    positional arguments, and returning a boolean scalar, satisfies this
    protocol.
    """

    def __call__(
        self, tree: TreeHeaps, max_split: UInt[Array, ' p'], /
    ) -> bool | Bool[Array, '']:
        """Tell whether a single tree passes this check.

        Parameters
        ----------
        tree
            The tree to validate.
        max_split
            For each variable, the largest admissible split value.

        Returns
        -------
        A boolean scalar (Python or array), true if the tree is valid.
        """
        ...

59 

60 

def check(func: CheckFunc) -> CheckFunc:
    """Register a tree validator in `CHECK_FUNCTIONS`.

    Meant to be used as a decorator on functions that check one aspect of
    the validity of a tree. Registered functions are invoked automatically
    by `check_tree`, `check_trace` and `debug_gbart`.

    The position at which the function lands in the list is the bit that
    `check_tree` raises in its error code when the check fails, so the
    registration order matters.

    Parameters
    ----------
    func
        The validator to register. It is called with a `TreeHeaps` and a
        `max_split` array, and must return a boolean scalar that is true
        when the tree is ok.

    Returns
    -------
    The very same function, not wrapped.
    """
    CHECK_FUNCTIONS.append(func)
    return func

81 

82 

@check
def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
    """Check that the integer dtypes of the tree are minimal and consistent.

    The variable indices must use the smallest unsigned dtype able to hold
    ``max_split.size - 1``, the splits must share the dtype of `max_split`,
    and that dtype must be an unsigned integer.
    """
    var_dtype = minimal_unsigned_dtype(max_split.size - 1)
    split_dtype = max_split.dtype

    if tree.var_tree.dtype != var_dtype:
        return False
    if tree.split_tree.dtype != split_dtype:
        return False
    return jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)

93 

94 

@check
def check_shapes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
    """Check that the arrays of the tree have mutually compatible shapes.

    The leaves may be 1- or 2-dimensional, while variables and splits must
    be 1-dimensional, have the same size, and be half as long as the last
    axis of the leaves.
    """
    leaf, var, split = tree.leaf_tree, tree.var_tree, tree.split_tree

    if leaf.ndim not in (1, 2):
        return False
    if var.ndim != 1:
        return False
    if split.ndim != 1:
        return False

    # reached only when `leaf` has at least one axis
    return leaf.shape[-1] == 2 * var.size == 2 * split.size

106 

107 

@check
def check_unused_node(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that slot 0 of the heap, which holds no node, is zeroed out."""
    var_is_clean = tree.var_tree[0] == 0
    split_is_clean = tree.split_tree[0] == 0
    return var_is_clean & split_is_clean

114 

115 

@check
def check_leaf_values(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that no leaf value is inf or nan."""
    finite = jnp.isfinite(tree.leaf_tree)
    return jnp.all(finite)

122 

123 

@check
def check_stray_nodes(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that no node marked as non-leaf hangs below a node marked as leaf.

    A node is marked as non-leaf when its split is nonzero.
    """
    heap_size = 2 * tree.split_tree.size
    node = jnp.arange(heap_size, dtype=minimal_unsigned_dtype(heap_size - 1))

    # indices past the end of `split_tree` read as 0, i.e., as leaves
    marked_decision = tree.split_tree.at[node].get(mode='fill', fill_value=0) != 0
    parent_marked_leaf = tree.split_tree[node >> 1] == 0

    # the parent of the root (index 1) is the unused slot 0, so the root
    # would be flagged by the rule above; exclude it explicitly
    stray = (marked_decision & parent_marked_leaf).at[1].set(False)
    return ~jnp.any(stray)

139 

140 

@check
def check_rule_consistency(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> bool | Bool[Array, '']:
    """Check that each decision rule is a proper subset of its ancestors' rules.

    While descending the tree, keep for each variable the interval
    ``(lower, upper]`` of splits still allowed by the ancestors, and flag
    any nonzero split that falls outside the interval of its variable.
    """
    heap_size = tree.var_tree.size
    if heap_size < 4:
        # the children of the root sit at indices 2 and 3: with a smaller
        # heap there is no descendant rule to compare
        return True

    # start from the extreme values of the dtype rather than 0 and
    # max_split, such that this function never flags out of bounds splits
    split_dtype = tree.split_tree.dtype
    info = jnp.iinfo(split_dtype)
    lowest, highest = info.min, info.max
    num_vars = max_split.size
    first_childless_node = heap_size // 2

    def _violations(
        node: int, lower: UInt[Array, ' p'], upper: UInt[Array, ' p']
    ) -> Bool[Array, '']:
        # decision rule stored at this node
        var = tree.var_tree[node]
        split = tree.split_tree[node]

        # interval inherited from the ancestors; an out of bounds variable
        # gets the widest interval, since it is not checked here
        lo = lower.at[var].get(mode='fill', fill_value=lowest)
        hi = upper.at[var].get(mode='fill', fill_value=highest)

        # a zero split means there is no rule, so there is nothing to flag
        outside = (split <= lo) | (split > hi)
        bad = jnp.where(split, outside, False)

        if node >= first_childless_node:
            return bad

        # without a rule, aim the update past the end of the bounds arrays
        target = jnp.where(split, var, num_vars)
        left_bad = _violations(2 * node, lower, upper.at[target].set(split - 1))
        right_bad = _violations(2 * node + 1, lower.at[target].set(split), upper)
        return bad | left_bad | right_bad

    widest_lower = jnp.full(num_vars, lowest, split_dtype)
    widest_upper = jnp.full(num_vars, highest, split_dtype)
    return ~_violations(1, widest_lower, widest_upper)

182 

183 

@check
def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that the leaves are exactly one more than the internal nodes."""
    leaf_mask = is_actual_leaf(tree.split_tree, add_bottom_level=True)
    leaf_count = jnp.count_nonzero(leaf_mask)
    # internal nodes are those with a nonzero split
    internal_count = jnp.count_nonzero(tree.split_tree)
    return leaf_count == internal_count + 1

191 

192 

@check
def check_var_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that the variable of each decision node is in [0, max_split.size)."""
    not_below = tree.var_tree >= 0
    not_above = tree.var_tree < max_split.size
    # nodes with a zero split carry no rule, so their variable is not checked
    has_no_rule = ~tree.split_tree.astype(bool)
    return jnp.all((not_below & not_above) | has_no_rule)

201 

202 

@check
def check_split_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that each split is in [0, max_split[var]].

    A node whose variable is out of bounds is compared against the largest
    int32 instead of an entry of `max_split`.
    """
    limits = max_split.astype(jnp.int32)
    fallback = jnp.iinfo(jnp.int32).max
    node_limit = limits.at[tree.var_tree].get(mode='fill', fill_value=fallback)

    not_below = tree.split_tree >= 0
    not_above = tree.split_tree <= node_limit
    return jnp.all(not_below & not_above)

214 

215 

def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
    """Run all the registered checks on a single tree.

    Use `describe_error` to decode the value returned by this function.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    An integer with bit ``i`` set if the i-th function in `CHECK_FUNCTIONS` failed.
    """
    # smallest unsigned type with one bit per registered check
    all_failed = 2 ** len(CHECK_FUNCTIONS) - 1
    error = minimal_unsigned_dtype(all_failed)(0)

    for position, func in enumerate(CHECK_FUNCTIONS):
        # checks may return a Python bool or an array, normalize first
        failed = ~jnp.bool_(func(tree, max_split))
        error |= failed << position

    return error

240 

241 

def describe_error(error: int | Integer[Array, '']) -> list[str]:
    """Translate an error code into the names of the failed checks.

    Parameters
    ----------
    error
        A scalar error code, as returned by `check_tree` or as found in
        each element of the output of `check_trace`.

    Returns
    -------
    A list of the function names that implement the failed checks.
    """
    failed_names = []
    for position, func in enumerate(CHECK_FUNCTIONS):
        # bit `position` is set when the check at that position failed
        if error & (1 << position):
            failed_names.append(func.__name__)
    return failed_names

255 

256 

@jit
def check_trace(
    trace: TreeHeaps, max_split: UInt[Array, ' p']
) -> UInt[Array, '*batch_shape']:
    """Run all the registered checks on each tree in a collection.

    Use `describe_error` to decode each of the values returned by this
    function.

    Parameters
    ----------
    trace
        The set of trees to check. This object can have additional attributes
        beyond the tree arrays, they are ignored.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    A tensor with one error code per tree.
    """

    def _check_one(leaf_tree, var_tree, split_tree):  # noqa: ANN001, ANN202
        # rebuild a tree object out of the bare arrays of a single tree
        return check_tree(TreesTrace(leaf_tree, var_tree, split_tree), max_split)

    # the leaves have one axis more than the splits when each leaf holds a
    # vector (presumably the multivariate case); make it a core dimension
    if trace.leaf_tree.ndim > trace.split_tree.ndim:
        signature = '(k,ts),(hts),(hts)->()'
    else:
        signature = '(ts),(hts),(hts)->()'

    # vectorize over all the leading batch dimensions
    batched_check_tree = jnp.vectorize(_check_one, signature=signature)

    # split the work in batches along each batch axis, last axis first
    max_io_nbytes = 2**24  # 16 MiB
    batch_ndim = trace.split_tree.ndim - 1
    for axis in range(batch_ndim - 1, -1, -1):
        batched_check_tree = autobatch(batched_check_tree, max_io_nbytes, axis, axis)

    return batched_check_tree(trace.leaf_tree, trace.var_tree, trace.split_tree)

289 

290 return batched_check_tree(trace.leaf_tree, trace.var_tree, trace.split_tree) 1ab