Coverage for src / bartz / mcmcstep / _step.py: 99%
420 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/mcmcstep/_step.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implement `step`, `step_trees`, and the accept-reject logic."""
27from dataclasses import replace
28from functools import partial
30try:
31 # available since jax v0.6.1
32 from jax import shard_map
33except ImportError:
34 # deprecated in jax v0.8.0
35 from jax.experimental.shard_map import shard_map
37import jax
38from equinox import Module, tree_at
39from jax import jit, lax, named_call, random, vmap
40from jax import numpy as jnp
41from jax.scipy.linalg import solve_triangular
42from jax.scipy.special import gammaln, logsumexp
43from jax.sharding import Mesh, PartitionSpec
44from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt, UInt32
46from bartz.grove import var_histogram
47from bartz.jaxext import split, truncated_normal_onesided, vmap_nodoc
48from bartz.mcmcstep._moves import Moves, propose_moves
49from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field, vmap_chains
@partial(jit, donate_argnums=(1,))
@vmap_chains
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Advance the BART MCMC by one full iteration.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The BART mcmc state after the iteration.

    Notes
    -----
    The buffers of the input state are donated to the output state, so the
    input state must not be used again after calling `step`. This only matters
    when calling `step` from outside of `jax.jit`.
    """
    keys = split(key, 3)
    # draw the keys in the same order in which the sub-steps consume them
    forest_key = keys.pop()
    aux_key = keys.pop()
    sparse_key = keys.pop()

    binary_outcome = bart.y.dtype == bool
    if binary_outcome:
        # the forest update runs with the error precision temporarily set to
        # one; the field is cleared again before sampling the latent variable
        state = replace(bart, error_cov_inv=jnp.array(1.0))
        state = step_trees(forest_key, state)
        state = replace(state, error_cov_inv=None)
        state = step_z(aux_key, state)
    else:
        # continuous or multivariate regression
        state = step_trees(forest_key, bart)
        state = step_error_cov_inv(aux_key, state)

    return step_config(step_sparse(sparse_key, state))
@named_call
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Run the forest update of the BART MCMC.

    Moves are first proposed for the trees, then accepted or rejected while
    the leaf values are sampled.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The BART mcmc state with the forest updated.
    """
    keys = split(key)
    propose_key = keys.pop()
    accept_key = keys.pop()
    proposed = propose_moves(propose_key, bart.forest)
    return accept_moves_and_sample_leaves(accept_key, bart, proposed)
@named_call
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Decide which proposed moves to accept and draw new leaf values.

    The work is split in three stages: a parallel one, a sequential one, and a
    final one, executed in this order.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    updated_bart, resolved_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(updated_bart, resolved_moves)
class Counts(Module):
    """Per-tree datapoint counts for the nodes touched by the proposed moves."""

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """How many datapoints fall in the left child."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """How many datapoints fall in the right child."""

    total: UInt[Array, '*chains num_trees'] = field(chains=True)
    """How many datapoints fall in the parent, i.e., ``left + right``."""
class Precs(Module):
    """Per-tree likelihood precision scales for the nodes touched by the moves.

    For a given tree node, the "likelihood precision scale" is obtained by
    summing the inverse squared error scales over the datapoints that the node
    selects.
    """

    left: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale of the left child."""

    right: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale of the right child."""

    total: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale of the parent, i.e., ``left + right``."""
