Coverage for src / bartz / debug / _check.py: 98%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/debug/_check.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implement functions to check validity of trees.""" 

26 

27from typing import Protocol 

28 

29from jax import jit 

30from jax import numpy as jnp 

31from jaxtyping import Array, Bool, Integer, UInt 

32 

33from bartz.grove import TreeHeaps, is_actual_leaf 

34from bartz.jaxext import autobatch, minimal_unsigned_dtype 

35from bartz.mcmcloop import TreesTrace 

36 

# Registry of the tree-validity checks, filled in definition order by the
# `check` decorator. The position of a function in this list is the bit it
# controls in the error code built by `check_tree`.
CHECK_FUNCTIONS = []

38 

39 

class CheckFunc(Protocol):
    """Protocol for functions that check whether a tree is valid.

    Functions with this call signature can be registered with the `check`
    decorator.
    """

    def __call__(
        self, tree: TreeHeaps, max_split: UInt[Array, ' p'], /
    ) -> bool | Bool[Array, '']:
        """Check whether a tree is valid.

        Both arguments are positional-only, so implementations are free to
        pick their own parameter names.

        Parameters
        ----------
        tree
            The tree to check.
        max_split
            The maximum split value for each variable.

        Returns
        -------
        A boolean scalar indicating whether the tree is valid.
        """
        ...

60 

61 

def check(func: CheckFunc) -> CheckFunc:
    """Add a function to a list of functions used to check trees.

    Use to decorate functions that check whether a tree is valid in some way.
    These functions are invoked automatically by `check_tree`, `check_trace` and
    `debug_gbart`.

    The order of registration matters: the index of the function in
    `CHECK_FUNCTIONS` is the bit that `check_tree` sets in its error code when
    the check fails, and that `describe_error` maps back to the function name.

    Parameters
    ----------
    func
        The function to add to the list. It must accept a `TreeHeaps` and a
        `max_split` argument, and return a boolean scalar that indicates if the
        tree is ok.

    Returns
    -------
    The function unchanged.
    """
    # append (never insert) so that the bits of already registered checks
    # keep their position
    CHECK_FUNCTIONS.append(func)
    return func

82 

83 

@check
def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
    """Check that the integer dtypes of the tree arrays are the expected ones.

    The variable indices must use the dtype given by `minimal_unsigned_dtype`
    for the largest variable index, the split values must share the dtype of
    `max_split`, and `max_split` itself must be an unsigned integer array.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    True if all dtypes are as expected.
    """
    # only dtypes are inspected, never array values, so each test is a plain
    # Python bool
    var_dtype_ok = tree.var_tree.dtype == minimal_unsigned_dtype(max_split.size - 1)
    split_dtype_ok = tree.split_tree.dtype == max_split.dtype
    max_split_is_unsigned = jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)
    return var_dtype_ok and split_dtype_ok and max_split_is_unsigned

94 

95 

@check
def check_sizes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
    """Check that array sizes are coherent.

    The leaf array must have twice as many node slots as each of the
    decision-rule arrays `var_tree` and `split_tree`.

    Parameters
    ----------
    tree
        The tree to check.
    _max_split
        Unused, present only to match the `CheckFunc` signature.

    Returns
    -------
    True if the sizes of the three arrays are consistent.
    """
    # Count the node slots along the last axis instead of using `.size`:
    # `check_trace` passes leaf arrays with an extra leading axis (signature
    # '(k,ts)') when `leaf_tree` has more dimensions than `split_tree`, and
    # the total element count would then be k times the number of slots.
    num_leaf_slots = tree.leaf_tree.shape[-1]
    return num_leaf_slots == 2 * tree.var_tree.size == 2 * tree.split_tree.size

100 

101 

@check
def check_unused_node(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that the unused node slot at index 0 holds only zeros.

    Parameters
    ----------
    tree
        The tree to check.
    _max_split
        Unused, present only to match the `CheckFunc` signature.

    Returns
    -------
    True if both the variable and the split stored in slot 0 are zero.
    """
    var_is_clean = tree.var_tree[0] == 0
    split_is_clean = tree.split_tree[0] == 0
    return var_is_clean & split_is_clean

108 

109 

