Coverage for src / bartz / mcmcstep / _step.py: 99%
441 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
1# bartz/src/bartz/mcmcstep/_step.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement `step`, `step_trees`, and the accept-reject logic."""
27from dataclasses import replace
28from functools import partial
30# WORKAROUND(jax<0.6.1): shard_map was promoted from jax.experimental to top-level in 0.6.1
31try:
32 from jax import shard_map
33except ImportError:
34 from jax.experimental.shard_map import shard_map
36import jax
37from equinox import Module, tree_at
38from jax import jit, lax, named_call, random, vmap
39from jax import numpy as jnp
40from jax.scipy.linalg import solve_triangular
41from jax.scipy.special import gammaln, logsumexp
42from jax.sharding import Mesh, PartitionSpec
43from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt, UInt32
45from bartz.grove import var_histogram
46from bartz.jaxext import split, truncated_normal_onesided, vmap_nodoc
47from bartz.mcmcstep._moves import Moves, propose_moves
48from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field, vmap_chains
@partial(jit, donate_argnums=(1,))
@vmap_chains
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Do one MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    The memory of the input state is re-used for the output state (the state
    argument is donated to jit via ``donate_argnums``), so the input state can
    not be used any more after calling `step`. All this applies outside of
    `jax.jit`.
    """
    # NOTE: keys are popped lazily, so which key each sub-step receives depends
    # on which optional sub-steps are enabled; this is deterministic because
    # the `if`s below are static Python branches under jit.
    keys = split(key, 4)

    # sample the forest (trees) conditional on everything else
    bart = step_trees(keys.pop(), bart)

    # resample the latent variables z, if present (binary regression)
    if bart.z is not None:
        bart = step_z(keys.pop(), bart)

    # update the error covariance, if it has a prior (df set)
    if bart.error_cov_df is not None:
        bart = step_error_cov_inv(keys.pop(), bart)

    # sparsity step — presumably updates variable-selection weights; see
    # `step_sparse` (defined elsewhere in this module)
    bart = step_sparse(keys.pop(), bart)

    # final bookkeeping update of the step configuration
    return step_config(bart)
@named_call
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Run the forest-sampling step of the BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    rng = split(key)
    # propose one grow/prune move per tree, then accept/reject them
    proposed = propose_moves(rng.pop(), bart.forest)
    return accept_moves_and_sample_leaves(rng.pop(), bart, proposed)
@named_call
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Accept or reject the proposed moves and sample the new leaf values.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    # three stages: per-tree precomputations (parallel), the tree-by-tree
    # accept/reject scan (sequential), then the final cleanup (parallel)
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    bart, accepted_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(bart, accepted_moves)
class Counts(Module):
    """Number of datapoints in the nodes involved in proposed moves, one entry per tree."""

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the left child."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the right child."""

    total: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the parent (``= left + right``)."""
class Precs(Module):
    """Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node.
    """

    left: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the left child."""

    right: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the right child."""

    total: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the parent (``= left + right``)."""