class PreLkV(Module):
    """Per-tree terms of the likelihood ratio that do not depend on tree order.

    Since they do not involve the sequential part of the update, these terms
    can be evaluated for all trees in parallel.
    """

    left: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """Term for the left child.

    Univariate case, a scalar:

    ``1 / error_cov_inv + n_left / leaf_prior_cov_inv``.

    Multivariate case, a matrix:

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_left * error_cov_inv) @ error_cov_inv``.

    Here ``n_left`` is the number of datapoints in the left child, replaced by
    the likelihood precision scale in the heteroskedastic case."""

    right: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """Term for the right child.

    Univariate case, a scalar:

    ``1 / error_cov_inv + n_right / leaf_prior_cov_inv``.

    Multivariate case, a matrix:

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_right * error_cov_inv) @ error_cov_inv``.

    Here ``n_right`` is the number of datapoints in the right child, replaced
    by the likelihood precision scale in the heteroskedastic case."""

    total: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """Term for the parent node.

    Univariate case, a scalar:

    ``1 / error_cov_inv + n_total / leaf_prior_cov_inv``.

    Multivariate case, a matrix:

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_total * error_cov_inv) @ error_cov_inv``.

    Here ``n_total`` is the number of datapoints in the parent node, replaced
    by the likelihood precision scale in the heteroskedastic case."""

    log_sqrt_term: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Logarithm of the square-root factor appearing in the likelihood ratio."""
class PreLk(Module):
    """Terms of the likelihood ratio that are common to all trees."""

    exp_factor: Float32[Array, '*chains'] = field(chains=True)
    """Multiplicative factor of the likelihood ratio, the same for every tree."""
class PreLf(Module):
    """Quantities prepared in advance for sampling the leaves from their posterior.

    They can be evaluated for all trees in parallel.

    Per tree and per leaf, each quantity is a scalar in the univariate case and
    a matrix or vector in the multivariate case.
    """

    mean_factor: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k k 2**d']
    ) = field(chains=True)
    """Factor that, right-multiplied by the sum of the scaled residuals, yields
    the posterior mean."""

    centered_leaves: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """Zero-mean normal draws which, added to the posterior mean, give the
    posterior samples of the leaves."""
class ParallelStageOut(Module):
    """Container for the results of `accept_moves_parallel_stage`."""

    bart: State
    """The BART mcmc state, updated only in part."""

    moves: Moves
    """The proposed moves, where `partial_ratio` has been set to `None` and
    `log_trans_prior_ratio` holds its final value."""

    prec_trees: (
        Float32[Array, '*chains num_trees 2**d']
        | Int32[Array, '*chains num_trees 2**d']
    ) = field(chains=True)
    """Likelihood precision scale of each actual or potential leaf node. When
    no precision scale is defined, it is the number of points in each leaf."""

    move_precs: Precs | Counts
    """Likelihood precision scale of each node changed by the moves. When
    `bart.prec_scale` is not set, this is the same object as `move_counts`."""

    prelkv: PreLkV
    """Pre-computed per-tree terms of the likelihood ratios."""

    prelk: PreLk | None
    """Pre-computed terms of the likelihood ratios shared by all trees, if any."""

    prelf: PreLf
    """Pre-computed terms for sampling the leaves."""
@named_call
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Prepare, in parallel over trees, what is needed to accept or reject moves.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object collecting everything that can be computed in parallel.
    """
    # settings that are not modified by any of the updates below
    min_node_points = bart.forest.min_points_per_decision_node
    min_leaf_points = bart.forest.min_points_per_leaf
    unweighted = bart.prec_scale is None

    # pretend that all grow moves were accepted
    grown_forest = replace(
        bart.forest,
        var_tree=moves.var_tree,
        leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
        leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
    )
    bart = replace(bart, forest=grown_forest)

    # plain datapoint counts are needed by the thresholds on the number of
    # points and also stand in for the precisions when there are no weights
    counts_needed = (
        unweighted or min_node_points is not None or min_leaf_points is not None
    )
    if counts_needed:
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.config
        )

    # flag the (potential) leaves with enough datapoints to be grown further
    if min_node_points is not None:
        half_size = bart.forest.var_tree.shape[1]
        enough_points = count_trees[:, :half_size] >= min_node_points
        moves = replace(moves, affluence_tree=moves.affluence_tree & enough_points)

    # propagate the (possibly updated) affluence tree to the state
    bart = tree_at(
        lambda state: state.forest.affluence_tree, bart, moves.affluence_tree
    )

    # forbid the moves whose children would hold too few datapoints
    if min_leaf_points is not None:
        left_ok = move_counts.left >= min_leaf_points
        right_ok = move_counts.right >= min_leaf_points
        moves = replace(moves, allowed=moves.allowed & left_ok & right_ok)

    # per-leaf totals, weighted by the error precision scale if there is one
    if unweighted:
        prec_trees, move_precs = count_trees, move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale, bart.forest.leaf_indices, moves, bart.config
        )
    assert move_precs is not None

    # fill in the remaining non-likelihood part of the acceptance ratio
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    if bart.forest.log_likelihood is not None:
        log_trans_prior = moves.log_trans_prior_ratio
    else:
        log_trans_prior = None
    counted_forest = replace(
        bart.forest,
        grow_prop_count=jnp.sum(moves.grow),
        prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
        log_trans_prior=log_trans_prior,
    )
    bart = replace(bart, forest=counted_forest)

    assert bart.error_cov_inv is not None
    prior_cov_inv = bart.forest.leaf_prior_cov_inv
    prelkv, prelk = precompute_likelihood_terms(
        bart.error_cov_inv, prior_cov_inv, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, bart.error_cov_inv, prior_cov_inv)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
@named_call
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Re-assign datapoints to leaves as if the grow moves had been carried out.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    # for non-grow moves, compare against an index that no leaf can have, so
    # that no datapoint is re-assigned
    unreachable = jnp.array(2 * moves.var_tree.size)
    target_node = jnp.where(moves.grow, moves.node, unreachable)
    in_target = leaf_indices == target_node

    # child index: left child is `2 * node`, right child is one more
    predictor: UInt[Array, ' n'] = X[moves.grow_var, :]
    goes_right = predictor >= moves.grow_split
    left_index = moves.node.astype(leaf_indices.dtype) << 1
    child_index = left_index + goes_right

    return jnp.where(in_target, child_index, leaf_indices)
def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, 'num_trees 2**d'], Counts]
    | tuple[Float32[Array, 'num_trees 2**d'], Precs]
):
    """Shared implementation of `compute_count_trees` and `compute_prec_trees`.

    The single-tree routine is applied to all trees with `vmap` when
    `config.prec_count_num_trees` is `None`, and with `lax.map` using that
    value as batch size otherwise.
    """
    batch_size = config.prec_count_num_trees

    if batch_size is None:
        over_trees = vmap(_compute_count_or_prec_tree, in_axes=(None, 0, 0, None))
        return over_trees(prec_scale, leaf_indices, moves, config)

    def one_tree(
        tree_args: tuple[UInt[Array, ' n'], Moves],
    ) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
        tree_leaf_indices, tree_moves = tree_args
        return _compute_count_or_prec_tree(
            prec_scale, tree_leaf_indices, tree_moves, config
        )

    return lax.map(one_tree, (leaf_indices, moves), batch_size=batch_size)
