Coverage for src / bartz / mcmcstep / _state.py: 94%
409 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/mcmcstep/_state.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Module defining the BART MCMC state and initialization."""
27from collections.abc import Callable, Hashable
28from dataclasses import fields, replace
29from functools import partial, wraps
30from math import log2
31from typing import Any, Literal, TypedDict, TypeVar
33import numpy
34from equinox import Module, error_if, filter_jit
35from equinox import field as eqx_field
36from jax import (
37 NamedSharding,
38 device_put,
39 eval_shape,
40 jit,
41 lax,
42 make_mesh,
43 random,
44 tree,
45 vmap,
46)
47from jax import numpy as jnp
48from jax.scipy.linalg import solve_triangular
49from jax.sharding import AxisType, Mesh, PartitionSpec
50from jaxtyping import Array, Bool, Float32, Int32, Integer, PyTree, Shaped, UInt
52from bartz.grove import tree_depths
53from bartz.jaxext import get_default_device, is_key, minimal_unsigned_dtype
def field(*, chains: bool = False, data: bool = False, **kwargs: Any):  # noqa: ANN202
    """Wrap `equinox.field`, adding the `chains` and `data` markers.

    Parameters
    ----------
    chains
        True if the arrays stored in the field may carry an additional
        leading axis that indexes independent Markov chains.
    data
        True if the trailing axis of the arrays stored in the field runs
        over the units of the data.
    **kwargs
        Forwarded to `equinox.field`.

    Returns
    -------
    A dataclass field descriptor whose metadata holds a True entry for each requested marker; markers left False are not stored at all.
    """
    # work on a copy so that a mapping supplied by the caller is not mutated
    extra = dict(kwargs.pop('metadata', {}))

    # the marker names are reserved, the caller can not set them by hand
    assert 'chains' not in extra
    assert 'data' not in extra

    requested = (('chains', chains), ('data', data))
    extra.update({name: True for name, enabled in requested if enabled})
    return eqx_field(metadata=extra, **kwargs)
def chain_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Build the `jax.vmap` axes specification that selects the chain axes.

    The result is meant to be passed as `in_axes` or `out_axes` to
    `jax.vmap`, so that the mapping runs over all the chain axes present in
    the pytree `x`, and over nothing else.

    Parameters
    ----------
    x
        A pytree. Any subpytree stored in a Module attribute declared with
        ``field(..., chains=True)`` is taken to have a leading chain axis.

    Returns
    -------
    A pytree with the same structure as `x` with 0 or None in the leaves.
    """
    chain_axis, not_mapped = 0, None
    return _find_metadata(x, 'chains', chain_axis, not_mapped)
