Coverage for src / bartz / debug / _traceconv.py: 88%
94 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
1# bartz/src/bartz/debug/_traceconv.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Debugging utilities. The main functionality is the class `debug_mc_gbart`."""
27from math import ceil, log2
28from re import fullmatch
30import numpy
31from equinox import Module, field
32from jax import numpy as jnp
33from jaxtyping import Array, Float32, UInt
35from bartz.BART._gbart import FloatLike
36from bartz.grove import TreeHeaps
37from bartz.jaxext import minimal_unsigned_dtype
40def _get_next_line(s: str, i: int) -> tuple[str, int]:
41 """Get the next line from a string and the new index."""
42 i_new = s.find('\n', i) 1a
43 if i_new == -1: 43 ↛ 44line 43 didn't jump to line 44 because the condition on line 43 was never true1a
44 return s[i:], len(s)
45 return s[i:i_new], i_new + 1 1a
class BARTTraceMeta(Module):
    """Metadata of R BART tree traces.

    Instances are produced by `scan_BART_trees`, which reads `ndpost` and
    `ntree` from the header line of the trace and derives `numcut` and
    `heap_size` from the node definitions it scans.
    """

    # Declared with `field(static=True)`, like the other plain-integer fields.
    ndpost: int = field(static=True)
    """The number of posterior draws."""

    ntree: int = field(static=True)
    """The number of trees in the model."""

    # As filled in by `scan_BART_trees`: for each variable, the largest split
    # index seen in the trace plus one; stored as an unsigned-integer array.
    numcut: UInt[Array, ' p']
    """The maximum split value for each variable."""

    # As filled in by `scan_BART_trees`: the smallest power of two that is
    # strictly greater than the largest heap index seen in the trace.
    heap_size: int = field(static=True)
    """The size of the heap required to store the trees."""
def scan_BART_trees(trees: str) -> BARTTraceMeta:
    """Scan an R BART tree trace checking for errors and parsing metadata.

    The expected layout is a header line ``<ndpost> <ntree> <p>`` followed,
    for each draw and each tree, by a line with the number of nodes and then
    one line per node of the form ``<heap index> <var> <split> <leaf>``.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.

    Returns
    -------
    An object containing the metadata.

    Raises
    ------
    ValueError
        If the string is malformed, if a node refers to a variable index that
        is not smaller than the number of variables declared in the header,
        or if the string contains leftover characters.
    """
    # parse first line
    line, i_char = _get_next_line(trees, 0)
    i_line = 1
    match = fullmatch(r'(\d+) (\d+) (\d+)', line)
    if match is None:
        msg = f'Malformed header at {i_line=}'
        raise ValueError(msg)
    ndpost, ntree, p = map(int, match.groups())

    # initial values for maxima
    max_heap_index = 0
    numcut = numpy.zeros(p, int)

    # cycle over iterations and trees
    for i_iter in range(ndpost):
        for i_tree in range(ntree):
            # parse first line of tree definition
            line, i_char = _get_next_line(trees, i_char)
            i_line += 1
            match = fullmatch(r'(\d+)', line)
            if match is None:
                msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
                raise ValueError(msg)
            num_nodes = int(line)

            # cycle over nodes
            for i_node in range(num_nodes):
                # parse node definition
                line, i_char = _get_next_line(trees, i_char)
                i_line += 1
                match = fullmatch(
                    r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
                )
                if match is None:
                    msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)
                i_heap = int(match.group(1))
                var = int(match.group(2))
                split = int(match.group(3))

                # reject variable indices the header did not declare, instead
                # of letting the array lookup below fail with an IndexError
                if var >= p:
                    msg = f'Variable index {var} out of range for {p=} at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)

                # update maxima
                numcut[var] = max(numcut[var], split)
                max_heap_index = max(max_heap_index, i_heap)

    # internal invariant: the cursor never moves past the end of the string
    assert i_char <= len(trees)
    if i_char < len(trees):
        msg = f'Leftover {len(trees) - i_char} characters in string'
        raise ValueError(msg)

    # determine minimal integer type for numcut
    numcut += 1  # because BART is 0-based
    split_dtype = minimal_unsigned_dtype(numcut.max())
    numcut = jnp.array(numcut.astype(split_dtype))

    # determine minimum heap size to store the trees: the smallest power of
    # two strictly greater than the largest heap index observed
    heap_size = 2 ** ceil(log2(max_heap_index + 1))

    return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)
