Coverage for src/bartz/debug/_traceconv.py: 89%
101 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/debug/_traceconv.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Parsing of R BART3 tree traces."""
27import math
28from re import fullmatch
29from typing import ClassVar
31import numpy
32from jax import numpy as jnp
33from jax.sharding import Mesh
34from jaxtyping import Array, Float32, UInt
36from bartz._jaxext import Module, field, minimal_unsigned_dtype
37from bartz.BART._gbart import FloatLike
38from bartz.grove import TreeHeaps
41def _get_next_line(s: str, i: int) -> tuple[str, int]:
42 """Get the next line from a string and the new index."""
43 i_new = s.find('\n', i)
44 if i_new == -1: 44 ↛ 45line 44 didn't jump to line 45 because the condition on line 44 was never true
45 return s[i:], len(s)
46 return s[i:i_new], i_new + 1
class BARTTraceMeta(Module):
    """Metadata of an R BART tree trace, as collected by `scan_BART_trees`."""

    ndpost: int = field(static=True)
    """The number of posterior draws (first number of the trace header)."""

    ntree: int = field(static=True)
    """The number of trees in the model (second number of the trace header)."""

    numcut: UInt[Array, ' p']
    """Per variable, the largest split value in the 1-based convention used by
    bartz, i.e. the largest 0-based BART split index seen in the trace plus one."""

    heap_size: int = field(static=True)
    """The size of the heap required to store the trees: the smallest power of
    two strictly greater than the largest heap index found in the trace."""
def scan_BART_trees(trees: str) -> BARTTraceMeta:
    """Scan an R BART tree trace checking for errors and parsing metadata.

    The trace is expected to start with a header line ``ndpost ntree p``,
    followed, for each draw and each tree, by a line holding the number of
    nodes and then one line per node of the form
    ``heap_index var split leaf_value``.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.

    Returns
    -------
    An object containing the metadata.

    Raises
    ------
    ValueError
        If the string is malformed, a node references a variable index not
        smaller than the number of predictors ``p`` declared in the header, or
        the string contains leftover characters.
    """
    # parse first line
    line, i_char = _get_next_line(trees, 0)
    i_line = 1
    match = fullmatch(r'(\d+) (\d+) (\d+)', line)
    if match is None:
        msg = f'Malformed header at {i_line=}'
        raise ValueError(msg)
    ndpost, ntree, p = map(int, match.groups())

    # initial values for maxima
    max_heap_index = 0
    numcut = numpy.zeros(p, int)

    # cycle over iterations and trees
    for i_iter in range(ndpost):
        for i_tree in range(ntree):
            # parse first line of tree definition
            line, i_char = _get_next_line(trees, i_char)
            i_line += 1
            match = fullmatch(r'(\d+)', line)
            if match is None:
                msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
                raise ValueError(msg)
            num_nodes = int(match.group(1))

            # cycle over nodes
            for i_node in range(num_nodes):
                # parse node definition
                line, i_char = _get_next_line(trees, i_char)
                i_line += 1
                match = fullmatch(
                    r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
                )
                if match is None:
                    msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)
                i_heap = int(match.group(1))
                var = int(match.group(2))
                split = int(match.group(3))

                # the header declares p predictors; an out-of-range variable
                # index is malformed input, report it as such instead of
                # letting `numcut[var]` raise a bare IndexError
                if var >= p:
                    msg = (
                        f'Variable index {var} out of range for {p=} at '
                        f'{i_iter=} {i_tree=} {i_node=} {i_line=}'
                    )
                    raise ValueError(msg)

                # update maxima
                numcut[var] = max(numcut[var], split)
                max_heap_index = max(max_heap_index, i_heap)

    assert i_char <= len(trees)
    if i_char < len(trees):
        msg = f'Leftover {len(trees) - i_char} characters in string'
        raise ValueError(msg)

    # determine minimal integer type for numcut
    numcut += 1  # because BART is 0-based
    # `initial=0` keeps this well-defined when p == 0 (empty array); for
    # non-empty arrays every entry is >= 1, so the result is unchanged
    split_dtype = minimal_unsigned_dtype(numcut.max(initial=0).item())
    numcut = jnp.array(numcut.astype(split_dtype))

    # determine minimum heap size to store the trees: the smallest power of
    # two strictly greater than the largest heap index
    heap_size = 2 ** math.ceil(math.log2(max_heap_index + 1))

    return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)
class TraceWithOffset(Module):
    """Tree trace plus scalar offset, usable with `bartz.mcmcloop.evaluate_trace`."""

    leaf_tree: Float32[Array, 'ndpost ntree tree_size'] = field(samples=0)
    var_tree: UInt[Array, 'ndpost ntree tree_size//2'] = field(samples=0)
    split_tree: UInt[Array, 'ndpost ntree tree_size//2'] = field(samples=0)
    offset: Float32[Array, '']

    has_chains: ClassVar[bool] = False
    """There is no chain axis: the leading axis of each array indexes samples."""

    mesh: ClassVar[Mesh | None] = None
    """There is no device mesh: the trees are built on the host, unsharded."""

    @classmethod
    def from_trees_trace(
        cls, trees: TreeHeaps, offset: Float32[Array, '']
    ) -> 'TraceWithOffset':
        """Build a `TraceWithOffset` from a `~bartz.grove.TreeHeaps` and an offset.

        The leaf, variable and split heaps are taken as-is from `trees`.
        """
        heap_arrays = {
            name: getattr(trees, name)
            for name in ('leaf_tree', 'var_tree', 'split_tree')
        }
        return cls(**heap_arrays, offset=offset)
def trees_BART_to_bartz(
    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
) -> tuple[TraceWithOffset, BARTTraceMeta]:
    """Convert trees from the R BART format to the bartz format.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.
    min_maxdepth
        The maximum tree depth of the output will be set to the maximum
        observed depth in the input trees. Use this parameter to require at
        least this maximum depth in the output format. Must be non-negative.
    offset
        The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be
        summed to the sum of trees. To match that behavior, this function
        returns an offset as well, zero by default. Set with this parameter
        otherwise.

    Returns
    -------
    trace : TraceWithOffset
        A representation of the trees compatible with the trace returned by
        `bartz.mcmcloop.run_mcmc`.
    meta : BARTTraceMeta
        The metadata of the trace, containing the number of iterations, trees,
        and the maximum split value.

    Raises
    ------
    ValueError
        If `trees` is malformed (as detected by `scan_BART_trees`) or
        `min_maxdepth` is negative.
    """
    if min_maxdepth < 0:
        msg = f'{min_maxdepth=} must be non-negative'
        raise ValueError(msg)

    # scan all the string checking for errors and determining sizes; after
    # this the second pass below can parse without re-validating
    meta = scan_BART_trees(trees)

    # skip first line
    _, i_char = _get_next_line(trees, 0)

    # The root is stored at heap index 1 (index 0 is cleared below), so at
    # least 2 slots are needed. Without this floor, a trace with no nodes
    # would give heap_size 1, and `is_internal[0]` would index an empty array.
    heap_size = max(meta.heap_size, 2**min_maxdepth, 2)
    half_size = heap_size // 2

    leaf_trees = numpy.zeros((meta.ndpost, meta.ntree, heap_size), dtype=numpy.float32)
    var_trees = numpy.zeros(
        (meta.ndpost, meta.ntree, half_size),
        dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
    )
    split_trees = numpy.zeros(
        (meta.ndpost, meta.ntree, half_size), dtype=meta.numcut.dtype
    )

    # scratch buffer marking which heap slots are parents of some node;
    # allocated once and reset for every tree
    is_internal = numpy.zeros(half_size, dtype=bool)

    # cycle over iterations and trees
    for i_iter in range(meta.ndpost):
        for i_tree in range(meta.ntree):
            # parse first line of tree definition
            line, i_char = _get_next_line(trees, i_char)
            num_nodes = int(line)

            is_internal[:] = False

            # cycle over nodes
            for _ in range(num_nodes):
                # parse node definition
                line, i_char = _get_next_line(trees, i_char)
                heap_str, var_str, split_str, leaf_str = line.split()
                i_heap = int(heap_str)

                # update values
                leaf_trees[i_iter, i_tree, i_heap] = float(leaf_str)
                # the parent of node i_heap is i_heap // 2, so that slot is
                # an internal node
                is_internal[i_heap // 2] = True
                if i_heap < half_size:
                    var_trees[i_iter, i_tree, i_heap] = int(var_str)
                    # bartz split values are 1-based, BART's are 0-based
                    split_trees[i_iter, i_tree, i_heap] = int(split_str) + 1

            # slot 0 is only marked by the root (1 // 2) and is not a node
            is_internal[0] = False
            # a split value of 0 marks slots that are not internal nodes
            split_trees[i_iter, i_tree, ~is_internal] = 0

    return TraceWithOffset(
        leaf_tree=jnp.array(leaf_trees),
        var_tree=jnp.array(var_trees),
        split_tree=jnp.array(split_trees),
        offset=jnp.float32(0.0 if offset is None else offset),
    ), meta