Coverage for src/bartz/mcmcstep/_state.py: 93% (301 statements)
coverage.py v7.13.0, created at 2026-01-13 00:35 +0000

# bartz/src/bartz/mcmcstep/_state.py
#
# Copyright (c) 2024-2026, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Module defining the BART MCMC state and initialization."""

from collections.abc import Callable, Hashable
from dataclasses import Field, fields
from functools import partial, wraps
from math import ceil, log2
from typing import Any, Literal, TypeVar

from equinox import Module, error_if
from equinox import field as eqx_field
from jax import NamedSharding, device_put, eval_shape, make_mesh, random, tree, vmap
from jax import numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.sharding import AxisType, Mesh, PartitionSpec
from jax.tree import flatten
from jaxtyping import Array, Bool, Float32, Int32, Integer, PyTree, Shaped, UInt

from bartz.grove import make_tree, tree_depths
from bartz.jaxext import get_default_device, is_key, minimal_unsigned_dtype


def field(*, chains: bool = False, data: bool = False, **kwargs) -> Field:
    """Extend `equinox.field` with two new parameters.

    Parameters
    ----------
    chains
        Whether the arrays in the field have an optional first axis that
        represents independent Markov chains.
    data
        Whether the last axis of the arrays in the field represents units of
        the data.
    **kwargs
        Other parameters passed to `equinox.field`.

    Returns
    -------
    A dataclass field descriptor with the special attributes in the metadata,
    left unset if `False`.
    """
    metadata = dict(kwargs.pop('metadata', {}))
    assert 'chains' not in metadata
    assert 'data' not in metadata
    if chains:
        metadata['chains'] = True
    if data:
        metadata['data'] = True
    return eqx_field(metadata=metadata, **kwargs)
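

# Illustrative sketch: `_example_field_metadata` is a hypothetical, never-called
# helper (not part of the original module) showing how the `chains`/`data` flags
# of `field` end up in the dataclass metadata that `chain_vmap_axes` and
# `data_vmap_axes` later inspect.
def _example_field_metadata() -> None:
    class Demo(Module):  # hypothetical demo class
        counts: Int32[Array, '*chains'] = field(chains=True)
        resid: Float32[Array, '*chains n'] = field(chains=True, data=True)
        scale: Float32[Array, '']  # no flags: neither chain nor data axis

    meta = {f.name: f.metadata for f in fields(Demo)}
    assert meta['counts'].get('chains') is True
    assert meta['resid'].get('data') is True
    assert 'chains' not in meta['scale'] and 'data' not in meta['scale']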


def chain_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for chains.

    This function determines the argument to the `in_axes` or `out_axes`
    parameter of `jax.vmap` to vmap over all and only the chain axes found in
    the pytree `x`.

    Parameters
    ----------
    x
        A pytree. Subpytrees that are Module attributes marked with
        ``field(..., chains=True)`` are considered to have a leading chain axis.

    Returns
    -------
    A pytree with the same structure as `x`, with 0 or None in the leaves.
    """
    return _find_metadata(x, 'chains', 0, None)


def data_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for data.

    This is analogous to `chain_vmap_axes` but returns -1 for all fields
    marked with ``field(..., data=True)``.
    """
    return _find_metadata(x, 'data', -1, None)
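

# Illustrative sketch: `_example_chain_vmap_axes` is a hypothetical, never-called
# helper (not part of the original module). It shows that chain-marked fields map
# to axis 0 and everything else to `None`, and how that plugs into `jax.vmap`.
def _example_chain_vmap_axes() -> None:
    class Demo(Module):  # hypothetical demo class
        counts: Int32[Array, '*chains'] = field(chains=True)
        scale: Float32[Array, '']

    demo = Demo(counts=jnp.zeros(4, int), scale=jnp.float32(1.5))  # 4 chains
    axes = chain_vmap_axes(demo)
    assert axes.counts == 0 and axes.scale is None

    def bump(d: Demo) -> Demo:  # a toy per-chain update
        return Demo(counts=d.counts + 1, scale=d.scale)

    out = vmap(bump, in_axes=(axes,), out_axes=axes)(demo)
    assert out.counts.shape == (4,)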


T = TypeVar('T')


def _find_metadata(
    x: PyTree[Any, ' S'], key: Hashable, if_true: T, if_false: T
) -> PyTree[T, ' S']:
    """Replace all subtrees of x marked with a metadata key."""
    if isinstance(x, Module):
        args = []
        for f in fields(x):
            v = getattr(x, f.name)
            if f.metadata.get('static', False):
                args.append(v)
            elif f.metadata.get(key, False):
                subtree = tree.map(lambda _: if_true, v)
                args.append(subtree)
            else:
                args.append(_find_metadata(v, key, if_true, if_false))
        return x.__class__(*args)

    def is_leaf(x) -> bool:
        return isinstance(x, Module)

    def get_axes(x: Module | Any) -> PyTree[T]:
        if isinstance(x, Module):
            return _find_metadata(x, key, if_true, if_false)
        else:
            return tree.map(lambda _: if_false, x)

    return tree.map(get_axes, x, is_leaf=is_leaf)


class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_tree
        The leaf values.
    var_tree
        The decision axes.
    split_tree
        The decision boundaries.
    affluence_tree
        Marks leaves that can be grown.
    max_split
        The maximum split index for each predictor.
    blocked_vars
        Indices of variables that are not used. This shall include at least
        the `i` such that ``max_split[i] == 0``, otherwise behavior is
        undefined.
    p_nonterminal
        The prior probability of each node being nonterminal, conditional on
        its ancestors. Includes the nodes at maximum depth, which should be
        set to 0.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_decision_node
        The minimum number of data points in a decision node.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree.
    log_likelihood
        The log likelihood ratio.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    leaf_prior_cov_inv
        The prior precision matrix of a leaf, conditional on the tree structure.
        For the univariate case (k=1), this is a scalar (the inverse variance).
        The prior covariance of the sum of trees is
        ``num_trees * leaf_prior_cov_inv^-1``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If `None`, use a uniform distribution.
    theta
        The concentration parameter for the Dirichlet prior on the variable
        distribution `s`. Required only to update `s`.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
        See `step_theta`.
    """

    leaf_tree: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    split_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    max_split: UInt[Array, ' p']
    blocked_vars: UInt[Array, ' q'] | None
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    leaf_indices: UInt[Array, '*chains num_trees n'] = field(chains=True, data=True)
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    log_trans_prior: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    log_likelihood: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    grow_prop_count: Int32[Array, '*chains'] = field(chains=True)
    prune_prop_count: Int32[Array, '*chains'] = field(chains=True)
    grow_acc_count: Int32[Array, '*chains'] = field(chains=True)
    prune_acc_count: Int32[Array, '*chains'] = field(chains=True)
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
    log_s: Float32[Array, '*chains p'] | None = field(chains=True)
    theta: Float32[Array, '*chains'] | None = field(chains=True)
    a: Float32[Array, ''] | None
    b: Float32[Array, ''] | None
    rho: Float32[Array, ''] | None

    def num_chains(self) -> int | None:
        """Return the number of chains, or `None` if not multichain."""
        # maybe this should be replaced by chain_shape() -> () | (int,)
        if self.var_tree.ndim == 2:
            return None
        else:
            return self.var_tree.shape[0]
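

# Illustrative sketch: `_example_tree_heap_layout` is a hypothetical, never-called
# helper (not part of the original module). Assuming the heap layout suggested by
# the ``2**d`` tree sizes and by `init` below (which starts `leaf_indices` at 1
# and marks node 1 of `affluence_tree` as growable), node 1 is the root, node i
# has children 2*i and 2*i + 1, and index 0 is unused. `num_chains()` tells
# single from multichain states by whether `var_tree` has a leading chain axis.
def _example_tree_heap_layout() -> None:
    def children(i: int) -> tuple[int, int]:
        return 2 * i, 2 * i + 1

    assert children(1) == (2, 3)  # the root's children
    assert children(3) == (6, 7)  # their children sit at depth 2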


class StepConfig(Module):
    """Options for the MCMC step.

    Parameters
    ----------
    steps_done
        The number of MCMC steps completed so far.
    sparse_on_at
        After how many steps to turn on variable selection.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    mesh
        The mesh used to shard data and computation across multiple devices.
    """

    steps_done: Int32[Array, '']
    sparse_on_at: Int32[Array, ''] | None
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    mesh: Mesh | None = field(static=True)


class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors.
    y
        The response. If the data type is `bool`, the model is binary regression.
    resid
        The residuals (`y` or `z` minus sum of trees).
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    error_cov_inv
        The inverse error covariance (scalar for univariate, matrix for
        multivariate). `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    error_cov_df
    error_cov_scale
        The df and scale parameters of the inverse Wishart prior on the noise
        covariance. For the univariate case, the relationship to the inverse
        gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``.
        `None` in binary regression.
    forest
        The sum of trees model.
    config
        Metadata and configurations for the MCMC step.
    """

    X: UInt[Array, 'p n'] = field(data=True)
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'] = field(
        data=True
    )
    z: None | Float32[Array, '*chains n'] = field(chains=True, data=True)
    offset: Float32[Array, ''] | Float32[Array, ' k']
    resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=True, data=True
    )
    error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None = (
        field(chains=True)
    )
    prec_scale: Float32[Array, ' n'] | None = field(data=True)
    error_cov_df: Float32[Array, ''] | None
    error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None
    forest: Forest
    config: StepConfig


def _init_shape_shifting_parameters(
    y: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    error_scale: Float32[Any, ' n'] | None,
    error_cov_df: float | Float32[Any, ''] | None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Any, 'k k'] | None,
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
) -> tuple[
    bool,
    tuple[()] | tuple[int],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
]:
    """
    Check and initialize parameters that change array type/shape based on outcome kind.

    Parameters
    ----------
    y
        The response variable; the outcome type is deduced from `y` and then
        all other parameters are checked against it.
    offset
        The offset to add to the predictions.
    error_scale
        Per-observation error scale (univariate only).
    error_cov_df
        The error covariance degrees of freedom.
    error_cov_scale
        The error covariance scale.
    leaf_prior_cov_inv
        The inverse of the leaf prior covariance.

    Returns
    -------
    is_binary
        Whether the outcome is binary.
    kshape
        The outcome shape: empty for univariate, (k,) for multivariate.
    error_cov_inv
        The initialized error covariance inverse.
    error_cov_df
        The error covariance degrees of freedom (as array).
    error_cov_scale
        The error covariance scale (as array).

    Raises
    ------
    ValueError
        If `y` is binary and multivariate.
    """
    # determine outcome kind, binary/continuous x univariate/multivariate
    is_binary = y.dtype == bool
    kshape = y.shape[:-1]

    # Binary vs continuous
    if is_binary:
        if kshape:
            msg = 'Binary multivariate regression not supported, open an issue at https://github.com/bartz-org/bartz/issues if you need it.'
            raise ValueError(msg)
        assert error_scale is None
        assert error_cov_df is None
        assert error_cov_scale is None
        error_cov_inv = None
    else:
        error_cov_df = jnp.asarray(error_cov_df)
        error_cov_scale = jnp.asarray(error_cov_scale)
        assert error_cov_scale.shape == 2 * kshape

        # Multivariate vs univariate
        if kshape:
            error_cov_inv = error_cov_df * _inv_via_chol_with_gersh(error_cov_scale)
        else:
            # inverse gamma prior: alpha = df / 2, beta = scale / 2
            error_cov_inv = error_cov_df / error_cov_scale

    assert leaf_prior_cov_inv.shape == 2 * kshape
    assert offset.shape == kshape

    return is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale
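

# Illustrative sketch: `_example_univariate_error_prior` is a hypothetical,
# never-called helper (not part of the original module). For a univariate
# continuous outcome the error precision is initialized to ``df / scale``,
# which under the stated inverse-gamma parameterization (alpha = df / 2,
# beta = scale / 2) equals the prior mean of the precision.
def _example_univariate_error_prior() -> None:
    is_binary, kshape, error_cov_inv, df, scale = _init_shape_shifting_parameters(
        y=jnp.zeros(10),
        offset=jnp.float32(0.0),
        error_scale=None,
        error_cov_df=4.0,
        error_cov_scale=2.0,
        leaf_prior_cov_inv=jnp.float32(1.0),
    )
    assert not is_binary and kshape == ()
    assert float(error_cov_inv) == 2.0  # df / scale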


def _parse_p_nonterminal(
    p_nonterminal: Float32[Any, ' d_minus_1'],
) -> Float32[Array, ' d_minus_1+1']:
    """Check it's in (0, 1) and pad with a 0 at the end."""
    p_nonterminal = jnp.asarray(p_nonterminal)
    ok = (p_nonterminal > 0) & (p_nonterminal < 1)
    p_nonterminal = error_if(p_nonterminal, ~ok, 'p_nonterminal must be in (0, 1)')
    return jnp.pad(p_nonterminal, (0, 1))


def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Float32[Any, ' k n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] | Float32[Any, ' k'],
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d_minus_1'],
    leaf_prior_cov_inv: float | Float32[Any, ''] | Float32[Array, 'k k'],
    error_cov_df: float | Float32[Any, ''] | None = None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Array, 'k k'] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_batch_size: int | None | Literal['auto'] = 'auto',
    count_batch_size: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: bool = True,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
    sparse_on_at: int | Integer[Any, ''] | None = None,
    num_chains: int | None = None,
    mesh: Mesh | dict[str, int] | None = None,
    target_platform: Literal['cpu', 'gpu'] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If the data type is `bool`, the regression model is
        binary regression with probit. If two-dimensional, the outcome is
        multivariate with the first axis indicating the component.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    leaf_prior_cov_inv
        The prior precision matrix of a leaf, conditional on the tree structure.
        For the univariate case (k=1), this is a scalar (the inverse variance).
        The prior covariance of the sum of trees is
        ``num_trees * leaf_prior_cov_inv^-1``. The prior mean of leaves is
        always zero.
    error_cov_df
    error_cov_scale
        The df and scale parameters of the inverse Wishart prior on the error
        covariance. For the univariate case, the relationship to the inverse
        gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``.
        Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, it defaults to 1
        for all points, which potentially skips some calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', it's chosen automatically based on the target platform; see
        the description of `target_platform` below for how it is determined.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        Whether to check `max_split` for variables without available cutpoints.
        If any are found, they are put into a list of variables to exclude from
        the MCMC. If `False`, no check is performed, but the results may be
        wrong if any variable is blocked. The function is jax-traceable only
        if this is set to `False`.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken
        into account in the Metropolis-Hastings ratio because it would be
        expensive to compute. Grow moves that would violate this constraint
        are vetoed. This parameter is independent of
        `min_points_per_decision_node` and there is no check that they are
        coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, use a uniform distribution. If not specified while
        `theta` (or `rho`, `a`, `b`) is, it's initialized automatically.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized automatically.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
    sparse_on_at
        After how many MCMC steps to turn on variable selection.
    num_chains
        The number of independent MCMC chains to represent in the state. Single
        chain with scalar values if not specified.
    mesh
        A jax mesh used to shard data and computation across multiple devices.
        If it has a 'chains' axis, that axis is used to shard the chains. If it
        has a 'data' axis, that axis is used to shard the datapoints.

        As a shorthand, if a dictionary mapping axis names to axis sizes is
        passed, the corresponding mesh is created; e.g., ``dict(chains=4,
        data=2)`` lets jax pick 8 devices, splits the chains (whose number must
        be a multiple of 4) across 4 pairs of devices, and splits the data in
        two within each pair.

        Note: if a mesh is passed, the arrays are always sharded according to
        it. In particular, even if the mesh has no 'chains' or 'data' axis, the
        arrays will be replicated on all devices in the mesh.
    target_platform
        Platform ('cpu' or 'gpu') used to determine the batch sizes
        automatically. If `mesh` is specified, the platform is inferred from
        the devices in the mesh. Otherwise, if `y` is a concrete array (i.e.,
        `init` is not invoked in a `jax.jit` context), the platform is set to
        the platform of `y`. Otherwise, use `target_platform`.

        To avoid confusion, in all cases where the `target_platform` argument
        would be ignored, `init` raises an exception if `target_platform` is
        set.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    # convert to array all array-like arguments that are used in other
    # configurations but don't need further processing themselves
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    offset = jnp.asarray(offset)
    leaf_prior_cov_inv = jnp.asarray(leaf_prior_cov_inv)
    max_split = jnp.asarray(max_split)

    # check p_nonterminal and pad it with a 0 at the end (still not final shape)
    p_nonterminal = _parse_p_nonterminal(p_nonterminal)

    # process arguments that change depending on outcome type
    is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale = (
        _init_shape_shifting_parameters(
            y, offset, error_scale, error_cov_df, error_cov_scale, leaf_prior_cov_inv
        )
    )

    # extract array sizes from arguments
    (max_depth,) = p_nonterminal.shape
    p, n = X.shape

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        theta = rho
    if log_s is None and theta is not None:
        log_s = jnp.zeros(max_split.size)
    if not _all_none_or_not_none(theta, sparse_on_at):
        msg = 'sparsity params (either theta or rho,a,b) and sparse_on_at must be either all None or all set'
        raise ValueError(msg)

    # process multichain settings
    chain_shape = () if num_chains is None else (num_chains,)
    resid_shape = chain_shape + y.shape
    tree_shape = (*chain_shape, num_trees)
    add_chains = partial(_add_chains, chain_shape=chain_shape)

    # determine batch sizes for reductions
    mesh = _parse_mesh(num_chains, mesh)
    target_platform = _parse_target_platform(
        y, mesh, target_platform, resid_batch_size, count_batch_size
    )
    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size,
        count_batch_size,
        y,
        max_depth,
        num_trees,
        num_chains,
        mesh,
        target_platform,
    )

    # initialize all remaining stuff and put it in an unsharded state
    state = State(
        X=X,
        y=y,
        z=jnp.full(resid_shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(resid_shape)
        if is_binary
        else jnp.broadcast_to(y - offset[..., None], resid_shape),
        error_cov_inv=add_chains(error_cov_inv),
        prec_scale=(
            None if error_scale is None else jnp.reciprocal(jnp.square(error_scale))
        ),
        error_cov_df=error_cov_df,
        error_cov_scale=error_cov_scale,
        forest=Forest(
            leaf_tree=make_tree(max_depth, jnp.float32, tree_shape + kshape),
            var_tree=make_tree(
                max_depth - 1, minimal_unsigned_dtype(p - 1), tree_shape
            ),
            split_tree=make_tree(max_depth - 1, max_split.dtype, tree_shape),
            affluence_tree=(
                make_tree(max_depth - 1, bool, tree_shape)
                .at[..., 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else n >= min_points_per_decision_node
                )
            ),
            blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split),
            max_split=max_split,
            grow_prop_count=jnp.zeros(chain_shape, int),
            grow_acc_count=jnp.zeros(chain_shape, int),
            prune_prop_count=jnp.zeros(chain_shape, int),
            prune_acc_count=jnp.zeros(chain_shape, int),
            p_nonterminal=p_nonterminal[tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[tree_depths(2 ** (max_depth - 1))],
            leaf_indices=jnp.ones(
                (*tree_shape, n), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            log_trans_prior=jnp.zeros((*chain_shape, num_trees))
            if save_ratios
            else None,
            log_likelihood=jnp.zeros((*chain_shape, num_trees))
            if save_ratios
            else None,
            leaf_prior_cov_inv=leaf_prior_cov_inv,
            log_s=add_chains(_asarray_or_none(log_s)),
            theta=add_chains(_asarray_or_none(theta)),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
        config=StepConfig(
            steps_done=jnp.int32(0),
            sparse_on_at=_asarray_or_none(sparse_on_at),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            mesh=mesh,
        ),
    )

    # move all arrays to the appropriate device
    return _shard_state(state)
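

# Illustrative sketch: `_example_init_univariate` is a hypothetical, never-called
# helper (not part of the original module) showing a minimal single-chain call to
# `init` on synthetic data; the prior values are arbitrary example choices.
def _example_init_univariate() -> None:
    p, n, num_trees = 3, 50, 20
    X = random.randint(random.key(0), (p, n), 0, 11).astype(jnp.uint8)  # in [0, 10]
    y = jnp.linspace(-1.0, 1.0, n)
    state = init(
        X=X,
        y=y,
        offset=y.mean(),
        max_split=jnp.full(p, 10, jnp.uint8),  # cutpoints 1..10 for each predictor
        num_trees=num_trees,
        p_nonterminal=0.95 * 0.5 ** jnp.arange(5),  # tree depth capped by this length
        leaf_prior_cov_inv=jnp.float32(4.0 * num_trees),  # arbitrary example value
        error_cov_df=3.0,
        error_cov_scale=3.0,
    )
    assert state.forest.num_chains() is None  # single chain
    assert state.forest.leaf_indices.shape == (num_trees, n)  # one per (tree, point)
    assert int(state.config.steps_done) == 0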


def _get_blocked_vars(
    filter_splitless_vars: bool, max_split: UInt[Array, ' p']
) -> None | UInt[Array, ' q']:
    """Initialize the `blocked_vars` field."""
    if filter_splitless_vars:
        (p,) = max_split.shape
        (blocked_vars,) = jnp.nonzero(max_split == 0)
        return blocked_vars.astype(minimal_unsigned_dtype(p))
        # see `fully_used_variables` for the type cast
    else:
        return None


def _add_chains(
    x: Shaped[Array, '*shape'] | None, chain_shape: tuple[int, ...]
) -> Shaped[Array, '*shape'] | Shaped[Array, ' num_chains *shape'] | None:
    """Broadcast `x` to all chains."""
    if x is None:
        return None
    else:
        return jnp.broadcast_to(x, chain_shape + x.shape)


def _parse_mesh(
    num_chains: int | None, mesh: Mesh | dict[str, int] | None
) -> Mesh | None:
    """Parse the `mesh` argument."""
    if mesh is None:
        return None

    # convert dict format to actual mesh
    if isinstance(mesh, dict):
        assert set(mesh).issubset({'chains', 'data'})
        mesh = make_mesh(
            tuple(mesh.values()), tuple(mesh), axis_types=(AxisType.Auto,) * len(mesh)
        )

    # check there's no chain mesh axis if there are no chains
    if num_chains is None:
        assert 'chains' not in mesh.axis_names

    # check the axes we use are in auto mode
    assert 'chains' not in mesh.axis_names or 'chains' in _auto_axes(mesh)
    assert 'data' not in mesh.axis_names or 'data' in _auto_axes(mesh)

    return mesh
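

# Illustrative sketch: `_example_parse_mesh_dict` is a hypothetical, never-called
# helper (not part of the original module). The dict shorthand is expanded with
# `make_mesh`; a size-1 'data' axis is used here so the example fits on a single
# device (``dict(chains=4, data=2)`` would require 8 devices).
def _example_parse_mesh_dict() -> None:
    mesh = _parse_mesh(num_chains=None, mesh={'data': 1})
    assert mesh.axis_names == ('data',)
    assert mesh.devices.size == 1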


def _parse_target_platform(
    y: Array,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
    resid_batch_size: int | None | Literal['auto'],
    count_batch_size: int | None | Literal['auto'],
) -> Literal['cpu', 'gpu'] | None:
    if mesh is not None:
        assert target_platform is None, 'mesh provided, do not set target_platform'
        return mesh.devices.flat[0].platform
    elif hasattr(y, 'platform'):
        assert target_platform is None, 'device inferred from y, unset target_platform'
        return y.platform()
    elif resid_batch_size == 'auto' or count_batch_size == 'auto':
        assert target_platform in ('cpu', 'gpu')
        return target_platform
    else:
        assert target_platform is None, 'target_platform not used, unset it'
        return target_platform


def _auto_axes(mesh: Mesh) -> list[str]:
    """Re-implement `Mesh.auto_axes` because that's missing in jax v0.5."""
    # Mesh.auto_axes added in jax v0.6.0
    return [
        n
        for n, t in zip(mesh.axis_names, mesh.axis_types, strict=True)
        if t == AxisType.Auto
    ]


def _shard_state(state: State) -> State:
    """Place all fields in the state on the appropriate devices."""
    mesh = state.config.mesh
    if mesh is None:
        return state

    def shard_leaf(
        x: Array | None, chain_axis: int | None, data_axis: int | None
    ) -> Array | None:
        if x is None:
            return None

        spec = [None] * x.ndim
        if chain_axis is not None and 'chains' in mesh.axis_names:
            spec[chain_axis] = 'chains'
        if data_axis is not None and 'data' in mesh.axis_names:
            spec[data_axis] = 'data'

        spec = PartitionSpec(*spec)
        return device_put(x, NamedSharding(mesh, spec), donate=True)

    return tree.map(
        shard_leaf,
        state,
        chain_vmap_axes(state),
        data_vmap_axes(state),
        is_leaf=lambda x: x is None,
    )
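

# Illustrative sketch: `_example_partition_spec` is a hypothetical, never-called
# helper (not part of the original module). It replays the spec-building logic of
# `shard_leaf` for one array whose last axis is the data axis, on a trivial
# single-device mesh.
def _example_partition_spec() -> None:
    mesh = make_mesh((1,), ('data',), axis_types=(AxisType.Auto,))
    x = jnp.zeros((7, 50))  # e.g. (num_trees, n); the data axis is the last one
    chain_axis, data_axis = None, -1  # what chain_vmap_axes/data_vmap_axes report
    spec = [None] * x.ndim
    if chain_axis is not None and 'chains' in mesh.axis_names:
        spec[chain_axis] = 'chains'
    if data_axis is not None and 'data' in mesh.axis_names:
        spec[data_axis] = 'data'
    assert PartitionSpec(*spec) == PartitionSpec(None, 'data')
    device_put(x, NamedSharding(mesh, PartitionSpec(*spec)))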


def _all_none_or_not_none(*args):
    is_none = [x is None for x in args]
    return all(is_none) or not any(is_none)


def _asarray_or_none(x):
    if x is None:
        return None
    return jnp.asarray(x)


def _get_platform(mesh: Mesh | None) -> str:
    if mesh is None:
        return get_default_device().platform
    else:
        return mesh.devices.flat[0].platform


def _choose_suffstat_batch_size(
    resid_batch_size: int | None | Literal['auto'],
    count_batch_size: int | None | Literal['auto'],
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'],
    max_depth: int,
    num_trees: int,
    num_chains: int | None,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
) -> tuple[int | None, int | None]:
    """Determine batch sizes for reductions."""
    # get number of outcomes and of datapoints
    k, n = _get_k_n(y)

    # get per-device values
    if num_chains is None:
        num_chains = 1
    num_chains //= get_axis_size(mesh, 'chains')
    n //= get_axis_size(mesh, 'data')

    # compute auxiliary sizes
    batch_size = k * num_chains
    unbatched_accum_bytes_times_batch_size = num_trees * 2**max_depth * 4 * n

    def final_round(s: float) -> int | None:
        # multiply by batch_size because if the calculation is already
        # parallelizable over batching dims there is correspondingly less need
        # to parallelize across datapoints
        s *= batch_size

        # at least 1, i.e., each datapoint is its own batch
        s = max(1, s)

        # round to the nearest power of 2 because I guess XLA and the hardware
        # will like that
        s = 2 ** round(log2(s))

        # disable batching if the batch is as large as the whole dataset
        return s if s < n else None

    if resid_batch_size != 'auto':
        rbs = resid_batch_size
    elif target_platform == 'cpu':
        rbs = final_round(n / 6)
        # instead of 6 I guess I should have in general the number of "good"
        # physical cores
    elif target_platform == 'gpu':
        rbs = final_round((2 * n) ** (1 / 3))

    if count_batch_size != 'auto':
        cbs = count_batch_size
    elif target_platform == 'cpu':
        cbs = None
    elif target_platform == 'gpu':
        cbs = (n / 16) ** 0.5

        # ensure we don't exceed ~512MiB of memory usage per device
        max_memory = 2**29
        min_batch_size = ceil(unbatched_accum_bytes_times_batch_size / max_memory)
        cbs = max(cbs, min_batch_size)

        cbs = final_round(cbs)

    return rbs, cbs
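

# Illustrative sketch: `_example_batch_sizes_cpu` is a hypothetical, never-called
# helper (not part of the original module). On cpu with n = 10_000 single-chain
# univariate data, the residual batch size is n / 6 ~= 1667 rounded to the
# nearest power of two, and leaf counts are left unbatched.
def _example_batch_sizes_cpu() -> None:
    rbs, cbs = _choose_suffstat_batch_size(
        'auto',
        'auto',
        jnp.zeros(10_000),
        max_depth=6,
        num_trees=200,
        num_chains=None,
        mesh=None,
        target_platform='cpu',
    )
    assert rbs == 2048 and cbs is None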


def get_axis_size(mesh: Mesh | None, axis_name: str) -> int:
    if mesh is None or axis_name not in mesh.axis_names:
        return 1
    else:
        i = mesh.axis_names.index(axis_name)
        return mesh.axis_sizes[i]


def _get_k_n(y: Array) -> tuple[int, int]:
    if y.ndim == 2:
        return y.shape
    else:
        (n,) = y.shape
        return 1, n


def chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool = False
) -> Float32[Array, '*batch_shape k k']:
    """Cholesky with Gershgorin stabilization, supports batching."""
    return _chol_with_gersh_impl(mat, absolute_eps)


@partial(jnp.vectorize, signature='(k,k)->(k,k)', excluded=(1,))
def _chol_with_gersh_impl(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool
) -> Float32[Array, '*batch_shape k k']:
    rho = jnp.max(jnp.sum(jnp.abs(mat), axis=1), initial=0.0)
    eps = jnp.finfo(mat.dtype).eps
    u = mat.shape[0] * rho * eps
    if absolute_eps:
        u += eps
    mat = mat.at[jnp.diag_indices_from(mat)].add(u)
    return jnp.linalg.cholesky(mat)
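

# Illustrative sketch: `_example_gershgorin_jitter` is a hypothetical, never-called
# helper (not part of the original module). The jitter added to the diagonal is
# ``k * rho * eps``, where ``rho`` (the largest absolute row sum) bounds the
# eigenvalues by the Gershgorin circle theorem, so the factorization still
# reproduces the input to within roundoff.
def _example_gershgorin_jitter() -> None:
    mat = jnp.array([[4.0, 1.0], [1.0, 3.0]])
    rho = jnp.max(jnp.sum(jnp.abs(mat), axis=1))  # = 5, bounds the eigenvalues
    u = mat.shape[0] * rho * jnp.finfo(mat.dtype).eps  # tiny relative to `mat`
    assert u < 1e-5
    L = chol_with_gersh(mat)
    assert jnp.allclose(L @ L.T, mat, rtol=1e-4)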


def _inv_via_chol_with_gersh(mat: Float32[Array, 'k k']) -> Float32[Array, 'k k']:
    """Compute a matrix inverse via Cholesky with Gershgorin stabilization.

    DO NOT USE THIS FUNCTION UNLESS YOU REALLY NEED TO.
    """
    # With mat = L @ L.T, the inverse is L^-T @ L^-1; L^-1 is obtained by
    # solving L @ L_inv = I with a triangular solve.
    L = chol_with_gersh(mat)
    I = jnp.eye(mat.shape[0], dtype=mat.dtype)
    L_inv = solve_triangular(L, I, lower=True)
    return L_inv.T @ L_inv
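
# Hedged sanity check (illustrative only):
#
#     A = jnp.array([[2.0, 1.0], [1.0, 2.0]])
#     A_inv = _inv_via_chol_with_gersh(A)
#     jnp.allclose(A @ A_inv, jnp.eye(2), atol=1e-5)   # True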


def get_num_chains(x: PyTree) -> int | None:
    """Get the number of chains of a pytree.

    Find all nodes in the structure that define `num_chains()`, stopping the
    traversal at nodes that define it. Check that all the values obtained by
    invoking `num_chains` are equal, then return that common value.
    """
    leaves, _ = flatten(x, is_leaf=lambda x: hasattr(x, 'num_chains'))
    num_chains = [x.num_chains() for x in leaves if hasattr(x, 'num_chains')]
    ref = num_chains[0]
    assert all(c == ref for c in num_chains)
    return ref
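
# Hypothetical illustration (the class below is not part of this module): any
# node exposing a `num_chains()` method takes part in the consistency check,
# while other leaves are ignored.
#
#     class _Fake:
#         def __init__(self, n):
#             self.n = n
#         def num_chains(self):
#             return self.n
#
#     get_num_chains({'a': _Fake(4), 'b': [_Fake(4), jnp.zeros(3)]})   # -> 4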


def _chain_axes_with_keys(x: PyTree) -> PyTree[int | None]:
    """Return `chain_vmap_axes(x)`, but with the axis set to 0 for random keys."""
    axes = chain_vmap_axes(x)

    def axis_if_key(x, axis):
        # Random keys are assumed to carry a leading per-chain axis, e.g. after
        # `_split_all_keys`.
        if is_key(x):
            return 0
        else:
            return axis

    return tree.map(axis_if_key, x, axes)


def _get_mc_out_axes(
    fun: Callable[[tuple, dict], PyTree], args: PyTree, in_axes: PyTree[int | None]
) -> PyTree[int | None]:
    """Decide chain vmap axes for outputs."""
    vmapped_fun = vmap(fun, in_axes=in_axes)
    # Trace the function abstractly to learn the output pytree without running it.
    out = eval_shape(vmapped_fun, *args)
    return chain_vmap_axes(out)
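
# For reference, a hedged sketch of the `eval_shape` trick used above: the
# function is evaluated only on shapes and dtypes, so no actual computation
# happens.
#
#     out = eval_shape(lambda a: {'s': a.sum()}, jnp.zeros((3, 4)))
#     out['s'].shape, out['s'].dtype   # ((), dtype('float32'))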


def _find_mesh(x: PyTree) -> Mesh | None:
    """Find the mesh used for chains."""

    class MeshFound(Exception):
        pass

    def find_mesh(x: State | Any):
        if isinstance(x, State):
            # Use an exception to stop the traversal at the first State found.
            raise MeshFound(x.config.mesh)

    try:
        tree.map(find_mesh, x, is_leaf=lambda x: isinstance(x, State))
    except MeshFound as e:
        return e.args[0]
    else:
        # No State in the pytree: there is no mesh to report.
        raise ValueError


def _split_all_keys(x: PyTree, num_chains: int) -> PyTree:
    """Split all random keys into `num_chains` keys each."""
    mesh = _find_mesh(x)

    def split_key(x):
        if is_key(x):
            x = random.split(x, num_chains)
            # Shard the per-chain keys along the 'chains' mesh axis if there is one.
            if mesh is not None and 'chains' in mesh.axis_names:
                x = device_put(x, NamedSharding(mesh, PartitionSpec('chains')))
        return x

    return tree.map(split_key, x)
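
# For reference, a hedged sketch of the key splitting above (the sharding
# branch only applies when the mesh has a 'chains' axis):
#
#     keys = random.split(random.key(0), 4)   # a (4,)-shaped batch of keys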


def vmap_chains(
    fun: Callable[..., T], *, auto_split_keys: bool = False
) -> Callable[..., T]:
    """Apply vmap on chain axes automatically if the inputs are multichain."""

    @wraps(fun)
    def auto_vmapped_fun(*args, **kwargs) -> T:
        all_args = args, kwargs
        num_chains = get_num_chains(all_args)
        if num_chains is not None:
            if auto_split_keys:
                all_args = _split_all_keys(all_args, num_chains)

            # Pack args and kwargs into two positional arguments so vmap sees a
            # fixed signature.
            def wrapped_fun(args, kwargs):
                return fun(*args, **kwargs)

            mc_in_axes = _chain_axes_with_keys(all_args)
            mc_out_axes = _get_mc_out_axes(wrapped_fun, all_args, mc_in_axes)
            vmapped_fun = vmap(wrapped_fun, in_axes=mc_in_axes, out_axes=mc_out_axes)
            return vmapped_fun(*all_args)

        else:
            # No chain axis in the inputs: call the function directly.
            return fun(*args, **kwargs)

    return auto_vmapped_fun
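
# A hypothetical end-to-end sketch (the class and function below are
# illustrative, not part of this module): fields marked with
# ``field(chains=True)`` carry an optional leading chain axis, `num_chains()`
# reports it, and the decorated function is written for a single chain and
# vmapped over chains automatically.
#
#     class Counts(Module):
#         hits: Int32[Array, '*chains n'] = field(chains=True)
#
#         def num_chains(self) -> int | None:
#             return self.hits.shape[0] if self.hits.ndim == 2 else None
#
#     @vmap_chains
#     def double(c: Counts) -> Counts:
#         return Counts(2 * c.hits)
#
#     double(Counts(jnp.ones((8, 100), jnp.int32)))   # vmapped over 8 chains
#     double(Counts(jnp.ones(100, jnp.int32)))        # called directly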