class PreLkV(Module):
    """Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees.
    """

    left: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_left / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_left * error_cov_inv) @ error_cov_inv``.

    ``n_left`` is the number of datapoints in the left child, or the
    likelihood precision scale in the heteroskedastic case."""

    right: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_right / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_right * error_cov_inv) @ error_cov_inv``.

    ``n_right`` is the number of datapoints in the right child, or the
    likelihood precision scale in the heteroskedastic case."""

    total: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_total / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_total * error_cov_inv) @ error_cov_inv``.

    ``n_total`` is the number of datapoints in the parent node, or the
    likelihood precision scale in the heteroskedastic case."""

    log_sqrt_term: Float32[Array, '*chains num_trees'] = field(chains=True)
    """The logarithm of the square root term of the likelihood ratio."""
class PreLk(Module):
    """Non-sequential terms of the likelihood ratio shared by all trees.

    Only used in the univariate case; the multivariate path returns `None`
    in its place (see `precompute_likelihood_terms`).
    """

    exp_factor: Float32[Array, '*chains'] = field(chains=True)
    """The factor to multiply the likelihood ratio by, shared by all trees."""
class PreLf(Module):
    """Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees.

    For each tree and leaf, the terms are scalars in the univariate case, and
    matrices/vectors in the multivariate case.
    """

    mean_factor: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k k 2**d']
    ) = field(chains=True)
    """The factor to be right-multiplied by the sum of the scaled residuals to
    obtain the posterior mean."""

    centered_leaves: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """The mean-zero normal values to be added to the posterior mean to
    obtain the posterior leaf samples."""
class ParallelStageOut(Module):
    """The output of `accept_moves_parallel_stage`."""

    bart: State
    """A partially updated BART mcmc state."""

    moves: Moves
    """The proposed moves, with `partial_ratio` set to `None` and
    `log_trans_prior_ratio` set to its final value."""

    prec_trees: (
        Float32[Array, '*chains num_trees 2**d']
        | Int32[Array, '*chains num_trees 2**d']
    ) = field(chains=True)
    """The likelihood precision scale in each potential or actual leaf node. If
    there is no precision scale, this is the number of points in each leaf."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves. If
    `bart.prec_scale` is not set, this is set to `move_counts`."""

    prelkv: PreLkV
    """Object with pre-computed terms of the likelihood ratios, per tree."""

    prelk: PreLk | None
    """Object with pre-computed terms of the likelihood ratios, shared by all
    trees (univariate case only, `None` in the multivariate case)."""

    prelf: PreLf
    """Object with pre-computed terms of the leaf samples."""
@named_call
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf; skipped only when nothing below
    # needs the counts (no point thresholds and a precision scale is in use)
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.config
        )

        # mark which leaves & potential leaves have enough points to be grown
        if bart.forest.min_points_per_decision_node is not None:
            # only the first half of the tree array can hold decision nodes
            count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
            moves = replace(
                moves,
                affluence_tree=moves.affluence_tree
                & (count_half_trees >= bart.forest.min_points_per_decision_node),
            )

    # copy updated affluence_tree to state
    bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

    # veto grove move if new leaves don't have enough datapoints
    if bart.forest.min_points_per_leaf is not None:
        # `move_counts` is always bound here: this condition is one of the
        # disjuncts guarding the counting branch above
        moves = replace(
            moves,
            allowed=moves.allowed
            & (move_counts.left >= bart.forest.min_points_per_leaf)
            & (move_counts.right >= bart.forest.min_points_per_leaf),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        # homoskedastic case: precisions coincide with the plain counts
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale, bart.forest.leaf_indices, moves, bart.config
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    prelkv, prelk = precompute_likelihood_terms(
        bart.error_cov_inv, bart.forest.leaf_prior_cov_inv, move_precs
    )
    prelf = precompute_leaf_terms(
        key, prec_trees, bart.error_cov_inv, bart.forest.leaf_prior_cov_inv
    )

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
@named_call
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to apply a grow move.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    # index of the would-be left child of the grown node
    left_child = moves.node.astype(leaf_indices.dtype) << 1
    # for each datapoint, which side of the new split it falls on
    split_var_values: UInt[Array, ' n'] = X[moves.grow_var, :]
    goes_right = split_var_values >= moves.grow_split
    # when the move is not a grow, match against an out-of-range node index so
    # that no datapoint is reassigned
    out_of_range = jnp.array(2 * moves.var_tree.size)
    target_node = jnp.where(moves.grow, moves.node, out_of_range)
    new_indices = jnp.where(
        leaf_indices == target_node, left_child + goes_right, leaf_indices
    )
    return new_indices
def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, 'num_trees 2**d'], Counts]
    | tuple[Float32[Array, 'num_trees 2**d'], Precs]
):
    """Implement `compute_count_trees` and `compute_prec_trees`.

    Runs `_compute_count_or_prec_tree` over all trees, either fully
    vectorized (vmap) or in batches (lax.map) depending on the configuration.
    """
    if config.prec_count_num_trees is None:
        # no batching requested: vectorize over all trees at once
        vectorized = vmap(_compute_count_or_prec_tree, in_axes=(None, 0, 0, None))
        return vectorized(prec_scale, leaf_indices, moves, config)

    # batched path: process `prec_count_num_trees` trees at a time to bound
    # peak memory
    def one_tree(
        args: tuple[UInt[Array, ' n'], Moves],
    ) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
        tree_leaf_indices, tree_moves = args
        return _compute_count_or_prec_tree(
            prec_scale, tree_leaf_indices, tree_moves, config
        )

    return lax.map(
        one_tree, (leaf_indices, moves), batch_size=config.prec_count_num_trees
    )
def _compute_count_or_prec_tree(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, ' n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
    """Compute the count or precision tree for a single tree."""
    # the var tree covers only decision nodes; leaves need twice the slots
    (half_size,) = moves.var_tree.shape
    full_size = 2 * half_size

    if prec_scale is None:
        # plain counts: scatter 1 per datapoint into an unsigned integer tree
        scatter_value = 1
        container = Counts
        out_dtype = jnp.uint32
        num_batches = config.count_num_batches
    else:
        # weighted counts: scatter each datapoint's precision scale
        scatter_value = prec_scale
        container = Precs
        out_dtype = jnp.float32
        num_batches = config.prec_num_batches

    trees = _scatter_add(
        scatter_value, leaf_indices, full_size, out_dtype, num_batches, config.mesh
    )

    # gather the totals in the nodes touched by the move
    in_left = trees[moves.left]
    in_right = trees[moves.right]
    node_totals = container(left=in_left, right=in_right, total=in_left + in_right)

    # write the parent total into the (non-leaf) node position
    trees = trees.at[moves.node].set(node_totals.total)

    return trees, node_totals
@named_call
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, config: StepConfig
) -> tuple[UInt32[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    count_trees : Int32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    # passing no precision scale makes the shared implementation count
    # datapoints instead of summing precisions
    result = _compute_count_or_prec_trees(None, leaf_indices, moves, config)
    return result
@named_call
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    # the shared implementation sums `prec_scale` per leaf when it is given
    result = _compute_count_or_prec_trees(prec_scale, leaf_indices, moves, config)
    return result
@partial(vmap_nodoc, in_axes=(0, None))
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    This function adds the probability of choosing a prune move over the grow
    move in the inverse transition, and the a priori probability that the
    children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    # whether each child of the move node could itself be grown
    # (out-of-bounds reads yield False)
    left_growable = moves.affluence_tree.at[moves.left].get(
        mode='fill', fill_value=False
    )
    right_growable = moves.affluence_tree.at[moves.right].get(
        mode='fill', fill_value=False
    )

    # probability of proposing a prune in the reverse transition of a grow:
    # 1 if no grow would be possible afterwards, else 1/2
    grow_again_allowed = (
        (moves.num_growable >= 2) | left_growable | right_growable
    )
    p_prune_if_grow = jnp.where(grow_again_allowed, 0.5, 1.0)

    # probability of having proposed a prune: 1 if no grow was possible
    p_prune_if_prune = jnp.where(moves.num_growable, 0.5, 1)

    # pick the branch matching the actual move type
    p_prune = jnp.where(moves.grow, p_prune_if_grow, p_prune_if_prune)

    # prior probability that both children are terminal (leaves)
    prob_leaf_left = 1 - p_nonterminal[moves.left] * left_growable
    prob_leaf_right = 1 - p_nonterminal[moves.right] * right_growable
    prob_both_leaves = prob_leaf_left * prob_leaf_right

    assert moves.partial_ratio is not None
    return replace(
        moves,
        log_trans_prior_ratio=jnp.log(
            moves.partial_ratio * prob_both_leaves * p_prune
        ),
        partial_ratio=None,
    )
@named_call
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if the
    grow move was accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_values = leaf_trees[..., moves.node]
    # for non-grow moves, redirect the writes out of bounds so they are dropped
    dropped = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, dropped)
    right_target = jnp.where(moves.grow, moves.right, dropped)
    with_left = leaf_trees.at[..., left_target].set(parent_values)
    return with_left.at[..., right_target].set(parent_values)
623def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']:
624 """Compute logdet of A = LL' via Cholesky (sum of log of diag^2)."""
625 diags: Float32[Array, '... k'] = jnp.diagonal(L, axis1=-2, axis2=-1) 1de
626 return 2.0 * jnp.sum(jnp.log(diags), axis=-1) 1de
def _precompute_likelihood_terms_uv(
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """Univariate specialization of `precompute_likelihood_terms`."""
    err_var = jnp.reciprocal(error_cov_inv)  # sigma^2
    leaf_var = jnp.reciprocal(leaf_prior_cov_inv)  # sigma_mu^2

    # per-node terms sigma^2 + n * sigma_mu^2 for the nodes touched by each move
    term_left = err_var + move_precs.left * leaf_var
    term_right = err_var + move_precs.right * leaf_var
    term_total = err_var + move_precs.total * leaf_var

    per_tree = PreLkV(
        left=term_left,
        right=term_right,
        total=term_total,
        log_sqrt_term=jnp.log(err_var * term_total / (term_left * term_right)) / 2,
    )
    shared = PreLk(exp_factor=error_cov_inv / leaf_prior_cov_inv / 2)
    return per_tree, shared
def _precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    move_precs: Counts,
) -> tuple[PreLkV, None]:
    """Multivariate specialization of `precompute_likelihood_terms`.

    Takes integer datapoint counts (`Counts`): the multivariate path assumes a
    homoskedastic error model, so there is no per-datapoint precision scale.
    Returns `None` in place of the shared `PreLk` terms.
    """
    # broadcast per-tree counts against the k x k matrices
    nL: UInt[Array, 'num_trees 1 1'] = move_precs.left[..., None, None]
    nR: UInt[Array, 'num_trees 1 1'] = move_precs.right[..., None, None]
    nT: UInt[Array, 'num_trees 1 1'] = move_precs.total[..., None, None]

    # Cholesky factors of the posterior leaf precisions
    # leaf_prior_cov_inv + n * error_cov_inv, one per involved node
    L_left: Float32[Array, 'num_trees k k'] = chol_with_gersh(
        error_cov_inv * nL + leaf_prior_cov_inv
    )
    L_right: Float32[Array, 'num_trees k k'] = chol_with_gersh(
        error_cov_inv * nR + leaf_prior_cov_inv
    )
    L_total: Float32[Array, 'num_trees k k'] = chol_with_gersh(
        error_cov_inv * nT + leaf_prior_cov_inv
    )

    # log sqrt(det(prior_prec) det(post_prec_total) / (det(post_prec_left)
    # det(post_prec_right))), computed from the Cholesky diagonals
    log_sqrt_term: Float32[Array, ' num_trees'] = 0.5 * (
        _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))
        + _logdet_from_chol(L_total)
        - _logdet_from_chol(L_left)
        - _logdet_from_chol(L_right)
    )

    def _term_from_chol(
        L: Float32[Array, 'num_trees k k'],
    ) -> Float32[Array, 'num_trees k k']:
        # error_cov_inv @ inv(L L') @ error_cov_inv, via one triangular solve:
        # Y = L^-1 @ error_cov_inv, then Y' Y = error_cov_inv inv(LL') error_cov_inv
        rhs: Float32[Array, 'num_trees k k'] = jnp.broadcast_to(error_cov_inv, L.shape)
        Y: Float32[Array, 'num_trees k k'] = solve_triangular(L, rhs, lower=True)
        return Y.mT @ Y

    prelkv = PreLkV(
        left=_term_from_chol(L_left),
        right=_term_from_chol(L_right),
        total=_term_from_chol(L_total),
        log_sqrt_term=log_sqrt_term,
    )

    return prelkv, None
