Coverage for src/bartz/mcmcstep/_step.py: 99%
490 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/mcmcstep/_step.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement `step`, `step_trees`, and the accept-reject logic."""
27from dataclasses import replace
28from functools import partial
29from typing import overload
31from equinox import AbstractVar
32from jax import lax, named_call, random, vmap
33from jax import numpy as jnp
34from jax.nn import softmax
35from jax.scipy.linalg import solve_triangular
36from jax.scipy.special import gammaln, logsumexp
37from jaxtyping import Array, Bool, Float32, Int32, Key, Shaped, UInt, UInt32
39from bartz._jaxext import (
40 Module,
41 field,
42 jit,
43 split,
44 truncated_normal_onesided,
45 vmap_nodoc,
46)
47from bartz._jaxext.random import loggamma
48from bartz.grove import var_histogram
49from bartz.mcmcstep._moves import Moves, propose_moves, split_range
50from bartz.mcmcstep._reduction import ReductionConfig
51from bartz.mcmcstep._state import (
52 Forest,
53 State,
54 StepConfig,
55 chol_with_gersh,
56 get_axis_size,
57 shard_map_state,
58 split_key_for_chains,
59 vmap_chains,
60)
@jit(donate_argnums=(1,))
@split_key_for_chains
@shard_map_state
@vmap_chains
def step(key: Key[Array, ''], state: State) -> State:
    """
    Advance the BART MCMC chain by one full iteration.

    The forest is always resampled. The latent `z` and the error covariance
    are updated only when the state carries them (``state.z`` not `None`,
    ``state.error_cov_inv.nu`` not `None`). Finally `step_sparse` and
    `step_config` are applied.

    Parameters
    ----------
    key
        A jax random key.
    state
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    The input state is donated (``donate_argnums=(1,)``): its buffers are
    reused for the output, so it must not be touched again after calling
    `step`. This only matters when `step` is called outside of `jax.jit`.
    """
    key_pool = split(key, 4)

    # keys are consumed in this fixed order; skipped sub-steps do not consume one
    updated = step_trees(key_pool.pop(), state)
    if updated.z is not None:
        updated = step_z(key_pool.pop(), updated)
    if updated.error_cov_inv.nu is not None:
        updated = step_error_cov_inv(key_pool.pop(), updated)
    updated = step_sparse(key_pool.pop(), updated)

    return step_config(updated)
@named_call
def step_trees(key: Key[Array, ''], state: State) -> State:
    """
    Resample the forest: propose one move per tree, then accept/reject it.

    Parameters
    ----------
    key
        A jax random key.
    state
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    The proposal counters are reset by this step.
    """
    key_pool = split(key)
    proposed = propose_moves(key_pool.pop(), state.forest)
    return accept_moves_and_sample_leaves(key_pool.pop(), state, proposed)
@named_call
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], state: State, moves: Moves
) -> State:
    """
    Run the accept/reject decision and draw the new leaf values.

    The work is split in three stages: one parallel across trees, one
    sequential over trees, and a final one.

    Parameters
    ----------
    key
        A jax random key.
    state
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, state, moves)
    new_state, final_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(new_state, final_moves)
class Counts(Module):
    """Number of datapoints in the nodes involved in proposed moves for each tree."""

    # Built in `_compute_count_or_prec_tree` by concatenating the two child
    # sums with their total, so the parent slot is always left + right.
    lrt: UInt[Array, '*num_trees 3']
    """Number of datapoints in the left child, right child, and parent
    (``= left + right``), stacked along the trailing axis."""
class PreLkV(Module):
    """Non-sequential terms of the likelihood ratio for each tree.

    These terms are derived from the leaf precompute terms (`PreLf`) gathered
    at the nodes involved in each move. The terms for the left child, right
    child, and their join (the parent node) are stacked along the axis right
    after the tree axis. Each term is, in the univariate case, the scalar

    ``error_cov_inv^2 / (leaf_prior_cov_inv + n * error_cov_inv)``.

    In the multivariate homoskedastic or scalar weight case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n * error_cov_inv) @ error_cov_inv``.

    In the multivariate vector-weight case, this is instead

    ``chol(leaf_prior_cov_inv + n * error_cov_inv)``.

    ``n`` is the number of datapoints in the node, or the likelihood precision
    scale in the heteroskedastic case.
    """

    # `log_sqrt_term` is declared before `lrt` so its single (union-free)
    # annotation binds the variadic `*num_trees` axis first; otherwise the
    # runtime typechecker can greedily mis-bind `*num_trees` against the `k`
    # axis of the `... | ... k k` union (the multivariate and univariate
    # layouts are rank-ambiguous).
    log_sqrt_term: Float32[Array, '*num_trees']
    """The logarithm of the square root term of the likelihood ratio."""

    lrt: Float32[Array, '*num_trees 3'] | Float32[Array, '*num_trees 3 k k']
    """Scaled full conditional variance, scaled covariance, or precision
    cholesky, for the left child, right child, and their join."""
class PreLf(Module):
    """Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees.

    For each tree and leaf, the terms are scalars in the univariate case
    (`PreLfUV`), and matrices/vectors in the multivariate case (`PreLfMV`,
    `PreLfMVHet`).

    Abstract base: the layouts differ in rank, so they live in concrete
    subclasses with union-free annotations; a single class carrying a shape
    union would make the greedy variadic mis-bind against the ``k`` axes under
    the runtime typechecker. The concrete class also tags the meaning of
    `mean_factor`, which drives the dispatch in `precompute_likelihood_terms`
    and in the sequential stage. The ``num_trees`` axis is variadic so the same
    annotations also match a per-element layout if vmapped over trees.
    """

    # Declared abstract here; each concrete subclass re-declares the field
    # with a single, rank-specific annotation.
    mean_factor: AbstractVar[
        Float32[Array, '*num_trees tree_size']
        | Float32[Array, '*num_trees k k tree_size']
    ]
    """The factor to be right-multiplied by the sum of the scaled residuals to
    obtain the posterior mean."""

    centered_leaves: AbstractVar[
        Float32[Array, '*num_trees tree_size']
        | Float32[Array, '*num_trees k tree_size']
    ]
    """The mean-zero normal values to be added to the posterior mean to
    obtain the posterior leaf samples."""
class PreLfUV(PreLf):
    """`PreLf` for the univariate case."""

    # Built by keyword in `_precompute_leaf_terms_uv`.
    mean_factor: Float32[Array, '*num_trees tree_size']
    """``error_cov_inv / prec``, where ``prec`` is the posterior precision of
    the leaf."""

    centered_leaves: Float32[Array, '*num_trees tree_size']
    """Zero-mean normal draws with the posterior variance of each leaf."""
class PreLfMV(PreLf):
    """`PreLf` for the multivariate homoskedastic or scalar-weight case."""

    # NOTE: `_precompute_leaf_terms_mv` constructs this class positionally
    # from a (mean_factor, centered, logdet) tuple, so the field order below
    # must not change.
    mean_factor: Float32[Array, '*num_trees k k tree_size']
    """``error_cov_inv @ inv(prec)``, where ``prec`` is the posterior precision
    of the leaf."""

    centered_leaves: Float32[Array, '*num_trees k tree_size']
    """Zero-mean normal draws with the posterior covariance of each leaf."""

    logdet_prec: Float32[Array, '*num_trees tree_size']
    """The log-determinant of the posterior precision of each leaf."""
class PreLfMVHet(PreLf):
    """`PreLf` for the multivariate vector-weight case."""

    # NOTE: `_precompute_leaf_terms_mv_het` constructs this class positionally
    # from a (cholesky, centered) tuple, so the field order below must not
    # change.
    mean_factor: Float32[Array, '*num_trees k k tree_size']
    """The lower Cholesky factor of the posterior precision of each leaf; the
    mean solve happens downstream in the sequential stage."""

    centered_leaves: Float32[Array, '*num_trees k tree_size']
    """Zero-mean normal draws with the posterior covariance of each leaf."""
class ParallelStageOut(Module):
    """The output of `accept_moves_parallel_stage`."""

    state: State
    """A partially updated BART mcmc state."""

    moves: Moves
    """The proposed moves, with `partial_ratio` set to `None` and
    `log_trans_prior_ratio` set to its final value."""

    # `num_trees` stays a fixed (non-variadic) axis: `ParallelStageOut` is always
    # built with the tree axis present (never per tree under vmap), so the union
    # is disambiguated by rank/dtype and needs no anchor (cf. `PreLf`).
    # uint32 when these are plain counts (no `prec_scale`), float32 otherwise.
    prec_trees: (
        Float32[Array, 'num_trees tree_size']
        | UInt32[Array, 'num_trees tree_size']
        | Float32[Array, 'num_trees k k tree_size']
    )
    """The likelihood precision scale in each potential or actual leaf node."""

    prelkv: PreLkV
    """Object with pre-computed terms of the likelihood ratios."""

    prelf: PreLf
    """Object with pre-computed terms of the leaf samples."""