def data_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Build the `jax.vmap` axes specification that selects the data axes.

    Works like `chain_vmap_axes`, except that the fields looked for are those
    declared with ``field(..., data=True)``, and their leaves are set to -1
    (the last axis) instead of 0.
    """
    data_axis, not_mapped = -1, None
    return _find_metadata(x, 'data', data_axis, not_mapped)
T = TypeVar('T')


def _find_metadata(
    x: PyTree[Any, ' S'], key: Hashable, if_true: T, if_false: T
) -> PyTree[T, ' S']:
    """Map a pytree to `if_true`/`if_false` leaves based on field metadata.

    Leaves that sit under a Module field whose metadata has a truthy `key`
    become `if_true`; all the other leaves become `if_false`. Fields marked
    as static are carried over unchanged, and `_LazyArray` objects are
    treated as single leaves rather than as modules to recurse into.
    """

    def lazy_leaf(node: object) -> bool:
        return isinstance(node, _LazyArray)

    def plain_module(node: object) -> bool:
        # _LazyArray is itself a Module, but it has to behave like an array
        return isinstance(node, Module) and not lazy_leaf(node)

    def fill(value: T, subtree: object) -> PyTree[T]:
        return tree.map(lambda _: value, subtree, is_leaf=lazy_leaf)

    if plain_module(x):

        def convert(f: Any) -> object:
            content = getattr(x, f.name)
            if f.metadata.get('static', False):
                return content
            if f.metadata.get(key, False):
                return fill(if_true, content)
            return _find_metadata(content, key, if_true, if_false)

        # rebuild the module positionally, following the declaration order
        return x.__class__(*[convert(f) for f in fields(x)])

    def expand(node: object) -> PyTree[T]:
        if plain_module(node):
            return _find_metadata(node, key, if_true, if_false)
        return fill(if_false, node)

    def stop_at_module(node: object) -> bool:
        return isinstance(node, Module)  # this catches _LazyArray as well

    return tree.map(expand, x, is_leaf=stop_at_module)
class Forest(Module):
    """MCMC state of the sum-of-trees component."""

    leaf_tree: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """Values stored in the leaves."""

    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Axis (predictor index) of each decision rule."""

    split_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Boundary (split index) of each decision rule."""

    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Flags the leaves that are allowed to grow."""

    max_split: UInt[Array, ' p']
    """Largest split index available for each predictor."""

    blocked_vars: UInt[Array, ' q'] | None
    """Indices of the variables that are not used. It must list at least
    every `i` with ``max_split[i] == 0``, the behavior is undefined
    otherwise."""

    p_nonterminal: Float32[Array, ' 2**d']
    """Prior probability that a node is nonterminal, given its ancestors.
    The nodes at maximum depth are included, and should be set to 0."""

    p_propose_grow: Float32[Array, ' 2**(d-1)']
    """Unnormalized probability of choosing a leaf for a grow proposal."""

    leaf_indices: UInt[Array, '*chains num_trees n'] = field(chains=True, data=True)
    """For each tree, the index of the leaf each datapoint falls into."""

    min_points_per_decision_node: Int32[Array, ''] | None
    """Minimum number of data points required in a decision node."""

    min_points_per_leaf: Int32[Array, ''] | None
    """Minimum number of data points required in a leaf node."""

    log_trans_prior: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """Log of the transition and prior Metropolis-Hastings ratio of the move
    proposed on each tree."""

    log_likelihood: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """Log of the likelihood ratio."""

    grow_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """How many grow moves were proposed during one full MCMC cycle."""

    prune_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """How many prune moves were proposed during one full MCMC cycle."""

    grow_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """How many grow moves were accepted during one full MCMC cycle."""

    prune_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """How many prune moves were accepted during one full MCMC cycle."""

    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """Prior precision matrix of a leaf, given the tree structure. In the
    univariate case (k=1) it is a scalar, the inverse variance. The prior
    covariance of the sum of trees is
    ``num_trees * leaf_prior_cov_inv^-1``."""

    log_s: Float32[Array, '*chains p'] | None = field(chains=True)
    """Log of the prior probability of picking each variable for the split
    of a decision rule, given the ancestors. Not normalized. `None` stands
    for a uniform distribution."""

    theta: Float32[Array, '*chains'] | None = field(chains=True)
    """Concentration parameter of the Dirichlet prior on the variable
    distribution `s`. Needed only to update `log_s`."""

    a: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Needed only to sample `theta`.
    See `step_theta`."""

    b: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Needed only to sample `theta`.
    See `step_theta`."""

    rho: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Needed only to sample `theta`.
    See `step_theta`."""

    def num_chains(self) -> int | None:
        """Return the number of chains, or `None` if not multichain."""
        # maybe this should be replaced by chain_shape() -> () | (int,)
        # without a chain axis `var_tree` has just the (tree, node) axes
        single_chain = self.var_tree.ndim == 2
        return None if single_chain else self.var_tree.shape[0]
class StepConfig(Module):
    """Settings that control how an MCMC step is carried out."""

    steps_done: Int32[Array, '']
    """How many MCMC steps have been completed up to now."""

    sparse_on_at: Int32[Array, ''] | None
    """Number of steps after which variable selection is switched on."""

    resid_num_batches: int | None = field(static=True)
    """Number of batches used to compute the sum of residuals. With `None`
    the computation is not batched."""

    count_num_batches: int | None = field(static=True)
    """Number of batches used to compute counts. With `None` the computation
    is not batched."""

    prec_num_batches: int | None = field(static=True)
    """Number of batches used to compute precision scales. With `None` the
    computation is not batched."""

    prec_count_num_trees: int | None = field(static=True)
    """Size of the batches of trees processed together when computing the
    count and prec trees."""

    mesh: Mesh | None = field(static=True)
    """Mesh used to shard data and computation across multiple devices."""
class State(Module):
    """Full MCMC state of BART."""

    X: UInt[Array, 'p n'] = field(data=True)
    """Predictors."""

    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'] = field(
        data=True
    )
    """Response. A `bool` data type means the model is binary regression."""

    z: None | Float32[Array, '*chains n'] = field(chains=True, data=True)
    """Latent variable of binary regression. `None` in continuous
    regression."""

    offset: Float32[Array, ''] | Float32[Array, ' k']
    """Constant shift that is added to the sum of trees."""

    resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=True, data=True
    )
    """Residuals (`y` or `z` minus sum of trees)."""

    error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None = (
        field(chains=True)
    )
    """Inverse of the error covariance (a scalar if univariate, a matrix if
    multivariate). `None` in binary regression."""

    prec_scale: Float32[Array, ' n'] | None = field(data=True)
    """Scale on the error precision, i.e., ``1 / error_scale ** 2``.
    `None` in binary regression."""

    error_cov_df: Float32[Array, ''] | None
    """df parameter of the inverse Wishart prior on the noise covariance.
    In the univariate case it is related to the inverse gamma prior
    parameters by ``alpha = df / 2``.
    `None` in binary regression."""

    error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """Scale parameter of the inverse Wishart prior on the noise covariance.
    In the univariate case it is related to the inverse gamma prior
    parameters by ``beta = scale / 2``.
    `None` in binary regression."""

    forest: Forest
    """Sum of trees model."""

    config: StepConfig
    """Metadata and configuration of the MCMC step."""
330def _init_shape_shifting_parameters(
331 y: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
332 offset: Float32[Array, ''] | Float32[Array, ' k'],
333 error_scale: Float32[Any, ' n'] | None,
334 error_cov_df: float | Float32[Any, ''] | None,
335 error_cov_scale: float | Float32[Any, ''] | Float32[Any, 'k k'] | None,
336 leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
337) -> tuple[
338 bool,
339 tuple[()] | tuple[int],
340 None | Float32[Array, ''],
341 None | Float32[Array, ''],
342 None | Float32[Array, ''],
343]:
344 """
345 Check and initialize parameters that change array type/shape based on outcome kind.
347 Parameters
348 ----------
349 y
350 The response variable; the outcome type is deduced from `y` and then
351 all other parameters are checked against it.
352 offset
353 The offset to add to the predictions.
354 error_scale
355 Per-observation error scale (univariate only).
356 error_cov_df
357 The error covariance degrees of freedom.
358 error_cov_scale
359 The error covariance scale.
360 leaf_prior_cov_inv
361 The inverse of the leaf prior covariance.
363 Returns
364 -------
365 is_binary
366 Whether the outcome is binary.
367 kshape
368 The outcome shape, empty for univariate, (k,) for multivariate.
369 error_cov_inv
370 The initialized error covariance inverse.
371 error_cov_df
372 The error covariance degrees of freedom (as array).
373 error_cov_scale
374 The error covariance scale (as array).
376 Raises
377 ------
378 ValueError
379 If `y` is binary and multivariate.
380 """
381 # determine outcome kind, binary/continuous x univariate/multivariate
382 is_binary = y.dtype == bool 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
383 kshape = y.shape[:-1] 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
385 # Binary vs continuous
386 if is_binary: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
387 if kshape: 387 ↛ 388line 387 didn't jump to line 388 because the condition on line 387 was never true2ubK B D vbwbh C xbtbybzbF G H AbBbCbDbI EbA
388 msg = 'Binary multivariate regression not supported, open an issue at https://github.com/bartz-org/bartz/issues if you need it.'
389 raise ValueError(msg)
390 assert error_scale is None 2ubK B D vbwbh C xbtbybzbF G H AbBbCbDbI EbA
391 assert error_cov_df is None 2ubK B D vbwbh C xbtbybzbF G H AbBbCbDbI EbA
392 assert error_cov_scale is None 2ubK B D vbwbh C xbtbybzbF G H AbBbCbDbI EbA
393 error_cov_inv = None 2ubK B D vbwbh C xbtbybzbF G H AbBbCbDbI EbA
394 else:
395 error_cov_df = jnp.asarray(error_cov_df) 2{ ibi , z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ Z . / qbebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
396 error_cov_scale = jnp.asarray(error_cov_scale) 2{ ibi , z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ Z . / qbebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
397 assert error_cov_scale.shape == 2 * kshape 2{ ibi , z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ Z . / qbebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
399 # Multivariate vs univariate
400 if kshape: 2{ ibi , z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ Z . / qbebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
401 error_cov_inv = error_cov_df * _inv_via_chol_with_gersh(error_cov_scale) 2W 0 1 2 X 3 4 5 Y 6 7 8 n o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
402 else:
403 # inverse gamma prior: alpha = df / 2, beta = scale / 2
404 error_cov_inv = error_cov_df / error_cov_scale 2{ ibi , z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ Z . / qbebrbS t fbsbT x j k L = ? @ ; gb9 ! # $ ( ) * +
406 assert leaf_prior_cov_inv.shape == 2 * kshape 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
407 assert offset.shape == kshape 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
409 return is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
def _parse_p_nonterminal(
    p_nonterminal: Float32[Any, ' d_minus_1'],
) -> Float32[Array, ' d_minus_1+1']:
    """Require all values to be in (0, 1), then append a trailing 0."""
    probs = jnp.asarray(p_nonterminal)
    # written as the negation of "inside the interval", so that values that
    # fail both comparisons are flagged as well
    outside = jnp.logical_not(jnp.logical_and(probs > 0, probs < 1))
    checked = error_if(probs, outside, 'p_nonterminal must be in (0, 1)')
    return jnp.pad(checked, (0, 1))
def make_p_nonterminal(
    d: int,
    alpha: float | Float32[Array, ''] = 0.95,
    beta: float | Float32[Array, ''] = 2.0,
) -> Float32[Array, ' {d}-1']:
    """Compute the `p_nonterminal` argument of `init`.

    The values follow the formula:

        P_nt(depth) = alpha / (1 + depth)^beta,     with depth 0-based

    Parameters
    ----------
    d
        The maximum depth of the trees (d=1 means tree with only root node)
    alpha
        The prior probability that the root node has children, conditional
        on it being possible
    beta
        The exponent of the power law with which the probability of having
        children decays with depth.

    Returns
    -------
    An array of probabilities, one per tree level but the last.
    """
    assert d >= 1
    # 1-based level of each depth that is allowed to have children
    level = 1 + jnp.arange(d - 1)
    decay = level.astype(float) ** beta
    return alpha / decay
class _LazyArray(Module):
    """Deferred call to an array-creating function such as `jax.numpy.zeros`, in the spirit of `functools.partial`."""

    array_creator: Callable
    shape: tuple[int, ...]
    args: tuple

    def __init__(
        self, array_creator: Callable, shape: tuple[int, ...], *args: Any
    ) -> None:
        self.array_creator = array_creator
        self.shape = shape
        self.args = args

    def __call__(self, **kwargs: Any) -> T:
        # the shape always goes first, followed by the stored extra arguments
        positional = (self.shape, *self.args)
        return self.array_creator(*positional, **kwargs)

    @property
    def ndim(self) -> int:
        # number of axes the array would have, without creating it
        return len(self.shape)
475def init(
476 *,
477 X: UInt[Any, 'p n'],
478 y: Float32[Any, ' n'] | Float32[Any, ' k n'] | Bool[Any, ' n'],
479 offset: float | Float32[Any, ''] | Float32[Any, ' k'],
480 max_split: UInt[Any, ' p'],
481 num_trees: int,
482 p_nonterminal: Float32[Any, ' d_minus_1'],
483 leaf_prior_cov_inv: float | Float32[Any, ''] | Float32[Array, 'k k'],
484 error_cov_df: float | Float32[Any, ''] | None = None,
485 error_cov_scale: float | Float32[Any, ''] | Float32[Array, 'k k'] | None = None,
486 error_scale: Float32[Any, ' n'] | None = None,
487 min_points_per_decision_node: int | Integer[Any, ''] | None = None,
488 resid_num_batches: int | None | Literal['auto'] = 'auto',
489 count_num_batches: int | None | Literal['auto'] = 'auto',
490 prec_num_batches: int | None | Literal['auto'] = 'auto',
491 prec_count_num_trees: int | None | Literal['auto'] = 'auto',
492 save_ratios: bool = False,
493 filter_splitless_vars: int = 0,
494 min_points_per_leaf: int | Integer[Any, ''] | None = None,
495 log_s: Float32[Any, ' p'] | None = None,
496 theta: float | Float32[Any, ''] | None = None,
497 a: float | Float32[Any, ''] | None = None,
498 b: float | Float32[Any, ''] | None = None,
499 rho: float | Float32[Any, ''] | None = None,
500 sparse_on_at: int | Integer[Any, ''] | None = None,
501 num_chains: int | None = None,
502 mesh: Mesh | dict[str, int] | None = None,
503 target_platform: Literal['cpu', 'gpu'] | None = None,
504) -> State:
505 """
506 Make a BART posterior sampling MCMC initial state.
508 Parameters
509 ----------
510 X
511 The predictors. Note this is trasposed compared to the usual convention.
512 y
513 The response. If the data type is `bool`, the regression model is binary
514 regression with probit. If two-dimensional, the outcome is multivariate
515 with the first axis indicating the component.
516 offset
517 Constant shift added to the sum of trees. 0 if not specified.
518 max_split
519 The maximum split index for each variable. All split ranges start at 1.
520 num_trees
521 The number of trees in the forest.
522 p_nonterminal
523 The probability of a nonterminal node at each depth. The maximum depth
524 of trees is fixed by the length of this array. Use `make_p_nonterminal`
525 to set it with the conventional formula.
526 leaf_prior_cov_inv
527 The prior precision matrix of a leaf, conditional on the tree structure.
528 For the univariate case (k=1), this is a scalar (the inverse variance).
529 The prior covariance of the sum of trees is
530 ``num_trees * leaf_prior_cov_inv^-1``. The prior mean of leaves is
531 always zero.
532 error_cov_df
533 error_cov_scale
534 The df and scale parameters of the inverse Wishart prior on the error
535 covariance. For the univariate case, the relationship to the inverse
536 gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``.
537 Leave unspecified for binary regression.
538 error_scale
539 Each error is scaled by the corresponding factor in `error_scale`, so
540 the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
541 Not supported for binary regression. If not specified, defaults to 1 for
542 all points, but potentially skipping calculations.
543 min_points_per_decision_node
544 The minimum number of data points in a decision node. 0 if not
545 specified.
546 resid_num_batches
547 count_num_batches
548 prec_num_batches
549 The number of batches, along datapoints, for summing the residuals,
550 counting the number of datapoints in each leaf, and computing the
551 likelihood precision in each leaf, respectively. `None` for no batching.
552 If 'auto', it's chosen automatically based on the target platform; see
553 the description of `target_platform` below for how it is determined.
554 prec_count_num_trees
555 The number of trees to process at a time when counting datapoints or
556 computing the likelihood precision. If `None`, do all trees at once,
557 which may use too much memory. If 'auto' (default), it's chosen
558 automatically.
559 save_ratios
560 Whether to save the Metropolis-Hastings ratios.
561 filter_splitless_vars
562 The maximum number of variables without splits that can be ignored. If
563 there are more, `init` raises an exception.
564 min_points_per_leaf
565 The minimum number of datapoints in a leaf node. 0 if not specified.
566 Unlike `min_points_per_decision_node`, this constraint is not taken into
567 account in the Metropolis-Hastings ratio because it would be expensive
568 to compute. Grow moves that would violate this constraint are vetoed.
569 This parameter is independent of `min_points_per_decision_node` and
570 there is no check that they are coherent. It makes sense to set
571 ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
572 log_s
573 The logarithm of the prior probability for choosing a variable to split
574 along in a decision rule, conditional on the ancestors. Not normalized.
575 If not specified, use a uniform distribution. If not specified and
576 `theta` or `rho`, `a`, `b` are, it's initialized automatically.
577 theta
578 The concentration parameter for the Dirichlet prior on `s`. Required
579 only to update `log_s`. If not specified, and `rho`, `a`, `b` are
580 specified, it's initialized automatically.
581 a
582 b
583 rho
584 Parameters of the prior on `theta`. Required only to sample `theta`.
585 sparse_on_at
586 After how many MCMC steps to turn on variable selection.
587 num_chains
588 The number of independent MCMC chains to represent in the state. Single
589 chain with scalar values if not specified.
590 mesh
591 A jax mesh used to shard data and computation across multiple devices.
592 If it has a 'chains' axis, that axis is used to shard the chains. If it
593 has a 'data' axis, that axis is used to shard the datapoints.
595 As a shorthand, if a dictionary mapping axis names to axis size is
596 passed, the corresponding mesh is created, e.g., ``dict(chains=4,
597 data=2)`` will let jax pick 8 devices to split chains (which must be a
598 multiple of 4) across 4 pairs of devices, where in each pair the data is
599 split in two.
601 Note: if a mesh is passed, the arrays are always sharded according to
602 it. In particular even if the mesh has no 'chains' or 'data' axis, the
603 arrays will be replicated on all devices in the mesh.
604 target_platform
605 Platform ('cpu' or 'gpu') used to determine the number of batches
606 automatically. If `mesh` is specified, the platform is inferred from the
607 devices in the mesh. Otherwise, if `y` is a concrete array (i.e., `init`
608 is not invoked in a `jax.jit` context), the platform is set to the
609 platform of `y`. Otherwise, use `target_platform`.
611 To avoid confusion, in all cases where the `target_platform` argument
612 would be ignored, `init` raises an exception if `target_platform` is
613 set.
615 Returns
616 -------
617 An initialized BART MCMC state.
619 Raises
620 ------
621 ValueError
622 If `y` is boolean and arguments unused in binary regression are set.
624 Notes
625 -----
626 In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
627 of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
628 child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
629 integers in the range ``[0, 1, ..., max_split[i]]``.
630 """
631 # convert to array all array-like arguments that are used in other
632 # configurations but don't need further processing themselves
633 X = jnp.asarray(X) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
634 y = jnp.asarray(y) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
635 offset = jnp.asarray(offset) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
636 leaf_prior_cov_inv = jnp.asarray(leaf_prior_cov_inv) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
637 max_split = jnp.asarray(max_split) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
639 # check p_nonterminal and pad it with a 0 at the end (still not final shape)
640 p_nonterminal = _parse_p_nonterminal(p_nonterminal) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
642 # process arguments that change depending on outcome type
643 is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale = ( 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
644 _init_shape_shifting_parameters(
645 y, offset, error_scale, error_cov_df, error_cov_scale, leaf_prior_cov_inv
646 )
647 )
649 # extract array sizes from arguments
650 (max_depth,) = p_nonterminal.shape 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
651 p, n = X.shape 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
653 # check and initialize sparsity parameters
654 if not _all_none_or_not_none(rho, a, b): 654 ↛ 655line 654 didn't jump to line 655 because the condition on line 654 was never true2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
655 msg = 'rho, a, b are not either all `None` or all set'
656 raise ValueError(msg)
657 if theta is None and rho is not None: 2{ ubibi , K z (bu B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
658 theta = rho 2{ u | } ~ m R abU - V [ % ] bbcbdb^ ebS fbT k
659 if log_s is None and theta is not None: 2{ ubibi , K z (bu B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
660 log_s = jnp.zeros(max_split.size) 2{ u | } ~ m R abU - V [ % ] bbcbdb^ ebS fbT j k
661 if not _all_none_or_not_none(theta, sparse_on_at): 661 ↛ 662line 661 didn't jump to line 662 because the condition on line 661 was never true2{ ubibi , K z (bu B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
662 msg = 'sparsity params (either theta or rho,a,b) and sparse_on_at must be either all None or all set'
663 raise ValueError(msg)
665 # process multichain settings
666 chain_shape = () if num_chains is None else (num_chains,) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L FbW 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
667 resid_shape = chain_shape + y.shape 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L FbW 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
668 add_chains = partial(_add_chains, chain_shape=chain_shape) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
670 # determine batch sizes for reductions
671 mesh = _parse_mesh(num_chains, mesh) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
672 target_platform = _parse_target_platform( 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
673 y, mesh, target_platform, resid_num_batches, count_num_batches, prec_num_batches
674 )
675 red_cfg = _parse_reduction_configs( 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
676 resid_num_batches,
677 count_num_batches,
678 prec_num_batches,
679 prec_count_num_trees,
680 y,
681 num_trees,
682 mesh,
683 target_platform,
684 )
686 # check there aren't too many deactivated predictors
687 msg = ( 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
688 f'there are more than {filter_splitless_vars=} predictors with no splits, '
689 'please increase `filter_splitless_vars` or investigate the missing splits'
690 )
691 offset = error_if(offset, jnp.sum(max_split == 0) > filter_splitless_vars, msg) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
693 # determine shapes for trees
694 tree_shape = (*chain_shape, num_trees) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
695 tree_size = 2**max_depth 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
697 # initialize all remaining stuff and put it in an unsharded state
698 state = State( 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
699 X=X,
700 y=y,
701 z=_LazyArray(jnp.full, resid_shape, offset) if is_binary else None,
702 offset=offset,
703 resid=_LazyArray(jnp.zeros, resid_shape)
704 if is_binary
705 else None, # in this case, resid is created later after y and offset are sharded
706 error_cov_inv=add_chains(error_cov_inv),
707 prec_scale=error_scale, # temporarily set to error_scale, fix after sharding
708 error_cov_df=error_cov_df,
709 error_cov_scale=error_cov_scale,
710 forest=Forest(
711 leaf_tree=_LazyArray(
712 jnp.zeros, (*tree_shape, *kshape, tree_size), jnp.float32
713 ),
714 var_tree=_LazyArray(
715 jnp.zeros, (*tree_shape, tree_size // 2), minimal_unsigned_dtype(p - 1)
716 ),
717 split_tree=_LazyArray(
718 jnp.zeros, (*tree_shape, tree_size // 2), max_split.dtype
719 ),
720 affluence_tree=_LazyArray(
721 _initial_affluence_tree,
722 (*tree_shape, tree_size // 2),
723 n,
724 min_points_per_decision_node,
725 ),
726 blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split),
727 max_split=max_split,
728 grow_prop_count=_LazyArray(jnp.zeros, chain_shape, int),
729 grow_acc_count=_LazyArray(jnp.zeros, chain_shape, int),
730 prune_prop_count=_LazyArray(jnp.zeros, chain_shape, int),
731 prune_acc_count=_LazyArray(jnp.zeros, chain_shape, int),
732 p_nonterminal=p_nonterminal[tree_depths(tree_size)],
733 p_propose_grow=p_nonterminal[tree_depths(tree_size // 2)],
734 leaf_indices=_LazyArray(
735 jnp.ones, (*tree_shape, n), minimal_unsigned_dtype(tree_size - 1)
736 ),
737 min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
738 min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
739 log_trans_prior=_LazyArray(jnp.zeros, (*chain_shape, num_trees))
740 if save_ratios
741 else None,
742 log_likelihood=_LazyArray(jnp.zeros, (*chain_shape, num_trees))
743 if save_ratios
744 else None,
745 leaf_prior_cov_inv=leaf_prior_cov_inv,
746 log_s=add_chains(_asarray_or_none(log_s)),
747 theta=add_chains(_asarray_or_none(theta)),
748 rho=_asarray_or_none(rho),
749 a=_asarray_or_none(a),
750 b=_asarray_or_none(b),
751 ),
752 config=StepConfig(
753 steps_done=jnp.int32(0),
754 sparse_on_at=_asarray_or_none(sparse_on_at),
755 mesh=mesh,
756 **red_cfg,
757 ),
758 )
760 # delete big input arrays such that they can be deleted as soon as they
761 # are sharded, only those arrays that contain an (n,) sized axis
762 del X, y, error_scale 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
764 # move all arrays to the appropriate device
765 state = _shard_state(state) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
767 # calculate initial resid in the continuous outcome case, such that y and
768 # offset are already sharded if needed
769 if state.resid is None: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
770 resid = _LazyArray(_initial_resid, resid_shape, state.y, state.offset) 2{ ibi , z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ Z . / qbebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
771 resid = _shard_leaf(resid, 0, -1, state.config.mesh) 2{ ibi , z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ Z . / qbebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
772 state = replace(state, resid=resid) 2{ ibi , z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ Z . / qbebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
774 # calculate prec_scale after sharding to do the calculation on the right
775 # devices
776 if state.prec_scale is not None: 2{ ubibi , K z (bu B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
777 prec_scale = _compute_prec_scale(state.prec_scale) 2ibz (bq jbkblbl r mbs E y v ' w nbobpb_ rbt sbx
778 state = replace(state, prec_scale=prec_scale) 2ibz (bq jbkblbl r mbs E y v ' w nbobpb_ rbt sbx
780 # make all types strong to avoid unwanted recompilations
781 return _remove_weak_types(state) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
784def _initial_resid(
785 shape: tuple[int, ...],
786 y: Float32[Array, ' n'] | Float32[Array, 'k n'],
787 offset: Float32[Array, ''] | Float32[Array, ' k'],
788) -> Float32[Array, ' n'] | Float32[Array, 'k n']:
789 """Calculate the initial value for `State.resid` in the continuous outcome case."""
790 return jnp.broadcast_to(y - offset[..., None], shape) 1uqmlRrUs[]wZStTxjkLW01=2X3?45Y@678;nopcdNefgabJ9!#$MOPQ
793def _initial_affluence_tree(
794 shape: tuple[int, ...], n: int, min_points_per_decision_node: int | None
795) -> Array:
796 """Create the initial value of `Forest.affluence_tree`."""
797 return ( 1iuBqDmhlRCrVy[Fv%G']HwZSItTAxjkLWXY;nopcdNefgabJ9!#$MOPQ()*+
798 jnp.zeros(shape, bool)
799 .at[..., 1]
800 .set(
801 True
802 if min_points_per_decision_node is None
803 else n >= min_points_per_decision_node
804 )
805 )
@partial(jit, donate_argnums=(0,))
def _compute_prec_scale(error_scale: Float32[Array, ' n']) -> Float32[Array, ' n']:
    """Return ``1 / error_scale**2``.

    This lives in its own jitted function so that the input buffer can be
    donated (`donate_argnums`), which avoids intermediate copies.
    """
    squared = jnp.square(error_scale)
    return jnp.reciprocal(squared)
818def _get_blocked_vars(
819 filter_splitless_vars: int, max_split: UInt[Array, ' p']
820) -> None | UInt[Array, ' q']:
821 """Initialize the `blocked_vars` field."""
822 if filter_splitless_vars: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
823 (p,) = max_split.shape 1iAL
824 (blocked_vars,) = jnp.nonzero( 1iAL
825 max_split == 0, size=filter_splitless_vars, fill_value=p
826 )
827 return blocked_vars.astype(minimal_unsigned_dtype(p)) 1iAL
828 # see `fully_used_variables` for the type cast
829 else:
830 return None 2{ ubib, K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT x j k W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
833def _add_chains(
834 x: Shaped[Array, '*shape'] | None, chain_shape: tuple[int, ...]
835) -> Shaped[Array, '*shape'] | Shaped[Array, ' num_chains *shape'] | None:
836 """Broadcast `x` to all chains."""
837 if x is None: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
838 return None 2ubi , K z B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
839 else:
840 return jnp.broadcast_to(x, chain_shape + x.shape) 2{ ibi , K z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ Z . / qbebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
843def _parse_mesh(
844 num_chains: int | None, mesh: Mesh | dict[str, int] | None
845) -> Mesh | None:
846 """Parse the `mesh` argument."""
847 if mesh is None: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
848 return None 2ubK B D vbwbm h l C xbtbybzb[ F G ] H AbBbCbDbZ . / qbI EbA W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p N ` hb9 ! # $ M O P Q ( ) * +
850 # convert dict format to actual mesh
851 if isinstance(mesh, dict): 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J
852 assert set(mesh).issubset({'chains', 'data'}) 1cdefgabJ
853 mesh = make_mesh( 1cdefgabJ
854 tuple(mesh.values()), tuple(mesh), axis_types=(AxisType.Auto,) * len(mesh)
855 )
857 # check there's no chain mesh axis if there are no chains
858 if num_chains is None: 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J
859 assert 'chains' not in mesh.axis_names 2{ , | } ~ R abU - V % bbcbdb^ ebS fbT J
861 # check the axes we use are in auto mode
862 assert 'chains' not in mesh.axis_names or 'chains' in _auto_axes(mesh) 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J
863 assert 'data' not in mesh.axis_names or 'data' in _auto_axes(mesh) 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J
865 return mesh 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J
868def _parse_target_platform(
869 y: Array,
870 mesh: Mesh | None,
871 target_platform: Literal['cpu', 'gpu'] | None,
872 resid_num_batches: int | None | Literal['auto'],
873 count_num_batches: int | None | Literal['auto'],
874 prec_num_batches: int | None | Literal['auto'],
875) -> Literal['cpu', 'gpu'] | None:
876 if mesh is not None: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
877 assert target_platform is None, 'mesh provided, do not set target_platform' 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J
878 return mesh.devices.flat[0].platform 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J
879 elif hasattr(y, 'platform'): 2ubK B D vbwbm h l C xbtbybzb[ F G ] H AbBbCbDbZ . / qbI EbA W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p N ` hb9 ! # $ M O P Q ( ) * +
880 assert target_platform is None, 'device inferred from y, unset target_platform' 2ubK B D vbwbm h l C xbtbybzb[ F G ] H AbBbCbDbZ . / qbI EbA n o p N ` hb9 ! # $ M O P Q ( ) * +
881 return y.platform() 2ubK B D vbwbm h l C xbtbybzb[ F G ] H AbBbCbDbZ . / qbI EbA n o p N ` hb9 ! # $ M O P Q ( ) * +
882 elif ( 2tbs ;
883 resid_num_batches == 'auto'
884 or count_num_batches == 'auto'
885 or prec_num_batches == 'auto'
886 ):
887 assert target_platform in ('cpu', 'gpu') 2tbW 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gb
888 return target_platform 2tbW 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gb
889 else:
890 assert target_platform is None, 'target_platform not used, unset it' 1s
891 return target_platform 1s
894def _auto_axes(mesh: Mesh) -> list[str]:
895 """Re-implement `Mesh.auto_axes` because that's missing in jax v0.5."""
896 # Mesh.auto_axes added in jax v0.6.0
897 return [ 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J
898 n
899 for n, t in zip(mesh.axis_names, mesh.axis_types, strict=True)
900 if t == AxisType.Auto
901 ]
# jit and donate because otherwise type conversion would create copies
@partial(filter_jit, donate='all')
def _remove_weak_types(x: PyTree[Array, 'T']) -> PyTree[Array, 'T']:
    """Convert all weakly typed arrays in `x` to strongly typed ones.

    This is to avoid recompilation in `run_mcmc` or `step`.
    """

    def strengthen(leaf: T) -> T:
        # casting to its own dtype yields a strongly typed array; leaves that
        # are not weakly typed arrays pass through unchanged
        is_weak_array = isinstance(leaf, Array) and leaf.weak_type
        return leaf.astype(leaf.dtype) if is_weak_array else leaf

    return tree.map(strengthen, x)