@named_call
def precompute_likelihood_terms(
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk | None]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Handles both univariate and multivariate cases based on the shape of the
    input arrays. The multivariate implementation assumes a homoskedastic error
    model (i.e., the residual covariance is the same for all observations).

    Parameters
    ----------
    error_cov_inv
        The inverse error variance (univariate) or the inverse of the error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk | None
        Pre-computed terms of the likelihood ratio, shared by all trees.
        `None` in the multivariate case.
    """
    # a matrix-valued inverse error covariance signals the multivariate case
    multivariate = error_cov_inv.ndim == 2
    if not multivariate:
        return _precompute_likelihood_terms_uv(
            error_cov_inv, leaf_prior_cov_inv, move_precs
        )
    assert isinstance(move_precs, Counts)
    return _precompute_likelihood_terms_mv(
        error_cov_inv, leaf_prior_cov_inv, move_precs
    )
def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees 2**d'] | None = None,
) -> PreLf:
    """Univariate specialization of `precompute_leaf_terms`."""
    # posterior variance of each leaf: 1 / (n / sigma^2 + 1 / sigma_mu^2)
    likelihood_prec = prec_trees * error_cov_inv
    posterior_var = jnp.reciprocal(likelihood_prec + leaf_prior_cov_inv)

    if z is None:
        z = random.normal(key, prec_trees.shape, error_cov_inv.dtype)

    # The posterior mean is mean_lk * prec_lk * var_post, and the summed scaled
    # residuals are resid_tree = mean_lk * prec_tree, so the factor that maps
    # resid_tree to the posterior mean simplifies to
    #   resid_tree / prec_tree * prec_lk * var_post / resid_tree
    #   = var_post / sigma2.
    return PreLf(
        mean_factor=posterior_var * error_cov_inv,
        centered_leaves=z * jnp.sqrt(posterior_var),
    )
def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d k'] | None = None,
) -> PreLf:
    """Multivariate specialization of `precompute_leaf_terms`.

    Parameters
    ----------
    key
        A jax random key, used only if `z` is not provided.
    prec_trees
        The number of datapoints (or precision scale) in each potential or
        actual leaf node.
    error_cov_inv
        The inverse of the error covariance matrix.
    leaf_prior_cov_inv
        The inverse of the prior covariance matrix of each leaf.
    z
        Optional standard normal noise (for testing); sampled from `key` if
        omitted.

    Returns
    -------
    Pre-computed terms for multivariate leaf sampling.
    """
    num_trees, tree_size = prec_trees.shape
    k = error_cov_inv.shape[0]
    n_k: Float32[Array, 'num_trees tree_size 1 1'] = prec_trees[..., None, None]

    # Only broadcast the inverse of error covariance matrix to satisfy JAX's
    # batching rules for `lax.linalg.solve_triangular`, which does not support
    # implicit broadcasting.
    error_cov_inv_batched = jnp.broadcast_to(
        error_cov_inv, (num_trees, tree_size, k, k)
    )

    # posterior precision of each leaf: prior precision + n * error precision
    posterior_precision: Float32[Array, 'num_trees tree_size k k'] = (
        leaf_prior_cov_inv + n_k * error_cov_inv_batched
    )

    # lower Cholesky factor of the posterior precision
    L_prec: Float32[Array, 'num_trees tree_size k k'] = chol_with_gersh(
        posterior_precision
    )

    # mean_factor = inv(posterior_precision) @ error_cov_inv, computed with two
    # triangular solves against the Cholesky factor
    Y: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, error_cov_inv_batched, lower=True
    )
    mean_factor: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, Y, trans='T', lower=True
    )
    mean_factor = mean_factor.mT
    mean_factor_out: Float32[Array, 'num_trees k k tree_size'] = jnp.moveaxis(
        mean_factor, 1, -1
    )

    if z is None:
        z = random.normal(key, (num_trees, tree_size, k))
    # x = L^-T z has covariance inv(L L') = inv(posterior_precision).
    # BUG FIX: `lower=True` was missing here; the default `lower=False`
    # combined with `trans='T'` makes solve_triangular read only the (zero)
    # upper triangle of the lower Cholesky factor, producing samples with the
    # wrong covariance. Every other solve against L_prec passes `lower=True`.
    centered_leaves: Float32[Array, 'num_trees tree_size k'] = solve_triangular(
        L_prec, z, trans='T', lower=True
    )
    centered_leaves_out: Float32[Array, 'num_trees k tree_size'] = jnp.swapaxes(
        centered_leaves, -1, -2
    )

    return PreLf(mean_factor=mean_factor_out, centered_leaves=centered_leaves_out)
@named_call
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d']
    | Float32[Array, 'num_trees 2**d k']
    | None = None,
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Handles both univariate and multivariate cases based on the shape of the
    input arrays.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse error variance (univariate) or the inverse of error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    z
        Optional standard normal noise to use for sampling the centered leaves.
        This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    # a matrix-valued inverse error covariance signals the multivariate case
    multivariate = error_cov_inv.ndim == 2
    if not multivariate:
        return _precompute_leaf_terms_uv(
            key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z
        )
    return _precompute_leaf_terms_mv(
        key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z
    )
@named_call
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """

    # scan body: process one tree, threading the residuals through as carry
    def loop(
        resid: Float32[Array, ' n'] | Float32[Array, ' k n'], pt: SeqStageInPerTree
    ) -> tuple[
        Float32[Array, ' n'] | Float32[Array, ' k n'],
        tuple[
            Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
            Bool[Array, ''],
            Bool[Array, ''],
            Float32[Array, ''] | None,
        ],
    ]:
        resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            resid,
            SeqStageInAllTrees(
                pso.bart.X,
                pso.bart.config.resid_num_batches,
                pso.bart.config.mesh,
                pso.bart.prec_scale,
                pso.bart.forest.log_likelihood is not None,
                pso.prelk,
            ),
            pt,
        )
        return resid, (leaf_tree, acc, to_prune, lkratio)

    # stack the per-tree inputs so lax.scan iterates over trees
    pts = SeqStageInPerTree(
        pso.bart.forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        pso.bart.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )
    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts)

    # write the final residuals, leaf values, and (optionally) likelihood
    # ratios back into the state
    bart = replace(
        pso.bart,
        resid=resid,
        forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio),
    )
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return bart, moves
class SeqStageInAllTrees(Module):
    """The inputs to `accept_move_and_sample_leaves` that are shared by all trees."""

    X: UInt[Array, 'p n']
    """The predictors."""

    resid_num_batches: int | None = field(static=True)
    """The number of batches for computing the sum of residuals in each leaf."""

    mesh: Mesh | None = field(static=True)
    """The mesh of devices to use."""

    prec_scale: Float32[Array, ' n'] | None
    """The scale of the precision of the error on each datapoint. If None, it
    is assumed to be 1."""

    save_ratios: bool = field(static=True)
    """Whether to save the acceptance ratios."""

    prelk: PreLk | None
    """The pre-computed terms of the likelihood ratio which are shared across
    trees (`None` in the multivariate case)."""