def _compute_count_or_prec_tree(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, ' n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
    """Build the count tree, or the precision tree, of one tree.

    Without `prec_scale` each datapoint contributes 1 and the result is a
    `uint32` tree with a `Counts` summary; otherwise each datapoint contributes
    its precision scale and the result is a `float32` tree with a `Precs`
    summary.
    """
    (half_size,) = moves.var_tree.shape
    tree_size = 2 * half_size

    if prec_scale is None:
        weights, summary_cls, out_dtype = 1, Counts, jnp.uint32
        num_batches = config.count_num_batches
    else:
        weights, summary_cls, out_dtype = prec_scale, Precs, jnp.float32
        num_batches = config.prec_num_batches

    per_leaf = _scatter_add(
        weights, leaf_indices, tree_size, out_dtype, num_batches, config.mesh
    )

    # totals in the nodes involved in the move
    in_left = per_leaf[moves.left]
    in_right = per_leaf[moves.right]
    summary = summary_cls(left=in_left, right=in_right, total=in_left + in_right)

    # the node of the move is not a leaf here, so store the children's total
    per_leaf = per_leaf.at[moves.node].set(summary.total)

    return per_leaf, summary
@named_call
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, config: StepConfig
) -> tuple[UInt32[Array, 'num_trees 2**d'], Counts]:
    """
    Tally how many datapoints fall into each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    count_trees : UInt32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The number of points in the leaves grown or pruned by the moves.
    """
    return _compute_count_or_prec_trees(
        prec_scale=None, leaf_indices=leaf_indices, moves=moves, config=config
    )
@named_call
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Accumulate the likelihood precision scale of each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    return _compute_count_or_prec_trees(
        prec_scale=prec_scale, leaf_indices=leaf_indices, moves=moves, config=config
    )
@partial(vmap_nodoc, in_axes=(0, None))
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Finish computing the part of the MH ratio that excludes the likelihood.

    Two factors are multiplied into the partial ratio: the probability of
    picking a prune move rather than a grow move in the reverse transition, and
    the prior probability that both children are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """

    def is_growable(child):
        # children outside of the affluence tree count as not growable
        return moves.affluence_tree.at[child].get(mode='fill', fill_value=False)

    def prob_terminal(child, growable):
        # a child that can not be grown is a leaf with probability one
        return 1 - p_nonterminal[child] * growable

    left_growable = is_growable(moves.left)
    right_growable = is_growable(moves.right)

    # probability of proposing a prune after a grow: prune is the only option
    # unless some leaf can still be grown
    can_grow_again = (moves.num_growable >= 2) | left_growable | right_growable
    p_prune_after_grow = jnp.where(can_grow_again, 0.5, 1.0)

    # probability of proposing a prune when the current move is a prune
    p_prune_after_prune = jnp.where(moves.num_growable, 0.5, 1)

    p_prune = jnp.where(moves.grow, p_prune_after_grow, p_prune_after_prune)

    # prior probability that both children are terminal
    pt_children = prob_terminal(moves.left, left_growable) * prob_terminal(
        moves.right, right_growable
    )

    assert moves.partial_ratio is not None
    log_ratio = jnp.log(moves.partial_ratio * pt_children * p_prune)
    return replace(moves, log_trans_prior_ratio=log_ratio, partial_ratio=None)
@named_call
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Make the leaf values usable with post-grow indices on the original tree.

    Where the move is a grow, the value of the leaf being grown is written
    into the two positions its children would occupy.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[..., moves.node]

    # where the move is not a grow, use an out-of-bounds index instead of the
    # position of the child
    out_of_bounds = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, out_of_bounds)
    right_target = jnp.where(moves.grow, moves.right, out_of_bounds)

    with_left = leaf_trees.at[..., left_target].set(parent_value)
    return with_left.at[..., right_target].set(parent_value)
def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']:
    """Return log(det(A)) for A = LL', as twice the sum of the logs of diag(L)."""
    log_diag: Float32[Array, '... k'] = jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1))
    return 2.0 * log_diag.sum(axis=-1)
