Coverage for src / bartz / debug / _traceconv.py: 88%
94 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/debug/_traceconv.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
"""Conversion of tree traces produced by the R BART package to the bartz format."""
27from math import ceil, log2
28from re import fullmatch
30import numpy
31from equinox import Module, field
32from jax import numpy as jnp
33from jaxtyping import Array, Float32, UInt
35from bartz.BART._gbart import FloatLike
36from bartz.grove import TreeHeaps
37from bartz.jaxext import minimal_unsigned_dtype
40def _get_next_line(s: str, i: int) -> tuple[str, int]:
41 """Get the next line from a string and the new index."""
42 i_new = s.find('\n', i) 1abc
43 if i_new == -1: 43 ↛ 44line 43 didn't jump to line 44 because the condition on line 43 was never true1abc
44 return s[i:], len(s)
45 return s[i:i_new], i_new + 1 1abc
class BARTTraceMeta(Module):
    """Metadata of R BART tree traces, as collected by `scan_BART_trees`."""

    ndpost: int = field(static=True)
    """The number of posterior draws."""

    ntree: int = field(static=True)
    """The number of trees in the model."""

    numcut: UInt[Array, ' p']
    """The maximum split value for each variable.

    As computed by `scan_BART_trees`, this is in the bartz 1-based split
    convention (R BART's 0-based split index plus one), stored with the
    smallest unsigned integer dtype able to hold the largest value.
    """

    heap_size: int = field(static=True)
    """The size of the heap required to store the trees.

    As computed by `scan_BART_trees`, this is the smallest power of two
    strictly greater than the largest heap index found in the trace.
    """
def scan_BART_trees(trees: str) -> BARTTraceMeta:
    """Scan an R BART tree trace checking for errors and parsing metadata.

    The expected format is a header line ``ndpost ntree p``, followed, for
    each posterior draw and each tree, by a line with the number of nodes of
    the tree and then one line per node of the form
    ``heap_index var split leaf_value``.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.

    Returns
    -------
    An object containing the metadata.

    Raises
    ------
    ValueError
        If the string is malformed, if a node references a variable index not
        smaller than the number of predictors ``p`` declared in the header, or
        if the string contains leftover characters.
    """
    # parse first line
    line, i_char = _get_next_line(trees, 0)
    i_line = 1
    match = fullmatch(r'(\d+) (\d+) (\d+)', line)
    if match is None:
        msg = f'Malformed header at {i_line=}'
        raise ValueError(msg)
    ndpost, ntree, p = map(int, match.groups())

    # initial values for maxima
    max_heap_index = 0
    numcut = numpy.zeros(p, int)

    # cycle over iterations and trees
    for i_iter in range(ndpost):
        for i_tree in range(ntree):
            # parse first line of tree definition
            line, i_char = _get_next_line(trees, i_char)
            i_line += 1
            match = fullmatch(r'(\d+)', line)
            if match is None:
                msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
                raise ValueError(msg)
            num_nodes = int(line)

            # cycle over nodes
            for i_node in range(num_nodes):
                # parse node definition
                line, i_char = _get_next_line(trees, i_char)
                i_line += 1
                match = fullmatch(
                    r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
                )
                if match is None:
                    msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)
                i_heap = int(match.group(1))
                var = int(match.group(2))
                split = int(match.group(3))

                # `var` indexes `numcut`, so it must be below the number of
                # predictors declared in the header; report it as malformed
                # input instead of letting numpy raise a bare IndexError
                if var >= p:
                    msg = (
                        f'Variable index out of range at {i_iter=} {i_tree=} '
                        f'{i_node=} {i_line=}: {var=} >= {p=}'
                    )
                    raise ValueError(msg)

                # update maxima
                numcut[var] = max(numcut[var], split)
                max_heap_index = max(max_heap_index, i_heap)

    # the whole string must have been consumed by the parsing above
    assert i_char <= len(trees)
    if i_char < len(trees):
        msg = f'Leftover {len(trees) - i_char} characters in string'
        raise ValueError(msg)

    # determine minimal integer type for numcut
    numcut += 1  # because BART is 0-based
    split_dtype = minimal_unsigned_dtype(numcut.max())
    numcut = jnp.array(numcut.astype(split_dtype))

    # determine minimum heap size to store the trees: the smallest power of
    # two strictly greater than the largest heap index
    heap_size = 2 ** ceil(log2(max_heap_index + 1))

    return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)
