Coverage for src / bartz / grove / _check.py: 98%
94 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
1# bartz/src/bartz/grove/_check.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement functions to check validity of trees."""
27from typing import Protocol
29from jax import jit
30from jax import numpy as jnp
31from jaxtyping import Array, Bool, Integer, UInt
33from bartz.grove._grove import TreeHeaps, TreesTrace, is_actual_leaf
34from bartz.jaxext import autobatch, minimal_unsigned_dtype
# Registry of tree-checking functions, filled in definition order by the
# `check` decorator. The position of a function in this list is the bit it
# controls in the error codes assembled by `check_tree`.
CHECK_FUNCTIONS = []
class CheckFunc(Protocol):
    """Protocol for functions that check whether a tree is valid.

    Any callable that takes a tree and the per-variable maximum splits as
    positional arguments and returns a boolean scalar satisfies this protocol.
    """

    def __call__(
        self, tree: TreeHeaps, max_split: UInt[Array, ' p'], /
    ) -> bool | Bool[Array, '']:
        """Check whether a tree is valid.

        Parameters
        ----------
        tree
            The tree to check.
        max_split
            The maximum split value for each variable.

        Returns
        -------
        A boolean scalar indicating whether the tree is valid.
        """
        ...
def check(func: CheckFunc) -> CheckFunc:
    """Register a tree-checking function.

    Meant to be used as a decorator on functions that verify one aspect of
    the validity of a tree. Registered functions are invoked automatically by
    `check_tree`, `check_trace` and `debug_gbart`.

    Parameters
    ----------
    func
        The function to register. It must accept a `TreeHeaps` and a
        `max_split` argument, and return a boolean scalar that is true when
        the tree is ok.

    Returns
    -------
    The function unchanged.
    """
    # registration order matters: it determines the bit assigned to the check
    registry = CHECK_FUNCTIONS
    registry.append(func)
    return func
@check
def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
    """Check that integer types are as small as possible and coherent."""
    # variable indices must use the smallest unsigned type that fits p - 1
    var_ok = tree.var_tree.dtype == minimal_unsigned_dtype(max_split.size - 1)
    # split values must share the type of their upper bounds
    split_ok = tree.split_tree.dtype == max_split.dtype
    # the bounds themselves must be unsigned
    bounds_unsigned = jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)
    return var_ok and split_ok and bounds_unsigned
@check
def check_shapes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
    """Check that array shapes are coherent."""
    leaf, var, split = tree.leaf_tree, tree.var_tree, tree.split_tree

    # leaves may have one extra leading axis, rules are always flat
    if leaf.ndim not in (1, 2) or var.ndim != 1 or split.ndim != 1:
        return False

    # the leaf heap must be twice as long as each of the rule heaps
    num_leaf_slots = leaf.shape[-1]
    return num_leaf_slots == 2 * var.size and num_leaf_slots == 2 * split.size