def _precompute_likelihood_terms_uv(
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """Univariate implementation of `precompute_likelihood_terms`."""
    sigma2 = jnp.reciprocal(error_cov_inv)
    sigma_mu2 = jnp.reciprocal(leaf_prior_cov_inv)

    def node_term(prec):
        # error variance plus the prior leaf variance scaled by the precision
        # (or count) of the node
        return sigma2 + prec * sigma_mu2

    term_left = node_term(move_precs.left)
    term_right = node_term(move_precs.right)
    term_total = node_term(move_precs.total)

    per_tree = PreLkV(
        left=term_left,
        right=term_right,
        total=term_total,
        log_sqrt_term=jnp.log(sigma2 * term_total / (term_left * term_right)) / 2,
    )
    shared = PreLk(exp_factor=error_cov_inv / leaf_prior_cov_inv / 2)
    return per_tree, shared
def _precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    move_precs: Counts,
) -> tuple[PreLkV, None]:
    """Multivariate implementation of `precompute_likelihood_terms`.

    There are no terms shared by all trees in this case, so the second
    returned element is `None`.
    """

    def posterior_chol(
        count: UInt[Array, ' num_trees'],
    ) -> Float32[Array, 'num_trees k k']:
        # factor of `count * error_cov_inv + leaf_prior_cov_inv`, per tree
        n: UInt[Array, 'num_trees 1 1'] = count[..., None, None]
        return chol_with_gersh(error_cov_inv * n + leaf_prior_cov_inv)

    def quadratic_term(
        chol: Float32[Array, 'num_trees k k'],
    ) -> Float32[Array, 'num_trees k k']:
        # error_cov_inv' @ inv(chol @ chol') @ error_cov_inv
        rhs: Float32[Array, 'num_trees k k'] = jnp.broadcast_to(
            error_cov_inv, chol.shape
        )
        half: Float32[Array, 'num_trees k k'] = solve_triangular(
            chol, rhs, lower=True
        )
        return half.mT @ half

    chol_left = posterior_chol(move_precs.left)
    chol_right = posterior_chol(move_precs.right)
    chol_total = posterior_chol(move_precs.total)

    logdet_prior = _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))
    log_sqrt_term: Float32[Array, ' num_trees'] = 0.5 * (
        logdet_prior
        + _logdet_from_chol(chol_total)
        - _logdet_from_chol(chol_left)
        - _logdet_from_chol(chol_right)
    )

    per_tree = PreLkV(
        left=quadratic_term(chol_left),
        right=quadratic_term(chol_right),
        total=quadratic_term(chol_total),
        log_sqrt_term=log_sqrt_term,
    )
    return per_tree, None
@named_call
def precompute_likelihood_terms(
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk | None]:
    """
    Prepare the terms of the likelihood ratio used by the acceptance step.

    The univariate or the multivariate implementation is selected from the
    number of dimensions of `error_cov_inv`. The multivariate implementation
    assumes a homoskedastic error model (i.e., the residual covariance is the
    same for all observations).

    Parameters
    ----------
    error_cov_inv
        The inverse error variance (univariate) or the inverse of the error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk | None
        Pre-computed terms of the likelihood ratio, shared by all trees.
    """
    is_multivariate = error_cov_inv.ndim == 2
    if not is_multivariate:
        return _precompute_likelihood_terms_uv(
            error_cov_inv, leaf_prior_cov_inv, move_precs
        )

    # the multivariate code path works with plain counts only
    assert isinstance(move_precs, Counts)
    return _precompute_likelihood_terms_mv(
        error_cov_inv, leaf_prior_cov_inv, move_precs
    )
def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees 2**d'] | None = None,
) -> PreLf:
    """Univariate implementation of `precompute_leaf_terms`.

    If `z` is not given, standard normal values are drawn using `key`.
    """
    likelihood_prec = prec_trees * error_cov_inv
    posterior_var = jnp.reciprocal(likelihood_prec + leaf_prior_cov_inv)

    if z is not None:
        noise = z
    else:
        noise = random.normal(key, prec_trees.shape, error_cov_inv.dtype)

    # Why mean_factor = posterior_var * error_cov_inv: the posterior mean is
    #     mean = mean_lk * likelihood_prec * posterior_var,
    # and the sum of scaled residuals is (roughly) mean_lk * prec_tree, so
    #     mean / resid_tree = likelihood_prec / prec_tree * posterior_var
    #                       = error_cov_inv * posterior_var.
    return PreLf(
        mean_factor=posterior_var * error_cov_inv,
        centered_leaves=noise * jnp.sqrt(posterior_var),
    )
def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d k'] | None = None,
) -> PreLf:
    """Multivariate implementation of `precompute_leaf_terms`.

    Parameters
    ----------
    key
        A jax random key, used only if `z` is not given.
    prec_trees
        The per-leaf multiplier of `error_cov_inv` in the posterior precision.
    error_cov_inv
        The inverse of the error covariance matrix.
    leaf_prior_cov_inv
        The inverse of the prior covariance matrix of each leaf.
    z
        Optional standard normal values to use in place of random draws.

    Returns
    -------
    Pre-computed terms with the leaf axis moved last: `mean_factor` has shape
    ``(num_trees, k, k, 2**d)`` and `centered_leaves` has shape
    ``(num_trees, k, 2**d)``.
    """
    num_trees, tree_size = prec_trees.shape
    k = error_cov_inv.shape[0]
    n_k: Float32[Array, 'num_trees tree_size 1 1'] = prec_trees[..., None, None]

    # Only broadcast the inverse of error covariance matrix to satisfy JAX's
    # batching rules for `lax.linalg.solve_triangular`, which does not support
    # implicit broadcasting.
    error_cov_inv_batched = jnp.broadcast_to(
        error_cov_inv, (num_trees, tree_size, k, k)
    )

    posterior_precision: Float32[Array, 'num_trees tree_size k k'] = (
        leaf_prior_cov_inv + n_k * error_cov_inv_batched
    )

    # lower triangular factor: posterior_precision = L_prec @ L_prec.mT
    L_prec: Float32[Array, 'num_trees tree_size k k'] = chol_with_gersh(
        posterior_precision
    )

    # two triangular solves give inv(posterior_precision) @ error_cov_inv
    Y: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, error_cov_inv_batched, lower=True
    )
    mean_factor: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, Y, trans='T', lower=True
    )
    mean_factor = mean_factor.mT
    mean_factor_out: Float32[Array, 'num_trees k k tree_size'] = jnp.moveaxis(
        mean_factor, 1, -1
    )

    if z is None:
        z = random.normal(key, (num_trees, tree_size, k))

    # Solve L_prec.mT @ x = z, so that cov(x) = inv(posterior_precision).
    # `lower=True` is required: the default (`lower=False`) reads the upper
    # triangle of `L_prec` instead of the factor itself, which produces samples
    # with the wrong covariance.
    centered_leaves: Float32[Array, 'num_trees tree_size k'] = solve_triangular(
        L_prec, z, trans='T', lower=True
    )
    centered_leaves_out: Float32[Array, 'num_trees k tree_size'] = jnp.swapaxes(
        centered_leaves, -1, -2
    )

    return PreLf(mean_factor=mean_factor_out, centered_leaves=centered_leaves_out)
