Coverage for src / bartz / debug / _prior.py: 100%
97 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/debug/_prior.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
"""Debugging utilities to sample trees from the BART prior. The main functionality is the function `sample_prior`."""
27from dataclasses import replace
28from functools import partial
30from equinox import Module
31from jax import jit, lax, random
32from jax import numpy as jnp
33from jax.tree_util import tree_map
34from jaxtyping import Array, Bool, Float32, Int32, Key, UInt
36from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
37from bartz.jaxext import split as split_key
38from bartz.mcmcstep._moves import randint_masked
class SamplePriorStack(Module):
    """Manually managed recursion stack used by `sample_prior`.

    Level ``i`` of the stack holds the state needed while visiting a node at
    depth ``i`` of a binary tree whose maximum depth is `d`.
    """

    nonterminal: Bool[Array, ' d-1']
    """Whether the node is valid or the recursion is into unused node slots."""

    lower: UInt[Array, 'd-1 p']
    """The available cutpoints along ``var`` are in the integer range
    ``[1 + lower[var], 1 + upper[var])``."""

    upper: UInt[Array, 'd-1 p']
    """The available cutpoints along ``var`` are in the integer range
    ``[1 + lower[var], 1 + upper[var])``."""

    var: UInt[Array, ' d-1']
    """The variable of a decision node."""

    split: UInt[Array, ' d-1']
    """The cutpoint of a decision node."""

    @classmethod
    def initial(
        cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p']
    ) -> 'SamplePriorStack':
        """Build the stack in the state expected at the start of the recursion.

        Every level starts as valid, with the full cutpoint range of every
        variable available and zeroed decision rules.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorStack` ready to start the recursion.
        """
        levels = p_nonterminal.size
        num_vars = max_split.size
        cut_dtype = max_split.dtype
        # smallest unsigned type able to index every variable
        index_dtype = minimal_unsigned_dtype(num_vars - 1)

        return cls(
            nonterminal=jnp.ones(levels, bool),
            lower=jnp.zeros((levels, num_vars), cut_dtype),
            upper=jnp.broadcast_to(max_split, (levels, num_vars)),
            var=jnp.zeros(levels, index_dtype),
            split=jnp.zeros(levels, cut_dtype),
        )
class SamplePriorTrees(Module):
    """Container for the heap-encoded trees produced by `sample_prior`."""

    leaf_tree: Float32[Array, '* 2**d']
    """The array representing the trees, see `bartz.grove`."""

    var_tree: UInt[Array, '* 2**(d-1)']
    """The array representing the trees, see `bartz.grove`."""

    split_tree: UInt[Array, '* 2**(d-1)']
    """The array representing the trees, see `bartz.grove`."""

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorTrees':
        """Create the output trees before the recursion fills in the structure.

        The leaf values drawn here are final; only the decision-rule arrays
        are later overwritten.

        Parameters
        ----------
        key
            A jax random key, consumed to draw the leaf values.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        Trees with random leaves and all-zero decision-rule arrays.
        """
        # p_nonterminal has d-1 entries, so the leaf heap has 2**d slots and
        # the decision-node heaps have half as many.
        num_leaf_slots = 2 ** (p_nonterminal.size + 1)
        num_decision_slots = num_leaf_slots // 2

        leaves = sigma_mu * random.normal(key, (num_leaf_slots,))
        var_dtype = minimal_unsigned_dtype(max_split.size - 1)

        return cls(
            leaf_tree=leaves,
            var_tree=jnp.zeros(num_decision_slots, dtype=var_dtype),
            split_tree=jnp.zeros(num_decision_slots, dtype=max_split.dtype),
        )
class SamplePriorCarry(Module):
    """State threaded through every step of the recursion in `sample_prior`."""

    key: Key[Array, '']
    """A jax random key used to sample decision rules."""

    stack: SamplePriorStack
    """The stack used to manage the recursion."""

    trees: SamplePriorTrees
    """The output arrays."""

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorCarry':
        """Assemble the starting state of the recursion.

        Parameters
        ----------
        key
            A jax random key, split between the recursion and the leaf draws.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorCarry` ready to start the recursion.
        """
        subkeys = split_key(key)
        # The pop order fixes which subkey goes where: the first one is carried
        # along the recursion, the second one draws the leaves.
        carried_key = subkeys.pop()
        leaves_key = subkeys.pop()

        return cls(
            key=carried_key,
            stack=SamplePriorStack.initial(p_nonterminal, max_split),
            trees=SamplePriorTrees.initial(
                leaves_key, sigma_mu, p_nonterminal, max_split
            ),
        )