@named_call
def accept_moves_parallel_stage(
    key: Key[Array, ''], state: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    The state is first rewritten as if every grow move had been accepted. The
    cached counts and precision trees are then refreshed at the nodes each
    move touches. Next the move validity and the non-likelihood MH ratio are
    finalized. Finally the leaf and likelihood terms are pre-computed.

    Parameters
    ----------
    key
        A jax random key.
    state
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    state = replace(
        state,
        forest=replace(
            state.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(
                moves, state.forest.leaf_indices, state.X
            ),
            leaf_tree=adapt_leaf_trees_to_grow_indices(state.forest.leaf_tree, moves),
        ),
    )

    # update the cached number of datapoints per leaf at the nodes involved
    # in the moves
    # NOTE: `count_trees` and `move_counts` are bound only inside this branch;
    # every later use is guarded by one of the three conditions tested here.
    if (
        state.forest.min_points_per_decision_node is not None
        or state.forest.min_points_per_leaf is not None
        or state.prec_scale is None
    ):
        assert state.forest.count_tree is not None
        count_trees, move_counts = compute_count_trees(
            state.forest.count_tree, state.forest.leaf_indices, moves, state.config
        )
        state = replace(state, forest=replace(state.forest, count_tree=count_trees))

    # affluence of the nodes touched by each move: whether they would be
    # growable as leaves (admissible rule + enough datapoints). The children
    # must also lie within the heap, i.e. not be at the bottom level; the
    # parent always does. These feed the transition ratio and the final
    # `affluence_tree` update.
    _, half = state.forest.var_tree.shape
    lrt_affluent = (moves.lrt_nodes < half) & moves.lrt_growable
    if state.forest.min_points_per_decision_node is not None:
        lrt_affluent &= move_counts.lrt >= state.forest.min_points_per_decision_node
    moves = replace(moves, lrt_affluent=lrt_affluent)

    # veto grove move if new leaves don't have enough datapoints
    if state.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & jnp.all(
                move_counts.lrt[..., :2] >= state.forest.min_points_per_leaf, axis=-1
            ),
        )

    # update the cached number of datapoints per leaf, weighted by error
    # precision scale, at the nodes involved in the moves
    if state.prec_scale is None:
        # without a per-datapoint precision scale, the plain counts serve as
        # the likelihood precision scale
        prec_trees = count_trees
    else:
        assert state.forest.prec_tree is not None
        prec_trees = compute_prec_trees(
            state.forest.prec_tree,
            state.prec_scale,
            state.forest.leaf_indices,
            moves,
            state.config,
        )
        state = replace(state, forest=replace(state.forest, prec_tree=prec_trees))

    # compute some missing information about moves
    moves = complete_ratio(moves, state.forest.p_nonterminal)
    save_ratios = state.forest.log_likelihood is not None
    # NOTE(review): `prune_prop_count` is masked by `moves.allowed`, which may
    # have been narrowed by the `min_points_per_leaf` veto above, while
    # `grow_prop_count` is not, so vetoed grows are still counted. Confirm
    # that this asymmetry is intended.
    state = replace(
        state,
        forest=replace(
            state.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    # the likelihood terms are gathered from the leaf terms, so leaves first
    prelf = precompute_leaf_terms(
        key, prec_trees, state.error_cov_inv.value, state.forest.leaf_prior_cov_inv
    )
    prelkv = precompute_likelihood_terms(
        state.error_cov_inv.value, state.forest.leaf_prior_cov_inv, prelf, moves
    )

    return ParallelStageOut(
        state=state, moves=moves, prec_trees=prec_trees, prelkv=prelkv, prelf=prelf
    )
@named_call
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Reassign datapoints to the children of each proposed grow.

    Trees whose move is not a grow keep their indices unchanged.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    updated_indices = _apply_grow_to_indices(moves, leaf_indices, X)
    return updated_indices
@partial(vmap_nodoc, in_axes=(0, 0, None))
def _apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, ' n'], X: UInt[Array, 'p n']
) -> UInt[Array, ' n']:
    """Single-tree body of `apply_grow_to_indices`."""
    # an index no datapoint can hold, used to disable the update for non-grows
    out_of_range = jnp.array(2 * moves.var_tree.size)
    grown_node = jnp.where(moves.grow, moves.lrt_nodes[2], out_of_range)

    # left child index, plus one for the datapoints that go right
    predictor_row = X[moves.grow_var, :]
    destination = moves.lrt_nodes[0].astype(leaf_indices.dtype) + (
        predictor_row >= moves.grow_split
    )

    return jnp.where(leaf_indices == grown_node, destination, leaf_indices)
def _fill_lrt_total(lrt: Shaped[Array, '*k_k 3']) -> Shaped[Array, '*k_k 3']:
    """Overwrite the third (total) slot of a left/right/total stack with left + right.

    Whatever the total slot held before is discarded; the left and right slots
    are returned untouched. Only elementwise operations are used so that XLA
    can fuse them.
    """
    is_total_slot = jnp.arange(3) == 2
    left_plus_right = lrt[..., 0] + lrt[..., 1]
    return jnp.where(is_total_slot, left_plus_right[..., None], lrt)
@overload
def _compute_count_or_prec_trees(
    prec_scale: None,
    trees: UInt32[Array, 'num_trees tree_size'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, 'num_trees tree_size'], Counts]: ...


@overload
def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'],
    trees: Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k k tree_size'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[Float32[Array, 'num_trees tree_size'], None]
    | tuple[Float32[Array, 'num_trees k k tree_size'], None]
): ...


def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'] | None,
    trees: UInt32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k k tree_size'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, 'num_trees tree_size'], Counts]
    | tuple[Float32[Array, 'num_trees tree_size'], None]
    | tuple[Float32[Array, 'num_trees k k tree_size'], None]
):
    """Shared implementation of `compute_count_trees` and `compute_prec_trees`.

    Trees are processed all at once with `vmap`, unless
    ``config.prec_count_num_trees`` is set. In that case they go through
    `lax.map` in batches of that many trees.
    """
    batch_size = config.prec_count_num_trees

    if batch_size is not None:

        def one_tree(
            args: tuple[
                UInt32[Array, ' tree_size']
                | Float32[Array, ' tree_size']
                | Float32[Array, 'k k tree_size'],
                UInt[Array, ' n'],
                Moves,
            ],
        ) -> (
            tuple[UInt32[Array, ' tree_size'], Counts]
            | tuple[Float32[Array, ' tree_size'], None]
            | tuple[Float32[Array, 'k k tree_size'], None]
        ):
            tree, tree_leaf_indices, tree_moves = args
            return _compute_count_or_prec_tree(
                prec_scale, tree, tree_leaf_indices, tree_moves, config
            )

        return lax.map(one_tree, (trees, leaf_indices, moves), batch_size=batch_size)

    # prec_scale and config are shared by every tree
    all_trees = vmap(_compute_count_or_prec_tree, in_axes=(None, 0, 0, 0, None))
    return all_trees(prec_scale, trees, leaf_indices, moves, config)
