Coverage for src / bartz / debug / _debuggbart.py: 66%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/debug/_debuggbart.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Debugging utilities. The main functionality is the class `debug_mc_gbart`.""" 

26 

27from dataclasses import replace 

28from typing import Any 

29 

30from equinox import error_if 

31from jax import numpy as jnp 

32from jax import tree 

33from jax.sharding import PartitionSpec 

34from jax.tree_util import tree_map 

35from jaxtyping import Array, Bool, Float32, Int32, UInt 

36 

37from bartz.BART import gbart, mc_gbart 

38from bartz.debug._check import check_trace 

39from bartz.grove import ( 

40 evaluate_forest, 

41 forest_depth_distr, 

42 format_tree, 

43 points_per_node_distr, 

44) 

45from bartz.jaxext import equal_shards 

46from bartz.mcmcloop import TreesTrace 

47 

48 

class debug_mc_gbart(mc_gbart):
    """A subclass of `mc_gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `mc_gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid.
    check_replicated_trees
        If the data is sharded across devices, check that the trees are equal
        on all devices in the final state. Set to `False` to allow jax tracing.
    **kwargs
        Passed to `mc_gbart`.
    """

    def __init__(
        self,
        *args: Any,
        check_trees: bool = True,
        check_replicated_trees: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)

        if check_trees:
            bad = self.check_trees()
            bad_count = jnp.count_nonzero(bad)
            # `error_if` returns `offset` unchanged but raises if any tree is
            # invalid. Storing its result ties the check to a value of the
            # fitted model. `__dict__` is written directly, presumably to
            # bypass an immutable `__setattr__` on `_bart` -- TODO confirm.
            self._bart.__dict__['offset'] = error_if(
                self._bart.offset, bad_count > 0, 'invalid trees found in trace'
            )

        state = self._mcmc_state
        mesh = state.config.mesh
        if check_replicated_trees and mesh is not None and 'data' in mesh.axis_names:
            # `leaf_indices` is excluded from the comparison because it is
            # not expected to be replicated across data shards.
            replicated_forest = replace(state.forest, leaf_indices=None)
            equal = equal_shards(
                replicated_forest, 'data', in_specs=PartitionSpec(), mesh=mesh
            )
            equal_array = jnp.stack(tree.leaves(equal))
            all_equal = jnp.all(equal_array)
            # we could use error_if here for traceability, but last time we
            # tried it hanged on error, maybe it was due to sharding.
            assert all_equal.item(), 'the trees are different across devices'

    def print_tree(
        self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
    ) -> None:
        """Print a single tree in human-readable format.

        Parameters
        ----------
        i_chain
            The index of the MCMC chain.
        i_sample
            The index of the (post-burnin) sample in the chain.
        i_tree
            The index of the tree in the sample.
        print_all
            If `True`, also print the content of unused node slots.
        """
        # Not named `tree`, to avoid shadowing the `jax.tree` module import.
        trees = TreesTrace.from_dataclass(self._main_trace)
        one_tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], trees)
        s = format_tree(one_tree, print_all=print_all)
        print(s)  # noqa: T201, this method is intended for debug

    def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']:
        """Return the square root of the harmonic mean of the error variance.

        Parameters
        ----------
        prior
            If `True`, use the prior distribution, otherwise use the full
            conditional at the last MCMC iteration.

        Returns
        -------
        The square root of the harmonic mean, sqrt(1/E[1/sigma^2]), in the
        selected distribution.
        """
        bart = self._mcmc_state
        assert bart.error_cov_df is not None
        assert bart.z is None
        resid = bart.resid
        # inverse gamma prior: alpha = df / 2, beta = scale / 2
        if prior:
            alpha = bart.error_cov_df / 2
            beta = bart.error_cov_scale / 2
        else:
            # Conjugate update. The number of datapoints is the length of the
            # last axis: `resid.size` would also count the chain axis and
            # overstate `alpha` when there are several chains.
            n = resid.shape[-1]
            alpha = bart.error_cov_df / 2 + n / 2
            # Squared norm over the datapoints axis. The ellipsis also
            # supports a residual without a leading chain axis.
            norm2 = jnp.einsum('...i,...i->...', resid, resid)
            beta = bart.error_cov_scale / 2 + norm2 / 2
        # for an inverse gamma, E[1/sigma^2] = alpha / beta
        error_cov_inv = alpha / beta
        return jnp.sqrt(jnp.reciprocal(error_cov_inv))

    def compare_resid(
        self,
    ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]:
        """Re-compute residuals to compare them with the updated ones.

        Returns
        -------
        resid1 : Float32[Array, 'mc_cores n']
            The final state of the residuals updated during the MCMC.
        resid2 : Float32[Array, 'mc_cores n']
            The residuals computed from the final state of the trees.
        """
        bart = self._mcmc_state
        resid1 = bart.resid

        forests = TreesTrace.from_dataclass(bart.forest)
        trees = evaluate_forest(bart.X, forests, sum_batch_axis=-1)

        # with a latent variable `z` the residuals are relative to it,
        # otherwise to the observed `y`
        if bart.z is not None:
            ref = bart.z
        else:
            ref = bart.y
        resid2 = ref - (trees + bart.offset)

        return resid1, resid2

    def avg_acc(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average acceptance rates of tree moves.

        Returns
        -------
        acc_grow : Float32[Array, 'mc_cores']
            The average acceptance rate of grow moves.
        acc_prune : Float32[Array, 'mc_cores']
            The average acceptance rate of prune moves.
        """
        trace = self._main_trace

        def acc(prefix: str) -> Float32[Array, ' mc_cores']:
            acc = getattr(trace, f'{prefix}_acc_count')
            prop = getattr(trace, f'{prefix}_prop_count')
            # Sum over the samples axis. It is the last axis both with and
            # without a leading chain axis.
            return acc.sum(axis=-1) / prop.sum(axis=-1)

        return acc('grow'), acc('prune')

    def avg_prop(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average proposal rate of grow and prune moves.

        Returns
        -------
        prop_grow : Float32[Array, 'mc_cores']
            The fraction of times grow was proposed instead of prune.
        prop_prune : Float32[Array, 'mc_cores']
            The fraction of times prune was proposed instead of grow.

        Notes
        -----
        This function does not take into account cases where no move was
        proposed.
        """
        trace = self._main_trace

        def prop(prefix: str) -> Array:
            # sum over the samples axis (last, with or without a chain axis)
            return getattr(trace, f'{prefix}_prop_count').sum(axis=-1)

        pgrow = prop('grow')
        pprune = prop('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def avg_move(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the move rate.

        Returns
        -------
        rate_grow : Float32[Array, 'mc_cores']
            The fraction of times a grow move was proposed and accepted.
        rate_prune : Float32[Array, 'mc_cores']
            The fraction of times a prune move was proposed and accepted.
        """
        agrow, aprune = self.avg_acc()
        pgrow, pprune = self.avg_prop()
        return agrow * pgrow, aprune * pprune

    def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']:
        """Histogram of tree depths for each state of the trees.

        Returns
        -------
        A matrix where each row contains a histogram of tree depths. A
        leading chain axis of size 1 is added if the trace has none.
        """
        out: Int32[Array, '*chains samples d']
        out = forest_depth_distr(self._main_trace.split_tree)
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def _points_per_node_distr(
        self, node_type: str
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        # Shared implementation of the `points_per_*_distr` methods.
        # `node_type` is forwarded to `points_per_node_distr`. A leading
        # chain axis of size 1 is added if the result has none.
        out: Int32[Array, '*chains samples n+1']
        out = points_per_node_distr(
            self._mcmc_state.X,
            self._main_trace.var_tree,
            self._main_trace.split_tree,
            node_type,
            sum_batch_axis=-1,
        )
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def points_per_decision_node_distr(
        self,
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes.

        Returns
        -------
        For each chain, a matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf-parent')

    def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to leaves.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf')

    def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Apply `check_trace` to all the tree draws.

        Returns
        -------
        The output of `check_trace` for each chain, sample and tree. A
        nonzero entry marks an invalid tree. A leading chain axis of size 1
        is added if the trace has none.
        """
        out: UInt[Array, '*chains samples num_trees']
        out = check_trace(self._main_trace, self._mcmc_state.forest.max_split)
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Find iterations where a tree becomes invalid.

        Returns
        -------
        An array where entry (c, i, j) is `True` if, in chain c, tree j is
        invalid at iteration i but not at iteration i-1. At the first
        iteration, any invalid tree counts as going bad.
        """
        bad = self.check_trees().astype(bool)
        # shift the samples axis by one, padding with `False` at the start
        bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
        return bad & ~bad_before

299 

300 

class debug_gbart(debug_mc_gbart, gbart):
    """A subclass of `gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid.
    check_replicated_trees
        If the data is sharded across devices, check that the trees are equal
        on all devices in the final state. Set to `False` to allow jax tracing.
    **kwargs
        Passed to `gbart`.
    """