def _shard_state(state: State) -> State:
    """Instantiate the lazily defined arrays of `state` and place every array on the appropriate devices."""

    def stop_at(node: object) -> bool:
        # `None` and lazy arrays are handled as leaves by `_shard_leaf`
        return node is None or isinstance(node, _LazyArray)

    place = partial(_shard_leaf, mesh=state.config.mesh)
    return tree.map(
        place,
        state,
        chain_vmap_axes(state),
        data_vmap_axes(state),
        is_leaf=stop_at,
    )
def _shard_leaf(
    x: Array | None | _LazyArray,
    chain_axis: int | None,
    data_axis: int | None,
    mesh: Mesh | None,
) -> Array | None:
    """
    Create `x` if it's lazy and shard it.

    Parameters
    ----------
    x
        The array, lazy array, or `None`.
    chain_axis
        The index of the axis of `x` to shard along the 'chains' mesh axis,
        `None` if there is no such axis.
    data_axis
        The index of the axis of `x` to shard along the 'data' mesh axis,
        `None` if there is no such axis.
    mesh
        The mesh, `None` to skip sharding.

    Returns
    -------
    The concrete array placed according to `mesh`, or `None` if `x` is `None`.
    """
    if x is None:
        return None

    sharding = None
    if mesh is not None:
        axes = [None] * x.ndim
        # an axis is sharded only if the mesh has the corresponding axis
        for axis, mesh_axis_name in ((chain_axis, 'chains'), (data_axis, 'data')):
            if axis is not None and mesh_axis_name in mesh.axis_names:
                axes[axis] = mesh_axis_name

        # remove trailing Nones to be consistent with jax's output, it's useful
        # for comparing shardings during debugging
        while axes and axes[-1] is None:
            del axes[-1]

        sharding = NamedSharding(mesh, PartitionSpec(*axes))

    if isinstance(x, _LazyArray):
        return _concretize_lazy_array(x, sharding)
    if sharding is None:
        return x
    return device_put(x, sharding, donate=True)