def _compute_count_or_prec_tree(
    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'] | None,
    tree: UInt32[Array, ' tree_size']
    | Float32[Array, ' tree_size']
    | Float32[Array, 'k k tree_size'],
    leaf_indices: UInt[Array, ' n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, ' tree_size'], Counts]
    | tuple[Float32[Array, ' tree_size'], None]
    | tuple[Float32[Array, 'k k tree_size'], None]
):
    """Refresh one tree's cached count or precision sums at its move's nodes.

    With ``prec_scale=None`` the cache holds uint32 datapoint counts, and the
    per-node counts are also returned as a `Counts`. Otherwise it accumulates
    ``prec_scale`` in float32 and the second element is `None`.
    """
    (half_size,) = moves.var_tree.shape
    full_size = 2 * half_size

    counting = prec_scale is None
    if counting:
        weights, acc_dtype, reducer = 1, jnp.uint32, config.count_reduction_config
    else:
        weights, acc_dtype, reducer = (
            prec_scale,
            jnp.float32,
            config.prec_reduction_config,
        )

    # The cache is already valid at the leaves, and only the move's nodes
    # change. So only the two children are reduced; they are contiguous at
    # (left, right) = (2 * node, 2 * node + 1) = lrt_nodes[:2].
    child_sums = reducer._reduce(  # noqa: SLF001
        weights,
        leaf_indices,
        size=full_size,
        subset_start=moves.lrt_nodes[0],
        subset_length=2,
        dtype=acc_dtype,
        data_sharded=config.data_sharded,
    )

    # Append the total for the parent (a non-leaf in the post-grow indexing
    # the reduction ran on) and scatter all three values into the cache.
    # Weighted counts are not needed: the likelihood terms come from the
    # leaf terms.
    parent_sum = child_sums[..., 0] + child_sums[..., 1]
    lrt = jnp.concatenate([child_sums, parent_sum[..., None]], axis=-1)
    new_tree = tree.at[..., moves.lrt_nodes].set(lrt)

    return (new_tree, Counts(lrt=lrt)) if counting else (new_tree, None)
@named_call
def compute_count_trees(
    count_trees: UInt32[Array, 'num_trees tree_size'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, 'num_trees tree_size'], Counts]:
    """
    Refresh the cached per-leaf datapoint counts at each move's nodes.

    Parameters
    ----------
    count_trees
        The cached number of points in each leaf; valid at the leaves of the
        pre-move trees.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    count_trees : UInt32[Array, 'num_trees tree_size']
        The updated cache, valid in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    updated_trees, move_counts = _compute_count_or_prec_trees(
        None, count_trees, leaf_indices, moves, config
    )
    return updated_trees, move_counts
@named_call
def compute_prec_trees(
    prec_trees: Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k k tree_size'],
    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> Float32[Array, 'num_trees tree_size'] | Float32[Array, 'num_trees k k tree_size']:
    """
    Refresh the cached per-leaf likelihood precision scale at each move's nodes.

    Parameters
    ----------
    prec_trees
        The cached likelihood precision scale in each leaf; valid at the leaves
        of the pre-move trees.
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    The updated cache, valid in each potential or actual leaf node.
    """
    # the second element is always None when a prec_scale is given
    return _compute_count_or_prec_trees(
        prec_scale, prec_trees, leaf_indices, moves, config
    )[0]
@partial(vmap_nodoc, in_axes=(0, None))
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' tree_size']) -> Moves:
    """
    Finish the non-likelihood part of the Metropolis-Hastings ratio.

    The partial ratio is multiplied by two factors: the probability of
    choosing a prune over a grow in the inverse transition, and the prior odds
    that the modified node is nonterminal with terminal children. The log of
    the product is then stored.

    Parameters
    ----------
    moves
        The proposed moves. The thresholds on the number of datapoints per
        node must already be reflected in them; this is done in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    assert moves.lrt_affluent is not None

    # Transition term. `lrt_affluent` already folds in the
    # `min_points_per_decision_node` threshold, because the grow proposal
    # draws from the leaves that pass it. It matters only here, not in the
    # prior term.
    p_prune_if_grow = jnp.where(
        (moves.num_growable >= 2) | jnp.any(moves.lrt_affluent[:2]), 0.5, 1.0
    )
    p_prune_if_prune = jnp.where(moves.num_growable, 0.5, 1)
    p_prune = jnp.where(moves.grow, p_prune_if_grow, p_prune_if_prune)

    # Prior term. The children's terminality uses admissibility without the
    # count thresholds: those thresholds are a proposal-efficiency device and
    # not part of the target distribution. Filling with 0.5 keeps the log
    # finite when the move is not allowed and its node indices are out of
    # bounds.
    pnt = p_nonterminal.at[moves.lrt_nodes].get(mode='fill', fill_value=0.5)
    node_odds = pnt[2] / (1 - pnt[2])
    children_terminal = jnp.prod(1 - pnt[:2] * moves.lrt_growable[:2])
    prior_ratio = node_odds * children_terminal

    assert moves.partial_ratio is not None
    log_ratio = jnp.log(moves.partial_ratio * prior_ratio * p_prune)
    return replace(moves, log_trans_prior_ratio=log_ratio, partial_ratio=None)
@named_call
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k tree_size'],
    moves: Moves,
) -> Float32[Array, 'num_trees tree_size'] | Float32[Array, 'num_trees k tree_size']:
    """
    Make the original leaf values readable through post-grow leaf indices.

    For every grow move, the grown leaf's value is duplicated into the two
    slots its children would occupy if the move were accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    adapted = _adapt_leaf_trees_to_grow_indices(leaf_trees, moves)
    return adapted
@vmap_nodoc
def _adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, ' tree_size'] | Float32[Array, ' k tree_size'],
    moves: Moves,
) -> Float32[Array, ' tree_size'] | Float32[Array, ' k tree_size']:
    """Single-tree body of `adapt_leaf_trees_to_grow_indices`."""
    parent_value = leaf_trees[..., moves.lrt_nodes[2]]
    # For non-grow moves every target is pushed past the end of the array,
    # where the scatter drops the update (JAX's default for out-of-bounds
    # sets). The parent slot is included and gets its own value back, so one
    # scatter covers all three nodes.
    targets = jnp.where(moves.grow, moves.lrt_nodes, leaf_trees.size)
    return leaf_trees.at[..., targets].set(parent_value[..., None])
def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']:
    """Return log det(L @ L.T), i.e. twice the sum of the log-diagonal of L."""
    log_diagonal = jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1))
    return 2.0 * jnp.sum(log_diagonal, axis=-1)
def compute_B(
    error_cov_inv: Float32[Array, 'k k'], resid: Float32[Array, 'k k *tree_size']
) -> Float32[Array, ' k *tree_size']:
    """Compute the leaf score from the leaf weighted sum of residuals.

    For each leaf, ``out[i] = sum_j error_cov_inv[i, j] * resid[i, j]``. This
    is an elementwise product reduced over the second axis, not a matrix
    product.
    """
    return jnp.einsum('ij,ij...->i...', error_cov_inv, resid)
def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees tree_size']
    | UInt32[Array, 'num_trees tree_size'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees tree_size'] | None = None,
) -> PreLfUV:
    """Build the univariate `PreLfUV` terms.

    If `z` is not given, standard normal noise is drawn from `key` with the
    dtype of `error_cov_inv`.
    """
    if z is None:
        z = random.normal(key, prec_trees.shape, error_cov_inv.dtype)

    # Posterior variance of each leaf. The operation order is kept as is:
    # `_precompute_likelihood_terms_uv` reproduces the empty-node value
    # bitwise.
    posterior_var = jnp.reciprocal(prec_trees * error_cov_inv + leaf_prior_cov_inv)

    # mean = mean_lk * prec_lk * var_post, with mean_lk ~ resid_tree / prec_tree
    # and prec_lk = prec_tree * error_cov_inv, so mean / resid_tree reduces to
    # var_post * error_cov_inv
    return PreLfUV(
        mean_factor=posterior_var * error_cov_inv,
        centered_leaves=z * jnp.sqrt(posterior_var),
    )
def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees tree_size']
    | UInt32[Array, 'num_trees tree_size'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees k tree_size'] | None = None,
) -> PreLfMV:
    """Build the multivariate homoskedastic / scalar-weight `PreLfMV` terms.

    For each leaf with scalar precision scale ``prec``, the posterior
    precision is ``P = leaf_prior_cov_inv + prec * error_cov_inv`` with
    lower Cholesky factor ``L``. The function stores three terms:
    ``error_cov_inv @ inv(P)`` as the mean factor, ``L^-T z`` as the centered
    draw (covariance ``inv(P)``), and ``log det P``.

    If `z` is not given, standard normal noise is drawn from `key` with the
    dtype of `error_cov_inv`, as in the univariate case.
    """
    num_trees, tree_size = prec_trees.shape
    k, _ = error_cov_inv.shape
    if z is None:
        # Draw in the working dtype (matching `_precompute_leaf_terms_uv`).
        # Otherwise, with x64 enabled the default float64 draws would promote
        # `centered_leaves` out of float32.
        z = random.normal(key, (num_trees, k, tree_size), error_cov_inv.dtype)

    def per_leaf(
        prec: Float32[Array, ''] | UInt32[Array, ''], z: Float32[Array, ' k']
    ) -> tuple[Float32[Array, 'k k'], Float32[Array, ' k'], Float32[Array, '']]:
        L_prec = chol_with_gersh(leaf_prior_cov_inv + prec * error_cov_inv)
        # (L^-T L^-1 E)^T = E inv(P) since E and P are symmetric
        Y = solve_triangular(L_prec, error_cov_inv, lower=True)
        mean_factor = solve_triangular(L_prec, Y, trans='T', lower=True).mT
        centered = solve_triangular(L_prec, z[:, None], trans='T', lower=True).squeeze(
            -1
        )
        # only a few leaves per tree end up using their logdet, but reducing
        # right away is lighter on memory than storing diagonals for later
        return mean_factor, centered, _logdet_from_chol(L_prec)

    # vmap over trees then over leaves; the leaf axis is trailing in both
    # `prec_trees`/`z` (in_axes) and the stored output (out_axes=-1). The
    # tuple order matches the positional field order of `PreLfMV`.
    return PreLfMV(*vmap(vmap(per_leaf, in_axes=(0, -1), out_axes=-1))(prec_trees, z))
def _precompute_leaf_terms_mv_het(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees k k tree_size'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees k tree_size'] | None = None,
) -> PreLfMVHet:
    """Build the multivariate vector-weight `PreLfMVHet` terms.

    For each leaf with matrix precision scale ``prec``, the posterior
    precision is ``P = leaf_prior_cov_inv + error_cov_inv * prec``, where
    ``*`` is the elementwise product. The function stores the lower Cholesky
    factor ``L`` of ``P`` as the mean factor, and ``L^-T z`` as the centered
    draw.

    If `z` is not given, standard normal noise is drawn from `key` with the
    dtype of `error_cov_inv`, as in the univariate case.
    """
    num_trees, k, _, tree_size = prec_trees.shape
    if z is None:
        # Draw in the working dtype (matching `_precompute_leaf_terms_uv`).
        # Otherwise, with x64 enabled the default float64 draws would promote
        # `centered_leaves` out of float32.
        z = random.normal(key, (num_trees, k, tree_size), error_cov_inv.dtype)

    def per_leaf(
        prec: Float32[Array, 'k k'], z: Float32[Array, ' k']
    ) -> tuple[Float32[Array, 'k k'], Float32[Array, ' k']]:
        # mean_factor stores the precision cholesky itself; the mean solve
        # happens downstream in the sequential stage
        L_prec = chol_with_gersh(leaf_prior_cov_inv + error_cov_inv * prec)
        centered = solve_triangular(L_prec, z[:, None], trans='T', lower=True).squeeze(
            -1
        )
        return L_prec, centered

    # vmap over trees then over leaves; the leaf axis is trailing in both
    # `prec_trees`/`z` (in_axes=-1) and the stored output (out_axes=-1). The
    # tuple order matches the positional field order of `PreLfMVHet`.
    return PreLfMVHet(
        *vmap(vmap(per_leaf, in_axes=(-1, -1), out_axes=-1))(prec_trees, z)
    )
@named_call
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees tree_size']
    | UInt32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k k tree_size'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees tree_size']
    | Float32[Array, 'num_trees k tree_size']
    | None = None,
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    The implementation is chosen from the array ranks. A scalar
    `error_cov_inv` selects the univariate path. Otherwise a rank-4
    `prec_trees` selects the vector-weight path, and anything else the
    homoskedastic / scalar-weight multivariate path.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse error variance (univariate) or the inverse of error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    z
        Optional standard normal noise to use for sampling the centered leaves.
        This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    if error_cov_inv.ndim == 0:
        implementation = _precompute_leaf_terms_uv
    elif prec_trees.ndim == 4:
        implementation = _precompute_leaf_terms_mv_het
    else:
        implementation = _precompute_leaf_terms_mv
    return implementation(key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z)
@vmap_nodoc
def _gather_lrt(
    leaf_values: Float32[Array, '*k_k tree_size'], lrt_nodes: Int32[Array, ' 3']
) -> Float32[Array, ' 3 *k_k']:
    """Pick one tree's values at its (left child, right child, parent) nodes.

    The selected node axis is moved to the front of the result.
    """
    picked = leaf_values[..., lrt_nodes]
    return jnp.moveaxis(picked, -1, 0)
def _precompute_likelihood_terms_uv(
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    prelf: PreLfUV,
    lrt_nodes: Int32[Array, 'num_trees 3'],
) -> PreLkV:
    """Build the univariate `PreLkV` terms from the univariate leaf terms.

    `lrt` holds ``error_cov_inv^2 * v`` at the left child, right child and
    parent, where ``v`` is the node's posterior variance. `log_sqrt_term` is
    ``log(lrt_L * lrt_R / (prior_lrt * lrt_T)) / 2``. Here ``prior_lrt`` is the
    same quantity with no data (``v = 1 / leaf_prior_cov_inv``), so the
    ``error_cov_inv^2`` factors cancel.
    """
    # mean_factor is error_cov_inv / prec, complete the sandwich
    lrt = error_cov_inv * _gather_lrt(prelf.mean_factor, lrt_nodes)
    # the same value with the prior-only precision, computed with the same
    # operations as in `_precompute_leaf_terms_uv` such that it matches `lrt`
    # bitwise on empty nodes and the ratio is exactly 1 without data
    prior_lrt = error_cov_inv * (jnp.reciprocal(leaf_prior_cov_inv) * error_cov_inv)
    log_sqrt_term = jnp.log(lrt[..., 0] * lrt[..., 1] / (prior_lrt * lrt[..., 2])) / 2
    return PreLkV(lrt=lrt, log_sqrt_term=log_sqrt_term)
def _precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    prelf: PreLfMV,
    lrt_nodes: Int32[Array, 'num_trees 3'],
) -> PreLkV:
    """Build the multivariate homoskedastic / scalar-weight `PreLkV` terms.

    `log_sqrt_term` is ``(log det P0 - log det P_L - log det P_R + log det P_T) / 2``,
    where ``P0`` is the prior precision and ``P_*`` the posterior precisions.
    `lrt` is ``error_cov_inv @ inv(P) @ error_cov_inv`` at each of the three
    nodes.
    """
    left_right_total_sign = jnp.array([-1.0, -1.0, 1.0])
    prior_logdet = _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))
    node_logdets = _gather_lrt(prelf.logdet_prec, lrt_nodes)
    log_sqrt_term = (prior_logdet + node_logdets @ left_right_total_sign) / 2

    # the stored mean factor is error_cov_inv @ inv(prec); right-multiplying
    # by error_cov_inv completes the sandwich
    node_mean_factors = _gather_lrt(prelf.mean_factor, lrt_nodes)  # (num_trees, 3, k, k)
    return PreLkV(
        lrt=node_mean_factors @ error_cov_inv, log_sqrt_term=log_sqrt_term
    )
def _precompute_likelihood_terms_mv_het(
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    prelf: PreLfMVHet,
    lrt_nodes: Int32[Array, 'num_trees 3'],
) -> PreLkV:
    """Likelihood-ratio terms for the multivariate case with vector weights.

    Here the leaf-sampling ``mean_factor`` is the Cholesky factor of each
    node's precision. It is gathered at the (left, right, total) nodes and
    returned unchanged as ``lrt``. The log-determinants are obtained from the
    same factors via `_logdet_from_chol`.
    """
    chol_nodes = _gather_lrt(prelf.mean_factor, lrt_nodes)  # (num_trees, 3, k, k)

    signs = jnp.array([-1.0, -1.0, 1.0])
    prior_logdet = _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))
    log_sqrt_term = (prior_logdet + _logdet_from_chol(chol_nodes) @ signs) / 2

    return PreLkV(lrt=chol_nodes, log_sqrt_term=log_sqrt_term)
@named_call
def precompute_likelihood_terms(
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    prelf: PreLf,
    moves: Moves,
) -> PreLkV:
    """
    Pre-compute the per-tree terms of the acceptance likelihood ratio.

    Most of what the likelihood ratio needs is already computed for leaf
    sampling. The terms are therefore obtained from `prelf` by gathering it at
    the three nodes (left child, right child, parent) involved in each move.

    Parameters
    ----------
    error_cov_inv
        The inverse error variance (univariate) or the inverse of the error
        covariance matrix (multivariate). This is the inverse global error
        variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        the prior covariance matrix of each leaf (multivariate).
    prelf
        The pre-computed terms of the leaf sampling, see `precompute_leaf_terms`.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    Pre-computed terms of the likelihood ratio, one per tree.
    """
    nodes = moves.lrt_nodes

    if isinstance(prelf, PreLfUV):
        return _precompute_likelihood_terms_uv(
            error_cov_inv, leaf_prior_cov_inv, prelf, nodes
        )

    if isinstance(prelf, PreLfMVHet):
        # error_cov_inv is not needed here: in this case the sequential stage
        # uses it directly when evaluating the likelihood ratio
        return _precompute_likelihood_terms_mv_het(leaf_prior_cov_inv, prelf, nodes)

    assert isinstance(prelf, PreLfMV)
    return _precompute_likelihood_terms_mv(
        error_cov_inv, leaf_prior_cov_inv, prelf, nodes
    )
@named_call
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.
    The residuals are threaded through a `lax.scan` over the trees. Each step
    accepts or rejects one move, samples the leaves of that tree, and updates
    the residuals.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    state : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    state = pso.state

    # inputs that every tree sees unchanged; built once and closed over by the
    # scan body
    shared = SeqStageInAllTrees(
        state.X,
        state.config.resid_reduction_config,
        state.config.data_sharded,
        state.prec_scale,
        state.forest.log_likelihood is not None,
        state.error_cov_inv.value if isinstance(pso.prelf, PreLfMVHet) else None,
    )

    def scan_body(
        carry_resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
        per_tree: SeqStageInPerTree,
    ) -> tuple[
        Float32[Array, ' n'] | Float32[Array, ' k n'],
        tuple[
            Float32[Array, ' tree_size'] | Float32[Array, ' k tree_size'],
            Bool[Array, ''],
            Bool[Array, ''],
            Float32[Array, ''] | None,
        ],
    ]:
        new_resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            carry_resid, shared, per_tree
        )
        return new_resid, (leaf_tree, acc, to_prune, lkratio)

    per_tree_inputs = SeqStageInPerTree(
        state.forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        state.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )
    final_resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(
        scan_body,
        state.resid,
        per_tree_inputs,
        unroll=state.config.sequential_unroll,
    )

    new_forest = replace(state.forest, leaf_tree=leaf_trees, log_likelihood=lkratio)
    new_state = replace(state, resid=final_resid, forest=new_forest)
    new_moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return new_state, new_moves
class SeqStageInAllTrees(Module):
    """The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    `accept_moves_sequential_stage` fills these from the parallel-stage state;
    they are identical for every tree processed by the sequential scan. Fields
    are positional: their order must match the constructor calls.
    """

    X: UInt[Array, 'p n']
    """The predictors."""

    resid_reduction_config: ReductionConfig
    """How to sum the residuals in each leaf."""

    # static field; forwarded to `sum_resid` to decide on the psum reduction
    data_sharded: bool = field(static=True)
    """Whether the data axis is sharded across devices."""

    prec_scale: Float32[Array, ' n'] | Float32[Array, 'k k n'] | None
    """The scale of the precision of the error on each datapoint. If None, it
    is assumed to be 1."""

    # static field; read in Python-level control flow (`if not at.save_ratios`)
    save_ratios: bool = field(static=True)
    """Whether to save the acceptance ratios."""

    # when not None, also selects the vector-weight branches in
    # `accept_move_and_sample_leaves` and `compute_likelihood_ratio`
    error_cov_inv: Float32[Array, 'k k'] | None
    """The global error precision scale. Set only in the multivariate
    vector-weight case, where the sequential stage needs it to compute the
    leaf scores."""
class SeqStageInPerTree(Module):
    """The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    The array fields carry a leading `num_trees` axis that `lax.scan` slices
    away, so each scan step sees the inputs of a single tree. Fields are
    positional: their order must match the constructor call.
    """

    # Although consumed one tree at a time by `lax.scan`, this object is only
    # ever constructed in the stacked (batched) form fed to the scan, so
    # `num_trees` stays a fixed (non-variadic) leading axis disambiguated by
    # rank/dtype (cf. `ParallelStageOut`); the per-tree slices reach the scan
    # body in `accept_moves_sequential_stage` via scan, which does not re-run
    # `__init__`.
    leaf_tree: (
        Float32[Array, 'num_trees tree_size'] | Float32[Array, 'num_trees k tree_size']
    )
    """The leaf values of the trees."""

    # multiplies `leaf_tree` in `accept_move_and_sample_leaves`; the UInt32
    # variant presumably holds per-node datapoint counts when there is no
    # precision scale (TODO confirm against the parallel stage)
    prec_tree: (
        Float32[Array, 'num_trees tree_size']
        | UInt32[Array, 'num_trees tree_size']
        | Float32[Array, 'num_trees k k tree_size']
    )
    """The likelihood precision scale in each potential or actual leaf node."""

    move: Moves
    """The proposed move, see `propose_moves`."""

    leaf_indices: UInt[Array, 'num_trees n']
    """The leaf indices for the largest version of the tree compatible with
    the move."""

    prelkv: PreLkV
    """The pre-computed terms of the likelihood ratio which are specific to the tree."""

    prelf: PreLf
    """The pre-computed terms of the leaf sampling which are specific to the tree."""
@named_call
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' tree_size'] | Float32[Array, ' k tree_size'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    This processes a single tree. It sums the residuals in the leaves of the
    largest tree compatible with the move, then evaluates the likelihood ratio
    of the move and accepts or rejects it. Finally, it sets the leaf values
    from their posterior mean plus the pre-computed centered noise, and
    replaces the old tree with the new one in the residuals.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, ' n'] | Float32[Array, ' k n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, ' tree_size'] | Float32[Array, ' k tree_size']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale

    tree_size = pt.leaf_tree.shape[-1]  # 2**d

    resid_tree = sum_resid(
        scaled_resid,
        pt.leaf_indices,
        tree_size,
        at.resid_reduction_config,
        at.data_sharded,
    )

    # subtract starting tree from function: since resid = data - forest, adding
    # back this tree's leaves (weighted by the per-node precision scale) gives
    # the sums of the partial residuals that exclude this tree
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move and compute likelihood;
    # the children slots are written back unchanged to share a single scatter
    assert pt.move.lrt_nodes.dtype == jnp.int32
    resid_lrt = _fill_lrt_total(resid_tree[..., pt.move.lrt_nodes])
    resid_tree = resid_tree.at[..., pt.move.lrt_nodes].set(resid_lrt)

    log_lk_ratio = compute_likelihood_ratio(resid_lrt, pt.prelkv, at.error_cov_inv)

    # calculate accept/reject ratio; the likelihood ratio is for the grow
    # direction, so for a prune the total log ratio is negated (this assumes
    # log_trans_prior_ratio is also in the grow direction; TODO confirm in
    # `propose_moves`)
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move; `logu` is presumably the log of a
    # uniform draw (Metropolis-Hastings test)
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    if at.error_cov_inv is not None:
        # multivariate w/ vector weights: mean = inv(L L^T) b, with L the
        # per-node precision Cholesky factor, via two triangular solves
        b_tree = compute_B(at.error_cov_inv, resid_tree)  # (k, 2**d)
        l_lead = jnp.moveaxis(pt.prelf.mean_factor, -1, 0)  # (2**d, k, k)
        b_lead = b_tree.T[:, :, None]  # (2**d, k, 1)
        y = solve_triangular(l_lead, b_lead, lower=True)
        mu = solve_triangular(l_lead, y, lower=True, trans='T').squeeze(-1)
        mean_post = mu.T  # (k, 2**d)
    elif resid.ndim > 1:
        # multivariate homoskedastic or scalar weights
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, resid_tree)
    else:
        # univariate
        mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf;
    # the parent slot is written back unchanged to share a single scatter.
    # When not pruning, the index `tree_size` is out of bounds and JAX drops
    # the scatter writes.
    to_prune = acc ^ pt.move.grow
    leaf_tree = leaf_tree.at[
        ..., jnp.where(to_prune, pt.move.lrt_nodes, tree_size)
    ].set(leaf_tree[..., pt.move.lrt_nodes[2], None])
    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[..., pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
@named_call
def sum_resid(
    scaled_resid: (
        Float32[Array, ' n'] | Float32[Array, 'k n'] | Float32[Array, 'k k n']
    ),
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    reduction_config: ReductionConfig,
    data_sharded: bool,
) -> (
    Float32[Array, ' {tree_size}']
    | Float32[Array, 'k {tree_size}']
    | Float32[Array, 'k k {tree_size}']
):
    """
    Accumulate the residuals that fall into each leaf of a tree.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value), already multiplied by the
        error precision scale.
    leaf_indices
        For each datapoint, the index of the leaf it falls into.
    tree_size
        The length of the tree array (2 ** d); it sets the size of the output.
    reduction_config
        The strategy used to sum the residuals in each leaf.
    data_sharded
        Whether the data axis is sharded; if true, the sums are psum-reduced
        across the ``'data'`` axis of the enclosing `shard_map`.

    Returns
    -------
    The per-leaf float32 sums, with the leading dimensions of ``scaled_resid``
    and a trailing axis of length `tree_size` over the leaves.
    """
    # all the work is delegated to the configured reduction strategy
    reduce_fn = reduction_config._reduce  # noqa: SLF001
    return reduce_fn(
        scaled_resid,
        leaf_indices,
        size=tree_size,
        dtype=jnp.float32,
        data_sharded=data_sharded,
    )
def _compute_likelihood_ratio_uv(
    resid_lrt: Float32[Array, ' 3'], prelkv: PreLkV
) -> Float32[Array, '']:
    """Univariate log-likelihood ratio of a grow move.

    Adds half the signed sum of the quadratic forms ``r * lrt * r`` to the
    pre-computed log-determinant term. The two children count positively and
    the parent (total) counts negatively.
    """
    signs = jnp.array([1.0, 1.0, -1.0])
    quad_forms = resid_lrt * resid_lrt * prelkv.lrt
    return prelkv.log_sqrt_term + 0.5 * (quad_forms @ signs)
def _compute_likelihood_ratio_mv(
    resid_lrt: Float32[Array, 'k 3'], prelkv: PreLkV
) -> Float32[Array, '']:
    """Multivariate homoskedastic log-likelihood ratio of a grow move.

    For each node t in (left, right, total), the quadratic form is
    ``r_t' M_t r_t``, with ``r_t = resid_lrt[:, t]`` and ``M_t = prelkv.lrt[t]``.
    """
    signs = jnp.array([1.0, 1.0, -1.0])
    quad_forms = jnp.einsum('it,tij,jt->t', resid_lrt, prelkv.lrt, resid_lrt)
    return prelkv.log_sqrt_term + 0.5 * (quad_forms @ signs)
def _compute_likelihood_ratio_mv_het(
    resid_lrt: Float32[Array, 'k k 3'],
    error_cov_inv: Float32[Array, 'k k'],
    prelkv: PreLkV,
) -> Float32[Array, '']:
    """Multivariate vector-weight log-likelihood ratio of a grow move.

    ``prelkv.lrt`` holds lower Cholesky factors L_t of the node precisions.
    The quadratic form b_t' inv(L_t L_t^T) b_t is therefore ||inv(L_t) b_t||^2,
    which needs a single triangular solve per node.
    """
    signs = jnp.array([1.0, 1.0, -1.0])
    b = compute_B(error_cov_inv, resid_lrt)  # (k, 3)
    rhs = b.T[..., None]  # (3, k, 1)
    whitened = solve_triangular(prelkv.lrt, rhs, lower=True).squeeze(-1)  # (3, k)
    quad_forms = jnp.einsum('ti,ti->t', whitened, whitened)
    return prelkv.log_sqrt_term + 0.5 * (quad_forms @ signs)
@named_call
def compute_likelihood_ratio(
    resid_lrt: (Float32[Array, ' 3'] | Float32[Array, 'k 3'] | Float32[Array, 'k k 3']),
    prelkv: PreLkV,
    error_cov_inv: Float32[Array, 'k k'] | None,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    The variant is picked as follows: vector weights when `error_cov_inv` is
    given; otherwise univariate for a 1-D `resid_lrt`, else multivariate.

    Parameters
    ----------
    resid_lrt
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the left child, right child, and parent node
        involved in the move, stacked along the trailing axis.
    prelkv
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.
    error_cov_inv
        The global error precision scale. Set only in the multivariate
        vector-weight case.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    if error_cov_inv is not None:
        return _compute_likelihood_ratio_mv_het(resid_lrt, error_cov_inv, prelkv)
    if resid_lrt.ndim <= 1:
        return _compute_likelihood_ratio_uv(resid_lrt, prelkv)
    return _compute_likelihood_ratio_mv(resid_lrt, prelkv)
@named_call
def accept_moves_final_stage(state: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This is kept separate from `accept_moves_sequential_stage` because all of
    its work can be done in parallel across trees. It does three things:
    counts the accepted grow and prune moves, redirects the leaf indices, and
    updates the split and affluence trees.

    Parameters
    ----------
    state
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    assert moves.acc is not None
    forest = state.forest

    accepted_grow = moves.acc & moves.grow
    accepted_prune = moves.acc & ~moves.grow

    new_forest = replace(
        forest,
        grow_acc_count=jnp.sum(accepted_grow),
        prune_acc_count=jnp.sum(accepted_prune),
        leaf_indices=apply_moves_to_leaf_indices(forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(forest.split_tree, moves),
        affluence_tree=apply_moves_to_affluence_trees(forest.affluence_tree, moves),
    )
    return replace(state, forest=new_forest)
@named_call
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Redirect the leaf indices of each tree according to the accepted moves.

    Parameters
    ----------
    leaf_indices
        For each tree, the index of the leaf each datapoint falls into,
        computed as if the grow move was accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # the implementation handles a single tree and is mapped across trees by
    # its decorator
    updated = _apply_moves_to_leaf_indices(leaf_indices, moves)
    return updated
@vmap_nodoc
def _apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, ' n'], moves: Moves
) -> UInt[Array, ' n']:
    """Implement `apply_moves_to_leaf_indices` for a single tree."""
    assert moves.to_prune is not None

    # in the heap layout the children 2i and 2i + 1 differ only in the lowest
    # bit, so clearing it maps both children of the move onto the left one
    drop_low_bit = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_moved_children = (leaf_indices & drop_low_bit) == moves.lrt_nodes[0]

    # datapoints in a pair of children that gets pruned fall back to the parent
    must_redirect = in_moved_children & moves.to_prune
    parent = moves.lrt_nodes[2].astype(leaf_indices.dtype)
    return jnp.where(must_redirect, parent, leaf_indices)
@named_call
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees half_tree_size'], moves: Moves
) -> UInt[Array, 'num_trees half_tree_size']:
    """
    Write the outcome of the accepted moves into the split trees.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees (0 marks a
        node that is not a decision node).
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    # the implementation handles a single tree and is mapped across trees by
    # its decorator
    updated = _apply_moves_to_split_trees(split_tree, moves)
    return updated
