Coverage for src / bartz / debug / _traceconv.py: 88%

94 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/debug/_traceconv.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Debugging utilities. The main functionality is the class `debug_mc_gbart`.""" 

26 

27from math import ceil, log2 

28from re import fullmatch 

29 

30import numpy 

31from equinox import Module, field 

32from jax import numpy as jnp 

33from jaxtyping import Array, Float32, UInt 

34 

35from bartz.BART._gbart import FloatLike 

36from bartz.grove import TreeHeaps 

37from bartz.jaxext import minimal_unsigned_dtype 

38 

39 

40def _get_next_line(s: str, i: int) -> tuple[str, int]: 

41 """Get the next line from a string and the new index.""" 

42 i_new = s.find('\n', i) 1abc

43 if i_new == -1: 43 ↛ 44line 43 didn't jump to line 44 because the condition on line 43 was never true1abc

44 return s[i:], len(s) 

45 return s[i:i_new], i_new + 1 1abc

46 

47 

class BARTTraceMeta(Module):
    """Metadata of R BART tree traces.

    Instances are produced by `scan_BART_trees`. The integer fields are
    declared with ``field(static=True)``; `numcut` is the only array field.

    Attributes
    ----------
    ndpost
        The number of posterior draws.
    ntree
        The number of trees in the model.
    numcut
        The maximum split value for each variable.
    heap_size
        The size of the heap required to store the trees.
    """

    ndpost: int = field(static=True)
    ntree: int = field(static=True)
    numcut: UInt[Array, ' p']
    heap_size: int = field(static=True)

62 

63 

def scan_BART_trees(trees: str) -> BARTTraceMeta:
    """Scan an R BART tree trace checking for errors and parsing metadata.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.

    Returns
    -------
    An object containing the metadata.

    Raises
    ------
    ValueError
        If the string is malformed, if a node refers to a variable index not
        smaller than the number of variables declared in the header, or if
        the string contains leftover characters.
    """
    # parse first line: "<ndpost> <ntree> <p>"
    line, i_char = _get_next_line(trees, 0)
    i_line = 1
    match = fullmatch(r'(\d+) (\d+) (\d+)', line)
    if match is None:
        msg = f'Malformed header at {i_line=}'
        raise ValueError(msg)
    ndpost, ntree, p = map(int, match.groups())

    # initial values for maxima
    max_heap_index = 0
    numcut = numpy.zeros(p, int)

    # cycle over iterations and trees
    for i_iter in range(ndpost):
        for i_tree in range(ntree):
            # parse first line of tree definition: the number of nodes
            line, i_char = _get_next_line(trees, i_char)
            i_line += 1
            match = fullmatch(r'(\d+)', line)
            if match is None:
                msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
                raise ValueError(msg)
            num_nodes = int(line)

            # cycle over nodes
            for i_node in range(num_nodes):
                # parse node definition: "<heap index> <var> <split> <value>"
                line, i_char = _get_next_line(trees, i_char)
                i_line += 1
                match = fullmatch(
                    r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
                )
                if match is None:
                    msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)
                i_heap = int(match.group(1))
                var = int(match.group(2))
                split = int(match.group(3))

                # The header declares p variables. Without this check, an
                # out-of-range index would surface as an IndexError from the
                # numcut update below instead of the documented ValueError.
                if var >= p:
                    msg = f'Variable index {var} out of range for {p=} at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)

                # update maxima
                numcut[var] = max(numcut[var], split)
                max_heap_index = max(max_heap_index, i_heap)

    assert i_char <= len(trees)
    if i_char < len(trees):
        msg = f'Leftover {len(trees) - i_char} characters in string'
        raise ValueError(msg)

    # determine minimal integer type for numcut
    numcut += 1  # because BART is 0-based
    split_dtype = minimal_unsigned_dtype(numcut.max())
    numcut = jnp.array(numcut.astype(split_dtype))

    # determine minimum heap size to store the trees: the smallest power of 2
    # with room for the largest heap index seen
    heap_size = 2 ** ceil(log2(max_heap_index + 1))

    return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)

140 

141 

class TraceWithOffset(Module):
    """Implementation of `bartz.mcmcloop.Trace`.

    Stores a trace of tree ensembles in heap layout, together with one
    offset per posterior draw.
    """

    # leaf values; one heap of length 2**d per draw and tree
    leaf_tree: Float32[Array, 'ndpost ntree 2**d']
    # splitting variable; one slot per node of the first half of each heap
    var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    # split value; same layout as `var_tree`
    split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    # one offset value per posterior draw
    offset: Float32[Array, ' ndpost']

    @classmethod
    def from_trees_trace(
        cls, trees: TreeHeaps, offset: Float32[Array, '']
    ) -> 'TraceWithOffset':
        """Create a `TraceWithOffset` from a `TreeHeaps`.

        Parameters
        ----------
        trees
            Tree heaps whose `leaf_tree` is 3-dimensional, with the posterior
            draws along the first axis.
        offset
            A scalar offset, repeated once per posterior draw.

        Returns
        -------
        A trace that reuses the arrays of `trees` as they are, with `offset`
        broadcast to length ``ndpost``.
        """
        # unpacking (rather than indexing) also enforces that leaf_tree is 3-D
        ndpost, _, _ = trees.leaf_tree.shape
        per_draw_offset = jnp.full(ndpost, offset)
        return cls(
            leaf_tree=trees.leaf_tree,
            var_tree=trees.var_tree,
            split_tree=trees.split_tree,
            offset=per_draw_offset,
        )

162 

163 

def trees_BART_to_bartz(
    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
) -> tuple[TraceWithOffset, BARTTraceMeta]:
    """Convert trees from the R BART format to the bartz format.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.
    min_maxdepth
        The maximum tree depth of the output will be set to the maximum
        observed depth in the input trees. Use this parameter to require at
        least this maximum depth in the output format.
    offset
        The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be
        summed to the sum of trees. To match that behavior, this function
        returns an offset as well, zero by default. Set with this parameter
        otherwise.

    Returns
    -------
    trace : TraceWithOffset
        A representation of the trees compatible with the trace returned by
        `bartz.mcmcloop.run_mcmc`.
    meta : BARTTraceMeta
        The metadata of the trace, containing the number of iterations, trees,
        and the maximum split value.

    Raises
    ------
    ValueError
        Propagated from `scan_BART_trees` if `trees` is malformed.
    """
    # Validate the whole string first and learn the array sizes, so the
    # parsing loop below can trust the format without re-checking it.
    meta = scan_BART_trees(trees)
    ndpost, ntree = meta.ndpost, meta.ntree

    # enlarge the heaps if a larger minimum depth was requested
    heap_size = max(meta.heap_size, 2**min_maxdepth)
    half_heap = heap_size // 2

    leaf_trees = numpy.zeros((ndpost, ntree, heap_size), dtype=numpy.float32)
    var_trees = numpy.zeros(
        (ndpost, ntree, half_heap),
        dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
    )
    split_trees = numpy.zeros((ndpost, ntree, half_heap), dtype=meta.numcut.dtype)

    # the header line has already been parsed by the scan, so start after it
    _, pos = _get_next_line(trees, 0)

    for i_iter in range(ndpost):
        for i_tree in range(ntree):
            tree_header, pos = _get_next_line(trees, pos)
            has_children = numpy.zeros(half_heap, dtype=bool)

            for _ in range(int(tree_header)):
                node_line, pos = _get_next_line(trees, pos)
                # the scan guaranteed exactly four space-separated fields
                heap_str, var_str, split_str, leaf_str = node_line.split()
                node = int(heap_str)
                var = int(var_str)
                split = int(split_str)

                leaf_trees[i_iter, i_tree, node] = float(leaf_str)
                # the presence of a node marks its parent slot as internal
                has_children[node // 2] = True
                if node < half_heap:
                    var_trees[i_iter, i_tree, node] = var
                    # shift R BART's 0-based split to bartz's convention
                    split_trees[i_iter, i_tree, node] = split + 1

            # slot 0 only collects the mark left by the root (index 1 // 2)
            has_children[0] = False
            # nodes without children are not split
            split_trees[i_iter, i_tree, ~has_children] = 0

    if offset is None:
        offsets = jnp.zeros(ndpost)
    else:
        offsets = jnp.full(ndpost, offset)

    trace = TraceWithOffset(
        leaf_tree=jnp.array(leaf_trees),
        var_tree=jnp.array(var_trees),
        split_tree=jnp.array(split_trees),
        offset=offsets,
    )
    return trace, meta