@named_call
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d']
    | Float32[Array, 'num_trees 2**d k']
    | None = None,
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Dispatches to the multivariate implementation when `error_cov_inv` is a
    matrix (2 dimensions) and to the univariate implementation otherwise.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse error variance (univariate) or the inverse of error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    z
        Optional standard normal noise to use for sampling the centered leaves.
        This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    is_multivariate = error_cov_inv.ndim == 2
    implementation = (
        _precompute_leaf_terms_mv if is_multivariate else _precompute_leaf_terms_uv
    )
    return implementation(key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z)
@named_call
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    state = pso.bart
    forest = state.forest

    # inputs that do not change from tree to tree; closed over by the scan body
    shared = SeqStageInAllTrees(
        X=state.X,
        resid_num_batches=state.config.resid_num_batches,
        mesh=state.config.mesh,
        prec_scale=state.prec_scale,
        save_ratios=forest.log_likelihood is not None,
        prelk=pso.prelk,
    )

    # inputs with a leading tree axis; the scan slices one tree per iteration
    per_tree = SeqStageInPerTree(
        leaf_tree=forest.leaf_tree,
        prec_tree=pso.prec_trees,
        move=pso.moves,
        move_precs=pso.move_precs,
        leaf_indices=forest.leaf_indices,
        prelkv=pso.prelkv,
        prelf=pso.prelf,
    )

    def scan_body(
        resid: Float32[Array, ' n'] | Float32[Array, ' k n'], pt: SeqStageInPerTree
    ) -> tuple[
        Float32[Array, ' n'] | Float32[Array, ' k n'],
        tuple[
            Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
            Bool[Array, ''],
            Bool[Array, ''],
            Float32[Array, ''] | None,
        ],
    ]:
        # the residuals are the carry, everything else is stacked per tree
        new_resid, new_leaves, accepted, prune, ratio = (
            accept_move_and_sample_leaves(resid, shared, pt)
        )
        return new_resid, (new_leaves, accepted, prune, ratio)

    final_resid, stacked = lax.scan(scan_body, state.resid, per_tree)
    leaf_trees, acc, to_prune, lkratio = stacked

    updated_forest = replace(forest, leaf_tree=leaf_trees, log_likelihood=lkratio)
    updated_state = replace(state, resid=final_resid, forest=updated_forest)
    updated_moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return updated_state, updated_moves
class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    `accept_moves_sequential_stage` builds one instance, passing the fields
    positionally in the order declared here, and reuses it for every tree.
    """

    X: UInt[Array, 'p n']
    """The predictors."""

    # NOTE(review): `static=True` presumably marks the field as non-array
    # metadata (equinox-style static field) - confirm against the `field` import.
    resid_num_batches: int | None = field(static=True)
    """The number of batches for computing the sum of residuals in each leaf."""

    mesh: Mesh | None = field(static=True)
    """The mesh of devices to use."""

    prec_scale: Float32[Array, ' n'] | None
    """The scale of the precision of the error on each datapoint. If None, it
    is assumed to be 1."""

    # when False, `accept_move_and_sample_leaves` returns None as log ratio
    save_ratios: bool = field(static=True)
    """Whether to save the acceptance ratios."""

    # only required in the univariate case, see `compute_likelihood_ratio`
    prelk: PreLk | None
    """The pre-computed terms of the likelihood ratio which are shared across
    trees."""
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    `accept_moves_sequential_stage` builds one instance holding the values of
    all trees, passing the fields positionally in the order declared here, and
    scans over it with `lax.scan`, so that each iteration receives the slice
    for a single tree. The shapes annotated below are the per-tree ones.
    """

    leaf_tree: Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d']
    """The leaf values of the tree."""

    prec_tree: Float32[Array, ' 2**d']
    """The likelihood precision scale in each potential or actual leaf node."""

    move: Moves
    """The proposed move, see `propose_moves`."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves."""

    leaf_indices: UInt[Array, ' n']
    """The leaf indices for the largest version of the tree compatible with
    the move."""

    prelkv: PreLkV
    """The pre-computed terms of the likelihood ratio which are specific to the tree."""

    prelf: PreLf
    """The pre-computed terms of the leaf sampling which are specific to the tree."""
