Coverage for src/bartz/grove/_check.py: 98%
94 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/grove/_check.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement functions to check validity of trees."""
27from typing import Protocol, runtime_checkable
29from jax import numpy as jnp
30from jaxtyping import Array, Bool, Integer, UInt
32from bartz._jaxext import autobatch, jit, minimal_unsigned_dtype
33from bartz.grove._grove import TreeHeaps, TreesTrace, is_actual_leaf, is_multivariate
# Registry of tree-checking functions. It is filled in by the `check`
# decorator and iterated over, in registration order, by `check_tree` and
# `describe_error`.
CHECK_FUNCTIONS = []
@runtime_checkable
class CheckFunc(Protocol):
    """Structural type of a callable that validates one aspect of a tree."""

    def __call__(
        self, tree: TreeHeaps, max_split: UInt[Array, ' p'], /
    ) -> bool | Bool[Array, '']:
        """Tell whether `tree` passes this check.

        Parameters
        ----------
        tree
            The tree under examination.
        max_split
            For each variable, the largest admissible split value.

        Returns
        -------
        A scalar boolean (Python or array), true if the tree is valid.
        """
        ...
def check(func: CheckFunc) -> CheckFunc:
    """Register a tree-checking function.

    Meant to be used as a decorator on functions that verify one property of
    a tree. Registered functions are stored in `CHECK_FUNCTIONS` and are run
    automatically by `check_tree` and `check_trace`.

    Parameters
    ----------
    func
        The checker to register. It takes a `TreeHeaps` and a `max_split`
        argument, and returns a scalar boolean that is true if the tree
        passes the check.

    Returns
    -------
    The very same function, so the decorated name stays usable as is.
    """
    CHECK_FUNCTIONS.append(func)
    return func
@check
def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
    """Check that integer types are as small as possible and coherent.

    The variable indices must use the smallest unsigned type able to hold
    `max_split.size - 1`, the splits must share the type of `max_split`, and
    that type must be an unsigned integer.
    """
    if tree.var_tree.dtype != minimal_unsigned_dtype(max_split.size - 1):
        return False
    if tree.split_tree.dtype != max_split.dtype:
        return False
    return jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)
@check
def check_shapes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
    """Check that array shapes are coherent.

    The leaf array must have 1 or 2 axes, the variable and split arrays must
    have exactly 1, and the last axis of the leaf array must be twice as long
    as the variable and split arrays.
    """
    # bail out before touching shape[-1], which needs at least one axis
    if tree.leaf_tree.ndim not in (1, 2):
        return False
    if tree.var_tree.ndim != 1 or tree.split_tree.ndim != 1:
        return False
    twice_var = 2 * tree.var_tree.size
    twice_split = 2 * tree.split_tree.size
    return tree.leaf_tree.shape[-1] == twice_var and twice_var == twice_split
@check
def check_unused_node(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that the unused slot at heap index 0 contains only zeros."""
    var_is_clean = tree.var_tree[0] == 0
    split_is_clean = tree.split_tree[0] == 0
    return var_is_clean & split_is_clean