class SamplePriorX(Module):
    """Object representing the recursion scan in `sample_prior`.

    The sequence of nodes to visit is pre-computed recursively once, unrolling
    the recursion schedule.
    """

    node: Int32[Array, ' 2**(d-1)-1']
    """The heap index of the node to visit."""

    depth: Int32[Array, ' 2**(d-1)-1']
    """The depth of the node."""

    next_depth: Int32[Array, ' 2**(d-1)-1']
    """The depth of the next node to visit, either the left child or the right
    sibling of the node or of an ancestor."""

    @classmethod
    def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
        """Initialize the sequence of nodes to visit.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.

        Returns
        -------
        A `SamplePriorX` initialized with the sequence of nodes to visit, as
        int32 arrays.
        """
        max_depth = p_nonterminal.size
        seq = cls._sequence(max_depth)
        # a full binary tree truncated at max_depth levels has 2**max_depth - 1 nodes
        assert len(seq) == 2**max_depth - 1
        node = [heap_index for heap_index, _ in seq]
        depth = [level for _, level in seq]
        # the last node has no successor: use max_depth as a sentinel
        next_depth = [*depth[1:], max_depth]
        # Pin the dtype to match the Int32 annotations: without it, an empty
        # sequence would yield a float array, and with x64 enabled int64.
        return cls(
            node=jnp.array(node, dtype=jnp.int32),
            depth=jnp.array(depth, dtype=jnp.int32),
            next_depth=jnp.array(next_depth, dtype=jnp.int32),
        )

    @classmethod
    def _sequence(
        cls, max_depth: int, depth: int = 0, node: int = 1
    ) -> tuple[tuple[int, int], ...]:
        """Recursively generate a sequence [(node, depth), ...].

        Nodes are listed in pre-order (node, left subtree, right subtree),
        where the children of heap index ``node`` are ``2 * node`` and
        ``2 * node + 1``. Only nodes at depth < `max_depth` are included.
        """
        if depth < max_depth:
            out = ((node, depth),)
            out += cls._sequence(max_depth, depth + 1, 2 * node)
            out += cls._sequence(max_depth, depth + 1, 2 * node + 1)
            return out
        return ()