@check
def check_unused_node(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that slot 0 of the rule heaps, which holds no node, is zeroed."""
    var_is_clean = tree.var_tree[0] == 0
    split_is_clean = tree.split_tree[0] == 0
    return var_is_clean & split_is_clean
@check
def check_leaf_values(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that no leaf value is inf or nan."""
    finite = jnp.isfinite(tree.leaf_tree)
    return finite.all()
@check
def check_stray_nodes(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that no node marked as non-leaf has a parent marked as leaf."""
    num_slots = 2 * tree.split_tree.size
    node = jnp.arange(num_slots, dtype=minimal_unsigned_dtype(num_slots - 1))

    # a node is marked non-leaf by a nonzero split; indices past the end of
    # the split heap read as 0, i.e., as leaves
    has_rule = tree.split_tree.at[node].get(mode='fill', fill_value=0) != 0
    parent_has_no_rule = tree.split_tree[node >> 1] == 0

    # the root (index 1) has the unused slot 0 as "parent", do not flag it
    stray = (has_rule & parent_has_no_rule).at[1].set(False)
    return ~jnp.any(stray)
@check
def check_rule_consistency(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> bool | Bool[Array, '']:
    """Check that decision rules define proper subsets of ancestor rules.

    Walks the tree from the root, carrying for each variable the interval of
    split values still allowed by the rules of the ancestors, and flags any
    node whose split falls outside the interval of its variable.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable. Only its size is used
        here, to shape the per-variable boundaries.

    Returns
    -------
    A boolean scalar, true if no inconsistent rule was found.
    """
    # NOTE(review): presumably heap sizes are powers of two, such that a size
    # below 4 means only the root can hold a rule, and there are no ancestors
    # to compare with -- confirm
    if tree.var_tree.size < 4:
        return True

    # initial boundaries of decision rules. use extreme integers instead of 0,
    # max_split to avoid checking if there is something out of bounds.
    dtype = tree.split_tree.dtype
    small = jnp.iinfo(dtype).min
    large = jnp.iinfo(dtype).max
    lower = jnp.full(max_split.size, small, dtype)
    upper = jnp.full(max_split.size, large, dtype)
    # the split must be in (lower[var], upper[var]]

    def _check_recursive(
        node: int, lower: UInt[Array, ' p'], upper: UInt[Array, ' p']
    ) -> Bool[Array, '']:
        # `node` is a Python integer, so the recursion is unrolled in Python
        # and its depth is fixed by the heap size, not by the tree contents.
        # The return value is true if a bad rule is found at `node` or below.

        # read decision rule
        var = tree.var_tree[node]
        split = tree.split_tree[node]

        # get rule boundaries from ancestors. use fill value in case var is
        # out of bounds, we don't want to check out of bounds in this function
        lower_var = lower.at[var].get(mode='fill', fill_value=small)
        upper_var = upper.at[var].get(mode='fill', fill_value=large)

        # check rule is in bounds. a zero split marks no rule, never flag it
        bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False)

        # recurse
        if node < tree.var_tree.size // 2:
            # if there is no rule here, aim one past the end of the boundary
            # arrays; presumably the out-of-bounds updates below are then
            # dropped by jax, leaving the boundaries unchanged -- confirm
            idx = jnp.where(split, var, max_split.size)
            # left child: splits on `var` are capped at split - 1
            bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split - 1))
            # right child: splits on `var` must be strictly above split
            bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper)

        return bad

    # start from the root, which sits at index 1 of the heap
    return ~_check_recursive(1, lower, upper)
@check
def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that the leaves are one more than the internal nodes."""
    # internal nodes are those with a nonzero split
    internal_count = jnp.count_nonzero(tree.split_tree)
    leaf_mask = is_actual_leaf(tree.split_tree, add_bottom_level=True)
    leaf_count = jnp.count_nonzero(leaf_mask)
    return leaf_count == internal_count + 1
@check
def check_var_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that the variables of decision nodes are in [0, max_split.size)."""
    has_rule = tree.split_tree.astype(bool)
    var_ok = (tree.var_tree >= 0) & (tree.var_tree < max_split.size)
    # slots without a decision rule pass whatever their variable is
    return jnp.all(jnp.where(has_rule, var_ok, True))
@check
def check_split_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that each split is in [0, max_split[var]]."""
    # nodes whose variable is out of bounds get the largest int32 as limit,
    # such that they are effectively unbounded above in this check
    unbounded = jnp.iinfo(jnp.int32).max
    limits = max_split.astype(jnp.int32)
    node_limit = limits.at[tree.var_tree].get(mode='fill', fill_value=unbounded)

    not_negative = tree.split_tree >= 0
    not_too_large = tree.split_tree <= node_limit
    return jnp.all(not_negative & not_too_large)
def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
    """Run all the registered checks on a tree.

    Use `describe_error` to parse the error code returned by this function.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    An integer where each bit indicates whether a check failed.
    """
    # smallest unsigned type with one bit per registered check
    num_checks = len(CHECK_FUNCTIONS)
    code_dtype = minimal_unsigned_dtype(2**num_checks - 1)

    code = code_dtype(0)
    for position, func in enumerate(CHECK_FUNCTIONS):
        # checks may return a python bool or an array, normalize to the latter
        passed = jnp.bool_(func(tree, max_split))
        # raise the bit of this check if it did not pass
        code |= (~passed) << position
    return code
def describe_error(error: int | Integer[Array, '']) -> list[str]:
    """Translate an error code into the names of the failed checks.

    Parameters
    ----------
    error
        A scalar error code, as returned by `check_tree` or as found in the
        output of `check_trace`.

    Returns
    -------
    A list of the function names that implement the failed checks.
    """
    # bit number `bit` of the code belongs to the `bit`-th registered check
    return [
        func.__name__
        for bit, func in enumerate(CHECK_FUNCTIONS)
        if error & (1 << bit)
    ]
@jit
def check_trace(
    trace: TreeHeaps, max_split: UInt[Array, ' p']
) -> UInt[Array, '*batch_shape']:
    """Run all the registered checks on each tree in a set of trees.

    Use `describe_error` to parse the error codes returned by this function.

    Parameters
    ----------
    trace
        The set of trees to check. This object can have additional attributes
        beyond the tree arrays, they are ignored.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    A tensor of error codes for each tree.
    """

    def check_single(leaf_tree, var_tree, split_tree):
        # repack the bare arrays into a tree object for `check_tree`
        return check_tree(TreesTrace(leaf_tree, var_tree, split_tree), max_split)

    # vectorize check_tree over all batch dimensions. if the leaves have more
    # axes than the splits, the leaf heap has an additional core axis `k`
    if trace.leaf_tree.ndim > trace.split_tree.ndim:
        signature = '(k,ts),(hts),(hts)->()'
    else:
        signature = '(ts),(hts),(hts)->()'
    checker = jnp.vectorize(check_single, signature=signature)

    # automatically batch over all batch dimensions, wrapping from the last
    # batch axis to the first
    max_io_nbytes = 2**24  # 16 MiB
    num_batch_axes = trace.split_tree.ndim - 1
    for axis in range(num_batch_axes - 1, -1, -1):
        checker = autobatch(checker, max_io_nbytes, axis, axis)

    return checker(trace.leaf_tree, trace.var_tree, trace.split_tree)