@vmap_nodoc
def _apply_moves_to_split_trees(
    split_tree: UInt[Array, ' half_tree_size'], moves: Moves
) -> UInt[Array, ' half_tree_size']:
    """Implement `apply_moves_to_split_trees` for a single tree."""
    assert moves.to_prune is not None

    # one scatter covers every case; when neither a grow nor a prune is
    # applied, the index equals the array size, which is out of bounds, so
    # the write is dropped
    touches_node = moves.grow | moves.to_prune
    target = jnp.where(touches_node, moves.lrt_nodes[2], split_tree.size)

    # pruning (accepted prune or rejected grow) zeroes the node, while an
    # accepted grow writes the proposed cutpoint
    new_cut = jnp.where(moves.to_prune, 0, moves.grow_split).astype(split_tree.dtype)
    return split_tree.at[target].set(new_cut)
@named_call
def apply_moves_to_affluence_trees(
    affluence_tree: Bool[Array, 'num_trees half_tree_size'], moves: Moves
) -> Bool[Array, 'num_trees half_tree_size']:
    """
    Refresh the affluence trees after the accepted moves.

    The affluence tree flags the leaves that can be grown. Only the nodes
    touched by each move are rewritten; every other entry keeps its pre-move
    value.

    Parameters
    ----------
    affluence_tree
        The mask of the growable leaves in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated affluence trees.
    """
    # the implementation handles a single tree and is mapped across trees by
    # its decorator
    updated = _apply_moves_to_affluence_trees(affluence_tree, moves)
    return updated