# jit such that in recent jax versions the shards are created on the right
# devices immediately instead of being created on the wrong device and then
# copied
@filter_jit
def _concretize_lazy_array(x: _LazyArray, sharding: NamedSharding | None) -> Array:
    """Create an array from an abstract spec on the appropriate devices.

    `x` is called with no arguments to produce the array; if `sharding` is
    not `None`, the result is constrained to it.
    """
    array = x()
    if sharding is None:
        return array
    return lax.with_sharding_constraint(array, sharding)
def _all_none_or_not_none(*args: object) -> bool:
    """Return whether the arguments are either all `None` or all not `None`.

    With no arguments the result is `True`.
    """
    num_none = sum(arg is None for arg in args)
    return num_none == len(args) or num_none == 0
986def _asarray_or_none(x: object) -> Array | None:
987 if x is None: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
988 return None 2ubibi , K z B q D jbvbkbwblbh l R C r xbmbtbs - ybE V zby [ F v G ' ] H w AbnbBbobCbpbDb_ Z . / qbrbI t EbsbT A x j L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
989 return jnp.asarray(x) 2{ ibi , K z u B q | jb} kb~ lbm l R C r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gb
992def _get_platform(mesh: Mesh | None) -> str:
993 if mesh is None:
994 return get_default_device().platform
995 else:
996 return mesh.devices.flat[0].platform
class _ReductionConfig(TypedDict):
    """Fields of `StepConfig` related to reductions.

    Every value is either a concrete integer setting or `None`; this is the
    form produced by `_parse_reduction_configs` after resolving 'auto'.
    """

    # number of batches for the 'resid' reduction; `None` is what the automatic
    # choice yields when batching is disabled
    resid_num_batches: int | None
    # number of batches for the 'count' reduction, same convention
    count_num_batches: int | None
    # number of batches for the 'prec' reduction, same convention
    prec_num_batches: int | None
    # number of trees to process at a time; the automatic choice yields `None`
    # when it would cover all the trees
    prec_count_num_trees: int | None
def _parse_reduction_configs(
    resid_num_batches: int | None | Literal['auto'],
    count_num_batches: int | None | Literal['auto'],
    prec_num_batches: int | None | Literal['auto'],
    prec_count_num_trees: int | None | Literal['auto'],
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'],
    num_trees: int,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
) -> _ReductionConfig:
    """Determine settings for indexed reduces.

    Parameters
    ----------
    resid_num_batches
    count_num_batches
    prec_num_batches
        Number of batches for each reduction, or 'auto' to pick it based on
        `target_platform` and the per-device number of datapoints.
    prec_count_num_trees
        Number of trees to process at a time, or 'auto' to pick it.
    y
        The data; only the length of its last axis is used.
    num_trees
        The number of trees.
    mesh
        The device mesh, used to get the size of the 'data' axis.
    target_platform
        The platform the automatic batch settings are tuned for.

    Returns
    -------
    A dictionary with the resolved settings.
    """
    # per-device datapoints: the last axis of `y` divided by the 'data' axis size
    n_per_device = y.shape[-1] // get_axis_size(mesh, 'data')

    def resolve_batches(
        num_batches: int | None | Literal['auto'],
        which: Literal['resid', 'count', 'prec'],
    ) -> int | None:
        return _parse_num_batches(target_platform, n_per_device, num_batches, which)

    return {
        'resid_num_batches': resolve_batches(resid_num_batches, 'resid'),
        'count_num_batches': resolve_batches(count_num_batches, 'count'),
        'prec_num_batches': resolve_batches(prec_num_batches, 'prec'),
        'prec_count_num_trees': _parse_prec_count_num_trees(
            prec_count_num_trees, num_trees, n_per_device
        ),
    }
1032def _parse_num_batches(
1033 target_platform: Literal['cpu', 'gpu'] | None,
1034 n: int,
1035 num_batches: int | None | Literal['auto'],
1036 which: Literal['resid', 'count', 'prec'],
1037) -> int | None:
1038 """Return the number of batches or determine it automatically."""
1039 final_round = partial(_final_round, n) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1040 if num_batches != 'auto': 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1041 nb = num_batches 2{ ib, z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ ebrbS t fbsbT x ` 9 ! # $ M O P Q ( ) * +
1042 elif target_platform == 'cpu': 1042 ↛ 1044line 1042 didn't jump to line 1044 because the condition on line 1042 was always true2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1043 nb = final_round(16) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1044 elif target_platform == 'gpu':
1045 nb = dict(resid=1024, count=2048, prec=1024)[which] # on an A4000
1046 nb = final_round(nb)
1047 return nb 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1050def _final_round(n: int, num: float) -> int | None:
1051 """Bound batch size, round number of batches to a power of 2, and disable batching if there's only 1 batch."""
1052 # at least some elements per batch
1053 num = min(n // 32, num) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1055 # round to the nearest power of 2 because I guess XLA and the hardware
1056 # will like that (not sure about this, maybe just multiple of 32?)
1057 num = 2 ** round(log2(num)) if num else 0 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1059 # disable batching if the batch is as large as the whole dataset
1060 return num if num > 1 else None 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1063def _parse_prec_count_num_trees(
1064 prec_count_num_trees: int | None | Literal['auto'], num_trees: int, n: int
1065) -> int | None:
1066 """Return the number of trees to process at a time or determine it automatically."""
1067 if prec_count_num_trees != 'auto': 2{ ubibi , K z (bu B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1068 return prec_count_num_trees 2{ ib, z (bu q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ ebrbS t fbsbT x
1069 max_n_by_ntree = 2**27 # about 100M 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1070 pcnt = max_n_by_ntree // max(1, n) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1071 pcnt = min(num_trees, pcnt) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1072 pcnt = max(1, pcnt) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1073 pcnt = _search_divisor( 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1074 pcnt, num_trees, max(1, pcnt // 2), max(1, min(num_trees, pcnt * 2))
1075 )
1076 if pcnt >= num_trees: 1076 ↛ 1078line 1076 didn't jump to line 1078 because the condition on line 1076 was always true2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1077 pcnt = None 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1078 return pcnt 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1081def _search_divisor(target_divisor: int, dividend: int, low: int, up: int) -> int:
1082 """Find the divisor closest to `target_divisor` in [low, up] if `target_divisor` is not already.
1084 If there is none, give up and return `target_divisor`.
1085 """
1086 assert target_divisor >= 1 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1087 assert 1 <= low <= up <= dividend 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1088 if dividend % target_divisor == 0: 1088 ↛ 1090line 1088 didn't jump to line 1090 because the condition on line 1088 was always true2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1089 return target_divisor 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +
1090 candidates = numpy.arange(low, up + 1)
1091 divisors = candidates[dividend % candidates == 0]
1092 if divisors.size == 0:
1093 return target_divisor
1094 penalty = numpy.abs(divisors - target_divisor)
1095 closest = numpy.argmin(penalty)
1096 return divisors[closest].item()
def get_axis_size(mesh: Mesh | None, axis_name: str) -> int:
    """Return the size of the mesh axis `axis_name`.

    The size is 1 if there is no mesh or the mesh has no axis with that name.
    """
    if mesh is not None and axis_name in mesh.axis_names:
        position = mesh.axis_names.index(axis_name)
        return mesh.axis_sizes[position]
    return 1
def chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool = False
) -> Float32[Array, '*batch_shape k k']:
    """Cholesky with Gershgorin stabilization, supports batching.

    Parameters
    ----------
    mat
        The matrix, or batch of matrices, to decompose.
    absolute_eps
        Forwarded to the implementation, which adds the machine epsilon of
        the dtype to the diagonal shift when this is `True`.

    Returns
    -------
    The Cholesky factor of the stabilized matrix.
    """
    # `absolute_eps` goes in positionally because the implementation excludes
    # positional argument 1 from vectorization
    factor = _chol_with_gersh_impl(mat, absolute_eps)
    return factor
1114@partial(jnp.vectorize, signature='(k,k)->(k,k)', excluded=(1,))
1115def _chol_with_gersh_impl(
1116 mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool
1117) -> Float32[Array, '*batch_shape k k']:
1118 rho = jnp.max(jnp.sum(jnp.abs(mat), axis=1), initial=0.0) 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b
1119 eps = jnp.finfo(mat.dtype).eps 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b
1120 u = mat.shape[0] * rho * eps 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b
1121 if absolute_eps: 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b
1122 u += eps 2FbGbHbIbRbSbn c d N e a b J )bM *b+b,b-b.b/b
1123 mat = mat.at[jnp.diag_indices_from(mat)].add(u) 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b
1124 return jnp.linalg.cholesky(mat) 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b
def _inv_via_chol_with_gersh(mat: Float32[Array, 'k k']) -> Float32[Array, 'k k']:
    """Compute matrix inverse via Cholesky with Gershgorin stabilization.

    DO NOT USE THIS FUNCTION UNLESS YOU REALLY NEED TO.
    """
    chol = chol_with_gersh(mat)
    identity = jnp.eye(mat.shape[0], dtype=mat.dtype)
    # invert the lower triangular factor, then combine: (L L^T)^-1 = L^-T L^-1
    chol_inv = solve_triangular(chol, identity, lower=True)
    return chol_inv.T @ chol_inv
def get_num_chains(x: PyTree) -> int | None:
    """Get the number of chains of a pytree.

    Find all nodes in the structure that define 'num_chains()', stopping
    traversal at nodes that define it. Check all values obtained invoking
    `num_chains` are equal, then return it.

    Parameters
    ----------
    x
        The pytree to inspect.

    Returns
    -------
    The common value returned by the `num_chains` methods, or `None` if no node defines `num_chains`.
    """

    def defines_num_chains(node: object) -> bool:
        return hasattr(node, 'num_chains')

    leaves = tree.leaves(x, is_leaf=defines_num_chains)
    num_chains = [leaf.num_chains() for leaf in leaves if defines_num_chains(leaf)]

    # nothing chain-aware in the tree: report "no chains" instead of failing
    # with an IndexError on the empty list
    if not num_chains:
        return None

    ref = num_chains[0]
    assert all(c == ref for c in num_chains), (
        f'inconsistent numbers of chains: {num_chains}'
    )
    return ref
def _chain_axes_with_keys(x: PyTree) -> PyTree[int | None]:
    """Return `chain_vmap_axes(x)` but also set to 0 for random keys."""

    def pick_axis(leaf: object, axis: int | None) -> int | None:
        # random keys always get axis 0, everything else keeps its axis
        return 0 if is_key(leaf) else axis

    return tree.map(pick_axis, x, chain_vmap_axes(x))
def _get_mc_out_axes(
    fun: Callable[[tuple, dict], PyTree], args: PyTree, in_axes: PyTree[int | None]
) -> PyTree[int | None]:
    """Decide chain vmap axes for outputs.

    The vmapped function is evaluated abstractly on `args` with `eval_shape`,
    and the axes are derived from its output with `chain_vmap_axes`.
    """
    abstract_out = eval_shape(vmap(fun, in_axes=in_axes), *args)
    return chain_vmap_axes(abstract_out)
def _find_mesh(x: PyTree) -> Mesh | None:
    """Find the mesh used for chains.

    Parameters
    ----------
    x
        A pytree; `State` objects in it are treated as leaves.

    Returns
    -------
    The `config.mesh` attribute of the first `State` object found in `x`.

    Raises
    ------
    ValueError
        If `x` contains no `State` object.
    """

    def is_state(node: object) -> bool:
        return isinstance(node, State)

    for leaf in tree.leaves(x, is_leaf=is_state):
        if is_state(leaf):
            return leaf.config.mesh

    msg = 'can not determine the mesh: no `State` object found in the pytree'
    raise ValueError(msg)
def _split_all_keys(x: PyTree, num_chains: int) -> PyTree:
    """Split all random keys in `num_chains` keys.

    If the mesh found in `x` has a 'chains' axis, the split keys are placed
    with a sharding along that axis. Leaves that are not keys are returned
    unchanged.
    """
    mesh = _find_mesh(x)
    shard_over_chains = mesh is not None and 'chains' in mesh.axis_names

    def split_if_key(leaf: object) -> object:
        if not is_key(leaf):
            return leaf
        keys = random.split(leaf, num_chains)
        if shard_over_chains:
            keys = device_put(keys, NamedSharding(mesh, PartitionSpec('chains')))
        return keys

    return tree.map(split_if_key, x)
def vmap_chains(fun: Callable[..., T]) -> Callable[..., T]:
    """Apply vmap on chain axes automatically if the inputs are multichain.

    The returned function calls `fun` directly when `get_num_chains` gives
    `None` for the arguments. Otherwise it splits all random keys in the
    arguments across chains and calls `fun` under `vmap`, with input and
    output axes determined from the arguments and from the abstract output.
    """

    def call_packed(args: tuple[Any, ...], kwargs: dict[str, Any]) -> T:
        # adapter taking positional and keyword arguments as two pytrees
        return fun(*args, **kwargs)

    @wraps(fun)
    def auto_vmapped_fun(*args: Any, **kwargs: Any) -> T:
        num_chains = get_num_chains((args, kwargs))
        if num_chains is None:
            return fun(*args, **kwargs)

        packed = _split_all_keys((args, kwargs), num_chains)
        in_axes = _chain_axes_with_keys(packed)
        out_axes = _get_mc_out_axes(call_packed, packed, in_axes)
        return vmap(call_packed, in_axes=in_axes, out_axes=out_axes)(*packed)

    return auto_vmapped_fun