Coverage for src / bartz / debug / _traceconv.py: 88%

94 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 18:11 +0000

1# bartz/src/bartz/debug/_traceconv.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

"""Debugging utilities: conversion of R BART tree traces to the bartz format."""

26 

27from math import ceil, log2 

28from re import fullmatch 

29 

30import numpy 

31from equinox import Module, field 

32from jax import numpy as jnp 

33from jaxtyping import Array, Float32, UInt 

34 

35from bartz.BART._gbart import FloatLike 

36from bartz.grove import TreeHeaps 

37from bartz.jaxext import minimal_unsigned_dtype 

38 

39 

40def _get_next_line(s: str, i: int) -> tuple[str, int]: 

41 """Get the next line from a string and the new index.""" 

42 i_new = s.find('\n', i) 1a

43 if i_new == -1: 43 ↛ 44line 43 didn't jump to line 44 because the condition on line 43 was never true1a

44 return s[i:], len(s) 

45 return s[i:i_new], i_new + 1 1a

46 

47 

class BARTTraceMeta(Module):
    """Metadata of R BART tree traces.

    Instances are produced by `scan_BART_trees`, which fills every field
    while validating a trace string.
    """

    # Number of draws: first integer of the trace header line.
    ndpost: int = field(static=True)
    """The number of posterior draws."""

    # Trees per draw: second integer of the trace header line.
    ntree: int = field(static=True)
    """The number of trees in the model."""

    # As computed by `scan_BART_trees`: one entry per variable (p = third
    # integer of the header). Each entry is one plus the largest 0-based split
    # index seen for that variable in any node line (1 if the variable never
    # appears). It is stored with the dtype chosen by `minimal_unsigned_dtype`
    # for the overall maximum.
    numcut: UInt[Array, ' p']
    """The maximum split value for each variable."""

    # Power of 2 computed as 2 ** ceil(log2(max_heap_index + 1)), so every
    # heap index found in the trace is a valid position in an array of this size.
    heap_size: int = field(static=True)
    """The size of the heap required to store the trees."""

62 

63 

def scan_BART_trees(trees: str) -> BARTTraceMeta:
    """Scan an R BART tree trace checking for errors and parsing metadata.

    The trace is a sequence of newline-separated lines:

    - a header ``ndpost ntree p``;
    - then, for each of the ``ndpost * ntree`` trees, a line with the number
      of nodes of the tree;
    - followed by one line per node of the form
      ``heap_index variable split value``.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.

    Returns
    -------
    An object containing the metadata.

    Raises
    ------
    ValueError
        If the string is malformed, if a node refers to a variable index not
        smaller than the number of variables ``p`` declared in the header, or
        if the string contains leftover characters.
    """
    header_pattern = r'(\d+) (\d+) (\d+)'
    tree_header_pattern = r'(\d+)'
    # heap index, variable, split, then a signed decimal with optional exponent
    node_pattern = r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)'

    # parse first line
    line, i_char = _get_next_line(trees, 0)
    i_line = 1
    match = fullmatch(header_pattern, line)
    if match is None:
        msg = f'Malformed header at {i_line=}'
        raise ValueError(msg)
    ndpost, ntree, p = map(int, match.groups())

    # initial values for maxima
    max_heap_index = 0
    numcut = numpy.zeros(p, int)

    # cycle over iterations and trees
    for i_iter in range(ndpost):
        for i_tree in range(ntree):
            # parse first line of tree definition
            line, i_char = _get_next_line(trees, i_char)
            i_line += 1
            match = fullmatch(tree_header_pattern, line)
            if match is None:
                msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
                raise ValueError(msg)
            num_nodes = int(line)

            # cycle over nodes
            for i_node in range(num_nodes):
                # parse node definition
                line, i_char = _get_next_line(trees, i_char)
                i_line += 1
                match = fullmatch(node_pattern, line)
                if match is None:
                    msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)
                i_heap = int(match.group(1))
                var = int(match.group(2))
                split = int(match.group(3))

                # reject out-of-range variables explicitly: otherwise indexing
                # numcut would raise an IndexError instead of the documented
                # ValueError
                if var >= p:
                    msg = (
                        f'Variable index {var} out of range for {p=} at '
                        f'{i_iter=} {i_tree=} {i_node=} {i_line=}'
                    )
                    raise ValueError(msg)

                # update maxima
                numcut[var] = max(numcut[var], split)
                max_heap_index = max(max_heap_index, i_heap)

    assert i_char <= len(trees)
    if i_char < len(trees):
        msg = f'Leftover {len(trees) - i_char} characters in string'
        raise ValueError(msg)

    # determine minimal integer type for numcut
    numcut += 1  # because BART is 0-based
    # initial=0 keeps the reduction defined when p == 0; when p > 0 all
    # entries are >= 1, so the result is unchanged
    split_dtype = minimal_unsigned_dtype(numcut.max(initial=0))
    numcut = jnp.array(numcut.astype(split_dtype))

    # determine minimum heap size to store the trees
    heap_size = 2 ** ceil(log2(max_heap_index + 1))

    return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)