@vmap_nodoc
def _apply_moves_to_affluence_trees(
    affluence_tree: Bool[Array, ' half_tree_size'], moves: Moves
) -> Bool[Array, ' half_tree_size']:
    """Implement `apply_moves_to_affluence_trees` for a single tree."""
    assert moves.to_prune is not None
    assert moves.lrt_affluent is not None

    # which of (left, right, parent) end up as leaves: after a grow, the two
    # children; after a prune (accepted prune or rejected grow), the parent
    # alone. Nodes that do not become leaves are written as False.
    becomes_leaf = moves.to_prune ^ jnp.array([True, True, False])

    # a rejected prune applies nothing: the indices become the array size,
    # which is out of bounds, so the writes are dropped
    touches_nodes = moves.grow | moves.to_prune
    targets = jnp.where(touches_nodes, moves.lrt_nodes, affluence_tree.size)
    return affluence_tree.at[targets].set(moves.lrt_affluent & becomes_leaf)
@jit
def _sample_wishart_bartlett(
    key: Key[Array, ''],
    df: Float32[Array, ''] | float,
    scale_inv: Float32[Array, 'k k'],
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key.
    df
        Degrees of freedom. The diagonal of the Bartlett factor uses
        chi-squared draws with ``df - i`` degrees of freedom for
        ``i = 0, ..., k - 1``, so ``df > k - 1`` is required for a valid sample.
    scale_inv
        The inverse of the Wishart scale matrix, i.e., the scale matrix of the
        corresponding inverse-Wishart distribution of ``W^-1``.

    Returns
    -------
    A sample from Wishart(df, scale_inv^-1).
    """
    keys = split(key)

    # Diagonal elements: A_ii ~ sqrt(chi^2(df - i)), with chi^2(nu) = Gamma(nu/2, scale=2).
    # sqrt(2 * Gamma) = sqrt(2) * exp(loggamma / 2), folding the sqrt into the exp.
    k, _ = scale_inv.shape
    df_vector = df - jnp.arange(k)
    diag_A = jnp.sqrt(2.0) * jnp.exp(loggamma(keys.pop(), df_vector / 2.0) / 2.0)

    # Strictly lower triangle: independent standard normals, so A A^T ~ Wishart(df, I).
    off_diag_A = random.normal(keys.pop(), (k, k))
    A = jnp.tril(off_diag_A, -1) + jnp.diag(diag_A)

    # With scale_inv = L L^T (L lower, as used with lower=True below), the
    # solve gives T = L^-T A, hence T T^T = L^-T (A A^T) L^-1, which is
    # Wishart(df, L^-T L^-1) = Wishart(df, scale_inv^-1).
    L = chol_with_gersh(scale_inv, absolute_eps=True)
    T = solve_triangular(L, A, lower=True, trans='T')

    return T @ T.T
def _step_error_cov_inv_mv(key: Key[Array, ''], state: State) -> State:
    """Sample a dense error precision matrix from its Wishart full conditional.

    The posterior degrees of freedom are ``nu`` plus the effective number of
    datapoints. The posterior inverse scale is ``rate`` plus the outer product
    of the (optionally rescaled) residuals. Both are reduced across data
    shards when the data is sharded.
    """
    prior = state.error_cov_inv
    assert prior.nu is not None
    assert prior.rate is not None

    config = state.config
    weights = state.inv_sdev_scale
    resid = state.resid
    if weights is not None:
        # only 1-D weights reach this path (2-D ones take the diagonal update);
        # datapoints with zero weight are not counted
        n_eff = jnp.sum(weights != 0, axis=-1)
        if config.data_sharded:
            n_eff = lax.psum(n_eff, 'data')
        resid = resid * weights
    else:
        _, n_local = resid.shape
        n_eff = n_local * get_axis_size(config.mesh, 'data')

    outer = resid @ resid.T
    if config.data_sharded:
        outer = lax.psum(outer, 'data')

    df_post = prior.nu + n_eff
    scale_post = prior.rate + outer
    prec = _sample_wishart_bartlett(key, df_post, scale_post)
    return replace(state, error_cov_inv=replace(prior, value=prec))
def _step_error_cov_inv_diag(key: Key[Array, ''], state: State) -> State:
    """Per-component inverse-gamma update for univariate, mixed, and partial-missing paths.

    Each component precision is drawn as a standard gamma variate with shape
    ``nu/2 + n_eff/2``, divided by ``rate/2 + ||resid||^2/2``. For matrix
    precisions, only the diagonal of ``rate`` is used and the result is
    returned as a diagonal matrix. Components listed in `binary_indices` are
    fixed to 1.
    """
    prior = state.error_cov_inv
    assert prior.rate is not None
    assert prior.nu is not None

    config = state.config
    weights = state.inv_sdev_scale
    resid = state.resid
    if weights is None:
        *_, n_local = resid.shape
        n_eff = n_local * get_axis_size(config.mesh, 'data')
    else:
        resid = resid * weights
        n_eff = jnp.sum(weights != 0, axis=-1)
        if config.data_sharded:
            n_eff = lax.psum(n_eff, 'data')

    # gamma shape
    alpha = prior.nu / 2 + n_eff / 2

    # gamma rate: half the prior rate plus half the residual sum of squares
    norm2 = jnp.einsum('...n,...n->...', resid, resid)
    if config.data_sharded:
        norm2 = lax.psum(norm2, 'data')
    kshape = resid.shape[:-1]
    prior_rate = jnp.diag(prior.rate) if kshape else prior.rate
    beta = prior_rate / 2 + norm2 / 2

    # draw from the first key of a split, like the Bartlett sampler of the
    # dense path, so that the two updates coincide at k=1
    keys = split(key)
    prec = jnp.exp(loggamma(keys.pop(), alpha, kshape)) / beta
    if state.binary_indices is not None:
        prec = prec.at[state.binary_indices].set(1.0)
    if kshape:
        prec = jnp.diag(prec)
    return replace(state, error_cov_inv=replace(prior, value=prec))
@named_call
def step_error_cov_inv(key: Key[Array, ''], state: State) -> State:
    """
    MCMC-update the inverse error covariance.

    A dense precision matrix is sampled from its Wishart conditional when all
    of the following hold: the precision is a matrix, no binary indices are
    set, and `inv_sdev_scale` is absent or 1-D. Otherwise each component gets
    an independent gamma update.

    Parameters
    ----------
    key
        A jax random key.
    state
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    weights = state.inv_sdev_scale
    use_dense = (
        state.error_cov_inv.value.ndim == 2
        and state.binary_indices is None
        and (weights is None or weights.ndim == 1)
    )
    if use_dense:
        return _step_error_cov_inv_mv(key, state)
    return _step_error_cov_inv_diag(key, state)
@named_call
def step_z(key: Key[Array, ''], state: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    state
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state, with new `z` and residuals.
    """
    assert state.z is not None
    assert state.binary_y is not None

    idx = state.binary_indices
    old_resid = state.resid if idx is None else state.resid[..., idx, :]

    # the latent variable minus its residual is the current trees + offset
    latent_mean = state.z - old_resid

    if state.config.data_sharded:
        # decorrelate the seed across data shards; the seed is replicated
        # because the trees and most of the algorithm are replicated
        key = random.fold_in(key, lax.axis_index('data'))

    # the new residual is a one-sided truncated normal draw, with the side
    # chosen by `~binary_y` and the bound at `-latent_mean`
    new_resid = truncated_normal_onesided(key, (), ~state.binary_y, -latent_mean)
    new_z = latent_mean + new_resid

    if idx is not None:
        new_resid = state.resid.at[..., idx, :].set(new_resid)

    return replace(state, z=new_z, resid=new_resid)
def _blocked_mass_tree(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' half_tree_size'],
    split_tree: UInt[Array, ' half_tree_size'],
    max_split: UInt[Array, ' p'],
    s: Float32[Array, ' p'],
) -> Float32[Array, ' p']:
    """Per-variable data-augmentation mass blocked by a single tree.

    At each internal node, draws the latent augmentation weight ``lambda / e``
    (``lambda`` exponential, ``e`` the eligible split probability mass at the
    node) and adds it to every variable ineligible at that node.

    Parameters
    ----------
    key
        Random key for sampling.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree (0 at nodes that are not internal).
    max_split
        The maximum split index for each variable.
    s
        Split probabilities normalized over selectable variables.

    Returns
    -------
    The blocked mass for each variable.
    """
    (half_tree_size,) = split_tree.shape
    d_minus_1 = half_tree_size.bit_length() - 1  # number of decision-node levels
    p = max_split.size
    nodes = jnp.arange(half_tree_size)
    # named `cut` (not `split`) to avoid shadowing the module-level `split`
    # key-splitting helper
    cut = split_tree.astype(jnp.int32)
    is_internal = split_tree.astype(bool)

    # Range [lo, hi) of cutpoints still available for each node's own splitting
    # variable, given the constraints inherited from the ancestors.
    lo, hi = vmap(split_range, in_axes=(None, None, None, 0, 0))(
        var_tree, split_tree, max_split, nodes, var_tree
    )

    # An internal node exhausts its own variable for a child when its cutpoint
    # sits at the matching end of the available range, so the variable becomes
    # ineligible throughout that child's subtree. Row 0 is the left child (low
    # end lo), row 1 the right child (high end hi - 1).
    blocks = is_internal & (cut == jnp.stack([lo, hi - 1]))

    # A node can block at most its own splitting variable, so the per-variable
    # totals are recovered from these per-node blocks via top-down/bottom-up
    # accumulation over depth levels, rather than scanning each node's ancestors.

    # Ineligible mass per node: the s-mass of the variables blocked along the
    # path from the root. Each variable is blocked at exactly one node per path,
    # so summing the per-node increments top-down reproduces the per-node sum
    # over distinct ineligible variables.
    parent = nodes >> 1
    side = nodes & 1  # 0 if the node is a left child, 1 if a right child
    parent_blocks = blocks[side, parent]
    # var_tree[parent] is a valid index wherever parent_blocks holds (the parent
    # is then internal); elsewhere the clamped gather is masked away
    ineligible_mass = jnp.where(parent_blocks, s[var_tree[parent]], 0.0)
    for level in range(1, d_minus_1):
        lhs, rhs = 1 << level, 1 << (level + 1)
        parent_mass = jnp.repeat(ineligible_mass[lhs >> 1 : rhs >> 1], 2)
        ineligible_mass = ineligible_mass.at[lhs:rhs].add(parent_mass)

    # Per-node augmentation weight lambda_b / e_b, zero at non-internal nodes. The
    # eligible mass is positive at internal nodes (the split variable is eligible);
    # the floor only guards against round-off, and is unused where weight is zero.
    eligible_mass = jnp.maximum(1.0 - ineligible_mass, jnp.finfo(jnp.float32).eps)
    weight = jnp.where(is_internal, random.exponential(key, (half_tree_size,)), 0.0)
    weight /= eligible_mass

    # Subtree weight: total weight of each node's internal descendants and itself,
    # accumulated bottom-up. The children of the deepest decision level are leaves
    # and contribute nothing.
    subtree_weight = weight
    for level in range(d_minus_1 - 2, -1, -1):
        lhs, rhs = 1 << level, 1 << (level + 1)
        children = subtree_weight[2 * lhs : 2 * rhs].reshape(-1, 2).sum(axis=1)
        subtree_weight = subtree_weight.at[lhs:rhs].add(children)

    # A variable blocked by node b at one of its children is ineligible in that
    # whole child subtree, so it accumulates the subtree weight; scatter it onto
    # the splitting variable. Only the upper half of nodes have internal children
    # (the deepest decision nodes block into leaves, contributing nothing); their
    # children are exactly subtree_weight reshaped into [left, right] pairs.
    # Slicing to 2 * half is a no-op for half_tree_size >= 2 and keeps the
    # reshape valid when half_tree_size == 1 (no decision-node slots).
    half = half_tree_size // 2
    child_pairs = subtree_weight[: 2 * half].reshape(-1, 2).T  # (2, half)
    contrib = (blocks[:, :half] * child_pairs).sum(axis=0)
    # non-internal nodes scatter to index p, out of bounds, so JAX drops them
    scatter_var = jnp.where(is_internal, var_tree, p)[:half]
    return jnp.zeros(p).at[scatter_var].add(contrib)
def sample_s_augmentation(key: Key[Array, ''], forest: Forest) -> Int32[Array, ' p']:
    """Sample the data-augmentation counts for the exact full conditional of `s`.

    At each internal node, some variables have no available cutpoint given the
    ancestors, and some are blocked globally; neither can be split on there.
    The plain Dirichlet update for `s` is therefore only approximate. For each
    variable, this samples the number of ineligible draws discarded before
    each realized split; the counts are meant to be added to the variable
    usage counts.

    Parameters
    ----------
    key
        Random key for sampling.
    forest
        The forest, providing the trees and the current `log_s`.

    Returns
    -------
    The discarded-draws count for each variable.
    """
    assert forest.log_s is not None
    keys = split(key)
    num_trees, _ = forest.var_tree.shape

    # s renormalized over the variables with at least one cutpoint
    has_cutpoints = forest.max_split > 0
    s = softmax(forest.log_s, where=has_cutpoints)

    # per variable: the sum, over the internal nodes where it is ineligible,
    # of lambda_b / e_b with lambda_b ~ Exponential(1)
    per_tree_mass = vmap(_blocked_mass_tree, in_axes=(0, 0, 0, None, None))(
        keys.pop(num_trees), forest.var_tree, forest.split_tree, forest.max_split, s
    )  # (num_trees, p)
    blocked_mass = per_tree_mass.sum(axis=0)  # (p,)

    # Summed over nodes, the per-node negative-multinomial counts have no
    # closed form. Their Gamma-Poisson mixture does: conditional on the
    # lambdas, A_j ~ Poisson(s_j * blocked_mass[j]), independently across j.
    return random.poisson(keys.pop(), s * blocked_mass, dtype=jnp.int32)
@named_call
def step_s(key: Key[Array, ''], state: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    state
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    By default this full conditional is approximate, because it ignores the
    decision rules forbidden by the ancestors of each node. If
    ``state.config.augment`` is set, the forbidden rules are accounted for
    exactly with the data augmentation of `sample_s_augmentation`.
    """
    forest = state.forest
    assert forest.theta is not None

    # The Dirichlet key is popped first regardless of `augment`. Enabling
    # augmentation therefore leaves the Dirichlet draw unchanged whenever the
    # augmentation counts are all zero (no forbidden rules).
    keys = split(key)
    dirichlet_key = keys.pop()

    p = forest.max_split.size
    usage = var_histogram(p, forest.var_tree, forest.split_tree, sum_batch_axis=-1)

    concentration = forest.theta / p + usage
    if state.config.augment:
        concentration = concentration + sample_s_augmentation(keys.pop(), forest)

    # unnormalized log-Dirichlet draw: one log-gamma variate per variable
    log_s = loggamma(dirichlet_key, concentration)
    return replace(state, forest=replace(forest, log_s=log_s))
@named_call
def step_theta(key: Key[Array, ''], state: State, *, num_grid: int = 1000) -> State:
    """
    Resample the Dirichlet concentration `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b). The full conditional of
    lambda = theta / (theta + rho) is evaluated on a grid of values of lambda.
    One grid point is drawn from the resulting categorical distribution, and
    the matching theta is stored.

    Parameters
    ----------
    key
        Random key for sampling.
    state
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = state.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # Midpoints of num_grid equal-width bins that partition (0, 1). The grid
    # is symmetric around 1/2, which `_log_p_lambda` relies on.
    half_width = 1 / (2 * num_grid)
    grid = jnp.linspace(half_width, 1 - half_width, num_grid)

    # log of s rescaled to sum to 1
    log_s_normalized = forest.log_s - logsumexp(forest.log_s)

    # evaluate the log full conditional on the grid and pick one point
    log_weights, theta_candidates = _log_p_lambda(
        grid, log_s_normalized, forest.rho, forest.a, forest.b
    )
    chosen = random.categorical(key, log_weights)
    return replace(state, forest=replace(forest, theta=theta_candidates[chosen]))
def _log_p_lambda(
    lambda_: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Compute the unnormalized log full conditional of lambda = theta / (theta + rho).

    Parameters
    ----------
    lambda_
        Grid of values of lambda in (0, 1). It must be symmetric around 1/2,
        i.e., ``lambda_[::-1]`` must equal ``1 - lambda_``, because the
        reversed grid is used in place of ``1 - lambda_``.
    log_s
        Logarithm of s. It is assumed normalized, i.e., sum(exp(log_s)) == 1.
    rho
        Scale that maps lambda to theta = rho * lambda / (1 - lambda).
    a, b
        Parameters of the Beta(a, b) prior on lambda.

    Returns
    -------
    logp
        Log of the Beta(a, b) prior density of lambda plus the log of the
        Dirichlet(theta/p, ..., theta/p) density of s. The term
        -sum(log_s), which does not depend on theta, is dropped.
    theta
        The value of theta corresponding to each element of `lambda_`.
    """
    # relies on the grid symmetry: reversing it yields 1 - lambda_
    one_minus_lambda = lambda_[::-1]
    theta = rho * lambda_ / one_minus_lambda
    num_vars = log_s.size
    concentration = theta / num_vars

    # Beta(a, b) prior on lambda: (a - 1) log(lambda) + (b - 1) log(1 - lambda)
    log_prior = (a - 1) * jnp.log1p(-one_minus_lambda) + (b - 1) * jnp.log1p(
        -lambda_
    )

    # Dirichlet(theta/p, ..., theta/p) log density of s, minus its
    # theta-independent part
    log_density = (
        log_prior
        + gammaln(theta)
        - num_vars * gammaln(concentration)
        + concentration * jnp.sum(log_s)
    )
    return log_density, theta
@named_call
def step_sparse(key: Key[Array, ''], state: State) -> State:
    """
    Update the sparsity parameters, if enabled at this step.

    If ``state.config.sparse_on_at`` is `None`, no update is performed and
    `state` is returned as is. Otherwise, while
    ``state.config.steps_done < state.config.sparse_on_at``, `state` is
    returned unchanged. From that step on, this invokes `step_s`, and then
    `step_theta` only if ``state.forest.rho`` (the scale of the theta prior)
    is defined.

    Parameters
    ----------
    key
        Random key for sampling.
    state
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`, or the input
    state if the sparsity update is not active.
    """
    if state.config.sparse_on_at is None:
        return state

    # steps_done may be a traced value, so choose the branch with lax.cond
    # instead of a Python `if`
    return lax.cond(
        state.config.steps_done < state.config.sparse_on_at,
        lambda _key, state: state,
        _step_sparse,
        key,
        state,
    )
def _step_sparse(key: Key[Array, ''], state: State) -> State:
    """
    Update the sparsity parameters without checking whether they are active.

    Resample `log_s` with `step_s`, then resample `theta` with `step_theta`
    only if ``state.forest.rho`` is set. The check on ``sparse_on_at`` is
    done by the caller `step_sparse`.
    """
    subkeys = split(key)
    updated = step_s(subkeys.pop(), state)
    if updated.forest.rho is None:
        # no theta prior: theta stays fixed
        return updated
    return step_theta(subkeys.pop(), updated)
@named_call
def step_config(state: State) -> State:
    """
    Increment the MCMC step counter stored in the state configuration.

    Parameters
    ----------
    state
        The current BART state.

    Returns
    -------
    A copy of `state` whose ``config.steps_done`` is larger by one. All other
    fields are unchanged.
    """
    bumped_config = replace(state.config, steps_done=state.config.steps_done + 1)
    return replace(state, config=bumped_config)