Coverage for src / bartz / mcmcstep / _step.py: 99%
377 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
1# bartz/src/bartz/mcmcstep/_step.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement `step`, `step_trees`, and the accept-reject logic."""
27from dataclasses import replace
28from functools import partial
30try:
31 # available since jax v0.6.1
32 from jax import shard_map
33except ImportError:
34 # deprecated in jax v0.8.0
35 from jax.experimental.shard_map import shard_map
37import jax
38from equinox import Module, tree_at
39from jax import lax, random
40from jax import numpy as jnp
41from jax.lax import cond
42from jax.scipy.linalg import solve_triangular
43from jax.scipy.special import gammaln, logsumexp
44from jax.sharding import Mesh, PartitionSpec
45from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt, UInt32
47from bartz._profiler import (
48 get_profile_mode,
49 jit_and_block_if_profiling,
50 jit_if_not_profiling,
51 jit_if_profiling,
52 vmap_chains_if_not_profiling,
53 vmap_chains_if_profiling,
54)
55from bartz.grove import var_histogram
56from bartz.jaxext import split, truncated_normal_onesided, vmap_nodoc
57from bartz.mcmcstep._moves import Moves, propose_moves
58from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field
@partial(jit_if_not_profiling, donate_argnums=(1,))
@partial(vmap_chains_if_not_profiling, auto_split_keys=True)
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Advance a BART MCMC state by one full iteration.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The updated BART mcmc state.

    Notes
    -----
    Outside of `jax.jit`, the memory of the input state is donated to the
    output state, so `bart` cannot be used anymore after calling `step`.
    """
    # In profile mode the chain axis is not vmapped away by the decorator, so
    # a scalar key must be expanded into one key per chain here.
    num_chains = bart.forest.num_chains()
    chain_shape = (num_chains,) if num_chains is not None else ()
    expand_key = get_profile_mode() and num_chains is not None and key.ndim == 0
    if expand_key:
        key = random.split(key, num_chains)
    assert key.shape == chain_shape

    key_stack = split(key, 3)
    is_binary = bart.y.dtype == bool

    if is_binary:
        # binary regression: the trees are sampled with a unit error
        # precision, which is removed again before updating the latent z
        bart = replace(bart, error_cov_inv=jnp.ones(chain_shape))
        bart = step_trees(key_stack.pop(), bart)
        bart = replace(bart, error_cov_inv=None)
        bart = step_z(key_stack.pop(), bart)
    else:
        # continuous or multivariate regression
        bart = step_trees(key_stack.pop(), bart)
        bart = step_error_cov_inv(key_stack.pop(), bart)

    bart = step_sparse(key_stack.pop(), bart)
    return step_config(bart)
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Sample the forest: propose one move per tree, then accept/reject the moves
    and resample the leaves.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    The proposal counters are not accumulated across calls: they are
    overwritten with the counts of the moves proposed in this step.
    """
    subkeys = split(key)
    proposed_moves = propose_moves(subkeys.pop(), bart.forest)
    return accept_moves_and_sample_leaves(subkeys.pop(), bart, proposed_moves)
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Run the accept/reject procedure on proposed moves and draw new leaf values.

    The work is split into three stages: a parallel-across-trees
    pre-computation, a sequential pass over the trees, and a final stage.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    updated_bart, final_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(updated_bart, final_moves)
class Counts(Module):
    """
    Per-tree datapoint counts in the nodes touched by the proposed moves.

    Parameters
    ----------
    left
        How many datapoints fall in the left child.
    right
        How many datapoints fall in the right child.
    total
        How many datapoints fall in the parent node, i.e., ``left + right``.
    """

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    total: UInt[Array, '*chains num_trees'] = field(chains=True)
class Precs(Module):
    """
    Per-tree likelihood precision scale in the nodes touched by the proposed moves.

    For a tree node, the likelihood precision scale is the sum, over the
    datapoints that fall in the node, of their inverse squared error scales.

    Parameters
    ----------
    left
        Precision scale accumulated in the left child.
    right
        Precision scale accumulated in the right child.
    total
        Precision scale accumulated in the parent node, i.e., ``left + right``.
    """

    left: Float32[Array, '*chains num_trees'] = field(chains=True)
    right: Float32[Array, '*chains num_trees'] = field(chains=True)
    total: Float32[Array, '*chains num_trees'] = field(chains=True)
class PreLkV(Module):
    """
    Per-tree terms of the likelihood ratio that do not need the sequential pass.

    Each tree's terms depend only on that tree's move, so they are computed
    in parallel across trees.

    Parameters
    ----------
    left
    right
    total
        Univariate case: the scalar

        ``1 / error_cov_inv + n_* / leaf_prior_cov_inv``

        Multivariate case: the matrix

        ``error_cov_inv @ inv(leaf_prior_cov_inv + n_* * error_cov_inv) @ error_cov_inv``

        Here ``n_*`` is, respectively, the number of datapoints in the left
        child, in the right child, and in the parent node; with
        heteroskedastic errors it is the likelihood precision scale instead.
    log_sqrt_term
        The logarithm of the square-root factor of the likelihood ratio.
    """

    left: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    right: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    total: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    log_sqrt_term: Float32[Array, '*chains num_trees'] = field(chains=True)