def sample_prior_onetree(
    key: Key[Array, ''],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample a tree from the BART prior.

    The recursion over the tree is unrolled into a `lax.scan` over the
    pre-computed pre-order node sequence of `SamplePriorX`. A manually managed
    `SamplePriorStack` carries, for each depth, whether the current node may
    still be non-terminal and the range of cutpoints still available.

    Parameters
    ----------
    key
        A jax random key.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing a generated tree.
    """
    carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split)
    xs = SamplePriorX.initial(p_nonterminal)

    def loop(carry: SamplePriorCarry, x: SamplePriorX) -> tuple[SamplePriorCarry, None]:
        # Four subkeys, consumed in this order: variable, cutpoint,
        # bernoulli draw, next carry key. Reordering the pops changes samples.
        keys = split_key(carry.key, 4)

        # get variables at current stack level
        stack = carry.stack
        nonterminal = stack.nonterminal[x.depth]
        lower = stack.lower[x.depth, :]
        upper = stack.upper[x.depth, :]

        # sample a random decision rule
        # a variable is usable only if its cutpoint range is non-empty
        available: Bool[Array, ' p'] = lower < upper
        allowed = jnp.any(available)
        var = randint_masked(keys.pop(), available)
        # random.randint excludes its upper bound, so split lies in
        # [1 + lower[var], 1 + upper[var]), matching `SamplePriorStack`
        split = 1 + random.randint(keys.pop(), (), lower[var], upper[var])

        # cast to shorter integer types
        var = var.astype(carry.trees.var_tree.dtype)
        split = split.astype(carry.trees.split_tree.dtype)

        # decide whether to try to grow the node if it is growable
        pnt = p_nonterminal[x.depth]
        try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt)
        # a node stays non-terminal only if its parent was (stack flag), the
        # coin flip succeeds, and at least one variable has cutpoints left
        nonterminal &= try_nonterminal & allowed

        # update trees
        # var is written unconditionally; only split_tree records whether the
        # node is a decision node (0 otherwise) -- presumably the leaf marker
        # of the `bartz.grove` format, NOTE(review): confirm.
        trees = carry.trees
        trees = replace(
            trees,
            var_tree=trees.var_tree.at[x.node].set(var),
            split_tree=trees.split_tree.at[x.node].set(
                jnp.where(nonterminal, split, 0)
            ),
        )

        # The two closures below read `stack`, `lower`, `upper`, `var`, `split`
        # from this iteration; `stack` is rebound only after lax.cond returns,
        # so they see the pre-update stack.

        def write_push_stack() -> SamplePriorStack:
            """Update the stack to go to the left child.

            The child inherits this node's validity flag and ranges, with the
            upper bound along `var` lowered to ``split - 1`` (left side of the
            split). This node's rule is saved at its own depth for when its
            right child is visited later.

            For the last node in the sequence, `x.next_depth` is one past the
            last stack level; NOTE(review): this relies on JAX dropping
            out-of-bounds scatter updates, and the result is discarded anyway.
            """
            return replace(
                stack,
                nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal),
                lower=stack.lower.at[x.next_depth, :].set(lower),
                upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)),
                var=stack.var.at[x.depth].set(var),
                split=stack.split.at[x.depth].set(split),
            )

        def pop_push_stack() -> SamplePriorStack:
            """Update the stack to go to the right sibling, possibly at lower depth.

            The next node is the right child of the ancestor at depth
            ``x.next_depth - 1``. Its ranges are that ancestor's ranges with
            the lower bound along the ancestor's variable raised to its split
            (right side of the split). The validity flag at ``x.next_depth``
            is left untouched: it was written when the ancestor pushed its
            left child and nothing at that level has overwritten it since.
            """
            var = stack.var[x.next_depth - 1]
            split = stack.split[x.next_depth - 1]
            lower = stack.lower[x.next_depth - 1, :]
            upper = stack.upper[x.next_depth - 1, :]
            return replace(
                stack,
                lower=stack.lower.at[x.next_depth, :].set(lower.at[var].set(split)),
                upper=stack.upper.at[x.next_depth, :].set(upper),
            )

        # update stack
        # going deeper means the next node is the left child; otherwise it is
        # the right child of this node's parent or of a higher ancestor
        stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack)

        # update carry
        carry = replace(carry, key=keys.pop(), stack=stack, trees=trees)
        return carry, None

    carry, _ = lax.scan(loop, carry, xs)
    return carry.trees
@partial(vmap_nodoc, in_axes=(0, None, None, None))
def sample_prior_forest(
    keys: Key[Array, ' num_trees'],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample a set of independent trees from the BART prior.

    The function is vectorized over the first axis of `keys`; the other
    arguments are shared by all trees.

    Parameters
    ----------
    keys
        A sequence of jax random keys, one for each tree. This determines the
        number of trees sampled.
    max_split
        The maximum split value for each variable, i.e., the number of
        cutpoints along each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees, with a leading axis of the same
    length as `keys`.
    """
    # inside the vmap, `keys` is a single key
    return sample_prior_onetree(keys, max_split, p_nonterminal, sigma_mu)
@partial(jit, static_argnums=(1, 2))
def sample_prior(
    key: Key[Array, ''],
    trace_length: int,
    num_trees: int,
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample independent trees from the BART prior.

    Parameters
    ----------
    key
        A jax random key.
    trace_length
        The number of iterations.
    num_trees
        The number of trees for each iteration.
    max_split
        The number of cutpoints along each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
        This determines the maximum depth of the trees.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees, with batch shape (trace_length, num_trees).
    """
    # one independent key per (iteration, tree) pair, sampled as a flat batch
    keys = random.split(key, trace_length * num_trees)
    trees = sample_prior_forest(keys, max_split, p_nonterminal, sigma_mu)
    # Spell out the trailing shape instead of using -1: a -1 dimension cannot
    # be inferred when the array is empty (trace_length or num_trees == 0).
    return tree_map(
        lambda x: x.reshape(trace_length, num_trees, *x.shape[1:]), trees
    )