Coverage for src / bartz / debug / _debuggbart.py: 47%

66 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 18:11 +0000

1# bartz/src/bartz/debug/_debuggbart.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Debugging utilities. The main functionality is the class `debug_mc_gbart`.""" 

26 

27from typing import Any 

28 

29from jax import numpy as jnp 

30from jax.tree_util import tree_map 

31from jaxtyping import Array, Bool, Float32, Int32, UInt 

32 

33from bartz.BART import gbart, mc_gbart 

34from bartz.grove import TreesTrace, format_tree 

35 

36 

class debug_mc_gbart(mc_gbart):
    """A subclass of `mc_gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `mc_gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid.
    check_replicated_trees
        If the data is sharded across devices, check that the trees are equal
        on all devices in the final state. Set to `False` to allow jax tracing.
    **kwargs
        Passed to `mc_gbart`.
    """

    def __init__(
        self,
        *args: Any,
        check_trees: bool = True,
        check_replicated_trees: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # The two checks are independent: either can be disabled on its own.
        if check_trees:
            self._bart.check_trees(error=True)
        if check_replicated_trees:
            self._bart.check_replicated_trees()

    def print_tree(
        self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
    ) -> None:
        """Print a single tree in human-readable format.

        Parameters
        ----------
        i_chain
            The index of the MCMC chain.
        i_sample
            The index of the (post-burnin) sample in the chain.
        i_tree
            The index of the tree in the sample.
        print_all
            If `True`, also print the content of unused node slots.
        """
        tree = TreesTrace.from_dataclass(self._main_trace)
        # Select one tree: leading axes are (chain, sample, tree), the
        # trailing axis of each leaf is kept whole.
        tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], tree)
        s = format_tree(tree, print_all=print_all)
        print(s)  # noqa: T201, this method is intended for debug

    def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']:
        """Return the square root of the harmonic mean of the error variance.

        Parameters
        ----------
        prior
            If `True`, use the prior distribution, otherwise use the full
            conditional at the last MCMC iteration.

        Returns
        -------
        The value sqrt(1/E[1/sigma^2]) in the selected distribution.

        Notes
        -----
        The method asserts that the MCMC state has `error_cov_df` set and
        `z` unset.
        """
        bart = self._mcmc_state
        assert bart.error_cov_df is not None
        assert bart.z is None
        # inverse gamma prior: alpha = df / 2, beta = scale / 2
        if prior:
            alpha = bart.error_cov_df / 2
            beta = bart.error_cov_scale / 2
        else:
            # `resid` is indexed as (chain, point) by the einsum below, so the
            # number of observations entering each chain's conditional is the
            # length of the last axis. Using `resid.size` would count the
            # points of all chains together and inflate alpha when there is
            # more than one chain.
            n = bart.resid.shape[-1]
            alpha = bart.error_cov_df / 2 + n / 2
            norm2 = jnp.einsum('ij,ij->i', bart.resid, bart.resid)
            beta = bart.error_cov_scale / 2 + norm2 / 2
        # E[1/sigma^2] under an inverse gamma distribution is alpha / beta.
        error_cov_inv = alpha / beta
        return jnp.sqrt(jnp.reciprocal(error_cov_inv))

    def compare_resid(
        self, y: Float32[Array, ' n'] | Float32[Array, 'k n'] | None = None
    ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]:
        """Re-compute residuals to compare them with the updated ones."""
        return self._bart.compare_resid(y)

    def avg_acc(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average acceptance rates of tree moves.

        Returns
        -------
        acc_grow : Float32[Array, 'mc_cores']
            The average acceptance rate of grow moves.
        acc_prune : Float32[Array, 'mc_cores']
            The average acceptance rate of prune moves.
        """
        trace = self._main_trace

        def acc(prefix: str) -> Float32[Array, ' mc_cores']:
            # Ratio of accepted to proposed moves, both summed over axis 1.
            accepted = getattr(trace, f'{prefix}_acc_count')
            proposed = getattr(trace, f'{prefix}_prop_count')
            return accepted.sum(axis=1) / proposed.sum(axis=1)

        return acc('grow'), acc('prune')

    def avg_prop(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average proposal rate of grow and prune moves.

        Returns
        -------
        prop_grow : Float32[Array, 'mc_cores']
            The fraction of times grow was proposed instead of prune.
        prop_prune : Float32[Array, 'mc_cores']
            The fraction of times prune was proposed instead of grow.

        Notes
        -----
        This function does not take into account cases where no move was
        proposed.
        """
        trace = self._main_trace

        def prop(prefix: str) -> Array:
            return getattr(trace, f'{prefix}_prop_count').sum(axis=1)

        pgrow = prop('grow')
        pprune = prop('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def avg_move(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the move rate.

        Returns
        -------
        rate_grow : Float32[Array, 'mc_cores']
            The fraction of times a grow move was proposed and accepted.
        rate_prune : Float32[Array, 'mc_cores']
            The fraction of times a prune move was proposed and accepted.
        """
        agrow, aprune = self.avg_acc()
        pgrow, pprune = self.avg_prop()
        # rate = P(accepted | proposed) * P(proposed)
        return agrow * pgrow, aprune * pprune

    def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']:
        """Histogram of tree depths for each state of the trees."""
        return self._bart.depth_distr()

    def _points_per_node_distr(
        self, node_type: str
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Forward to the method of the same name on the wrapped `_bart` object."""
        return self._bart._points_per_node_distr(node_type)  # noqa: SLF001

    def points_per_decision_node_distr(
        self,
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes."""
        return self._bart.points_per_decision_node_distr()

    def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to leaves."""
        return self._bart.points_per_leaf_distr()

    def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Apply `check_trace` to all the tree draws."""
        return self._bart.check_trees()

    def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Find iterations where a tree becomes invalid.

        Returns
        -------
        An array where (i,j) is `True` if tree j is invalid at iteration i but not i-1.
        """
        bad = self.check_trees().astype(bool)
        # Shift the flags forward by one iteration along axis 1; the padded
        # first slot is `False`, so a tree that is invalid at the first
        # iteration is reported as going bad there.
        bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
        return bad & ~bad_before

218 

219 

class debug_gbart(debug_mc_gbart, gbart):
    """A subclass of `gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid.
    check_replicated_trees
        If the data is sharded across devices, check that the trees are equal
        on all devices in the final state. Set to `False` to allow jax tracing.
    **kwargs
        Passed to `gbart`.
    """