@named_call
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    The steps are: sum the (scaled) residuals in each leaf, add back the
    contribution of the current tree, compute the acceptance ratio, decide
    whether to accept, sample all leaves from their posterior, and finally
    update the residuals with the difference between old and new leaves.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n'] | Float32[Array, ' k n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d'] | Float32[Array, ' k 2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale

    tree_size = pt.leaf_tree.shape[-1]  # 2**d

    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, tree_size, at.resid_num_batches, at.mesh
    )

    # subtract starting tree from function
    # (i.e., add the tree's own contribution back to the summed residuals)
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move and compute likelihood
    resid_left = resid_tree[..., pt.move.left]
    resid_right = resid_tree[..., pt.move.right]
    resid_total = resid_left + resid_right
    assert pt.move.node.dtype == jnp.int32
    # store the parent's total so that its leaf value can be sampled as well
    resid_tree = resid_tree.at[..., pt.move.node].set(resid_total)

    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )

    # calculate accept/reject ratio
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    # the likelihood ratio is the one of a grow move (see
    # `compute_likelihood_ratio`); for a prune proposal the sign is flipped
    # NOTE(review): presumably `log_trans_prior_ratio` is also expressed for
    # the grow direction - confirm in `propose_moves`
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    if resid.ndim > 1:
        # multivariate: contract the outcome axis of the per-node matrix factor
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, resid_tree)
    else:
        mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf
    # to_prune is true for a rejected grow move or an accepted prune move
    to_prune = acc ^ pt.move.grow
    # when not pruning, the index `tree_size` is out of bounds on the last
    # axis; this relies on JAX dropping out-of-bounds updates in `.at[].set`,
    # which makes the two writes no-ops
    leaf_tree = (
        leaf_tree.at[..., jnp.where(to_prune, pt.move.left, tree_size)]
        .set(leaf_tree[..., pt.move.node])
        .at[..., jnp.where(to_prune, pt.move.right, tree_size)]
        .set(leaf_tree[..., pt.move.node])
    )
    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[..., pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
@named_call
@partial(jnp.vectorize, excluded=(1, 2, 3, 4), signature='(n)->(ts)')
def sum_resid(
    scaled_resid: Float32[Array, ' n'] | Float32[Array, 'k n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    resid_num_batches: int | None,
    mesh: Mesh | None,
) -> Float32[Array, ' {tree_size}'] | Float32[Array, 'k {tree_size}']:
    """
    Sum the residuals in each leaf.

    The function is vectorized over the leading axes of `scaled_resid`, so
    the same code serves the univariate and the multivariate case.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale. For multivariate case, shape is ``(k, n)`` where ``k``
        is the number of outcome columns.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    resid_num_batches
        The number of batches for computing the sum of residuals in each leaf.
    mesh
        The mesh of devices to use.

    Returns
    -------
    The sum of the residuals at data points in each leaf. For multivariate
    case, returns per-leaf sums of residual vectors.
    """
    # accumulate in float32, one output slot per node of the tree array
    return _scatter_add(
        values=scaled_resid,
        indices=leaf_indices,
        size=tree_size,
        dtype=jnp.float32,
        batch_size=resid_num_batches,
        mesh=mesh,
    )
def _scatter_add(
    values: Float32[Array, ' n'] | int,
    indices: Integer[Array, ' n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    mesh: Mesh | None,
) -> Shaped[Array, ' {size}']:
    """
    Indexed reduce with optional batching.

    Adds `values` into an array of length `size` at positions `indices`. If
    `mesh` has a 'data' axis, the reduction runs under `shard_map` with
    `indices` (and `values`, unless scalar) partitioned along 'data', and the
    per-shard results are summed with a final `psum`.
    """
    # `values` is either a scalar or matches `indices` elementwise
    values = jnp.asarray(values)
    assert values.ndim == 0 or values.shape == indices.shape

    reducer = partial(
        _scatter_add_impl, size=size, dtype=dtype, num_batches=batch_size
    )

    # single-device invocation
    if mesh is None or 'data' not in mesh.axis_names:
        return reducer(values, indices)

    # multi-device invocation: a scalar value is replicated, not partitioned
    values_spec = PartitionSpec('data') if values.shape else PartitionSpec()
    sharded_reducer = shard_map(
        partial(reducer, final_psum=True),
        in_specs=(values_spec, PartitionSpec('data')),
        out_specs=PartitionSpec(),
        mesh=mesh,
        **_get_shard_map_patch_kwargs(),
    )
    return sharded_reducer(values, indices)
def _get_shard_map_patch_kwargs() -> dict[str, bool]:
    """Return extra `shard_map` keyword arguments needed by the installed jax."""
    # see jax/issues/#34249, problem with vmap(shard_map(psum))
    # we tried the config jax_disable_vmap_shmap_error but it didn't work
    needs_patch = jax.__version__ in ('0.8.1', '0.8.2')
    return {'check_vma': False} if needs_patch else {}
def _scatter_add_impl(
    values: Float32[Array, ' n'] | Int32[Array, ''],
    indices: Integer[Array, ' n'],
    /,
    *,
    size: int,
    dtype: jnp.dtype,
    num_batches: int | None,
    final_psum: bool = False,
) -> Shaped[Array, ' {size}']:
    """
    Add `values` at `indices` into a zero array of length `size`.

    If `num_batches` is given, the additions are spread over `num_batches`
    separate accumulators (assigned round-robin by position) which are then
    summed. If `final_psum` is set, the result is summed over the 'data' axis.
    """
    if num_batches is not None:
        # in the sharded case, this is the size of the local shard, not the
        # full size
        (num_local,) = indices.shape
        accumulator_id = jnp.arange(num_local) % num_batches
        partial_sums = jnp.zeros((size, num_batches), dtype)
        partial_sums = partial_sums.at[indices, accumulator_id].add(values)
        total = partial_sums.sum(axis=1)
    else:
        total = jnp.zeros(size, dtype).at[indices].add(values)

    return lax.psum(total, 'data') if final_psum else total
def _compute_likelihood_ratio_uv(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """Univariate implementation of `compute_likelihood_ratio`."""

    def squared_over(resid, denominator):
        return resid * resid / denominator

    # children terms minus the parent term, each a squared residual sum
    # divided by the matching pre-computed denominator
    bracket = (
        squared_over(left_resid, prelkv.left)
        + squared_over(right_resid, prelkv.right)
        - squared_over(total_resid, prelkv.total)
    )
    return prelkv.log_sqrt_term + prelk.exp_factor * bracket
def _compute_likelihood_ratio_mv(
    total_resid: Float32[Array, ' k'],
    left_resid: Float32[Array, ' k'],
    right_resid: Float32[Array, ' k'],
    prelkv: PreLkV,
) -> Float32[Array, '']:
    """Multivariate implementation of `compute_likelihood_ratio`."""
    # quadratic forms r^T M r for the two children and for the parent
    qf_left, qf_right, qf_total = (
        resid @ matrix @ resid
        for resid, matrix in (
            (left_resid, prelkv.left),
            (right_resid, prelkv.right),
            (total_resid, prelkv.total),
        )
    )
    return prelkv.log_sqrt_term + 0.5 * (qf_left + qf_right - qf_total)
@named_call
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''] | Float32[Array, ' k'],
    left_resid: Float32[Array, ''] | Float32[Array, ' k'],
    right_resid: Float32[Array, ''] | Float32[Array, ' k'],
    prelkv: PreLkV,
    prelk: PreLk | None,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    Scalar residual sums select the univariate formula, which requires
    `prelk`; residual sums with at least one axis select the multivariate
    formula, which does not use `prelk`.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    if total_resid.ndim == 0:
        assert prelk is not None
        return _compute_likelihood_ratio_uv(
            total_resid, left_resid, right_resid, prelkv, prelk
        )
    return _compute_likelihood_ratio_mv(total_resid, left_resid, right_resid, prelkv)
@named_call
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    forest = bart.forest

    # accepted moves, split by kind
    accepted_grow = moves.acc & moves.grow
    accepted_prune = moves.acc & ~moves.grow

    new_forest = replace(
        forest,
        grow_acc_count=jnp.sum(accepted_grow),
        prune_acc_count=jnp.sum(accepted_prune),
        leaf_indices=apply_moves_to_leaf_indices(forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)
@named_call
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    index_dtype = leaf_indices.dtype
    # compare ignoring the lowest bit, so that a datapoint in either child
    # matches `moves.left` (presumably left is even and right is left + 1)
    ignore_lowest_bit = ~jnp.array(1, index_dtype)  # ...1111111110
    in_moved_children = (leaf_indices & ignore_lowest_bit) == moves.left
    assert moves.to_prune is not None
    # datapoints in the pruned children are reassigned to the parent node
    reassign = in_moved_children & moves.to_prune
    parent = moves.node.astype(index_dtype)
    return jnp.where(reassign, parent, leaf_indices)
@named_call
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None

    # an index equal to the array size is out of bounds and is used to turn
    # a write into a no-op
    no_write = split_tree.size
    grow_at = jnp.where(moves.grow, moves.node, no_write)
    prune_at = jnp.where(moves.to_prune, moves.node, no_write)

    # first write the proposed cutpoint for grow moves, then clear the node
    # if the tree has to be pruned there
    new_split = moves.grow_split.astype(split_tree.dtype)
    with_grow = split_tree.at[grow_at].set(new_split)
    return with_grow.at[prune_at].set(0)
@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Float32[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key
    df
        Degrees of freedom
    scale_inv
        Inverse of the Wishart scale matrix, i.e., the scale matrix of the
        corresponding Inverse Wishart distribution

    Returns
    -------
    A sample from Wishart(df, scale_inv^-1)
    """
    keys = split(key)
    dim, _ = scale_inv.shape

    # Diagonal elements: A_ii ~ sqrt(chi^2(df - i))
    # chi^2(k) = Gamma(k/2, scale=2)
    dof_per_row = df - jnp.arange(dim)
    chi2_draws = random.gamma(keys.pop(), dof_per_row / 2.0) * 2.0
    bartlett_diag = jnp.sqrt(chi2_draws)

    # strictly lower triangle: independent standard normals
    normal_draws = random.normal(keys.pop(), (dim, dim))
    bartlett_factor = jnp.tril(normal_draws, -1) + jnp.diag(bartlett_diag)

    # with scale_inv = L L^T, the factor L^-T A has the target distribution
    # once multiplied by its own transpose
    chol_scale_inv = chol_with_gersh(scale_inv, absolute_eps=True)
    sample_factor = solve_triangular(
        chol_scale_inv, bartlett_factor, lower=True, trans='T'
    )

    return sample_factor @ sample_factor.T
def _step_error_cov_inv_uv(key: Key[Array, ''], bart: State) -> State:
    """Univariate implementation of `step_error_cov_inv`."""
    resid = bart.resid
    weights = bart.prec_scale
    weighted_resid = resid if weights is None else resid * weights

    # inverse gamma prior: alpha = df / 2, beta = scale / 2
    posterior_alpha = bart.error_cov_df / 2 + resid.size / 2
    weighted_norm2 = resid @ weighted_resid
    posterior_beta = bart.error_cov_scale / 2 + weighted_norm2 / 2

    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_draw = random.gamma(key, posterior_alpha)
    return replace(bart, error_cov_inv=gamma_draw / posterior_beta)
def _step_error_cov_inv_mv(key: Key[Array, ''], bart: State) -> State:
    """Multivariate implementation of `step_error_cov_inv`."""
    resid = bart.resid
    num_datapoints = resid.shape[-1]

    # posterior parameters: add the number of datapoints to the degrees of
    # freedom and the residual cross-products to the scale
    posterior_df = bart.error_cov_df + num_datapoints
    posterior_scale = bart.error_cov_scale + resid @ resid.T

    new_precision = _sample_wishart_bartlett(key, posterior_df, posterior_scale)
    return replace(bart, error_cov_inv=new_precision)
@named_call
def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the inverse error covariance.

    The multivariate update is used when `bart.error_cov_inv` is a matrix
    (2 dimensions), the univariate one otherwise.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `error_cov_inv`.
    """
    current = bart.error_cov_inv
    assert current is not None
    update = _step_error_cov_inv_mv if current.ndim == 2 else _step_error_cov_inv_uv
    return update(key, bart)
@named_call
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state, with new `z` and `resid`.
    """
    # the part of the latent variable that is not residual
    forest_plus_offset = bart.z - bart.resid
    assert bart.y.dtype == bool
    # draw the new residuals from a one-sided truncated normal, with the
    # side selected by the outcome and the bound at minus the forest value
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -forest_plus_offset)
    new_z = forest_plus_offset + new_resid
    return replace(bart, z=new_z, resid=new_resid)
@named_call
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    This full conditional is approximated, because it does not take into account
    that there are forbidden decision rules.
    """
    forest = bart.forest
    assert forest.theta is not None

    # count how many times each variable is used in the forest
    num_vars = forest.max_split.size
    usage_count = var_histogram(
        num_vars, forest.var_tree, forest.split_tree, sum_batch_axis=-1
    )

    # Dirichlet posterior concentration, sampled as log-gamma variates
    concentration = forest.theta / num_vars + usage_count
    new_log_s = random.loggamma(key, concentration)

    return replace(bart, forest=replace(forest, log_s=new_log_s))
@named_call
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b). The update draws
    lambda = theta / (theta + rho) from a grid of values with probabilities
    given by `_log_p_lamda`.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # the grid points are the midpoints of num_grid bins in (0, 1)
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # normalize s
    normalized_log_s = forest.log_s - logsumexp(forest.log_s)

    # sample lambda on the grid and pick the matching theta
    log_weights, theta_grid = _log_p_lamda(
        lamda_grid, normalized_log_s, forest.rho, forest.a, forest.b
    )
    chosen = random.categorical(key, log_weights)
    new_theta = theta_grid[chosen]

    return replace(bart, forest=replace(forest, theta=new_theta))
def _log_p_lamda(
    lamda: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Compute the unnormalized log probability of each grid value of lambda.

    Returns the log probabilities and the values of theta corresponding to
    the grid points. The grid is expected to satisfy lamda[::-1] == 1 - lamda.
    """
    # in the following I use lamda[::-1] == 1 - lamda
    one_minus_lamda = lamda[::-1]
    theta = rho * lamda / one_minus_lamda
    num_vars = log_s.size

    # terms are accumulated in a fixed left-to-right order
    log_p = (a - 1) * jnp.log1p(-one_minus_lamda)  # log(lambda)
    log_p = log_p + (b - 1) * jnp.log1p(-lamda)  # log(1 - lambda)
    log_p = log_p + gammaln(theta)
    log_p = log_p - num_vars * gammaln(theta / num_vars)
    log_p = log_p + theta / num_vars * jnp.sum(log_s)
    return log_p, theta
@named_call
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if the parameters of
    the theta prior are defined. Nothing is done if `sparse_on_at` is not
    set in the configuration, or while the number of steps done is still
    below it.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    config = bart.config
    if config.sparse_on_at is None:
        return bart

    def leave_unchanged(_key: Key[Array, ''], state: State) -> State:
        return state

    # the switch depends on a runtime step counter, hence `lax.cond`
    too_early = config.steps_done < config.sparse_on_at
    return lax.cond(too_early, leave_unchanged, _step_sparse, key, bart)
def _step_sparse(key: Key[Array, ''], bart: State) -> State:
    """Run `step_s`, followed by `step_theta` if `rho` is defined."""
    keys = split(key)
    state = step_s(keys.pop(), bart)
    if state.forest.rho is None:
        return state
    return step_theta(keys.pop(), state)
@named_call
def step_config(bart: State) -> State:
    """Return the state with the `steps_done` counter of the config increased by one."""
    advanced_config = replace(bart.config, steps_done=bart.config.steps_done + 1)
    return replace(bart, config=advanced_config)