@check
def check_leaf_values(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that no leaf value is inf or nan.

    Parameters
    ----------
    tree
        The tree to check.
    _max_split
        Unused, present only to match the `CheckFunc` signature.

    Returns
    -------
    True if every entry of `tree.leaf_tree` is finite.
    """
    is_finite = jnp.isfinite(tree.leaf_tree)
    return jnp.all(is_finite)

116 

117 

@check
def check_stray_nodes(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that no marked-non-leaf node has a marked-leaf parent.

    A node is marked as non-leaf when its entry in `split_tree` is nonzero.

    Parameters
    ----------
    tree
        The tree to check.
    _max_split
        Unused, present only to match the `CheckFunc` signature.

    Returns
    -------
    True if there are no stray nodes.
    """
    num_nodes = 2 * tree.split_tree.size
    node = jnp.arange(num_nodes, dtype=minimal_unsigned_dtype(num_nodes - 1))
    parent = node >> 1

    # the upper half of the indices is past the end of split_tree; reading
    # with fill value 0 makes those nodes count as leaves
    node_split = tree.split_tree.at[node].get(mode='fill', fill_value=0)
    node_is_decision = node_split != 0
    parent_is_leaf = tree.split_tree[parent] == 0

    stray = node_is_decision & parent_is_leaf
    # node 1 is where the tree starts, its nominal parent is the unused slot 0,
    # so it can never be considered stray
    stray = stray.at[1].set(False)
    return ~jnp.any(stray)

133 

134 

@check
def check_rule_consistency(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> bool | Bool[Array, '']:
    """Check that decision rules define proper subsets of ancestor rules.

    The tree is walked recursively from node 1, carrying for each variable
    the interval of split values still allowed by the rules of the ancestors.
    A decision node fails the check if its split falls outside the interval
    of its variable.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable. Only its size is used.

    Returns
    -------
    True if every decision rule is compatible with the rules above it.
    """
    if tree.var_tree.size < 4:
        # trees this small are accepted without inspection
        return True

    # initial boundaries of decision rules. use extreme integers instead of 0,
    # max_split to avoid checking if there is something out of bounds.
    dtype = tree.split_tree.dtype
    small = jnp.iinfo(dtype).min
    large = jnp.iinfo(dtype).max
    lower = jnp.full(max_split.size, small, dtype)
    upper = jnp.full(max_split.size, large, dtype)
    # the split must be in (lower[var], upper[var]]

    def _check_recursive(
        node: int, lower: UInt[Array, ' p'], upper: UInt[Array, ' p']
    ) -> Bool[Array, '']:
        # read decision rule
        var = tree.var_tree[node]
        split = tree.split_tree[node]

        # get rule boundaries from ancestors. use fill value in case var is
        # out of bounds, we don't want to check out of bounds in this function
        lower_var = lower.at[var].get(mode='fill', fill_value=small)
        upper_var = upper.at[var].get(mode='fill', fill_value=large)

        # check rule is in bounds (a zero split marks a leaf, never bad)
        bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False)

        # recurse; `node` is a Python int, so the recursion is unrolled
        if node < tree.var_tree.size // 2:
            # Tighten the bounds for the children only if this node is a
            # decision node. Pick between the updated and the unchanged
            # arrays instead of scattering to the sentinel index
            # `max_split.size`: the sentinel would be cast to the dtype of
            # `var`, and overflows to a valid index when the number of
            # variables is exactly the number of values of that dtype (e.g.,
            # 256 variables with uint8 indices).
            # Out-of-bounds `var` still leaves the bounds untouched, because
            # out-of-bounds scatter updates are dropped.
            is_decision = split != 0
            left_upper = jnp.where(is_decision, upper.at[var].set(split - 1), upper)
            right_lower = jnp.where(is_decision, lower.at[var].set(split), lower)
            bad |= _check_recursive(2 * node, lower, left_upper)
            bad |= _check_recursive(2 * node + 1, right_lower, upper)

        return bad

    return ~_check_recursive(1, lower, upper)

176 

177 

@check
def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that the number of leaves is the number of internal nodes plus 1.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        Unused, present only to match the `CheckFunc` signature.

    Returns
    -------
    True if #leaves = 1 + #(internal nodes).
    """
    # internal nodes are the ones with a nonzero split
    internal_count = jnp.count_nonzero(tree.split_tree)
    leaf_mask = is_actual_leaf(tree.split_tree, add_bottom_level=True)
    leaf_count = jnp.count_nonzero(leaf_mask)
    return leaf_count == internal_count + 1

185 

186 

@check
def check_var_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that variables are in [0, max_split.size).

    Only decision nodes (nonzero split) are checked, the variable stored in
    other node slots is ignored.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable. Only its size is used.

    Returns
    -------
    True if the variable index of every decision node is valid.
    """
    decision_node = tree.split_tree.astype(bool)
    var_tree = tree.var_tree
    num_vars = max_split.size

    in_bounds = var_tree >= 0

    # JAX casts the Python int `num_vars` to the dtype of `var_tree` for the
    # comparison. If it does not fit (e.g., 256 variables with uint8 indices)
    # the bound would overflow and valid trees would be rejected. In that case
    # every representable index is below the bound, so the upper check is
    # vacuous and is skipped.
    var_dtype = var_tree.dtype
    bound_fits = (
        not jnp.issubdtype(var_dtype, jnp.integer)
        or num_vars <= jnp.iinfo(var_dtype).max
    )
    if bound_fits:
        in_bounds &= var_tree < num_vars

    return jnp.all(in_bounds | ~decision_node)

195 

196 

@check
def check_split_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that splits are in [0, max_split[var]].

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    True if every split is within the range allowed for its variable.
    """
    # nodes whose variable is out of bounds get the largest int32 as limit, so
    # that they are not flagged by this check
    no_limit = jnp.iinfo(jnp.int32).max
    limits = max_split.astype(jnp.int32)
    split_limit = limits.at[tree.var_tree].get(mode='fill', fill_value=no_limit)

    not_negative = tree.split_tree >= 0
    not_too_large = tree.split_tree <= split_limit
    return jnp.all(not_negative & not_too_large)

208 

209 

def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
    """Run all registered checks on a tree and pack the outcomes in a bitmask.

    Use `describe_error` to parse the error code returned by this function.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    An integer where each bit indicates whether a check failed. Bit ``i`` corresponds to the ``i``-th function in `CHECK_FUNCTIONS`.
    """
    # pick an unsigned type with at least one bit per registered check
    num_checks = len(CHECK_FUNCTIONS)
    code_type = minimal_unsigned_dtype(2**num_checks - 1)
    code = code_type(0)
    for position, check_func in enumerate(CHECK_FUNCTIONS):
        # checks may return either a Python bool or a scalar array
        passed = jnp.bool_(check_func(tree, max_split))
        code |= (~passed) << position
    return code

234 

235 

def describe_error(error: int | Integer[Array, '']) -> list[str]:
    """List the checks that failed according to an error code.

    Parameters
    ----------
    error
        A scalar error code, as returned by `check_tree` or found in each
        element of the output of `check_trace`.

    Returns
    -------
    A list of the function names that implement the failed checks.
    """
    failed_checks = []
    for position, check_func in enumerate(CHECK_FUNCTIONS):
        # bit `position` is set if the check at that index failed
        if error & (1 << position):
            failed_checks.append(check_func.__name__)
    return failed_checks

249 

250 

@jit
def check_trace(
    trace: TreeHeaps, max_split: UInt[Array, ' p']
) -> UInt[Array, '*batch_shape']:
    """Check the validity of a set of trees.

    Use `describe_error` to parse the error codes returned by this function.

    Parameters
    ----------
    trace
        The set of trees to check. This object can have additional attributes
        beyond the tree arrays, they are ignored.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    A tensor of error codes for each tree.
    """

    def check_single_tree(leaf_tree, var_tree, split_tree):  # noqa: ANN001, ANN202
        # repack the bare arrays into a tree object for `check_tree`
        return check_tree(TreesTrace(leaf_tree, var_tree, split_tree), max_split)

    # vectorize check_tree over all batch dimensions; the leaves carry an
    # additional core axis if they have more dimensions than the splits
    if trace.leaf_tree.ndim > trace.split_tree.ndim:
        signature = '(k,ts),(hts),(hts)->()'
    else:
        signature = '(ts),(hts),(hts)->()'
    checker = jnp.vectorize(check_single_tree, signature=signature)

    # automatically batch over all batch dimensions, wrapping from the last
    # batch axis to the first
    max_io_nbytes = 2**24  # 16 MiB
    num_batch_axes = trace.split_tree.ndim - 1
    for axis in range(num_batch_axes - 1, -1, -1):
        checker = autobatch(checker, max_io_nbytes, axis, axis)

    return checker(trace.leaf_tree, trace.var_tree, trace.split_tree)