class SeqStageInPerTree(Module):
    """The inputs to `accept_move_and_sample_leaves` that are separate for each tree."""

    leaf_tree: Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d']
    """The leaf values of the tree."""

    prec_tree: Float32[Array, ' 2**d']
    """The likelihood precision scale in each potential or actual leaf node."""

    move: Moves
    """The proposed move, see `propose_moves`."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves."""

    leaf_indices: UInt[Array, ' n']
    """The leaf indices for the largest version of the tree compatible with
    the move."""

    prelkv: PreLkV
    """The pre-computed terms of the likelihood ratio which are specific to the tree."""

    prelf: PreLf
    """The pre-computed terms of the leaf sampling which are specific to the tree."""
@named_call
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n'] | Float32[Array, ' k n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d'] | Float32[Array, ' k 2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    # (scaled by the per-datapoint error precision if heteroskedastic)
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale

    tree_size = pt.leaf_tree.shape[-1]  # 2**d

    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, tree_size, at.resid_num_batches, at.mesh
    )

    # subtract starting tree from function: `resid` excludes the whole forest,
    # so adding back this tree's (precision-scaled) leaf contribution makes
    # `resid_tree` the per-leaf residual w.r.t. the forest minus this tree
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move and compute likelihood
    resid_left = resid_tree[..., pt.move.left]
    resid_right = resid_tree[..., pt.move.right]
    resid_total = resid_left + resid_right
    assert pt.move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[..., pt.move.node].set(resid_total)

    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )

    # calculate accept/reject ratio; a prune move is the reverse of the
    # corresponding grow move, so its log-ratio is the negation
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move (logu is the pre-drawn log-uniform)
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    if resid.ndim > 1:
        # multivariate: mean_factor is a per-leaf matrix mixing the k outcomes
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, resid_tree)
    else:
        mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf;
    # when `to_prune` is False the index is `tree_size`, which is out of
    # bounds, and JAX drops out-of-bounds scatter updates, making it a no-op
    to_prune = acc ^ pt.move.grow
    leaf_tree = (
        leaf_tree.at[..., jnp.where(to_prune, pt.move.left, tree_size)]
        .set(leaf_tree[..., pt.move.node])
        .at[..., jnp.where(to_prune, pt.move.right, tree_size)]
        .set(leaf_tree[..., pt.move.node])
    )
    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[..., pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
@named_call
@partial(jnp.vectorize, excluded=(1, 2, 3, 4), signature='(n)->(ts)')
def sum_resid(
    scaled_resid: Float32[Array, ' n'] | Float32[Array, 'k n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    resid_num_batches: int | None,
    mesh: Mesh | None,
) -> Float32[Array, ' {tree_size}'] | Float32[Array, 'k {tree_size}']:
    """
    Accumulate the scaled residuals into a per-leaf total.

    The `jnp.vectorize` wrapper maps the reduction over any leading axes of
    `scaled_resid`, so the univariate ``(n,)`` and multivariate ``(k, n)``
    cases are both handled.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale. For multivariate case, shape is ``(k, n)`` where ``k``
        is the number of outcome columns.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    resid_num_batches
        The number of batches for computing the sum of residuals in each leaf.
    mesh
        The mesh of devices to use.

    Returns
    -------
    The sum of the residuals at data points in each leaf. For multivariate
    case, returns per-leaf sums of residual vectors.
    """
    # delegate to the generic scatter-add, accumulating in float32
    return _scatter_add(
        scaled_resid, leaf_indices, tree_size, jnp.float32, resid_num_batches, mesh
    )
def _scatter_add(
    values: Float32[Array, ' n'] | int,
    indices: Integer[Array, ' n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    mesh: Mesh | None,
) -> Shaped[Array, ' {size}']:
    """Sum `values` into `size` buckets chosen by `indices`, optionally sharded."""
    values = jnp.asarray(values)
    # `values` is either a scalar (broadcast over indices) or aligned with them
    assert values.ndim == 0 or values.shape == indices.shape

    impl = partial(_scatter_add_impl, size=size, dtype=dtype, num_batches=batch_size)

    # single-device invocation
    if mesh is None or 'data' not in mesh.axis_names:
        return impl(values, indices)

    # multi-device invocation: shard the inputs along 'data', reduce the
    # per-shard partial sums inside the mapped function, replicate the output
    if values.shape:
        values_spec = PartitionSpec('data')
    else:
        values_spec = PartitionSpec()
    mapped = shard_map(
        partial(impl, final_psum=True),
        in_specs=(values_spec, PartitionSpec('data')),
        out_specs=PartitionSpec(),
        mesh=mesh,
        **_get_shard_map_patch_kwargs(),
    )
    return mapped(values, indices)
1145def _get_shard_map_patch_kwargs() -> dict[str, bool]:
1146 # WORKAROUND(jax<=0.8.2): vmap(shard_map(psum)), jax#34249; the
1147 # jax_disable_vmap_shmap_error config did not work
1148 if jax.__version__ in ('0.8.1', '0.8.2'): 1148 ↛ 1149line 1148 didn't jump to line 1149 because the condition on line 1148 was never true1hn
1149 return {'check_vma': False}
1150 else:
1151 return {} 1hn
1154def _scatter_add_impl(
1155 values: Float32[Array, ' n'] | Int32[Array, ''],
1156 indices: Integer[Array, ' n'],
1157 /,
1158 *,
1159 size: int,
1160 dtype: jnp.dtype,
1161 num_batches: int | None,
1162 final_psum: bool = False,
1163) -> Shaped[Array, ' {size}']:
1164 if num_batches is None: 1akbj
1165 out = jnp.zeros(size, dtype).at[indices].add(values) 1aj
1167 else:
1168 # in the sharded case, n is the size of the local shard, not the full size
1169 (n,) = indices.shape 1kb
1170 batch_indices = jnp.arange(n) % num_batches 1kb
1171 out = ( 1kb
1172 jnp.zeros((size, num_batches), dtype)
1173 .at[indices, batch_indices]
1174 .add(values)
1175 .sum(axis=1)
1176 )
1178 if final_psum: 1ahbn
1179 out = lax.psum(out, 'data') 1hn
1180 return out 1ab
1183def _compute_likelihood_ratio_uv(
1184 total_resid: Float32[Array, ''],
1185 left_resid: Float32[Array, ''],
1186 right_resid: Float32[Array, ''],
1187 prelkv: PreLkV,
1188 prelk: PreLk,
1189) -> Float32[Array, '']:
1190 exp_term = prelk.exp_factor * ( 1ab
1191 left_resid * left_resid / prelkv.left
1192 + right_resid * right_resid / prelkv.right
1193 - total_resid * total_resid / prelkv.total
1194 )
1195 return prelkv.log_sqrt_term + exp_term 1ab
1198def _compute_likelihood_ratio_mv(
1199 total_resid: Float32[Array, ' k'],
1200 left_resid: Float32[Array, ' k'],
1201 right_resid: Float32[Array, ' k'],
1202 prelkv: PreLkV,
1203) -> Float32[Array, '']:
1204 def _quadratic_form( 1de
1205 r: Float32[Array, ' k'], mat: Float32[Array, 'k k']
1206 ) -> Float32[Array, '']:
1207 return r @ mat @ r 1de
1209 qf_left = _quadratic_form(left_resid, prelkv.left) 1de
1210 qf_right = _quadratic_form(right_resid, prelkv.right) 1de
1211 qf_total = _quadratic_form(total_resid, prelkv.total) 1de
1212 exp_term = 0.5 * (qf_left + qf_right - qf_total) 1de
1213 return prelkv.log_sqrt_term + exp_term 1de
@named_call
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''] | Float32[Array, ' k'],
    left_resid: Float32[Array, ''] | Float32[Array, ' k'],
    right_resid: Float32[Array, ''] | Float32[Array, ' k'],
    prelkv: PreLkV,
    prelk: PreLk | None,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    Dispatches to the univariate or multivariate implementation based on the
    rank of the residual arrays.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`. `prelk` may be `None` in the
        multivariate case, where it is not used.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    if total_resid.ndim == 0:
        # univariate: the per-node residual sums are scalars
        assert prelk is not None
        return _compute_likelihood_ratio_uv(
            total_resid, left_resid, right_resid, prelkv, prelk
        )
    # multivariate: the per-node residual sums are k-vectors
    return _compute_likelihood_ratio_mv(total_resid, left_resid, right_resid, prelkv)
@named_call
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    # tally accepted moves by kind and propagate the accepted moves into the
    # leaf indices and split trees
    new_forest = replace(
        bart.forest,
        grow_acc_count=jnp.sum(moves.acc & moves.grow),
        prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
        leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)
@named_call
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    assert moves.to_prune is not None
    # clearing the lowest bit maps both siblings of the modified pair (their
    # indices differ only in that bit) onto the left child index
    sibling_mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_moved_pair = (leaf_indices & sibling_mask) == moves.left
    # datapoints in a pruned pair are reassigned to the parent node
    parent = moves.node.astype(leaf_indices.dtype)
    return jnp.where(in_moved_pair & moves.to_prune, parent, leaf_indices)
@named_call
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # when the condition is false the index is split_tree.size, which is out
    # of bounds; JAX drops out-of-bounds scatter updates, making it a no-op
    grow_at = jnp.where(moves.grow, moves.node, split_tree.size)
    prune_at = jnp.where(moves.to_prune, moves.node, split_tree.size)
    grown = split_tree.at[grow_at].set(moves.grow_split.astype(split_tree.dtype))
    return grown.at[prune_at].set(0)
@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Float32[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key
    df
        Degrees of freedom
    scale_inv
        Scale matrix of the corresponding Inverse Wishart distribution

    Returns
    -------
    A sample from Wishart(df, scale)
    """
    # `split` is a project helper; presumably returns a stack of subkeys
    # consumed with .pop() — TODO confirm
    keys = split(key)

    # Diagonal elements: A_ii ~ sqrt(chi^2(df - i))
    # chi^2(k) = Gamma(k/2, scale=2)
    k, _ = scale_inv.shape
    df_vector = df - jnp.arange(k)
    chi2_samples = random.gamma(keys.pop(), df_vector / 2.0) * 2.0
    diag_A = jnp.sqrt(chi2_samples)

    # Bartlett factor A: standard normal strictly below the diagonal,
    # chi draws on the diagonal (the upper triangle is zero)
    off_diag_A = random.normal(keys.pop(), (k, k))
    A = jnp.tril(off_diag_A, -1) + jnp.diag(diag_A)
    # Cholesky of scale_inv; `chol_with_gersh` presumably adds a
    # Gershgorin-based jitter for numerical robustness — TODO confirm
    L = chol_with_gersh(scale_inv, absolute_eps=True)
    # T = L^-T A (solves L^T T = A), so that T T^T = L^-T A A^T L^-1 has
    # expectation df * scale_inv^-1, i.e. W ~ Wishart(df, scale_inv^-1)
    T = solve_triangular(L, A, lower=True, trans='T')

    return T @ T.T
1384def _step_error_cov_inv_uv(key: Key[Array, ''], bart: State) -> State:
1385 assert bart.error_cov_df is not None 1ab
1386 assert bart.error_cov_scale is not None 1ab
1388 resid = bart.resid 1ab
1389 # inverse gamma prior: alpha = df / 2, beta = scale / 2
1390 alpha = bart.error_cov_df / 2 + resid.size / 2 1ab
1391 if bart.prec_scale is None: 1akb
1392 scaled_resid = resid 1ab
1393 else:
1394 scaled_resid = resid * bart.prec_scale 1k
1395 norm2 = resid @ scaled_resid 1ab
1396 beta = bart.error_cov_scale / 2 + norm2 / 2 1ab
1398 sample = random.gamma(key, alpha) 1ab
1399 # random.gamma seems to be slow at compiling, maybe cdf inversion would
1400 # be better, but it's not implemented in jax
1401 return replace(bart, error_cov_inv=sample / beta) 1ab
def _step_error_cov_inv_mv(key: Key[Array, ''], bart: State) -> State:
    """Gibbs-update the error precision matrix from its Wishart posterior."""
    assert bart.error_cov_df is not None
    assert bart.error_cov_scale is not None

    # conjugate update: df grows by the number of datapoints, the scale by
    # the residual scatter matrix
    num_points = bart.resid.shape[-1]
    posterior_df = bart.error_cov_df + num_points
    posterior_scale = bart.error_cov_scale + bart.resid @ bart.resid.T

    prec = _sample_wishart_bartlett(key, posterior_df, posterior_scale)
    return replace(bart, error_cov_inv=prec)
1416def _step_error_cov_inv_diag(key: Key[Array, ''], bart: State) -> State:
1417 """Update diagonal error_cov_inv for mixed binary-continuous.
1419 Each continuous component gets an independent inverse-gamma update
1420 (like `_step_error_cov_inv_uv` repeated per component). Binary
1421 components stay fixed at 1.
1422 """
1423 assert bart.binary_indices is not None 1ilm
1424 assert bart.error_cov_scale is not None 1ilm
1425 assert bart.error_cov_df is not None 1ilm
1427 # per-component sum of squared residuals, shape (k,)
1428 norm2 = jnp.einsum('kn,kn->k', bart.resid, bart.resid) 1ilm
1430 # inverse-gamma posterior parameters
1431 *_, k, n = bart.resid.shape 1ilm
1432 scale_diag = jnp.diag(bart.error_cov_scale) 1ilm
1433 alpha = bart.error_cov_df / 2 + n / 2 1ilm
1434 beta = scale_diag / 2 + norm2 / 2 1ilm
1436 # sample independent gamma variates for all k components
1437 samples = random.gamma(key, alpha, (k,)) 1ilm
1438 new_diag = samples / beta 1ilm
1440 # keep binary components at 1.0
1441 new_diag = new_diag.at[bart.binary_indices].set(1.0) 1ilm
1443 return replace(bart, error_cov_inv=jnp.diag(new_diag)) 1ilm
@named_call
def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the inverse error covariance.

    Dispatches between the univariate, multivariate, and mixed
    binary-continuous updates.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `error_cov_inv`.
    """
    # mixed binary-continuous outcomes: diagonal precision update
    if bart.binary_indices is not None:
        return _step_error_cov_inv_diag(key, bart)
    # full precision matrix: Wishart update
    if bart.error_cov_inv.ndim == 2:
        return _step_error_cov_inv_mv(key, bart)
    # scalar precision: inverse-gamma update
    return _step_error_cov_inv_uv(key, bart)
@named_call
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    assert bart.z is not None
    assert bart.binary_y is not None

    # restrict to the binary outcome rows in the mixed case
    if bart.binary_indices is None:
        old_resid = bart.resid
    else:
        old_resid = bart.resid[..., bart.binary_indices, :]

    # forest prediction plus offset, recovered from z and its residual
    trees_plus_offset = bart.z - old_resid

    # resample the residual from a one-sided truncated normal; the truncation
    # side is determined by the observed binary outcome
    new_resid = truncated_normal_onesided(key, (), ~bart.binary_y, -trees_plus_offset)
    z = trees_plus_offset + new_resid

    # scatter the binary rows back into the full residual matrix
    if bart.binary_indices is not None:
        new_resid = bart.resid.at[..., bart.binary_indices, :].set(new_resid)

    return replace(bart, z=z, resid=new_resid)
@named_call
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    This full conditional is approximated, because it does not take into account
    that there are forbidden decision rules.
    """
    assert bart.forest.theta is not None

    # count how many times each predictor appears in a decision rule
    num_predictors = bart.forest.max_split.size
    varcount = var_histogram(
        num_predictors, bart.forest.var_tree, bart.forest.split_tree, sum_batch_axis=-1
    )

    # draw unnormalized log-Dirichlet components from the conjugate posterior
    # (normalization is deferred, see `step_theta`)
    alpha = bart.forest.theta / num_predictors + varcount
    new_log_s = random.loggamma(key, alpha)

    return replace(bart, forest=replace(bart.forest, log_s=new_log_s))
@named_call
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    assert bart.forest.log_s is not None
    assert bart.forest.rho is not None
    assert bart.forest.a is not None
    assert bart.forest.b is not None

    # grid of midpoints of num_grid equal-width bins of (0, 1)
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # normalize s
    log_s = bart.forest.log_s - logsumexp(bart.forest.log_s)

    # evaluate the posterior of lambda = theta / (theta + rho) on the grid
    # and sample one grid point
    logp, theta_grid = _log_p_lamda(
        lamda_grid, log_s, bart.forest.rho, bart.forest.a, bart.forest.b
    )
    choice = random.categorical(key, logp)

    return replace(bart, forest=replace(bart.forest, theta=theta_grid[choice]))
1591def _log_p_lamda(
1592 lamda: Float32[Array, ' num_grid'],
1593 log_s: Float32[Array, ' p'],
1594 rho: Float32[Array, ''],
1595 a: Float32[Array, ''],
1596 b: Float32[Array, ''],
1597) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
1598 # in the following I use lamda[::-1] == 1 - lamda
1599 theta = rho * lamda / lamda[::-1] 1afg
1600 p = log_s.size 1afg
1601 return ( 1afg
1602 (a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda)
1603 + (b - 1) * jnp.log1p(-lamda) # log(1 - lambda)
1604 + gammaln(theta)
1605 - p * gammaln(theta / p)
1606 + theta / p * jnp.sum(log_s)
1607 ), theta
@named_call
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if the parameters of
    the theta prior are defined.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    if bart.config.sparse_on_at is None:
        # sparsity is disabled: leave the state untouched
        return bart
    # skip the update during the initial iterations before sparse_on_at
    return lax.cond(
        bart.config.steps_done < bart.config.sparse_on_at,
        lambda _key, state: state,
        _step_sparse,
        key,
        bart,
    )
def _step_sparse(key: Key[Array, ''], bart: State) -> State:
    """Resample `log_s`, and `theta` if its prior hyperparameters are set."""
    keys = split(key)
    bart = step_s(keys.pop(), bart)
    if bart.forest.rho is None:
        return bart
    return step_theta(keys.pop(), bart)
@named_call
def step_config(bart: State) -> State:
    """Advance the per-iteration bookkeeping: increment `steps_done`."""
    new_config = replace(bart.config, steps_done=bart.config.steps_done + 1)
    return replace(bart, config=new_config)