140 

141 

class TraceWithOffset(Module):
    """Trees trace paired with a per-draw offset.

    Implementation of `bartz.mcmcloop.Trace`.
    """

    # per-node values of each tree, in heap order, for every draw
    leaf_tree: Float32[Array, 'ndpost ntree 2**d']
    # decision variable of each internal-node heap slot
    var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    # split value of each internal-node heap slot
    split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    # constant added to the sum of trees, one value per draw
    offset: Float32[Array, ' ndpost']

    @classmethod
    def from_trees_trace(
        cls, trees: TreeHeaps, offset: Float32[Array, '']
    ) -> 'TraceWithOffset':
        """Build a `TraceWithOffset` from a `TreeHeaps` and a scalar offset.

        The tree arrays are reused as they are; the scalar `offset` is
        replicated once per draw (the leading axis of ``trees.leaf_tree``).
        """
        # unpacking exactly three axes also rejects arrays of the wrong rank
        n_draws, _ntree, _heap = trees.leaf_tree.shape
        per_draw_offset = jnp.full(n_draws, offset)
        return cls(
            leaf_tree=trees.leaf_tree,
            var_tree=trees.var_tree,
            split_tree=trees.split_tree,
            offset=per_draw_offset,
        )

162 

163 

def trees_BART_to_bartz(
    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
) -> tuple[TraceWithOffset, BARTTraceMeta]:
    """Translate an R BART tree trace into bartz heap arrays.

    Parameters
    ----------
    trees
        Tree trace in the text format of the R BART package, e.g. the value
        of ``mc_gbart(...).treedraws['trees']``.
    min_maxdepth
        Lower bound on the maximum depth of the output. By default the heaps
        are just large enough for the deepest node found in `trees`; a larger
        value pads them to ``2**min_maxdepth`` slots.
    offset
        Constant to add to every draw's sum of trees, mirroring the offset in
        the trace returned by `bartz.mcmcloop.run_mcmc`. ``None`` means zero.

    Returns
    -------
    trace : TraceWithOffset
        The trees in the layout of the trace returned by
        `bartz.mcmcloop.run_mcmc`.
    meta : BARTTraceMeta
        Metadata collected by `scan_BART_trees`: number of draws, number of
        trees, per-variable maximum split, and minimal heap size.

    Raises
    ------
    ValueError
        Propagated from `scan_BART_trees` when `trees` is malformed.
    """
    # Validate everything up front, so the parsing below needs no checks.
    meta = scan_BART_trees(trees)

    n_slots = max(meta.heap_size, 2**min_maxdepth)
    n_internal = n_slots // 2
    lead = (meta.ndpost, meta.ntree)

    leaf_out = numpy.zeros((*lead, n_slots), dtype=numpy.float32)
    var_out = numpy.zeros(
        (*lead, n_internal), dtype=minimal_unsigned_dtype(meta.numcut.size - 1)
    )
    split_out = numpy.zeros((*lead, n_internal), dtype=meta.numcut.dtype)

    # The header line was already parsed by the scan; resume right after it.
    _, pos = _get_next_line(trees, 0)

    for draw in range(meta.ndpost):
        for tree in range(meta.ntree):
            # 1-d views: writing through them fills the 3-d output arrays
            leaf_row = leaf_out[draw, tree]
            var_row = var_out[draw, tree]
            split_row = split_out[draw, tree]

            count_line, pos = _get_next_line(trees, pos)
            # has_children[k] becomes True once some node has parent slot k
            has_children = numpy.zeros(n_internal, dtype=bool)

            for _ in range(int(count_line)):
                node_line, pos = _get_next_line(trees, pos)
                heap_str, var_str, split_str, leaf_str = node_line.split()
                heap = int(heap_str)

                leaf_row[heap] = float(leaf_str)
                has_children[heap // 2] = True
                if heap < n_internal:
                    var_row[heap] = int(var_str)
                    # R counts cutpoints from 0; bartz uses 0 for "no split"
                    split_row[heap] = int(split_str) + 1

            # Slot 0 is not a node (the root presumably sits at heap index 1),
            # so it must not count as internal.
            has_children[0] = False
            split_row[~has_children] = 0

    if offset is None:
        offsets = jnp.zeros(meta.ndpost)
    else:
        offsets = jnp.full(meta.ndpost, offset)

    trace = TraceWithOffset(
        leaf_tree=jnp.array(leaf_out),
        var_tree=jnp.array(var_out),
        split_tree=jnp.array(split_out),
        offset=offsets,
    )
    return trace, meta