Coverage for src / bartz / debug / _check.py: 98%
95 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/debug/_check.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement functions to check validity of trees."""
27from typing import Protocol
29from jax import jit
30from jax import numpy as jnp
31from jaxtyping import Array, Bool, Integer, UInt
33from bartz.grove import TreeHeaps, is_actual_leaf
34from bartz.jaxext import autobatch, minimal_unsigned_dtype
35from bartz.mcmcloop import TreesTrace
# Registry of tree checks, filled by the `check` decorator. The position of a
# function in this list is the bit it sets in the error codes of `check_tree`.
CHECK_FUNCTIONS: list['CheckFunc'] = []
class CheckFunc(Protocol):
    """Call signature shared by every tree validity check."""

    def __call__(
        self,
        tree: TreeHeaps,
        max_split: UInt[Array, ' p'],
        /,
    ) -> bool | Bool[Array, '']:
        """Tell whether `tree` passes this check.

        Parameters
        ----------
        tree
            The tree to inspect.
        max_split
            The maximum split value allowed for each variable.

        Returns
        -------
        A boolean scalar, true when the tree is valid.
        """
        ...
def check(func: CheckFunc) -> CheckFunc:
    """Register `func` as one of the checks run on trees.

    Apply it as a decorator to functions that test one aspect of tree
    validity. Registered functions are run automatically by `check_tree`,
    `check_trace` and `debug_gbart`.

    Parameters
    ----------
    func
        The check to register. It takes a `TreeHeaps` and a `max_split`
        argument and returns a boolean scalar, true if the tree passes.

    Returns
    -------
    `func` itself, unmodified.
    """
    # registration order fixes the bit index used for this check in the
    # error codes built by `check_tree`
    CHECK_FUNCTIONS.extend([func])
    return func
@check
def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
    """Verify that the integer dtypes of the tree are minimal and coherent.

    Three conditions must hold:

    - `var_tree` uses the dtype that `minimal_unsigned_dtype` chooses for the
      largest variable index;
    - `split_tree` shares the dtype of `max_split`;
    - `max_split` is unsigned.
    """
    var_dtype_ok = tree.var_tree.dtype == minimal_unsigned_dtype(max_split.size - 1)
    split_dtype_ok = tree.split_tree.dtype == max_split.dtype
    max_split_is_unsigned = jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)
    return var_dtype_ok and split_dtype_ok and max_split_is_unsigned
@check
def check_sizes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
    """Check that array sizes are coherent.

    The leaf array must have twice as many slots as the decision arrays
    (`var_tree` and `split_tree`), which must have equal sizes.

    The leaf array may carry a leading axis. For example, multivariate leaves
    have shape ``(k, ts)``, as handled by `check_trace`. For that reason the
    number of leaf slots is read from its last axis rather than its total
    size. Using the total size would reject every valid multivariate tree
    with ``k > 1``.
    """
    num_leaf_slots = tree.leaf_tree.shape[-1]
    return num_leaf_slots == 2 * tree.var_tree.size == 2 * tree.split_tree.size
@check
def check_unused_node(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Verify that heap slot 0, which holds no node, is left zeroed.

    Both its `var_tree` and its `split_tree` entries must be 0.
    """
    var_clean = tree.var_tree[0] == 0
    split_clean = tree.split_tree[0] == 0
    return jnp.logical_and(var_clean, split_clean)
@check
def check_leaf_values(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that all leaf values are finite, i.e., neither inf nor nan.

    Every entry of `leaf_tree` is tested, whatever its shape.
    """
    return jnp.all(jnp.isfinite(tree.leaf_tree))
@check
def check_stray_nodes(
    tree: TreeHeaps, _max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Verify that no node marked as a decision node has a leaf as parent.

    A slot counts as a decision node when its split value is nonzero, and as
    a leaf when it is zero. Indices span twice the length of `split_tree`.
    Slots past its end have no split entry and are read as leaves.
    """
    heap_size = 2 * tree.split_tree.size
    node = jnp.arange(heap_size, dtype=minimal_unsigned_dtype(heap_size - 1))

    # slots beyond split_tree are filled with 0, i.e., treated as leaves
    node_is_internal = (
        tree.split_tree.at[node].get(mode='fill', fill_value=0) != 0
    )
    parent_is_leaf = tree.split_tree[node >> 1] == 0
    orphaned = node_is_internal & parent_is_leaf

    # the root (index 1) has the unused slot 0 as "parent", so exempt it
    orphaned = orphaned.at[1].set(False)
    return ~jnp.any(orphaned)
@check
def check_rule_consistency(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> bool | Bool[Array, '']:
    """Check that decision rules define proper subsets of ancestor rules.

    Walk the tree from the root. For each variable, track the interval of
    split values still admissible given the ancestors' decisions. A decision
    node (nonzero split) is bad if its split falls outside the interval for
    its variable.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable; only its size is used here.

    Returns
    -------
    True if no decision node violates the bounds set by its ancestors.
    """
    # With fewer than 4 slots the recursion below never descends past the
    # root (`node < size // 2` is false for node 1), so there are no ancestor
    # bounds to compare against.
    if tree.var_tree.size < 4:
        return True

    # initial boundaries of decision rules. use extreme integers instead of 0,
    # max_split to avoid checking if there is something out of bounds.
    dtype = tree.split_tree.dtype
    small = jnp.iinfo(dtype).min
    large = jnp.iinfo(dtype).max
    lower = jnp.full(max_split.size, small, dtype)
    upper = jnp.full(max_split.size, large, dtype)
    # the split must be in (lower[var], upper[var]]

    def _check_recursive(
        node: int, lower: UInt[Array, ' p'], upper: UInt[Array, ' p']
    ) -> Bool[Array, '']:
        # Return whether `node` or any decision node below it is bad. `node`
        # is a Python int, so this recursion is unrolled at trace time into
        # one computation per visited slot.

        # read decision rule
        var = tree.var_tree[node]
        split = tree.split_tree[node]

        # get rule boundaries from ancestors. use fill value in case var is
        # out of bounds, we don't want to check out of bounds in this function
        lower_var = lower.at[var].get(mode='fill', fill_value=small)
        upper_var = upper.at[var].get(mode='fill', fill_value=large)

        # check rule is in bounds; leaves (split == 0) are never bad
        bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False)

        # recurse, if the children have slots in var_tree
        if node < tree.var_tree.size // 2:
            # At a leaf (split == 0), point the updates below at the
            # out-of-bounds index max_split.size. JAX skips scatter updates at
            # out-of-bounds indices by default, so the children inherit the
            # bounds unchanged. An out-of-range `var` is skipped the same way.
            idx = jnp.where(split, var, max_split.size)
            # left child: the interval for `var` is capped at split - 1
            bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split - 1))
            # right child: the interval for `var` starts above split
            bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper)

        return bad

    return ~_check_recursive(1, lower, upper)