class PreLk(Module):
    """
    Terms of the likelihood ratio that are identical for every tree.

    Only produced in the univariate case (the multivariate pre-computation
    returns `None` in its place).

    Parameters
    ----------
    exp_factor
        A factor of the likelihood ratio shared by all trees, computed as
        ``error_cov_inv / leaf_prior_cov_inv / 2``.
    """

    exp_factor: Float32[Array, '*chains'] = field(chains=True)
class PreLf(Module):
    """
    Pre-computed ingredients for drawing leaf values from their posterior.

    They depend on each tree separately, so they are computed in parallel
    across trees. Per tree and leaf, each term is a scalar in the univariate
    case and a matrix (``mean_factor``) or vector (``centered_leaves``) in the
    multivariate case.

    Parameters
    ----------
    mean_factor
        Multiplied on the right by the sum of the scaled residuals, it gives
        the posterior mean of the leaf.
    centered_leaves
        Zero-mean normal draws that, added to the posterior mean, give the
        posterior leaf samples.
    """

    mean_factor: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k k 2**d']
    ) = field(chains=True)
    centered_leaves: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
class ParallelStageOut(Module):
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node. If
        there is no precision scale, this is the number of points in each leaf.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, this is instead the `Counts` of the
        number of points in those nodes.
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples. `prelk` is `None` in the multivariate case.
    """

    bart: State
    moves: Moves
    # uint32 when these are plain datapoint counts (see
    # `_compute_count_or_prec_trees`), float32 when weighted by `prec_scale`
    prec_trees: (
        Float32[Array, '*chains num_trees 2**d']
        | UInt32[Array, '*chains num_trees 2**d']
    ) = field(chains=True)
    move_precs: Precs | Counts
    prelkv: PreLkV
    prelk: PreLk | None
    prelf: PreLf
@partial(jit_and_block_if_profiling, donate_argnums=(1, 2))
@vmap_chains_if_profiling
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf; only needed for the thresholds
    # below, or as the precision scale when there is no `prec_scale`. Every
    # later use of `count_trees`/`move_counts` is guarded by one of these
    # same conditions.
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.config
        )

    # mark which leaves & potential leaves have enough points to be grown
    if bart.forest.min_points_per_decision_node is not None:
        count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
        moves = replace(
            moves,
            affluence_tree=moves.affluence_tree
            & (count_half_trees >= bart.forest.min_points_per_decision_node),
        )

    # copy updated affluence_tree to state
    bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

    # veto the move if the new leaves don't have enough datapoints
    if bart.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & (move_counts.left >= bart.forest.min_points_per_leaf)
            & (move_counts.right >= bart.forest.min_points_per_leaf),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale, bart.forest.leaf_indices, moves, bart.config
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    assert bart.error_cov_inv is not None
    prelkv, prelk = precompute_likelihood_terms(
        bart.error_cov_inv, bart.forest.leaf_prior_cov_inv, move_precs
    )
    prelf = precompute_leaf_terms(
        key, prec_trees, bart.error_cov_inv, bart.forest.leaf_prior_cov_inv
    )

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Move the datapoints of a grown node down to its new children.

    Applied per tree (vmapped over `moves` and `leaf_indices`; `X` is shared).
    Datapoints in the grown node are sent to the left child ``2 * node``, or
    to the right child ``2 * node + 1`` when their value of the split
    variable is ``>= grow_split``. Trees whose move is not a grow are left
    unchanged.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    first_child = moves.node.astype(leaf_indices.dtype) << 1
    goes_right = X[moves.grow_var, :] >= moves.grow_split
    # an index no leaf can have, so non-grow moves match no datapoint
    no_node = jnp.array(2 * moves.var_tree.size)
    target_node = jnp.where(moves.grow, moves.node, no_node)
    in_target = leaf_indices == target_node
    return jnp.where(in_target, first_child + goes_right, leaf_indices)
@partial(vmap_nodoc, in_axes=(None, 0, 0, None))
def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, 'num_trees 2**d'], Counts]
    | tuple[Float32[Array, 'num_trees 2**d'], Precs]
):
    """
    Shared implementation of `compute_count_trees` and `compute_prec_trees`.

    Runs on one tree at a time (vmapped over `leaf_indices` and `moves`).
    Each datapoint contributes 1 (when `prec_scale` is None) or its
    `prec_scale` entry to the node its leaf index points to. The totals in
    the two children of the moved node are returned, and their sum is also
    written into the moved node itself.
    """
    (half_tree_size,) = moves.var_tree.shape
    full_tree_size = 2 * half_tree_size

    if prec_scale is None:
        weight, result_cls, acc_dtype = 1, Counts, jnp.uint32
    else:
        weight, result_cls, acc_dtype = prec_scale, Precs, jnp.float32

    node_totals = _scatter_add(
        weight,
        leaf_indices,
        full_tree_size,
        acc_dtype,
        config.count_batch_size,
        config.mesh,
    )

    # totals in the nodes modified by the move
    left_total = node_totals[moves.left]
    right_total = node_totals[moves.right]
    move_totals = result_cls(
        left=left_total, right=right_total, total=left_total + right_total
    )

    # write the parent's total into the non-leaf node
    node_totals = node_totals.at[moves.node].set(move_totals.total)

    return node_totals, move_totals
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, config: StepConfig
) -> tuple[UInt32[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    count_trees : UInt32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    # prec_scale=None selects the unweighted (uint32 count) path
    return _compute_count_or_prec_trees(None, leaf_indices, moves, config)
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Sum the error precision scales of the datapoints in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the error precision of each datapoint.
    leaf_indices
        The leaf index of each datapoint, computed on the deeper version of
        the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale of every potential or actual leaf node.
    precs : Precs
        The likelihood precision scale of the nodes involved in the moves.
    """
    # positional call: the vmap in_axes of the helper are positional
    result = _compute_count_or_prec_trees(prec_scale, leaf_indices, moves, config)
    return result
