Coverage for src / bartz / debug / _debuggbart.py: 66%
98 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/debug/_debuggbart.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Debugging utilities. The main functionality is the class `debug_mc_gbart`."""
27from dataclasses import replace
28from typing import Any
30from equinox import error_if
31from jax import numpy as jnp
32from jax import tree
33from jax.sharding import PartitionSpec
34from jax.tree_util import tree_map
35from jaxtyping import Array, Bool, Float32, Int32, UInt
37from bartz.BART import gbart, mc_gbart
38from bartz.debug._check import check_trace
39from bartz.grove import (
40 evaluate_forest,
41 forest_depth_distr,
42 format_tree,
43 points_per_node_distr,
44)
45from bartz.jaxext import equal_shards
46from bartz.mcmcloop import TreesTrace
class debug_mc_gbart(mc_gbart):
    """A subclass of `mc_gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `mc_gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid.
    check_replicated_trees
        If the data is sharded across devices, check that the trees are equal
        on all devices in the final state. Set to `False` to allow jax tracing.
    **kwargs
        Passed to `mc_gbart`.
    """

    def __init__(
        self,
        *args: Any,
        check_trees: bool = True,
        check_replicated_trees: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)

        if check_trees:
            bad = self.check_trees()
            bad_count = jnp.count_nonzero(bad)
            # Attach the validity check to `offset` with `error_if` instead of
            # raising here, so that this check does not prevent jax tracing.
            # NOTE(review): `__dict__` is written directly, presumably because
            # `_bart` does not allow normal attribute assignment — confirm.
            self._bart.__dict__['offset'] = error_if(
                self._bart.offset, bad_count > 0, 'invalid trees found in trace'
            )

        state = self._mcmc_state
        mesh = state.config.mesh
        if check_replicated_trees and mesh is not None and 'data' in mesh.axis_names:
            # Compare the whole forest except `leaf_indices`, which presumably
            # is per-datapoint and hence legitimately differs across data
            # shards — confirm.
            replicated_forest = replace(state.forest, leaf_indices=None)
            equal = equal_shards(
                replicated_forest, 'data', in_specs=PartitionSpec(), mesh=mesh
            )
            equal_array = jnp.stack(tree.leaves(equal))
            all_equal = jnp.all(equal_array)
            # we could use error_if here for traceability, but last time we
            # tried it hanged on error, maybe it was due to sharding.
            assert all_equal.item(), 'the trees are different across devices'

    @staticmethod
    def _sum_over_samples(x: Array) -> Array:
        """Sum a per-iteration trace over the samples, one total per chain.

        Parameters
        ----------
        x
            A trace with shape ``(chains, samples)``, or ``(samples,)`` when
            the trace has no chain axis.

        Returns
        -------
        An array with shape ``(chains,)`` (``(1,)`` without a chain axis).
        """
        # A single-chain trace may lack the leading chain axis (the same case
        # handled through `ndim` in `depth_distr` and `check_trees`); restore
        # it so that axis 1 is always the sample axis.
        return jnp.atleast_2d(x).sum(axis=1)

    def print_tree(
        self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
    ) -> None:
        """Print a single tree in human-readable format.

        Parameters
        ----------
        i_chain
            The index of the MCMC chain. Must be 0 if the trace has no chain
            axis.
        i_sample
            The index of the (post-burnin) sample in the chain.
        i_tree
            The index of the tree in the sample.
        print_all
            If `True`, also print the content of unused node slots.
        """
        trees = TreesTrace.from_dataclass(self._main_trace)
        # Without a chain axis `split_tree` is (samples, trees, nodes); add the
        # chain axis so that the indexing below is uniform.
        if self._main_trace.split_tree.ndim < 4:
            trees = tree_map(lambda x: x[None], trees)
        # named `selected`, not `tree`, to avoid shadowing the `jax.tree` import
        selected = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], trees)
        s = format_tree(selected, print_all=print_all)
        print(s)  # noqa: T201, this method is intended for debug

    def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']:
        """Return the square root of the harmonic mean of the error variance.

        Only valid for continuous regression (asserts that there is no latent
        variable `z`) with an inverse gamma prior on the error variance.

        Parameters
        ----------
        prior
            If `True`, use the prior distribution, otherwise use the full
            conditional at the last MCMC iteration.

        Returns
        -------
        The error standard deviation sqrt(1/E[1/sigma^2]), with the expectation
        taken under the selected distribution.
        """
        bart = self._mcmc_state
        assert bart.error_cov_df is not None
        assert bart.z is None
        # inverse gamma prior: alpha = df / 2, beta = scale / 2
        if prior:
            alpha = bart.error_cov_df / 2
            beta = bart.error_cov_scale / 2
        else:
            resid = bart.resid
            # The data points lie along the last axis of `resid` (one row per
            # chain); `resid.size` would also count the chains and inflate
            # alpha when there are multiple chains.
            n = resid.shape[-1]
            alpha = bart.error_cov_df / 2 + n / 2
            norm2 = jnp.einsum('...i,...i->...', resid, resid)
            beta = bart.error_cov_scale / 2 + norm2 / 2
        # for an inverse gamma distribution, E[1/sigma^2] = alpha / beta
        error_cov_inv = alpha / beta
        return jnp.sqrt(jnp.reciprocal(error_cov_inv))

    def compare_resid(
        self,
    ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]:
        """Re-compute residuals to compare them with the updated ones.

        Returns
        -------
        resid1 : Float32[Array, 'mc_cores n']
            The final state of the residuals updated during the MCMC.
        resid2 : Float32[Array, 'mc_cores n']
            The residuals computed from the final state of the trees.
        """
        bart = self._mcmc_state
        resid1 = bart.resid

        forests = TreesTrace.from_dataclass(bart.forest)
        trees = evaluate_forest(bart.X, forests, sum_batch_axis=-1)

        # with a latent variable (binary regression) residuals are w.r.t. `z`
        ref = bart.z if bart.z is not None else bart.y
        resid2 = ref - (trees + bart.offset)

        return resid1, resid2

    def avg_acc(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average acceptance rates of tree moves.

        Returns
        -------
        acc_grow : Float32[Array, 'mc_cores']
            The average acceptance rate of grow moves.
        acc_prune : Float32[Array, 'mc_cores']
            The average acceptance rate of prune moves.
        """
        trace = self._main_trace

        def acc(prefix: str) -> Float32[Array, ' mc_cores']:
            accepted = self._sum_over_samples(getattr(trace, f'{prefix}_acc_count'))
            proposed = self._sum_over_samples(getattr(trace, f'{prefix}_prop_count'))
            return accepted / proposed

        return acc('grow'), acc('prune')

    def avg_prop(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average proposal rate of grow and prune moves.

        Returns
        -------
        prop_grow : Float32[Array, 'mc_cores']
            The fraction of times grow was proposed instead of prune.
        prop_prune : Float32[Array, 'mc_cores']
            The fraction of times prune was proposed instead of grow.

        Notes
        -----
        This function does not take into account cases where no move was
        proposed.
        """
        trace = self._main_trace

        def prop(prefix: str) -> Array:
            return self._sum_over_samples(getattr(trace, f'{prefix}_prop_count'))

        pgrow = prop('grow')
        pprune = prop('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def avg_move(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the move rate.

        Returns
        -------
        rate_grow : Float32[Array, 'mc_cores']
            The fraction of times a grow move was proposed and accepted.
        rate_prune : Float32[Array, 'mc_cores']
            The fraction of times a prune move was proposed and accepted.
        """
        agrow, aprune = self.avg_acc()
        pgrow, pprune = self.avg_prop()
        return agrow * pgrow, aprune * pprune

    def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']:
        """Histogram of tree depths for each state of the trees.

        Returns
        -------
        A matrix where each row contains a histogram of tree depths.
        """
        out: Int32[Array, '*chains samples d']
        out = forest_depth_distr(self._main_trace.split_tree)
        # add the chain axis when the trace does not have it
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def _points_per_node_distr(
        self, node_type: str
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of points per node, for the node type `node_type`.

        `node_type` is forwarded to `points_per_node_distr`; the public
        callers use ``'leaf-parent'`` and ``'leaf'``.
        """
        out: Int32[Array, '*chains samples n+1']
        out = points_per_node_distr(
            self._mcmc_state.X,
            self._main_trace.var_tree,
            self._main_trace.split_tree,
            node_type,
            sum_batch_axis=-1,
        )
        # add the chain axis when the trace does not have it
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def points_per_decision_node_distr(
        self,
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes.

        Returns
        -------
        For each chain, a matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf-parent')

    def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to leaves.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf')

    def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Apply `check_trace` to all the tree draws.

        Returns
        -------
        An array with axes (chain, sample, tree); a nonzero entry marks an
        invalid tree. A chain axis of length 1 is added if the trace has none.
        """
        out: UInt[Array, '*chains samples num_trees']
        out = check_trace(self._main_trace, self._mcmc_state.forest.max_split)
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Find iterations where a tree becomes invalid.

        Returns
        -------
        A boolean array with axes (chain, sample, tree), where element (c, i, j)
        is `True` if tree j of chain c is invalid at sample i but not at sample
        i-1. At the first sample, any invalid tree counts as going bad.
        """
        bad = self.check_trees().astype(bool)
        # shift along the sample axis by one, padding with False at the start
        bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
        return bad & ~bad_before
class debug_gbart(debug_mc_gbart, gbart):
    """A subclass of `gbart` that adds debugging functionality.

    `debug_mc_gbart` is listed first among the bases so that its `__init__`
    (which runs the tree checks) and its debugging methods take precedence in
    the MRO; its `super().__init__` call dispatches to the next class in the
    MRO.

    Parameters
    ----------
    *args
        Passed to `gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid.
    check_replicated_trees
        If the data is sharded across devices, check that the trees are equal
        on all devices in the final state. Set to `False` to allow jax tracing.
    **kwargs
        Passed to `gbart`.
    """