@check
def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Verify the binary-tree identity: #leaves = #(internal nodes) + 1.

    Leaves are found with `is_actual_leaf`, including the bottom level.
    Internal nodes are the slots with a nonzero split value.
    """
    leaf_mask = is_actual_leaf(tree.split_tree, add_bottom_level=True)
    n_internal = jnp.count_nonzero(tree.split_tree)
    n_leaves = jnp.count_nonzero(leaf_mask)
    return n_leaves == 1 + n_internal
@check
def check_var_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Verify that decision nodes use a variable index in [0, max_split.size).

    Entries of `var_tree` at leaf slots (split value 0) are unconstrained.
    """
    var = tree.var_tree
    is_leaf = ~tree.split_tree.astype(bool)
    var_ok = (var >= 0) & (var < max_split.size)
    return jnp.all(is_leaf | var_ok)
@check
def check_split_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Verify that every split value lies in [0, max_split[var]].

    Slots whose variable index is out of range are compared against the
    largest int32 instead. Out-of-range variables are left to
    `check_var_in_bounds`.
    """
    sentinel = jnp.iinfo(jnp.int32).max
    limits = max_split.astype(jnp.int32)
    max_split_var = limits.at[tree.var_tree].get(mode='fill', fill_value=sentinel)
    split = tree.split_tree
    return jnp.all((split >= 0) & (split <= max_split_var))
def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
    """Run every registered check on a single tree.

    Decode the result with `describe_error`.

    Parameters
    ----------
    tree
        The tree to validate.
    max_split
        The maximum split value allowed for each variable.

    Returns
    -------
    An unsigned integer bitmask. Bit ``i`` is set when the ``i``-th function
    in `CHECK_FUNCTIONS` reports the tree as invalid.
    """
    code_dtype = minimal_unsigned_dtype(2 ** len(CHECK_FUNCTIONS) - 1)
    error = code_dtype(0)
    for bit_index, check_func in enumerate(CHECK_FUNCTIONS):
        passed = jnp.bool_(check_func(tree, max_split))
        error |= (~passed) << bit_index
    return error
def describe_error(error: int | Integer[Array, '']) -> list[str]:
    """Translate an error code from `check_tree` or `check_trace` into names.

    Parameters
    ----------
    error
        An error code as produced by `check_tree`, or one element of the
        output of `check_trace`.

    Returns
    -------
    The names of the check functions whose bit is set in `error`, in
    registration order.
    """
    failed_names = [
        check_func.__name__
        for bit_index, check_func in enumerate(CHECK_FUNCTIONS)
        if error & (1 << bit_index)
    ]
    return failed_names
@jit
def check_trace(
    trace: TreeHeaps, max_split: UInt[Array, ' p']
) -> UInt[Array, '*batch_shape']:
    """Validate every tree in a batch of trees.

    Decode the returned codes with `describe_error`.

    Parameters
    ----------
    trace
        The trees to validate. Only the `leaf_tree`, `var_tree` and
        `split_tree` attributes are read; anything else is ignored.
    max_split
        The maximum split value allowed for each variable.

    Returns
    -------
    An array with one error code per tree.
    """

    def check_single(leaf_tree, var_tree, split_tree):
        return check_tree(TreesTrace(leaf_tree, var_tree, split_tree), max_split)

    # leaves carry an extra core axis in the multivariate case
    if trace.leaf_tree.ndim > trace.split_tree.ndim:
        signature = '(k,ts),(hts),(hts)->()'
    else:
        signature = '(ts),(hts),(hts)->()'
    checker = jnp.vectorize(check_single, signature=signature)

    # wrap one autobatch layer per batch axis, innermost axis first
    max_io_nbytes = 2**24  # 16 MiB
    for axis in reversed(range(trace.split_tree.ndim - 1)):
        checker = autobatch(checker, max_io_nbytes, axis, axis)

    return checker(trace.leaf_tree, trace.var_tree, trace.split_tree)