Coverage for src/bartz/debug/_traceconv.py: 89%

101 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/debug/_traceconv.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Parsing of R BART3 tree traces.""" 

26 

27import math 

28from re import fullmatch 

29from typing import ClassVar 

30 

31import numpy 

32from jax import numpy as jnp 

33from jax.sharding import Mesh 

34from jaxtyping import Array, Float32, UInt 

35 

36from bartz._jaxext import Module, field, minimal_unsigned_dtype 

37from bartz.BART._gbart import FloatLike 

38from bartz.grove import TreeHeaps 

39 

40 

41def _get_next_line(s: str, i: int) -> tuple[str, int]: 

42 """Get the next line from a string and the new index.""" 

43 i_new = s.find('\n', i) 

44 if i_new == -1: 44 ↛ 45line 44 didn't jump to line 45 because the condition on line 44 was never true

45 return s[i:], len(s) 

46 return s[i:i_new], i_new + 1 

47 

48 

class BARTTraceMeta(Module):
    """Metadata of R BART tree traces.

    Instances are built by `scan_BART_trees` from the header line and the node
    definitions of the trace string.
    """

    # parsed from the first integer of the header line of the trace
    ndpost: int = field(static=True)
    """The number of posterior draws."""

    # parsed from the second integer of the header line of the trace
    ntree: int = field(static=True)
    """The number of trees in the model."""

    # one entry per predictor; `scan_BART_trees` stores the largest split
    # index found for each variable plus one, cast to the dtype chosen by
    # `minimal_unsigned_dtype`
    numcut: UInt[Array, ' p']
    """The maximum split value for each variable."""

    # `scan_BART_trees` sets this to the smallest power of two strictly
    # greater than the largest heap index found in the trace
    heap_size: int = field(static=True)
    """The size of the heap required to store the trees."""

63 

64 

def scan_BART_trees(trees: str) -> BARTTraceMeta:
    """Scan an R BART tree trace checking for errors and parsing metadata.

    The string is expected to start with a header line ``ndpost ntree p``,
    followed, for each draw and each tree, by a line with the number of nodes
    and then one line per node of the form ``heap_index var split leaf``.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.

    Returns
    -------
    An object containing the metadata.

    Raises
    ------
    ValueError
        If the string is malformed, if a node refers to a variable index that
        is not smaller than the number of predictors declared in the header,
        or if the string contains leftover characters.
    """
    # parse first line
    line, i_char = _get_next_line(trees, 0)
    i_line = 1
    match = fullmatch(r'(\d+) (\d+) (\d+)', line)
    if match is None:
        msg = f'Malformed header at {i_line=}'
        raise ValueError(msg)
    ndpost, ntree, p = map(int, match.groups())

    # initial values for maxima
    max_heap_index = 0
    numcut = numpy.zeros(p, int)

    # cycle over iterations and trees
    for i_iter in range(ndpost):
        for i_tree in range(ntree):
            # parse first line of tree definition
            line, i_char = _get_next_line(trees, i_char)
            i_line += 1
            match = fullmatch(r'(\d+)', line)
            if match is None:
                msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
                raise ValueError(msg)
            num_nodes = int(line)

            # cycle over nodes
            for i_node in range(num_nodes):
                # parse node definition
                line, i_char = _get_next_line(trees, i_char)
                i_line += 1
                match = fullmatch(
                    r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
                )
                if match is None:
                    msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)
                i_heap = int(match.group(1))
                var = int(match.group(2))
                split = int(match.group(3))

                # the variable index must address one of the `p` predictors;
                # report it as a malformed trace instead of letting the
                # array indexing below fail with an IndexError
                if var >= p:
                    msg = f'Variable index {var} out of range for {p=} at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)

                # update maxima
                numcut[var] = max(numcut[var], split)
                max_heap_index = max(max_heap_index, i_heap)

    # internal invariant of the line reader, not an input check
    assert i_char <= len(trees)
    if i_char < len(trees):
        msg = f'Leftover {len(trees) - i_char} characters in string'
        raise ValueError(msg)

    # determine minimal integer type for numcut
    numcut += 1  # because BART is 0-based
    split_dtype = minimal_unsigned_dtype(numcut.max().item())
    numcut = jnp.array(numcut.astype(split_dtype))

    # determine minimum heap size to store the trees: the smallest power of
    # two strictly greater than the largest heap index, computed with exact
    # integer arithmetic
    heap_size = 1 << max_heap_index.bit_length()

    return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)

141 

142 

class TraceWithOffset(Module):
    """Tree trace bundled with a scalar offset, usable with `bartz.mcmcloop.evaluate_trace`."""

    leaf_tree: Float32[Array, 'ndpost ntree tree_size'] = field(samples=0)
    var_tree: UInt[Array, 'ndpost ntree tree_size//2'] = field(samples=0)
    split_tree: UInt[Array, 'ndpost ntree tree_size//2'] = field(samples=0)
    offset: Float32[Array, '']

    has_chains: ClassVar[bool] = False
    """No chain axis; each leading axis is just the sample axis."""

    mesh: ClassVar[Mesh | None] = None
    """No device mesh; the trees are host-built and unsharded."""

    @classmethod
    def from_trees_trace(
        cls, trees: TreeHeaps, offset: Float32[Array, '']
    ) -> 'TraceWithOffset':
        """Build a `TraceWithOffset` out of a `~bartz.grove.TreeHeaps` and an offset."""
        # pull the three tree arrays out first, then hand them to the constructor
        leaf, var, split = trees.leaf_tree, trees.var_tree, trees.split_tree
        return cls(leaf_tree=leaf, var_tree=var, split_tree=split, offset=offset)

168 

169 

def trees_BART_to_bartz(
    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
) -> tuple[TraceWithOffset, BARTTraceMeta]:
    """Translate a trace of trees in the R BART text format into bartz arrays.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.
    min_maxdepth
        By default the output heaps are just large enough for the deepest
        tree found in the input. Pass this parameter to request heaps that
        accommodate at least this maximum depth.
    offset
        The trace returned by `bartz.mcmcloop.run_mcmc` carries an offset that
        is added to the sum of trees. For compatibility the output carries one
        too; it is zero unless specified here.

    Returns
    -------
    trace : TraceWithOffset
        The trees in a layout compatible with the trace returned by
        `bartz.mcmcloop.run_mcmc`.
    meta : BARTTraceMeta
        The metadata of the trace: number of iterations, number of trees,
        per-variable maximum split value and heap size.
    """
    # validation pass: raises on malformed input and provides the sizes
    meta = scan_BART_trees(trees)

    # place the cursor right after the header line
    _, cursor = _get_next_line(trees, 0)

    heap_size = max(meta.heap_size, 2**min_maxdepth)
    half_size = heap_size // 2
    trace_shape = (meta.ndpost, meta.ntree)

    leaf_trees = numpy.zeros((*trace_shape, heap_size), dtype=numpy.float32)
    var_trees = numpy.zeros(
        (*trace_shape, half_size),
        dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
    )
    split_trees = numpy.zeros((*trace_shape, half_size), dtype=meta.numcut.dtype)

    # walk the trace in the same order it is written: draws, then trees
    for i_iter in range(meta.ndpost):
        for i_tree in range(meta.ntree):
            # the tree starts with a line holding its number of nodes
            header, cursor = _get_next_line(trees, cursor)
            num_nodes = int(header)

            # views on the slots of the current tree in the output arrays
            leaves = leaf_trees[i_iter, i_tree]
            variables = var_trees[i_iter, i_tree]
            splits = split_trees[i_iter, i_tree]

            # marks heap slots that are the parent of some listed node
            has_child = numpy.zeros(half_size, dtype=bool)

            for _ in range(num_nodes):
                # each node line reads: heap index, variable, split, leaf value
                node_line, cursor = _get_next_line(trees, cursor)
                tokens = node_line.split()
                i_heap, var, split, leaf = (
                    int(tokens[0]),
                    int(tokens[1]),
                    int(tokens[2]),
                    float(tokens[3]),
                )

                leaves[i_heap] = leaf
                has_child[i_heap // 2] = True
                if i_heap < half_size:
                    variables[i_heap] = var
                    splits[i_heap] = split + 1  # shift the split by one

            # slot 0 gets marked as the parent of heap index 1; undo that
            has_child[0] = False
            # slots without children get split 0
            splits[~has_child] = 0

    trace = TraceWithOffset(
        leaf_tree=jnp.array(leaf_trees),
        var_tree=jnp.array(var_trees),
        split_tree=jnp.array(split_trees),
        offset=jnp.float32(0.0 if offset is None else offset),
    )
    return trace, meta