Coverage for src / bartz / debug / _debuggbart.py: 47%
66 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
1# bartz/src/bartz/debug/_debuggbart.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Debugging utilities. The main functionality is the class `debug_mc_gbart`."""
27from typing import Any
29from jax import numpy as jnp
30from jax.tree_util import tree_map
31from jaxtyping import Array, Bool, Float32, Int32, UInt
33from bartz.BART import gbart, mc_gbart
34from bartz.grove import TreesTrace, format_tree
class debug_mc_gbart(mc_gbart):
    """A subclass of `mc_gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `mc_gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid.
    check_replicated_trees
        If the data is sharded across devices, check that the trees are equal
        on all devices in the final state. Set to `False` to allow jax tracing.
    **kwargs
        Passed to `mc_gbart`.
    """

    def __init__(
        self,
        *args: Any,
        check_trees: bool = True,
        check_replicated_trees: bool = True,
        **kwargs: Any,
    ) -> None:
        # Run the regular fit first; the checks below inspect its outcome.
        super().__init__(*args, **kwargs)
        if check_trees:
            self._bart.check_trees(error=True)
        if check_replicated_trees:
            self._bart.check_replicated_trees()

    def print_tree(
        self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
    ) -> None:
        """Print a single tree in human-readable format.

        Parameters
        ----------
        i_chain
            The index of the MCMC chain.
        i_sample
            The index of the (post-burnin) sample in the chain.
        i_tree
            The index of the tree in the sample.
        print_all
            If `True`, also print the content of unused node slots.
        """
        trees = TreesTrace.from_dataclass(self._main_trace)
        # Pick out one tree: index the three leading axes of every array in
        # the trace and keep the trailing axis whole.
        tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], trees)
        print(format_tree(tree, print_all=print_all))  # noqa: T201, this method is intended for debug

    def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']:
        """Return the square root of the harmonic mean of the error variance.

        Parameters
        ----------
        prior
            If `True`, use the prior distribution, otherwise use the full
            conditional at the last MCMC iteration.

        Returns
        -------
        The value sqrt(1/E[1/sigma^2]) in the selected distribution.
        """
        state = self._mcmc_state
        assert state.error_cov_df is not None
        assert state.z is None

        # inverse gamma prior: alpha = df / 2, beta = scale / 2
        alpha = state.error_cov_df / 2
        beta = state.error_cov_scale / 2
        if not prior:
            resid = state.resid
            # The einsum below reduces the residuals row by row, so each row
            # is handled as a separate set of observations. The number of
            # observations entering alpha must then be the row length, not
            # `resid.size`, which would also multiply by the number of rows.
            alpha = alpha + resid.shape[-1] / 2
            beta = beta + jnp.einsum('ij,ij->i', resid, resid) / 2

        # E[1/sigma^2] = alpha / beta for an inverse gamma distribution
        error_cov_inv = alpha / beta
        return jnp.sqrt(jnp.reciprocal(error_cov_inv))

    def compare_resid(
        self, y: Float32[Array, ' n'] | Float32[Array, 'k n'] | None = None
    ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]:
        """Re-compute residuals to compare them with the updated ones."""
        return self._bart.compare_resid(y)

    def avg_acc(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average acceptance rates of tree moves.

        Returns
        -------
        acc_grow : Float32[Array, 'mc_cores']
            The average acceptance rate of grow moves.
        acc_prune : Float32[Array, 'mc_cores']
            The average acceptance rate of prune moves.
        """
        trace = self._main_trace

        def rate(prefix: str) -> Float32[Array, ' mc_cores']:
            # Ratio of accepted to proposed moves, both summed over axis 1.
            accepted = getattr(trace, f'{prefix}_acc_count')
            proposed = getattr(trace, f'{prefix}_prop_count')
            return accepted.sum(axis=1) / proposed.sum(axis=1)

        return rate('grow'), rate('prune')

    def avg_prop(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average proposal rate of grow and prune moves.

        Returns
        -------
        prop_grow : Float32[Array, 'mc_cores']
            The fraction of times grow was proposed instead of prune.
        prop_prune : Float32[Array, 'mc_cores']
            The fraction of times prune was proposed instead of grow.

        Notes
        -----
        This function does not take into account cases where no move was
        proposed.
        """
        trace = self._main_trace
        n_grow = trace.grow_prop_count.sum(axis=1)
        n_prune = trace.prune_prop_count.sum(axis=1)
        n_total = n_grow + n_prune
        return n_grow / n_total, n_prune / n_total

    def avg_move(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the move rate.

        Returns
        -------
        rate_grow : Float32[Array, 'mc_cores']
            The fraction of times a grow move was proposed and accepted.
        rate_prune : Float32[Array, 'mc_cores']
            The fraction of times a prune move was proposed and accepted.
        """
        acc_grow, acc_prune = self.avg_acc()
        prop_grow, prop_prune = self.avg_prop()
        # acceptance rate times proposal rate, separately for each move type
        return acc_grow * prop_grow, acc_prune * prop_prune

    def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']:
        """Histogram of tree depths for each state of the trees."""
        return self._bart.depth_distr()

    def _points_per_node_distr(
        self, node_type: str
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        # Thin forwarder to the private implementation on the wrapped object.
        return self._bart._points_per_node_distr(node_type)  # noqa: SLF001

    def points_per_decision_node_distr(
        self,
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes."""
        return self._bart.points_per_decision_node_distr()

    def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to leaves."""
        return self._bart.points_per_leaf_distr()

    def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Apply `check_trace` to all the tree draws."""
        return self._bart.check_trees()

    def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Find iterations where a tree becomes invalid.

        Returns
        -------
        An array where, within each chain, entry (i, j) is `True` if tree j is
        invalid at iteration i but not at iteration i-1. A tree that is
        already invalid at the first iteration is flagged there.
        """
        bad = self.check_trees().astype(bool)
        # Shift the flags forward by one step along the iteration axis
        # (axis 1); the slot opened at the front is filled with `False`.
        bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
        return bad & ~bad_before
class debug_gbart(debug_mc_gbart, gbart):
    """Variant of `gbart` with the debugging additions of `debug_mc_gbart`.

    Parameters
    ----------
    *args
        Forwarded to `gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid.
    check_replicated_trees
        If the data is sharded across devices, check that the trees are equal
        on all devices in the final state. Set to `False` to allow jax tracing.
    **kwargs
        Forwarded to `gbart`.
    """