@partial(vmap_nodoc, in_axes=(0, None))
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Finish the non-likelihood part of the Metropolis-Hastings ratio.

    Multiplies the partial ratio by the probability of choosing a prune
    rather than a grow in the inverse transition, and by the prior
    probability that both children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. They must already account for the thresholds on
        the number of datapoints per node, which are applied in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The prior probability of each node being nonterminal given its
        ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """

    def is_growable(node):
        # indices past the end of affluence_tree read as not growable
        return moves.affluence_tree.at[node].get(mode='fill', fill_value=False)

    left_growable = is_growable(moves.left)
    right_growable = is_growable(moves.right)

    # p_prune when the move is a grow
    grow_again_allowed = (moves.num_growable >= 2) | left_growable | right_growable
    p_prune_for_grow = jnp.where(grow_again_allowed, 0.5, 1.0)

    # p_prune when the move is a prune
    p_prune_for_prune = jnp.where(moves.num_growable, 0.5, 1)

    p_prune = jnp.where(moves.grow, p_prune_for_grow, p_prune_for_prune)

    # prior probability that both children are terminal
    pt_left = 1 - p_nonterminal[moves.left] * left_growable
    pt_right = 1 - p_nonterminal[moves.right] * right_growable
    pt_children = pt_left * pt_right

    assert moves.partial_ratio is not None
    log_ratio = jnp.log(moves.partial_ratio * pt_children * p_prune)
    return replace(moves, log_trans_prior_ratio=log_ratio, partial_ratio=None)
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Make post-grow leaf indices valid on the not-yet-grown tree.

    Where the move is a grow, the leaf value of the grown node is copied into
    the slots of its two prospective children. Other trees are unchanged.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    node_value = leaf_trees[..., moves.node]
    # an out-of-bounds index makes the scatter a no-op for non-grow moves
    skip = leaf_trees.size
    left_slot = jnp.where(moves.grow, moves.left, skip)
    right_slot = jnp.where(moves.grow, moves.right, skip)
    with_left = leaf_trees.at[..., left_slot].set(node_value)
    return with_left.at[..., right_slot].set(node_value)
def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']:
    """Return log det(L @ L.T) from the Cholesky factor L, i.e. 2 * sum(log(diag(L)))."""
    log_diag: Float32[Array, '... k'] = jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1))
    return 2.0 * log_diag.sum(axis=-1)
def _precompute_likelihood_terms_uv(
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """Univariate version of `precompute_likelihood_terms`."""
    error_var = lax.reciprocal(error_cov_inv)
    leaf_prior_var = lax.reciprocal(leaf_prior_cov_inv)

    # sigma^2 + n * sigma_mu^2 for the left child, right child and parent
    left, right, total = (
        error_var + n * leaf_prior_var
        for n in (move_precs.left, move_precs.right, move_precs.total)
    )

    per_tree = PreLkV(
        left=left,
        right=right,
        total=total,
        log_sqrt_term=jnp.log(error_var * total / (left * right)) / 2,
    )
    shared = PreLk(exp_factor=error_cov_inv / leaf_prior_cov_inv / 2)
    return per_tree, shared
def _precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    move_precs: Counts,
) -> tuple[PreLkV, None]:
    """Multivariate (homoskedastic) version of `precompute_likelihood_terms`."""

    def chol_of_node_prec(
        n: UInt[Array, ' num_trees'],
    ) -> Float32[Array, 'num_trees k k']:
        # Cholesky factor of n * error_cov_inv + leaf_prior_cov_inv, per tree
        return chol_with_gersh(error_cov_inv * n[..., None, None] + leaf_prior_cov_inv)

    def quadratic_term(
        chol: Float32[Array, 'num_trees k k'],
    ) -> Float32[Array, 'num_trees k k']:
        # error_cov_inv.T @ inv(chol @ chol.T) @ error_cov_inv, computed as
        # Y.T @ Y with Y = chol^-1 @ error_cov_inv; the broadcast is needed
        # because solve_triangular does not broadcast its operands
        rhs = jnp.broadcast_to(error_cov_inv, chol.shape)
        scaled = solve_triangular(chol, rhs, lower=True)
        return scaled.mT @ scaled

    chol_left = chol_of_node_prec(move_precs.left)
    chol_right = chol_of_node_prec(move_precs.right)
    chol_total = chol_of_node_prec(move_precs.total)
    chol_prior = chol_with_gersh(leaf_prior_cov_inv)

    log_sqrt_term: Float32[Array, ' num_trees'] = 0.5 * (
        _logdet_from_chol(chol_prior)
        + _logdet_from_chol(chol_total)
        - _logdet_from_chol(chol_left)
        - _logdet_from_chol(chol_right)
    )

    per_tree = PreLkV(
        left=quadratic_term(chol_left),
        right=quadratic_term(chol_right),
        total=quadratic_term(chol_total),
        log_sqrt_term=log_sqrt_term,
    )
    return per_tree, None
def precompute_likelihood_terms(
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk | None]:
    """
    Pre-compute the terms of the acceptance likelihood ratio.

    Dispatches on the rank of `error_cov_inv`: a 2-d array selects the
    multivariate implementation, anything else the univariate one. The
    multivariate implementation assumes homoskedastic errors (the same error
    covariance for every observation), so it requires plain `Counts`.

    Parameters
    ----------
    error_cov_inv
        The inverse error variance (univariate) or the inverse error
        covariance matrix (multivariate). In the univariate case with
        `prec_scale` set, this is the inverse global error variance factor.
    leaf_prior_cov_inv
        The inverse prior variance (univariate) or the inverse prior
        covariance matrix (multivariate) of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, in the attributes `left`, `right`, and `total` (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk | None
        Pre-computed terms of the likelihood ratio shared by all trees
        (`None` in the multivariate case).
    """
    if error_cov_inv.ndim != 2:
        return _precompute_likelihood_terms_uv(
            error_cov_inv, leaf_prior_cov_inv, move_precs
        )
    assert isinstance(move_precs, Counts)
    return _precompute_likelihood_terms_mv(
        error_cov_inv, leaf_prior_cov_inv, move_precs
    )
def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees 2**d'] | None = None,
) -> PreLf:
    """Univariate version of `precompute_leaf_terms`."""
    likelihood_prec = prec_trees * error_cov_inv
    posterior_var = lax.reciprocal(likelihood_prec + leaf_prior_cov_inv)
    noise = (
        random.normal(key, prec_trees.shape, error_cov_inv.dtype) if z is None else z
    )

    # Derivation of mean_factor:
    #   mean = mean_lk * prec_lk * var_post
    #   resid_tree = mean_lk * prec_tree  =>  mean_lk = resid_tree / prec_tree
    #   mean_factor = mean / resid_tree
    #               = 1 / prec_tree * prec_tree / sigma2 * var_post
    #               = var_post / sigma2 = var_post * error_cov_inv
    return PreLf(
        mean_factor=posterior_var * error_cov_inv,
        centered_leaves=noise * jnp.sqrt(posterior_var),
    )
def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d k'] | None = None,
) -> PreLf:
    """
    Multivariate version of `precompute_leaf_terms`.

    With ``P = leaf_prior_cov_inv + n * error_cov_inv`` (the posterior
    precision of a leaf) and ``L`` its lower Cholesky factor, the centered
    leaves are ``L^-T @ z``, which have covariance ``inv(P)``.
    """
    num_trees, tree_size = prec_trees.shape
    k = error_cov_inv.shape[0]
    n_k: Float32[Array, 'num_trees tree_size 1 1'] = prec_trees[..., None, None]

    # Only broadcast the inverse of error covariance matrix to satisfy JAX's
    # batching rules for `lax.linalg.solve_triangular`, which does not support
    # implicit broadcasting.
    error_cov_inv_batched = jnp.broadcast_to(
        error_cov_inv, (num_trees, tree_size, k, k)
    )

    posterior_precision: Float32[Array, 'num_trees tree_size k k'] = (
        leaf_prior_cov_inv + n_k * error_cov_inv_batched
    )

    # lower-triangular Cholesky factor of the posterior precision
    L_prec: Float32[Array, 'num_trees tree_size k k'] = chol_with_gersh(
        posterior_precision
    )
    Y: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, error_cov_inv_batched, lower=True
    )
    mean_factor: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, Y, trans='T', lower=True
    )
    mean_factor = mean_factor.mT
    mean_factor_out: Float32[Array, 'num_trees k k tree_size'] = jnp.moveaxis(
        mean_factor, 1, -1
    )

    if z is None:
        # same dtype policy as the univariate case
        z = random.normal(key, (num_trees, tree_size, k), error_cov_inv.dtype)
    # `lower=True` is required: solve_triangular defaults to lower=False and
    # would then read only the (diagonal) upper triangle of L_prec, dropping
    # the posterior correlations between the k components.
    centered_leaves: Float32[Array, 'num_trees tree_size k'] = solve_triangular(
        L_prec, z, trans='T', lower=True
    )
    centered_leaves_out: Float32[Array, 'num_trees k tree_size'] = jnp.swapaxes(
        centered_leaves, -1, -2
    )

    return PreLf(mean_factor=mean_factor_out, centered_leaves=centered_leaves_out)
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d']
    | Float32[Array, 'num_trees 2**d k']
    | None = None,
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    The multivariate implementation is used when `error_cov_inv` is a
    matrix (2 dimensions), the univariate one otherwise.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse error variance (univariate) or the inverse of error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    z
        Optional standard normal noise to use for sampling the centered leaves.
        This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    is_multivariate = error_cov_inv.ndim == 2
    impl = (
        _precompute_leaf_terms_mv if is_multivariate else _precompute_leaf_terms_uv
    )
    return impl(key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z)
@partial(jit_and_block_if_profiling, donate_argnums=(0,))
@vmap_chains_if_profiling
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.
    The trees are visited in order with `lax.scan`, threading the residuals
    from one tree to the next.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    bart_in = pso.bart
    forest_in = bart_in.forest

    # inputs that are the same for every tree
    shared = SeqStageInAllTrees(
        bart_in.X,
        bart_in.config.resid_batch_size,
        bart_in.config.mesh,
        bart_in.prec_scale,
        forest_in.log_likelihood is not None,
        pso.prelk,
    )

    # inputs with a leading tree axis, sliced one tree at a time by the scan
    per_tree = SeqStageInPerTree(
        forest_in.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        forest_in.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )

    def body(carry_resid, tree_inputs):
        carry_resid, leaf_tree, acc, to_prune, lk_ratio = (
            accept_move_and_sample_leaves(carry_resid, shared, tree_inputs)
        )
        return carry_resid, (leaf_tree, acc, to_prune, lk_ratio)

    final_resid, stacked = lax.scan(body, bart_in.resid, per_tree)
    new_leaf_trees, acc, to_prune, lk_ratios = stacked

    new_forest = replace(forest_in, leaf_tree=new_leaf_trees, log_likelihood=lk_ratios)
    new_bart = replace(bart_in, resid=final_resid, forest=new_forest)
    new_moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return new_bart, new_moves
class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    mesh
        The mesh of devices to use.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    save_ratios
        Whether to return the log-likelihood ratio of each move. If False,
        `accept_move_and_sample_leaves` returns None in its place.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees. Only used in the univariate case.
    """

    # NOTE: field order matters: `accept_moves_sequential_stage` constructs
    # instances positionally.
    X: UInt[Array, 'p n']
    # The static fields hold plain Python values used in Python-level
    # control flow and shape arithmetic, not arrays.
    resid_batch_size: int | None = field(static=True)
    mesh: Mesh | None = field(static=True)
    prec_scale: Float32[Array, ' n'] | None
    save_ratios: bool = field(static=True)
    prelk: PreLk | None
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    In `accept_moves_sequential_stage` every field carries an additional
    leading tree axis, which `lax.scan` slices off one tree at a time.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree (with a leading outcome axis of size k in
        the multivariate case).
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    # NOTE: field order matters: `accept_moves_sequential_stage` constructs
    # instances positionally.
    leaf_tree: Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d']
    prec_tree: Float32[Array, ' 2**d']
    move: Moves
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, ' n']
    prelkv: PreLkV
    prelf: PreLf
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Handles both the univariate and the multivariate case, distinguished by
    the number of dimensions of `resid`.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n'] | Float32[Array, ' k n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d'] | Float32[Array, ' k 2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale

    tree_size = pt.leaf_tree.shape[-1]  # 2**d

    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, tree_size, at.resid_batch_size, at.mesh
    )

    # subtract starting tree from function: add back this tree's current
    # leaf values (weighted by the precision scale), so that `resid_tree`
    # refers to the forest without this tree
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move and compute likelihood
    resid_left = resid_tree[..., pt.move.left]
    resid_right = resid_tree[..., pt.move.right]
    resid_total = resid_left + resid_right
    assert pt.move.node.dtype == jnp.int32
    # store the parent's sum, used below if its leaf value is sampled
    # (i.e., if the tree ends up pruned)
    resid_tree = resid_tree.at[..., pt.move.node].set(resid_total)

    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )

    # calculate accept/reject ratio; it is computed in the grow direction and
    # flipped for prune moves
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    if resid.ndim > 1:
        # multivariate: per-leaf matrix-vector product mean_factor @ resid
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, resid_tree)
    else:
        mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf.
    # `pt.leaf_indices` refer to the grown tree, so the children must be
    # collapsed into the parent when a grow move is rejected or a prune move
    # is accepted.
    to_prune = acc ^ pt.move.grow
    # when not pruning, the index `tree_size` is out of bounds and the update
    # is dropped
    leaf_tree = (
        leaf_tree.at[..., jnp.where(to_prune, pt.move.left, tree_size)]
        .set(leaf_tree[..., pt.move.node])
        .at[..., jnp.where(to_prune, pt.move.right, tree_size)]
        .set(leaf_tree[..., pt.move.node])
    )
    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[..., pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
@partial(jnp.vectorize, excluded=(1, 2, 3, 4), signature='(n)->(ts)')
def sum_resid(
    scaled_resid: Float32[Array, ' n'] | Float32[Array, 'k n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    resid_batch_size: int | None,
    mesh: Mesh | None,
) -> Float32[Array, ' {tree_size}'] | Float32[Array, 'k {tree_size}']:
    """
    Sum the residuals falling into each leaf.

    `jnp.vectorize` maps the computation over any leading axes of
    `scaled_resid`, so the multivariate ``(k, n)`` case is handled one outcome
    row at a time, always with the same `leaf_indices`.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale, of shape ``(n,)`` or ``(k, n)`` where ``k`` is the
        number of outcome columns.
    leaf_indices
        The leaf each data point falls into.
    tree_size
        The size of the tree array (2 ** d).
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    mesh
        The mesh of devices to use.

    Returns
    -------
    The per-leaf sums of the residuals, with the leaf axis last.
    """
    # inside the vectorized function `scaled_resid` is a single row
    return _scatter_add(
        values=scaled_resid,
        indices=leaf_indices,
        size=tree_size,
        dtype=jnp.float32,
        batch_size=resid_batch_size,
        mesh=mesh,
    )
def _scatter_add(
    values: Float32[Array, ' n'] | int,
    indices: Integer[Array, ' n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    mesh: Mesh | None,
) -> Shaped[Array, ' {size}']:
    """
    Indexed reduce with optional batching.

    Add `values` (one per index, or a single scalar for every index) into a
    zero array of length `size` at positions `indices`. If `mesh` has a
    ``'data'`` axis, the data is sharded along it, each device reduces its
    shard, and the partial results are summed across devices.
    """
    values = jnp.asarray(values)
    assert values.ndim == 0 or values.shape == indices.shape

    reduce_fn = partial(
        _scatter_add_impl, size=size, dtype=dtype, batch_size=batch_size
    )

    sharded = mesh is not None and 'data' in mesh.axis_names
    if not sharded:
        return reduce_fn(values, indices)

    # a scalar value is replicated; per-point values are split like the indices
    values_spec = PartitionSpec('data') if values.shape else PartitionSpec()
    sharded_fn = shard_map(
        partial(reduce_fn, final_psum=True),
        in_specs=(values_spec, PartitionSpec('data')),
        out_specs=PartitionSpec(),
        mesh=mesh,
        **_get_shard_map_patch_kwargs(),
    )
    return sharded_fn(values, indices)
def _get_shard_map_patch_kwargs():
    """
    Return extra keyword arguments for `shard_map` working around a jax bug.

    On jax 0.8.1 and 0.8.2, ``vmap(shard_map(psum))`` has a problem (see jax
    issue #34249), so ``check_vma`` is disabled there. Other versions get no
    extra arguments.
    """
    affected_versions = ('0.8.1', '0.8.2')
    return {'check_vma': False} if jax.__version__ in affected_versions else {}
def _scatter_add_impl(
    values: Float32[Array, ' n'] | Int32[Array, ''],
    indices: Integer[Array, ' n'],
    /,
    *,
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    final_psum: bool = False,
) -> Shaped[Array, ' {size}']:
    """
    Scatter-add `values` at `indices` into a zero array of length `size`.

    If `batch_size` is set, the points are spread round-robin over
    ``ceil(n / batch_size)`` columns, each column is reduced separately, and
    the columns are then summed. If `final_psum`, the result is summed over
    the ``'data'`` mesh axis.
    """
    if batch_size is not None:
        # under `shard_map`, n is the length of the local shard, not the
        # full data size
        (n,) = indices.shape
        full, rest = divmod(n, batch_size)
        num_batches = full + bool(rest)
        column = jnp.arange(n) % num_batches
        partial_sums = jnp.zeros((size, num_batches), dtype).at[indices, column].add(values)
        out = partial_sums.sum(axis=1)
    else:
        out = jnp.zeros(size, dtype).at[indices].add(values)

    if final_psum:
        out = lax.psum(out, 'data')
    return out
def _compute_likelihood_ratio_uv(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """Univariate log-likelihood ratio of a grow move from pre-computed terms."""

    def squared_over(r, denom):
        return r * r / denom

    bracket = (
        squared_over(left_resid, prelkv.left)
        + squared_over(right_resid, prelkv.right)
        - squared_over(total_resid, prelkv.total)
    )
    return prelkv.log_sqrt_term + prelk.exp_factor * bracket
def _compute_likelihood_ratio_mv(
    total_resid: Float32[Array, ' k'],
    left_resid: Float32[Array, ' k'],
    right_resid: Float32[Array, ' k'],
    prelkv: PreLkV,
) -> Float32[Array, '']:
    """Multivariate log-likelihood ratio of a grow move from pre-computed terms."""
    pairs = (
        (left_resid, prelkv.left),
        (right_resid, prelkv.right),
        (total_resid, prelkv.total),
    )
    # quadratic forms r^T M r for the left child, right child and parent
    qf_left, qf_right, qf_total = (vec @ mat @ vec for vec, mat in pairs)
    return prelkv.log_sqrt_term + 0.5 * (qf_left + qf_right - qf_total)
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''] | Float32[Array, ' k'],
    left_resid: Float32[Array, ''] | Float32[Array, ' k'],
    right_resid: Float32[Array, ''] | Float32[Array, ' k'],
    prelkv: PreLkV,
    prelk: PreLk | None,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    Scalar residual sums select the univariate formula, which requires
    `prelk`; vector residual sums select the multivariate one, which does not
    use `prelk`.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    if total_resid.ndim == 0:
        assert prelk is not None
        return _compute_likelihood_ratio_uv(
            total_resid, left_resid, right_resid, prelkv, prelk
        )
    return _compute_likelihood_ratio_mv(total_resid, left_resid, right_resid, prelkv)
@partial(jit_and_block_if_profiling, donate_argnums=(0, 1))
@vmap_chains_if_profiling
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees. It updates the acceptance counters, the
    leaf indices and the split trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    old_forest = bart.forest
    accepted_grow = moves.acc & moves.grow
    accepted_prune = moves.acc & ~moves.grow
    new_forest = replace(
        old_forest,
        grow_acc_count=jnp.sum(accepted_grow),
        prune_acc_count=jnp.sum(accepted_prune),
        leaf_indices=apply_moves_to_leaf_indices(old_forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(old_forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Datapoints in either child of the moved node are reassigned to the node
    itself when the move resolves to a prune (`moves.to_prune`).

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # clearing the lowest bit maps a sibling pair (2i, 2i + 1) to 2i; this
    # presumably relies on `moves.left` being even with
    # `moves.right == moves.left + 1`
    sibling_mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_children = (leaf_indices & sibling_mask) == moves.left
    assert moves.to_prune is not None
    parent = moves.node.astype(leaf_indices.dtype)
    return jnp.where(in_children & moves.to_prune, parent, leaf_indices)
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Grow proposals write the proposed cutpoint at the moved node; moves that
    resolve to a prune then reset that node's cutpoint to 0.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # an index equal to the array size is out of bounds, so the update is
    # dropped
    out_of_bounds = split_tree.size
    grow_at = jnp.where(moves.grow, moves.node, out_of_bounds)
    prune_at = jnp.where(moves.to_prune, moves.node, out_of_bounds)
    grown = split_tree.at[grow_at].set(moves.grow_split.astype(split_tree.dtype))
    return grown.at[prune_at].set(0)
@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Integer[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix from a Wishart using the Bartlett decomposition.

    The sample is W ~ Wishart(df, scale_inv^-1), i.e., W^-1 follows an inverse
    Wishart distribution with `df` degrees of freedom and scale `scale_inv`.

    Parameters
    ----------
    key
        A JAX random key.
    df
        Degrees of freedom.
    scale_inv
        Scale matrix of the corresponding Inverse Wishart distribution.

    Returns
    -------
    A sample from Wishart(df, scale_inv^-1).
    """
    keys = split(key)
    k, _ = scale_inv.shape

    # Bartlett factor A: lower triangular, A_ii ~ sqrt(chi^2(df - i)) and
    # standard normal below the diagonal; chi^2(nu) = 2 * Gamma(nu / 2)
    half_dofs = (df - jnp.arange(k)) / 2.0
    chi2_draws = random.gamma(keys.pop(), half_dofs) * 2.0
    normal_draws = random.normal(keys.pop(), (k, k))
    A = jnp.tril(normal_draws, -1) + jnp.diag(jnp.sqrt(chi2_draws))

    # with scale_inv = L L^T, T = L^-T A gives T T^T ~ Wishart(df, scale_inv^-1)
    L = chol_with_gersh(scale_inv, absolute_eps=True)
    T = solve_triangular(L, A, lower=True, trans='T')
    return T @ T.T
def _step_error_cov_inv_uv(key: Key[Array, ''], bart: State) -> State:
    """
    Gibbs-update the inverse error variance in the univariate case.

    The variance has an inverse gamma prior with shape ``error_cov_df / 2``
    and scale ``error_cov_scale / 2``, so the precision is drawn from a gamma
    with shape ``alpha`` and rate ``beta`` updated with the residuals.
    """
    resid = bart.resid
    if bart.prec_scale is None:
        weighted = resid
    else:
        weighted = resid * bart.prec_scale

    alpha = bart.error_cov_df / 2 + resid.size / 2
    beta = bart.error_cov_scale / 2 + (resid @ weighted) / 2

    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    unit_rate_draw = random.gamma(key, alpha)
    return replace(bart, error_cov_inv=unit_rate_draw / beta)
def _step_error_cov_inv_mv(key: Key[Array, ''], bart: State) -> State:
    """Gibbs-update the inverse error covariance matrix with a Wishart draw."""
    resid = bart.resid
    num_obs = resid.shape[-1]
    posterior_df = bart.error_cov_df + num_obs
    posterior_scale = bart.error_cov_scale + resid @ resid.T
    precision = _sample_wishart_bartlett(key, posterior_df, posterior_scale)
    return replace(bart, error_cov_inv=precision)
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the inverse error covariance.

    The multivariate update is used when `bart.error_cov_inv` is a matrix
    (2 dimensions), the univariate one otherwise.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `error_cov_inv`.
    """
    assert bart.error_cov_inv is not None
    multivariate = bart.error_cov_inv.ndim == 2
    update = _step_error_cov_inv_mv if multivariate else _step_error_cov_inv_uv
    return update(key, bart)
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    The new residual is drawn from a normal truncated on one side at
    ``-(z - resid)``. The side is selected by ``~bart.y``, presumably so that
    the sign of the latent variable agrees with the binary outcome.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    latent_mean = bart.z - bart.resid
    assert bart.y.dtype == bool
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -latent_mean)
    return replace(bart, z=latent_mean + new_resid, resid=new_resid)
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest. The stored `log_s` are log-gamma draws that are not
    normalized here.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    This full conditional is approximated, because it does not take into account
    that there are forbidden decision rules.
    """
    assert bart.forest.theta is not None
    forest = bart.forest

    num_vars = forest.max_split.size
    usage = var_histogram(
        num_vars, forest.var_tree, forest.split_tree, sum_batch_axis=-1
    )
    concentration = forest.theta / num_vars + usage
    new_log_s = random.loggamma(key, concentration)

    return replace(bart, forest=replace(forest, log_s=new_log_s))
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b). The posterior of
    lambda = theta / (theta + rho) is evaluated on a grid and sampled with a
    categorical draw.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # midpoints of num_grid equal bins partitioning (0, 1)
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # normalize s on the log scale
    log_s_normalized = forest.log_s - logsumexp(forest.log_s)

    log_weights, theta_grid = _log_p_lamda(
        lamda_grid, log_s_normalized, forest.rho, forest.a, forest.b
    )
    chosen = random.categorical(key, log_weights)

    return replace(bart, forest=replace(forest, theta=theta_grid[chosen]))
def _log_p_lamda(
    lamda: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Unnormalized log posterior of lambda = theta / (theta + rho) on a grid.

    `lamda` must be symmetric around 1/2 (``lamda[::-1] == 1 - lamda``), as is
    the grid built by `step_theta`.

    Returns
    -------
    log_p
        The unnormalized log density at each grid point.
    theta
        The value of theta corresponding to each grid point.
    """
    # the reversed grid stands in for 1 - lamda
    complement = lamda[::-1]
    theta = rho * lamda / complement
    p = log_s.size
    log_p = (
        (a - 1) * jnp.log1p(-complement)  # log(lambda)
        + (b - 1) * jnp.log1p(-lamda)  # log(1 - lambda)
        + gammaln(theta)
        - p * gammaln(theta / p)
        + theta / p * jnp.sum(log_s)
    )
    return log_p, theta
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if the parameters of
    the theta prior are defined. Nothing happens if sparsity is not configured
    (`sparse_on_at` is None) or before `sparse_on_at` steps have been done.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    if bart.config.sparse_on_at is None:
        return bart

    def _unchanged(_key, state):
        return state

    not_yet_active = bart.config.steps_done < bart.config.sparse_on_at
    return cond(not_yet_active, _unchanged, _step_sparse, key, bart)
def _step_sparse(key, bart):
    """Run `step_s`, then `step_theta` if the theta prior is set (`rho` not None)."""
    keys = split(key)
    updated = step_s(keys.pop(), bart)
    if updated.forest.rho is None:
        return updated
    return step_theta(keys.pop(), updated)
@jit_if_profiling
# jit to avoid the overhead of replace(_: Module)
def step_config(bart):
    """Return `bart` with the `steps_done` counter of its config incremented."""
    old_config = bart.config
    new_config = replace(old_config, steps_done=old_config.steps_done + 1)
    return replace(bart, config=new_config)