@check
def check_leaf_values(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that every leaf value is finite, i.e., neither inf nor nan."""
    finite = jnp.isfinite(tree.leaf_tree)
    return jnp.all(finite)
@check
def check_stray_nodes(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that no node marked as non-leaf has a parent marked as leaf."""
    num_slots = 2 * tree.split_tree.size
    node = jnp.arange(num_slots, dtype=minimal_unsigned_dtype(num_slots - 1))
    parent = node >> 1

    # the node indices span twice the length of split_tree; reads past its end
    # yield 0, which is the marker of a leaf
    node_split = tree.split_tree.at[node].get(mode='fill', fill_value=0)
    parent_split = tree.split_tree[parent]
    stray = (node_split != 0) & (parent_split == 0)

    # the parent index of node 1 is the unused slot 0, so node 1 is exempted
    stray = stray.at[1].set(False)
    return ~jnp.any(stray)
@check
def check_rule_consistency(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> bool | Bool[Array, '']:
    """Check that decision rules define proper subsets of ancestor rules.

    The tree is walked from the root while tracking, for each variable, the
    interval of split values still allowed by the rules of the ancestors. A
    decision node fails the check if its split falls outside the interval of
    its variable.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable. Only its size is used, to
        allocate one interval per variable.

    Returns
    -------
    True if all decision rules are consistent with those of their ancestors.
    """
    if tree.var_tree.size < 4:
        # the recursion below never descends from the root for such a small
        # heap, so there is no ancestor rule to compare against
        return True

    # initial boundaries of decision rules. use extreme integers instead of 0,
    # max_split to avoid checking if there is something out of bounds.
    dtype = tree.split_tree.dtype
    small = jnp.iinfo(dtype).min
    large = jnp.iinfo(dtype).max
    lower = jnp.full(max_split.size, small, dtype)
    upper = jnp.full(max_split.size, large, dtype)
    # the split must be in (lower[var], upper[var]]

    def _check_recursive(
        node: int, lower: UInt[Array, ' p'], upper: UInt[Array, ' p']
    ) -> Bool[Array, '']:
        # read decision rule
        var = tree.var_tree[node]
        split = tree.split_tree[node]

        # get rule boundaries from ancestors. use fill value in case var is
        # out of bounds, we don't want to check out of bounds in this function
        lower_var = lower.at[var].get(mode='fill', fill_value=small)
        upper_var = upper.at[var].get(mode='fill', fill_value=large)

        # check rule is in bounds; a zero split marks a leaf, which has no rule
        bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False)

        # recurse
        if node < tree.var_tree.size // 2:
            # Tighten the boundaries only if this node is a decision node.
            # Select between the updated and the untouched arrays instead of
            # scattering to a sentinel index `max_split.size`: the sentinel
            # is not representable in the dtype of `var` when the number of
            # variables is exactly 2**8 or 2**16, because that dtype is the
            # smallest one that holds `max_split.size - 1`.
            left_upper = jnp.where(split, upper.at[var].set(split - 1), upper)
            right_lower = jnp.where(split, lower.at[var].set(split), lower)
            bad |= _check_recursive(2 * node, lower, left_upper)
            bad |= _check_recursive(2 * node + 1, right_lower, upper)

        return bad

    return ~_check_recursive(1, lower, upper)
@check
def check_num_nodes(
    tree: TreeHeaps,
    max_split: UInt[Array, ' p'],  # noqa: ARG001
) -> Bool[Array, '']:
    """Check that the leaves are exactly one more than the internal nodes."""
    leaf_count = jnp.count_nonzero(
        is_actual_leaf(tree.split_tree, add_bottom_level=True)
    )
    # a nonzero split marks an internal (decision) node
    internal_count = jnp.count_nonzero(tree.split_tree)
    return leaf_count == internal_count + 1
@check
def check_var_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that the variables of decision nodes are in [0, max_split.size)."""
    is_decision = tree.split_tree.astype(bool)
    var_ok = (tree.var_tree >= 0) & (tree.var_tree < max_split.size)
    # valid iff no decision node carries an out-of-range variable
    offending = is_decision & ~var_ok
    return ~jnp.any(offending)
@check
def check_split_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that splits are in [0, max_split[var]].

    Nodes whose variable index is out of range are compared against the
    largest int32 value instead of an entry of `max_split`.
    """
    no_limit = jnp.iinfo(jnp.int32).max
    limits = max_split.astype(jnp.int32)
    node_limit = limits.at[tree.var_tree].get(mode='fill', fill_value=no_limit)
    not_negative = tree.split_tree >= 0
    not_too_large = tree.split_tree <= node_limit
    return jnp.all(not_negative & not_too_large)
def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
    """Run all registered checks on a single tree.

    The result can be decoded with `describe_error`.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    An integer bit mask; bit ``i`` is set if the ``i``-th function in `CHECK_FUNCTIONS` reported a failure.
    """
    # smallest unsigned type with one bit per registered check
    error_type = minimal_unsigned_dtype(2 ** len(CHECK_FUNCTIONS) - 1)
    error = jnp.zeros((), error_type)
    for position, checker in enumerate(CHECK_FUNCTIONS):
        # checkers may return a Python bool or an array, normalize first
        failed = ~jnp.bool_(checker(tree, max_split))
        error |= failed << position
    return error
def describe_error(error: int | Integer[Array, '']) -> list[str]:
    """Decode an error code produced by `check_tree` or `check_trace`.

    Parameters
    ----------
    error
        A single error code, i.e., the output of `check_tree` or one element
        of the output of `check_trace`.

    Returns
    -------
    The names of the check functions whose bit is set in `error`, in registration order.
    """
    failed_checks = []
    for position, checker in enumerate(CHECK_FUNCTIONS):
        if error & (1 << position):
            failed_checks.append(checker.__name__)
    return failed_checks
@jit
def check_trace(
    trace: TreeHeaps, max_split: UInt[Array, ' p']
) -> UInt[Array, '*batch_shape']:
    """Run all registered checks on every tree in a collection.

    The error codes can be decoded one at a time with `describe_error`.

    Parameters
    ----------
    trace
        The trees to check, stacked along leading batch axes. Attributes
        other than the tree arrays are ignored.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    An array with one error code per tree.
    """

    def _check_single(leaf_tree, var_tree, split_tree):
        # repackage the bare arrays of one tree into the form check_tree wants
        heaps = TreesTrace(
            leaf_tree=leaf_tree, var_tree=var_tree, split_tree=split_tree
        )
        return check_tree(heaps, max_split)

    # vectorize check_tree over all batch dimensions; multivariate leaves
    # carry one extra core axis
    if is_multivariate(trace):
        signature = '(k,ts),(hts),(hts)->()'
    else:
        signature = '(ts),(hts),(hts)->()'
    checker = jnp.vectorize(_check_single, signature=signature)

    # process each batch axis in chunks, innermost axis wrapped first
    max_io_nbytes = 2**24  # 16 MiB
    num_batch_axes = trace.split_tree.ndim - 1
    for axis in reversed(range(num_batch_axes)):
        checker = autobatch(checker, max_io_nbytes, axis, axis)

    return checker(trace.leaf_tree, trace.var_tree, trace.split_tree)