class TraceWithOffset(Module):
    """Implementation of `bartz.mcmcloop.Trace`."""

    leaf_tree: Float32[Array, 'ndpost ntree 2**d']
    var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    offset: Float32[Array, ' ndpost']

    @classmethod
    def from_trees_trace(
        cls, trees: TreeHeaps, offset: Float32[Array, '']
    ) -> 'TraceWithOffset':
        """Build a `TraceWithOffset` out of a `TreeHeaps` and a single offset.

        The tree arrays are taken from `trees` as they are, while `offset` is
        repeated once per posterior draw, with the number of draws read from
        the leading axis of ``trees.leaf_tree``.
        """
        # the leaf array is required to have exactly three axes
        num_draws, _num_trees, _heap_len = trees.leaf_tree.shape
        per_draw_offset = jnp.full(num_draws, offset)
        return cls(
            leaf_tree=trees.leaf_tree,
            var_tree=trees.var_tree,
            split_tree=trees.split_tree,
            offset=per_draw_offset,
        )
def trees_BART_to_bartz(
    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
) -> tuple[TraceWithOffset, BARTTraceMeta]:
    """Convert trees from the R BART format to the bartz format.

    The string is read twice: a first pass through `scan_BART_trees` validates
    it and determines the array sizes, and a second pass fills the arrays.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.
    min_maxdepth
        The maximum tree depth of the output is set to the maximum depth
        observed in the input trees. Use this parameter to require at least
        this maximum depth in the output format.
    offset
        The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to
        be summed to the sum of trees. To match that behavior, this function
        returns an offset as well, zero by default. Set it with this
        parameter otherwise.

    Returns
    -------
    trace : TraceWithOffset
        A representation of the trees compatible with the trace returned by
        `bartz.mcmcloop.run_mcmc`.
    meta : BARTTraceMeta
        The metadata of the trace, containing the number of iterations, trees,
        and the maximum split value.
    """
    # first pass: validate the whole string and collect the sizes
    meta = scan_BART_trees(trees)

    # second pass starts right after the header line
    _header, cursor = _get_next_line(trees, 0)

    # allocate the output heaps; decision arrays are half as long as leaves
    heap_size = max(meta.heap_size, 2**min_maxdepth)
    half_size = heap_size // 2
    trace_shape = (meta.ndpost, meta.ntree)
    leaf_trees = numpy.zeros((*trace_shape, heap_size), dtype=numpy.float32)
    var_trees = numpy.zeros(
        (*trace_shape, half_size),
        dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
    )
    split_trees = numpy.zeros((*trace_shape, half_size), dtype=meta.numcut.dtype)

    for i_iter in range(meta.ndpost):
        for i_tree in range(meta.ntree):
            # views on the heaps of the current tree; writes go to the
            # full arrays
            leaf_heap = leaf_trees[i_iter, i_tree]
            var_heap = var_trees[i_iter, i_tree]
            split_heap = split_trees[i_iter, i_tree]

            # the tree definition opens with its number of nodes
            line, cursor = _get_next_line(trees, cursor)
            num_nodes = int(line)

            # marks the heap slots that are the parent of some listed node
            has_child = numpy.zeros(half_size, dtype=bool)

            for _ in range(num_nodes):
                # node definition: heap index, variable, split, leaf value
                line, cursor = _get_next_line(trees, cursor)
                tokens = line.split()
                i_heap, var, split = (int(token) for token in tokens[:3])
                leaf = float(tokens[3])

                leaf_heap[i_heap] = leaf
                has_child[i_heap // 2] = True
                if i_heap < half_size:
                    var_heap[i_heap] = var
                    split_heap[i_heap] = split + 1

            # slot 0 gets marked by the node with heap index 1; undo that,
            # then clear the split of every slot that has no child
            has_child[0] = False
            split_heap[~has_child] = 0

    if offset is None:
        offset_trace = jnp.zeros(meta.ndpost)
    else:
        offset_trace = jnp.full(meta.ndpost, offset)

    trace = TraceWithOffset(
        leaf_tree=jnp.array(leaf_trees),
        var_tree=jnp.array(var_trees),
        split_tree=jnp.array(split_trees),
        offset=offset_trace,
    )
    return trace, meta