class TraceWithOffset(Module):
    """Implementation of `bartz.mcmcloop.Trace`.

    Stores one set of tree heaps per posterior draw, together with a per-draw
    offset that is meant to be added to the sum of trees.
    """

    leaf_tree: Float32[Array, 'ndpost ntree 2**d']
    var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    offset: Float32[Array, ' ndpost']

    @classmethod
    def from_trees_trace(
        cls, trees: TreeHeaps, offset: Float32[Array, '']
    ) -> 'TraceWithOffset':
        """Create a `TraceWithOffset` from a `TreeHeaps`.

        Parameters
        ----------
        trees
            Tree heaps whose arrays carry a leading axis of posterior draws.
        offset
            A scalar offset, replicated once per posterior draw.

        Returns
        -------
        A new trace that reuses the tree arrays of `trees`.
        """
        leaves = trees.leaf_tree
        # unpacking into three names also requires the leaf array to be 3-D
        ndpost, _ntree, _heap_size = leaves.shape
        per_draw_offset = jnp.full(ndpost, offset)
        return cls(
            leaf_tree=leaves,
            var_tree=trees.var_tree,
            split_tree=trees.split_tree,
            offset=per_draw_offset,
        )
def trees_BART_to_bartz(
    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
) -> tuple[TraceWithOffset, BARTTraceMeta]:
    """Convert trees from the R BART format to the bartz format.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.
    min_maxdepth
        The maximum tree depth of the output will be set to the maximum
        observed depth in the input trees. Use this parameter to require at
        least this maximum depth in the output format.
    offset
        The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be
        summed to the sum of trees. To match that behavior, this function
        returns an offset as well, zero by default. Set with this parameter
        otherwise. The offset is always stored as float32, like the leaves.

    Returns
    -------
    trace : TraceWithOffset
        A representation of the trees compatible with the trace returned by
        `bartz.mcmcloop.run_mcmc`.
    meta : BARTTraceMeta
        The metadata of the trace, containing the number of iterations, trees,
        the maximum split value, and the heap size required by the input.

    Raises
    ------
    ValueError
        If `trees` is malformed (detected by `scan_BART_trees`).
    """
    # scan all the string checking for errors and determining sizes
    meta = scan_BART_trees(trees)

    # skip the header line, already validated by the scan
    _, i_char = _get_next_line(trees, 0)

    heap_size = max(meta.heap_size, 2**min_maxdepth)
    # only the first half of the heap stores variables and splits
    half_size = heap_size // 2
    leaf_trees = numpy.zeros((meta.ndpost, meta.ntree, heap_size), dtype=numpy.float32)
    var_trees = numpy.zeros(
        (meta.ndpost, meta.ntree, half_size),
        dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
    )
    split_trees = numpy.zeros(
        (meta.ndpost, meta.ntree, half_size), dtype=meta.numcut.dtype
    )

    # cycle over iterations and trees
    for i_iter in range(meta.ndpost):
        for i_tree in range(meta.ntree):
            # parse first line of tree definition
            line, i_char = _get_next_line(trees, i_char)
            num_nodes = int(line)

            # heap positions that are the parent (index i // 2) of some node
            is_internal = numpy.zeros(half_size, dtype=bool)

            # cycle over nodes
            for _ in range(num_nodes):
                # parse node definition
                line, i_char = _get_next_line(trees, i_char)
                values = line.split()
                i_heap = int(values[0])
                var = int(values[1])
                split = int(values[2])
                leaf = float(values[3])

                # update values; splits are shifted to bartz's 1-based convention
                leaf_trees[i_iter, i_tree, i_heap] = leaf
                is_internal[i_heap // 2] = True
                if i_heap < half_size:
                    var_trees[i_iter, i_tree, i_heap] = var
                    split_trees[i_iter, i_tree, i_heap] = split + 1

            # the root at index 1 marks index 0 as its "parent", which is not
            # a real node; then zero the split of every non-internal node
            is_internal[0] = False
            split_trees[i_iter, i_tree, ~is_internal] = 0

    # use an explicit float32 dtype so the offset matches the `Float32`
    # annotation and the float32 leaves regardless of the JAX x64 setting
    fill_value = 0.0 if offset is None else offset
    return TraceWithOffset(
        leaf_tree=jnp.array(leaf_trees),
        var_tree=jnp.array(var_trees),
        split_tree=jnp.array(split_trees),
        offset=jnp.full(meta.ndpost, fill_value, dtype=jnp.float32),
    ), meta