Coverage for src / bartz / mcmcstep / _state.py: 94%

409 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/mcmcstep/_state.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Module defining the BART MCMC state and initialization.""" 

26 

27from collections.abc import Callable, Hashable 

28from dataclasses import fields, replace 

29from functools import partial, wraps 

30from math import log2 

31from typing import Any, Literal, TypedDict, TypeVar 

32 

33import numpy 

34from equinox import Module, error_if, filter_jit 

35from equinox import field as eqx_field 

36from jax import ( 

37 NamedSharding, 

38 device_put, 

39 eval_shape, 

40 jit, 

41 lax, 

42 make_mesh, 

43 random, 

44 tree, 

45 vmap, 

46) 

47from jax import numpy as jnp 

48from jax.scipy.linalg import solve_triangular 

49from jax.sharding import AxisType, Mesh, PartitionSpec 

50from jaxtyping import Array, Bool, Float32, Int32, Integer, PyTree, Shaped, UInt 

51 

52from bartz.grove import tree_depths 

53from bartz.jaxext import get_default_device, is_key, minimal_unsigned_dtype 

54 

55 

def field(*, chains: bool = False, data: bool = False, **kwargs: Any):  # noqa: ANN202
    """Extend `equinox.field` with two new parameters.

    Parameters
    ----------
    chains
        Whether the arrays in the field have an optional leading axis
        that indexes independent Markov chains.
    data
        Whether the trailing axis of the arrays in the field indexes
        units of the data.
    **kwargs
        Forwarded verbatim to `equinox.field`.

    Returns
    -------
    A dataclass field descriptor whose metadata carries the special
    attributes; each flag is present in the metadata only when True.
    """
    meta = dict(kwargs.pop('metadata', {}))
    assert 'chains' not in meta
    assert 'data' not in meta
    for flag_name, flag_set in (('chains', chains), ('data', data)):
        if flag_set:
            meta[flag_name] = True
    return eqx_field(metadata=meta, **kwargs)

82 

83 

def chain_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for chains.

    Compute the value to pass as `in_axes` or `out_axes` to `jax.vmap`
    in order to map over all and only the chain axes found in the
    pytree `x`.

    Parameters
    ----------
    x
        A pytree. Sub-pytrees stored in Module attributes declared with
        ``field(..., chains=True)`` are treated as carrying a leading
        chain axis.

    Returns
    -------
    A pytree with the same structure as `x`, holding 0 at leaves with a
    chain axis and None everywhere else.
    """
    return _find_metadata(x, key='chains', if_true=0, if_false=None)

102 

103 

def data_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for data.

    The data counterpart of `chain_vmap_axes`: leaves under attributes
    declared with ``field(..., data=True)`` map to -1 (the trailing
    axis); all other leaves map to None.
    """
    return _find_metadata(x, key='data', if_true=-1, if_false=None)

111 

112 

# Generic placeholder for the substitution values used by `_find_metadata`.
T = TypeVar('T')

114 

115 

def _find_metadata(
    x: PyTree[Any, ' S'], key: Hashable, if_true: T, if_false: T
) -> PyTree[T, ' S']:
    """Replace all subtrees of x marked with a metadata key.

    Walk the pytree treating `Module` instances as structured nodes:
    fields whose metadata sets `key` are replaced wholesale by
    `if_true`; static fields are kept as-is; everything else recurses.
    Leaves outside any marked field become `if_false`. `_LazyArray`
    instances are treated as single leaves throughout.
    """

    def _is_lazy(obj: object) -> bool:
        return isinstance(obj, _LazyArray)

    def _is_plain_module(obj: object) -> bool:
        # a _LazyArray is a Module too, but must count as a leaf
        return isinstance(obj, Module) and not _is_lazy(obj)

    def _process(obj: object) -> PyTree[T]:
        if not _is_plain_module(obj):
            # non-module subtree: every leaf is unmarked
            return tree.map(lambda _: if_false, obj, is_leaf=_is_lazy)
        new_args = []
        for fld in fields(obj):
            value = getattr(obj, fld.name)
            if fld.metadata.get('static', False):
                # static fields are not arrays; preserve them verbatim
                new_args.append(value)
            elif fld.metadata.get(key, False):
                # marked field: replace every leaf with if_true
                new_args.append(
                    tree.map(lambda _: if_true, value, is_leaf=_is_lazy)
                )
            else:
                new_args.append(_find_metadata(value, key, if_true, if_false))
        # rebuild the module positionally, relying on field order
        return obj.__class__(*new_args)

    # stop tree.map at Module boundaries (this catches _LazyArray too)
    return tree.map(_process, x, is_leaf=lambda obj: isinstance(obj, Module))

150 

151 

class Forest(Module):
    """Represents the MCMC state of a sum of trees.

    Shape convention: ``num_trees`` indexes the trees in the ensemble,
    ``d`` is the maximum tree depth, ``p`` the number of predictors,
    ``n`` the number of datapoints, ``k`` the number of outcome
    components, and the optional leading ``*chains`` axis indexes
    independent Markov chains (see ``field(..., chains=True)``).
    """

    leaf_tree: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """The leaf values."""

    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The decision axes."""

    split_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The decision boundaries."""

    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Marks leaves that can be grown."""

    max_split: UInt[Array, ' p']
    """The maximum split index for each predictor."""

    blocked_vars: UInt[Array, ' q'] | None
    """Indices of variables that are not used. This shall include at least
    the `i` such that ``max_split[i] == 0``, otherwise behavior is
    undefined."""

    p_nonterminal: Float32[Array, ' 2**d']
    """The prior probability of each node being nonterminal, conditional on
    its ancestors. Includes the nodes at maximum depth which should be set
    to 0."""

    p_propose_grow: Float32[Array, ' 2**(d-1)']
    """The unnormalized probability of picking a leaf for a grow proposal."""

    leaf_indices: UInt[Array, '*chains num_trees n'] = field(chains=True, data=True)
    """The index of the leaf each datapoints falls into, for each tree."""

    min_points_per_decision_node: Int32[Array, ''] | None
    """The minimum number of data points in a decision node."""

    min_points_per_leaf: Int32[Array, ''] | None
    """The minimum number of data points in a leaf node."""

    log_trans_prior: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """The log transition and prior Metropolis-Hastings ratio for the
    proposed move on each tree."""

    log_likelihood: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """The log likelihood ratio."""

    grow_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of grow proposals made during one full MCMC cycle."""

    prune_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of prune proposals made during one full MCMC cycle."""

    grow_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of grow moves accepted during one full MCMC cycle."""

    prune_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of prune moves accepted during one full MCMC cycle."""

    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """The prior precision matrix of a leaf, conditional on the tree structure.
    For the univariate case (k=1), this is a scalar (the inverse variance).
    The prior covariance of the sum of trees is
    ``num_trees * leaf_prior_cov_inv^-1``."""

    log_s: Float32[Array, '*chains p'] | None = field(chains=True)
    """The logarithm of the prior probability for choosing a variable to split
    along in a decision rule, conditional on the ancestors. Not normalized.
    If `None`, use a uniform distribution."""

    theta: Float32[Array, '*chains'] | None = field(chains=True)
    """The concentration parameter for the Dirichlet prior on the variable
    distribution `s`. Required only to update `log_s`."""

    a: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    b: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    rho: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    def num_chains(self) -> int | None:
        """Return the number of chains, or `None` if not multichain."""
        # maybe this should be replaced by chain_shape() -> () | (int,)
        # Without a chains axis var_tree is (num_trees, 2**(d-1)), so
        # ndim == 2 means there is no leading chain axis.
        if self.var_tree.ndim == 2:
            return None
        else:
            return self.var_tree.shape[0]

248 

249 

class StepConfig(Module):
    """Options for the MCMC step.

    The fields declared with ``field(static=True)`` are static
    (compile-time) configuration rather than traced arrays.
    """

    steps_done: Int32[Array, '']
    """The number of MCMC steps completed so far."""

    sparse_on_at: Int32[Array, ''] | None
    """After how many steps to turn on variable selection."""

    resid_num_batches: int | None = field(static=True)
    """The number of batches for computing the sum of residuals. If
    `None`, they are computed with no batching."""

    count_num_batches: int | None = field(static=True)
    """The number of batches for computing counts. If
    `None`, they are computed with no batching."""

    prec_num_batches: int | None = field(static=True)
    """The number of batches for computing precision scales. If
    `None`, they are computed with no batching."""

    prec_count_num_trees: int | None = field(static=True)
    """Batch size for processing trees to compute count and prec trees."""

    mesh: Mesh | None = field(static=True)
    """The mesh used to shard data and computation across multiple devices."""

277 

class State(Module):
    """Represents the MCMC state of BART.

    Array fields marked with ``field(..., chains=True)`` may carry an
    optional leading chain axis; fields marked with
    ``field(..., data=True)`` have their trailing axis indexing the
    datapoints.
    """

    X: UInt[Array, 'p n'] = field(data=True)
    """The predictors."""

    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'] = field(
        data=True
    )
    """The response. If the data type is `bool`, the model is binary regression."""

    z: None | Float32[Array, '*chains n'] = field(chains=True, data=True)
    """The latent variable for binary regression. `None` in continuous
    regression."""

    offset: Float32[Array, ''] | Float32[Array, ' k']
    """Constant shift added to the sum of trees."""

    resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=True, data=True
    )
    """The residuals (`y` or `z` minus sum of trees)."""

    error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None = (
        field(chains=True)
    )
    """The inverse error covariance (scalar for univariate, matrix for multivariate).
    `None` in binary regression."""

    prec_scale: Float32[Array, ' n'] | None = field(data=True)
    """The scale on the error precision, i.e., ``1 / error_scale ** 2``.
    `None` in binary regression."""

    error_cov_df: Float32[Array, ''] | None
    """The df parameter of the inverse Wishart prior on the noise
    covariance. For the univariate case, the relationship to the inverse
    gamma prior parameters is ``alpha = df / 2``.
    `None` in binary regression."""

    error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """The scale parameter of the inverse Wishart prior on the noise
    covariance. For the univariate case, the relationship to the inverse
    gamma prior parameters is ``beta = scale / 2``.
    `None` in binary regression."""

    forest: Forest
    """The sum of trees model."""

    config: StepConfig
    """Metadata and configurations for the MCMC step."""

328 

329 

def _init_shape_shifting_parameters(
    y: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    error_scale: Float32[Any, ' n'] | None,
    error_cov_df: float | Float32[Any, ''] | None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Any, 'k k'] | None,
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
) -> tuple[
    bool,
    tuple[()] | tuple[int],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
]:
    """
    Check and initialize parameters whose array type/shape depends on the outcome kind.

    Parameters
    ----------
    y
        The response; the outcome kind (binary vs. continuous,
        univariate vs. multivariate) is deduced from its dtype and
        shape, and the other parameters are validated against it.
    offset
        The offset to add to the predictions.
    error_scale
        Per-observation error scale (univariate only).
    error_cov_df
        The error covariance degrees of freedom.
    error_cov_scale
        The error covariance scale.
    leaf_prior_cov_inv
        The inverse of the leaf prior covariance.

    Returns
    -------
    is_binary
        Whether the outcome is binary.
    kshape
        The outcome shape: ``()`` for univariate, ``(k,)`` for
        multivariate.
    error_cov_inv
        The initialized error covariance inverse, `None` for binary
        regression.
    error_cov_df
        The degrees of freedom as an array (`None` for binary
        regression).
    error_cov_scale
        The scale as an array (`None` for binary regression).

    Raises
    ------
    ValueError
        If `y` is binary and multivariate.
    """
    # outcome kind: binary/continuous x univariate/multivariate
    is_binary = y.dtype == bool
    kshape = y.shape[:-1]  # () univariate, (k,) multivariate

    if not is_binary:
        error_cov_df = jnp.asarray(error_cov_df)
        error_cov_scale = jnp.asarray(error_cov_scale)
        # scalar for univariate, (k, k) for multivariate
        assert error_cov_scale.shape == 2 * kshape
        if kshape:
            # multivariate: invert the scale matrix via its Cholesky factor
            error_cov_inv = error_cov_df * _inv_via_chol_with_gersh(error_cov_scale)
        else:
            # inverse gamma prior: alpha = df / 2, beta = scale / 2
            error_cov_inv = error_cov_df / error_cov_scale
    else:
        if kshape:  # pragma: no cover - exercised only on misuse
            msg = 'Binary multivariate regression not supported, open an issue at https://github.com/bartz-org/bartz/issues if you need it.'
            raise ValueError(msg)
        # binary regression takes no error covariance parameters
        assert error_scale is None
        assert error_cov_df is None
        assert error_cov_scale is None
        error_cov_inv = None

    assert leaf_prior_cov_inv.shape == 2 * kshape
    assert offset.shape == kshape

    return is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale

410 

411 

def _parse_p_nonterminal(
    p_nonterminal: Float32[Any, ' d_minus_1'],
) -> Float32[Array, ' d_minus_1+1']:
    """Check it's in (0, 1) and pad with a 0 at the end."""
    probs = jnp.asarray(p_nonterminal)
    in_open_interval = (probs > 0) & (probs < 1)
    probs = error_if(probs, ~in_open_interval, 'p_nonterminal must be in (0, 1)')
    # the trailing 0 makes nodes at maximum depth always terminal
    return jnp.pad(probs, (0, 1))

420 

421 

def make_p_nonterminal(
    d: int,
    alpha: float | Float32[Array, ''] = 0.95,
    beta: float | Float32[Array, ''] = 2.0,
) -> Float32[Array, ' {d}-1']:
    """Prepare the `p_nonterminal` argument to `init`.

    The probability that a node at 0-based ``depth`` is nonterminal is
    computed with the formula:

        P_nt(depth) = alpha / (1 + depth)^beta

    Parameters
    ----------
    d
        The maximum depth of the trees (d=1 means a tree with only the
        root node).
    alpha
        The a priori probability of the root node having children,
        conditional on it being possible.
    beta
        The exponent of the power decay of the probability of having
        children with depth.

    Returns
    -------
    An array of probabilities, one per tree level but the last.
    """
    assert d >= 1
    # levels 1..d-1 are the (1 + depth) factors for depths 0..d-2
    levels = jnp.arange(1, d).astype(float)
    return alpha / levels**beta

451 

452 

class _LazyArray(Module):
    """Like `functools.partial` but specialized to array-creating functions like `jax.numpy.zeros`.

    Stores the creator function, the target shape, and extra positional
    arguments; calling the instance materializes the array.
    """

    array_creator: Callable  # invoked as array_creator(shape, *args, **kwargs)
    shape: tuple[int, ...]  # passed as the first argument to the creator
    args: tuple  # extra positional arguments forwarded to the creator

    def __init__(
        self, array_creator: Callable, shape: tuple[int, ...], *args: Any
    ) -> None:
        self.array_creator = array_creator
        self.shape = shape
        self.args = args

    def __call__(self, **kwargs: Any) -> T:
        """Create the array, forwarding `kwargs` to the creator function."""
        return self.array_creator(self.shape, *self.args, **kwargs)

    @property
    def ndim(self) -> int:
        """Number of dimensions, mimicking the attribute of a real array."""
        return len(self.shape)

473 

474 

def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Float32[Any, ' k n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] | Float32[Any, ' k'],
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d_minus_1'],
    leaf_prior_cov_inv: float | Float32[Any, ''] | Float32[Array, 'k k'],
    error_cov_df: float | Float32[Any, ''] | None = None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Array, 'k k'] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_num_batches: int | None | Literal['auto'] = 'auto',
    count_num_batches: int | None | Literal['auto'] = 'auto',
    prec_num_batches: int | None | Literal['auto'] = 'auto',
    prec_count_num_trees: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: int = 0,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
    sparse_on_at: int | Integer[Any, ''] | None = None,
    num_chains: int | None = None,
    mesh: Mesh | dict[str, int] | None = None,
    target_platform: Literal['cpu', 'gpu'] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual convention.
    y
        The response. If the data type is `bool`, the regression model is binary
        regression with probit. If two-dimensional, the outcome is multivariate
        with the first axis indicating the component.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array. Use `make_p_nonterminal`
        to set it with the conventional formula.
    leaf_prior_cov_inv
        The prior precision matrix of a leaf, conditional on the tree structure.
        For the univariate case (k=1), this is a scalar (the inverse variance).
        The prior covariance of the sum of trees is
        ``num_trees * leaf_prior_cov_inv^-1``. The prior mean of leaves is
        always zero.
    error_cov_df
    error_cov_scale
        The df and scale parameters of the inverse Wishart prior on the error
        covariance. For the univariate case, the relationship to the inverse
        gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``.
        Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1 for
        all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_num_batches
    count_num_batches
    prec_num_batches
        The number of batches, along datapoints, for summing the residuals,
        counting the number of datapoints in each leaf, and computing the
        likelihood precision in each leaf, respectively. `None` for no batching.
        If 'auto', it's chosen automatically based on the target platform; see
        the description of `target_platform` below for how it is determined.
    prec_count_num_trees
        The number of trees to process at a time when counting datapoints or
        computing the likelihood precision. If `None`, do all trees at once,
        which may use too much memory. If 'auto' (default), it's chosen
        automatically.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        The maximum number of variables without splits that can be ignored. If
        there are more, `init` raises an exception.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken into
        account in the Metropolis-Hastings ratio because it would be expensive
        to compute. Grow moves that would violate this constraint are vetoed.
        This parameter is independent of `min_points_per_decision_node` and
        there is no check that they are coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, use a uniform distribution. If not specified and
        `theta` or `rho`, `a`, `b` are, it's initialized automatically.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized automatically.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
    sparse_on_at
        After how many MCMC steps to turn on variable selection.
    num_chains
        The number of independent MCMC chains to represent in the state. Single
        chain with scalar values if not specified.
    mesh
        A jax mesh used to shard data and computation across multiple devices.
        If it has a 'chains' axis, that axis is used to shard the chains. If it
        has a 'data' axis, that axis is used to shard the datapoints.

        As a shorthand, if a dictionary mapping axis names to axis size is
        passed, the corresponding mesh is created, e.g., ``dict(chains=4,
        data=2)`` will let jax pick 8 devices to split chains (which must be a
        multiple of 4) across 4 pairs of devices, where in each pair the data is
        split in two.

        Note: if a mesh is passed, the arrays are always sharded according to
        it. In particular even if the mesh has no 'chains' or 'data' axis, the
        arrays will be replicated on all devices in the mesh.
    target_platform
        Platform ('cpu' or 'gpu') used to determine the number of batches
        automatically. If `mesh` is specified, the platform is inferred from the
        devices in the mesh. Otherwise, if `y` is a concrete array (i.e., `init`
        is not invoked in a `jax.jit` context), the platform is set to the
        platform of `y`. Otherwise, use `target_platform`.

        To avoid confusion, in all cases where the `target_platform` argument
        would be ignored, `init` raises an exception if `target_platform` is
        set.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    # convert to array all array-like arguments that are used in other
    # configurations but don't need further processing themselves
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    offset = jnp.asarray(offset)
    leaf_prior_cov_inv = jnp.asarray(leaf_prior_cov_inv)
    max_split = jnp.asarray(max_split)

    # check p_nonterminal and pad it with a 0 at the end (still not final shape)
    p_nonterminal = _parse_p_nonterminal(p_nonterminal)

    # process arguments that change depending on outcome type
    is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale = (
        _init_shape_shifting_parameters(
            y, offset, error_scale, error_cov_df, error_cov_scale, leaf_prior_cov_inv
        )
    )

    # extract array sizes from arguments
    (max_depth,) = p_nonterminal.shape
    p, n = X.shape

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        # theta is initialized from rho, see the docstring
        theta = rho
    if log_s is None and theta is not None:
        # uniform prior over splitting variables (log scale, unnormalized)
        log_s = jnp.zeros(max_split.size)
    if not _all_none_or_not_none(theta, sparse_on_at):
        msg = 'sparsity params (either theta or rho,a,b) and sparse_on_at must be either all None or all set'
        raise ValueError(msg)

    # process multichain settings
    chain_shape = () if num_chains is None else (num_chains,)
    resid_shape = chain_shape + y.shape
    add_chains = partial(_add_chains, chain_shape=chain_shape)

    # determine batch sizes for reductions
    mesh = _parse_mesh(num_chains, mesh)
    target_platform = _parse_target_platform(
        y, mesh, target_platform, resid_num_batches, count_num_batches, prec_num_batches
    )
    red_cfg = _parse_reduction_configs(
        resid_num_batches,
        count_num_batches,
        prec_num_batches,
        prec_count_num_trees,
        y,
        num_trees,
        mesh,
        target_platform,
    )

    # check there aren't too many deactivated predictors
    msg = (
        f'there are more than {filter_splitless_vars=} predictors with no splits, '
        'please increase `filter_splitless_vars` or investigate the missing splits'
    )
    offset = error_if(offset, jnp.sum(max_split == 0) > filter_splitless_vars, msg)

    # determine shapes for trees
    tree_shape = (*chain_shape, num_trees)
    tree_size = 2**max_depth

    # initialize all remaining stuff and put it in an unsharded state
    state = State(
        X=X,
        y=y,
        z=_LazyArray(jnp.full, resid_shape, offset) if is_binary else None,
        offset=offset,
        resid=_LazyArray(jnp.zeros, resid_shape)
        if is_binary
        else None,  # in this case, resid is created later after y and offset are sharded
        error_cov_inv=add_chains(error_cov_inv),
        prec_scale=error_scale,  # temporarily set to error_scale, fix after sharding
        error_cov_df=error_cov_df,
        error_cov_scale=error_cov_scale,
        forest=Forest(
            leaf_tree=_LazyArray(
                jnp.zeros, (*tree_shape, *kshape, tree_size), jnp.float32
            ),
            var_tree=_LazyArray(
                jnp.zeros, (*tree_shape, tree_size // 2), minimal_unsigned_dtype(p - 1)
            ),
            split_tree=_LazyArray(
                jnp.zeros, (*tree_shape, tree_size // 2), max_split.dtype
            ),
            affluence_tree=_LazyArray(
                _initial_affluence_tree,
                (*tree_shape, tree_size // 2),
                n,
                min_points_per_decision_node,
            ),
            blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split),
            max_split=max_split,
            grow_prop_count=_LazyArray(jnp.zeros, chain_shape, int),
            grow_acc_count=_LazyArray(jnp.zeros, chain_shape, int),
            prune_prop_count=_LazyArray(jnp.zeros, chain_shape, int),
            prune_acc_count=_LazyArray(jnp.zeros, chain_shape, int),
            p_nonterminal=p_nonterminal[tree_depths(tree_size)],
            p_propose_grow=p_nonterminal[tree_depths(tree_size // 2)],
            leaf_indices=_LazyArray(
                jnp.ones, (*tree_shape, n), minimal_unsigned_dtype(tree_size - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            log_trans_prior=_LazyArray(jnp.zeros, (*chain_shape, num_trees))
            if save_ratios
            else None,
            log_likelihood=_LazyArray(jnp.zeros, (*chain_shape, num_trees))
            if save_ratios
            else None,
            leaf_prior_cov_inv=leaf_prior_cov_inv,
            log_s=add_chains(_asarray_or_none(log_s)),
            theta=add_chains(_asarray_or_none(theta)),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
        config=StepConfig(
            steps_done=jnp.int32(0),
            sparse_on_at=_asarray_or_none(sparse_on_at),
            mesh=mesh,
            **red_cfg,
        ),
    )

    # delete big input arrays such that they can be deleted as soon as they
    # are sharded, only those arrays that contain an (n,) sized axis
    del X, y, error_scale

    # move all arrays to the appropriate device
    state = _shard_state(state)

    # calculate initial resid in the continuous outcome case, such that y and
    # offset are already sharded if needed
    if state.resid is None:
        resid = _LazyArray(_initial_resid, resid_shape, state.y, state.offset)
        resid = _shard_leaf(resid, 0, -1, state.config.mesh)
        state = replace(state, resid=resid)

    # calculate prec_scale after sharding to do the calculation on the right
    # devices
    if state.prec_scale is not None:
        prec_scale = _compute_prec_scale(state.prec_scale)
        state = replace(state, prec_scale=prec_scale)

    # make all types strong to avoid unwanted recompilations
    return _remove_weak_types(state)

782 

783 

def _initial_resid(
    shape: tuple[int, ...],
    y: Float32[Array, ' n'] | Float32[Array, 'k n'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
) -> Float32[Array, ' n'] | Float32[Array, 'k n']:
    """Compute the starting value of `State.resid` for continuous outcomes.

    Subtracts `offset` (aligned to broadcast over the trailing datapoint
    axis) from `y`, then broadcasts the result to `shape`.
    """
    centered = y - offset[..., None]
    return jnp.broadcast_to(centered, shape)

791 

792 

793def _initial_affluence_tree( 

794 shape: tuple[int, ...], n: int, min_points_per_decision_node: int | None 

795) -> Array: 

796 """Create the initial value of `Forest.affluence_tree`.""" 

797 return ( 1iuBqDmhlRCrVy[Fv%G']HwZSItTAxjkLWXY;nopcdNefgabJ9!#$MOPQ()*+

798 jnp.zeros(shape, bool) 

799 .at[..., 1] 

800 .set( 

801 True 

802 if min_points_per_decision_node is None 

803 else n >= min_points_per_decision_node 

804 ) 

805 ) 

806 

807 

@partial(jit, donate_argnums=(0,))
def _compute_prec_scale(error_scale: Float32[Array, ' n']) -> Float32[Array, ' n']:
    """Compute 1 / error_scale**2.

    Kept as a standalone jitted function with a donated argument so the
    input buffer can be reused, avoiding intermediate copies.
    """
    squared = jnp.square(error_scale)
    return jnp.reciprocal(squared)

816 

817 

def _get_blocked_vars(
    filter_splitless_vars: int, max_split: UInt[Array, ' p']
) -> None | UInt[Array, ' q']:
    """Initialize the `blocked_vars` field."""
    # zero means "don't filter anything": no blocked-variables array
    if not filter_splitless_vars:
        return None
    (p,) = max_split.shape
    # indices of variables with no available splits; padded with p
    (blocked_vars,) = jnp.nonzero(
        max_split == 0, size=filter_splitless_vars, fill_value=p
    )
    # see `fully_used_variables` for the type cast
    return blocked_vars.astype(minimal_unsigned_dtype(p))

831 

832 

def _add_chains(
    x: Shaped[Array, '*shape'] | None, chain_shape: tuple[int, ...]
) -> Shaped[Array, '*shape'] | Shaped[Array, ' num_chains *shape'] | None:
    """Broadcast `x` to all chains."""
    if x is not None:
        # prepend the chain axes to the existing shape
        return jnp.broadcast_to(x, chain_shape + x.shape)
    return None

841 

842 

843def _parse_mesh( 

844 num_chains: int | None, mesh: Mesh | dict[str, int] | None 

845) -> Mesh | None: 

846 """Parse the `mesh` argument.""" 

847 if mesh is None: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

848 return None 2ubK B D vbwbm h l C xbtbybzb[ F G ] H AbBbCbDbZ . / qbI EbA W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p N ` hb9 ! # $ M O P Q ( ) * +

849 

850 # convert dict format to actual mesh 

851 if isinstance(mesh, dict): 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J

852 assert set(mesh).issubset({'chains', 'data'}) 1cdefgabJ

853 mesh = make_mesh( 1cdefgabJ

854 tuple(mesh.values()), tuple(mesh), axis_types=(AxisType.Auto,) * len(mesh) 

855 ) 

856 

857 # check there's no chain mesh axis if there are no chains 

858 if num_chains is None: 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J

859 assert 'chains' not in mesh.axis_names 2{ , | } ~ R abU - V % bbcbdb^ ebS fbT J

860 

861 # check the axes we use are in auto mode 

862 assert 'chains' not in mesh.axis_names or 'chains' in _auto_axes(mesh) 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J

863 assert 'data' not in mesh.axis_names or 'data' in _auto_axes(mesh) 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J

864 

865 return mesh 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J

866 

867 

868def _parse_target_platform( 

869 y: Array, 

870 mesh: Mesh | None, 

871 target_platform: Literal['cpu', 'gpu'] | None, 

872 resid_num_batches: int | None | Literal['auto'], 

873 count_num_batches: int | None | Literal['auto'], 

874 prec_num_batches: int | None | Literal['auto'], 

875) -> Literal['cpu', 'gpu'] | None: 

876 if mesh is not None: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

877 assert target_platform is None, 'mesh provided, do not set target_platform' 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J

878 return mesh.devices.flat[0].platform 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J

879 elif hasattr(y, 'platform'): 2ubK B D vbwbm h l C xbtbybzb[ F G ] H AbBbCbDbZ . / qbI EbA W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p N ` hb9 ! # $ M O P Q ( ) * +

880 assert target_platform is None, 'device inferred from y, unset target_platform' 2ubK B D vbwbm h l C xbtbybzb[ F G ] H AbBbCbDbZ . / qbI EbA n o p N ` hb9 ! # $ M O P Q ( ) * +

881 return y.platform() 2ubK B D vbwbm h l C xbtbybzb[ F G ] H AbBbCbDbZ . / qbI EbA n o p N ` hb9 ! # $ M O P Q ( ) * +

882 elif ( 2tbs ;

883 resid_num_batches == 'auto' 

884 or count_num_batches == 'auto' 

885 or prec_num_batches == 'auto' 

886 ): 

887 assert target_platform in ('cpu', 'gpu') 2tbW 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gb

888 return target_platform 2tbW 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gb

889 else: 

890 assert target_platform is None, 'target_platform not used, unset it' 1s

891 return target_platform 1s

892 

893 

894def _auto_axes(mesh: Mesh) -> list[str]: 

895 """Re-implement `Mesh.auto_axes` because that's missing in jax v0.5.""" 

896 # Mesh.auto_axes added in jax v0.6.0 

897 return [ 2{ ibi , z u q | jb} kb~ lbm h l R r abmbU s - E V y v % ' w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L c d e f g a b J

898 n 

899 for n, t in zip(mesh.axis_names, mesh.axis_types, strict=True) 

900 if t == AxisType.Auto 

901 ] 

902 

903 

@partial(filter_jit, donate='all')
# jit and donate because otherwise type conversion would create copies
def _remove_weak_types(x: PyTree[Array, 'T']) -> PyTree[Array, 'T']:
    """Make all types strong.

    This is to avoid recompilation in `run_mcmc` or `step`.
    """

    def strengthen(leaf: T) -> T:
        # astype to the same dtype drops the weak_type flag
        if isinstance(leaf, Array) and leaf.weak_type:
            return leaf.astype(leaf.dtype)
        return leaf

    return tree.map(strengthen, x)

919 

920 

def _shard_state(state: State) -> State:
    """Place all arrays on the appropriate devices, and instantiate lazily defined arrays."""

    def handled_leaf(leaf: object) -> bool:
        # treat None and lazy specs as leaves so _shard_leaf sees them whole
        return leaf is None or isinstance(leaf, _LazyArray)

    shard = partial(_shard_leaf, mesh=state.config.mesh)
    return tree.map(
        shard,
        state,
        chain_vmap_axes(state),
        data_vmap_axes(state),
        is_leaf=handled_leaf,
    )

932 

933 

def _shard_leaf(
    x: Array | None | _LazyArray,
    chain_axis: int | None,
    data_axis: int | None,
    mesh: Mesh | None,
) -> Array | None:
    """Create `x` if it's lazy and shard it."""
    if x is None:
        return None

    if mesh is None:
        sharding = None
    else:
        # assign the 'chains'/'data' mesh axes to the matching array axes
        spec = [None] * x.ndim
        if chain_axis is not None and 'chains' in mesh.axis_names:
            spec[chain_axis] = 'chains'
        if data_axis is not None and 'data' in mesh.axis_names:
            spec[data_axis] = 'data'

        # remove trailing Nones to be consistent with jax's output, it's useful
        # for comparing shardings during debugging
        while spec and spec[-1] is None:
            spec.pop()

        sharding = NamedSharding(mesh, PartitionSpec(*spec))

    if isinstance(x, _LazyArray):
        return _concretize_lazy_array(x, sharding)
    if sharding is not None:
        return device_put(x, sharding, donate=True)
    return x

967 

968 

@filter_jit
# jit such that in recent jax versions the shards are created on the right
# devices immediately instead of being created on the wrong device and then
# copied
def _concretize_lazy_array(x: _LazyArray, sharding: NamedSharding | None) -> Array:
    """Create an array from an abstract spec on the appropriate devices."""
    # materialize the lazy array by calling it
    concrete = x()
    if sharding is None:
        return concrete
    return lax.with_sharding_constraint(concrete, sharding)

979 

980 

981def _all_none_or_not_none(*args: object) -> bool: 

982 is_none = [x is None for x in args] 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

983 return all(is_none) or not any(is_none) 2{ ubibi , K z (bu B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

984 

985 

986def _asarray_or_none(x: object) -> Array | None: 

987 if x is None: 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

988 return None 2ubibi , K z B q D jbvbkbwblbh l R C r xbmbtbs - ybE V zby [ F v G ' ] H w AbnbBbobCbpbDb_ Z . / qbrbI t EbsbT A x j L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

989 return jnp.asarray(x) 2{ ibi , K z u B q | jb} kb~ lbm l R C r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ ebrbS t fbsbT x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gb

990 

991 

992def _get_platform(mesh: Mesh | None) -> str: 

993 if mesh is None: 

994 return get_default_device().platform 

995 else: 

996 return mesh.devices.flat[0].platform 

997 

998 

class _ReductionConfig(TypedDict):
    """Fields of `StepConfig` related to reductions.

    Each value is the batching granularity for one indexed reduction, with
    `None` meaning batching is disabled for that reduction.
    """

    # number of batches for the residuals reduction (None = no batching)
    resid_num_batches: int | None
    # number of batches for the counts reduction (None = no batching)
    count_num_batches: int | None
    # number of batches for the precision reduction (None = no batching)
    prec_num_batches: int | None
    # number of trees processed at a time in the precision/count reduction
    # (None = all trees at once)
    prec_count_num_trees: int | None

1006 

1007 

def _parse_reduction_configs(
    resid_num_batches: int | None | Literal['auto'],
    count_num_batches: int | None | Literal['auto'],
    prec_num_batches: int | None | Literal['auto'],
    prec_count_num_trees: int | None | Literal['auto'],
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'],
    num_trees: int,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
) -> _ReductionConfig:
    """Determine settings for indexed reduces."""
    # datapoints held by each device along the 'data' mesh axis
    per_device_n = y.shape[-1] // get_axis_size(mesh, 'data')
    return dict(
        resid_num_batches=_parse_num_batches(
            target_platform, per_device_n, resid_num_batches, 'resid'
        ),
        count_num_batches=_parse_num_batches(
            target_platform, per_device_n, count_num_batches, 'count'
        ),
        prec_num_batches=_parse_num_batches(
            target_platform, per_device_n, prec_num_batches, 'prec'
        ),
        prec_count_num_trees=_parse_prec_count_num_trees(
            prec_count_num_trees, num_trees, per_device_n
        ),
    )

1030 

1031 

1032def _parse_num_batches( 

1033 target_platform: Literal['cpu', 'gpu'] | None, 

1034 n: int, 

1035 num_batches: int | None | Literal['auto'], 

1036 which: Literal['resid', 'count', 'prec'], 

1037) -> int | None: 

1038 """Return the number of batches or determine it automatically.""" 

1039 final_round = partial(_final_round, n) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1040 if num_batches != 'auto': 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1041 nb = num_batches 2{ ib, z u q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ ebrbS t fbsbT x ` 9 ! # $ M O P Q ( ) * +

1042 elif target_platform == 'cpu': 1042 ↛ 1044line 1042 didn't jump to line 1044 because the condition on line 1042 was always true2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1043 nb = final_round(16) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1044 elif target_platform == 'gpu': 

1045 nb = dict(resid=1024, count=2048, prec=1024)[which] # on an A4000 

1046 nb = final_round(nb) 

1047 return nb 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1048 

1049 

1050def _final_round(n: int, num: float) -> int | None: 

1051 """Bound batch size, round number of batches to a power of 2, and disable batching if there's only 1 batch.""" 

1052 # at least some elements per batch 

1053 num = min(n // 32, num) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1054 

1055 # round to the nearest power of 2 because I guess XLA and the hardware 

1056 # will like that (not sure about this, maybe just multiple of 32?) 

1057 num = 2 ** round(log2(num)) if num else 0 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1058 

1059 # disable batching if the batch is as large as the whole dataset 

1060 return num if num > 1 else None 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1061 

1062 

1063def _parse_prec_count_num_trees( 

1064 prec_count_num_trees: int | None | Literal['auto'], num_trees: int, n: int 

1065) -> int | None: 

1066 """Return the number of trees to process at a time or determine it automatically.""" 

1067 if prec_count_num_trees != 'auto': 2{ ubibi , K z (bu B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1068 return prec_count_num_trees 2{ ib, z (bu q | jb} kb~ lbm l R r abmbU s - E V y [ v % ' ] w bbnbcbobdbpb^ _ ebrbS t fbsbT x

1069 max_n_by_ntree = 2**27 # about 100M 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1070 pcnt = max_n_by_ntree // max(1, n) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1071 pcnt = min(num_trees, pcnt) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1072 pcnt = max(1, pcnt) 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1073 pcnt = _search_divisor( 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1074 pcnt, num_trees, max(1, pcnt // 2), max(1, min(num_trees, pcnt * 2)) 

1075 ) 

1076 if pcnt >= num_trees: 1076 ↛ 1078line 1076 didn't jump to line 1078 because the condition on line 1076 was always true2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1077 pcnt = None 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1078 return pcnt 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1079 

1080 

1081def _search_divisor(target_divisor: int, dividend: int, low: int, up: int) -> int: 

1082 """Find the divisor closest to `target_divisor` in [low, up] if `target_divisor` is not already. 

1083 

1084 If there is none, give up and return `target_divisor`. 

1085 """ 

1086 assert target_divisor >= 1 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1087 assert 1 <= low <= up <= dividend 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1088 if dividend % target_divisor == 0: 1088 ↛ 1090line 1088 didn't jump to line 1090 because the condition on line 1088 was always true2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1089 return target_divisor 2ubi K B D vbwbh C xbtbybzbF G H AbBbCbDbZ . / qbI EbA j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1090 candidates = numpy.arange(low, up + 1) 

1091 divisors = candidates[dividend % candidates == 0] 

1092 if divisors.size == 0: 

1093 return target_divisor 

1094 penalty = numpy.abs(divisors - target_divisor) 

1095 closest = numpy.argmin(penalty) 

1096 return divisors[closest].item() 

1097 

1098 

1099def get_axis_size(mesh: Mesh | None, axis_name: str) -> int: 

1100 if mesh is None or axis_name not in mesh.axis_names: 2{ ubibi , K z |b}b~b{bac]bu B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N e f g a b J ` hb9 ! # $ M O P Q ( ) * +

1101 return 1 2ubibi , K z |b}b~bac]bu B q D jbvbkbwblbm h l R C r xbmbU tbs ybE zby [ F v % G ' ] H w AbnbBbobCbpb^ Db_ Z . / qbrbS I t EbsbA x j k L W 0 1 = 2 X 3 ? 4 5 Y @ 6 7 8 ; gbn o p c d N ` hb9 ! # $ M O P Q ( ) * +

1102 else: 

1103 i = mesh.axis_names.index(axis_name) 2{ , z {b]bu q | } ~ m h l R r abU - V v % ' bbcbdb^ _ ebS t fbT e f g a b J

1104 return mesh.axis_sizes[i] 2{ , z {b]bu q | } ~ m h l R r abU - V v % ' bbcbdb^ _ ebS t fbT e f g a b J

1105 

1106 

def chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool = False
) -> Float32[Array, '*batch_shape k k']:
    """Cholesky with Gershgorin stabilization, supports batching.

    Parameters
    ----------
    mat
        A (batch of) square matrix(ces) to decompose.
    absolute_eps
        If True, the implementation also adds a fixed machine epsilon to the
        diagonal on top of the Gershgorin-scaled relative term.

    Returns
    -------
    The (batched) Cholesky factor of the diagonally-shifted matrix.
    """
    # thin wrapper: batching is handled by jnp.vectorize on the impl
    return _chol_with_gersh_impl(mat, absolute_eps)

1112 

1113 

1114@partial(jnp.vectorize, signature='(k,k)->(k,k)', excluded=(1,)) 

1115def _chol_with_gersh_impl( 

1116 mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool 

1117) -> Float32[Array, '*batch_shape k k']: 

1118 rho = jnp.max(jnp.sum(jnp.abs(mat), axis=1), initial=0.0) 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b

1119 eps = jnp.finfo(mat.dtype).eps 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b

1120 u = mat.shape[0] * rho * eps 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b

1121 if absolute_eps: 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b

1122 u += eps 2FbGbHbIbRbSbn c d N e a b J )bM *b+b,b-b.b/b

1123 mat = mat.at[jnp.diag_indices_from(mat)].add(u) 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b

1124 return jnp.linalg.cholesky(mat) 2FbW Gb0 Hb1 Jb2 KbX Ib3 Lb4 Mb5 NbY Rb6 Sb7 Tb8 n o p c d N e f g a b J ` hb9 ! # $ )bM O P Q ( ) * + :b;b=b?b@b[b*b+b,b-b.b/b

1125 

1126 

def _inv_via_chol_with_gersh(mat: Float32[Array, 'k k']) -> Float32[Array, 'k k']:
    """Compute matrix inverse via Cholesky with Gershgorin stabilization.

    DO NOT USE THIS FUNCTION UNLESS YOU REALLY NEED TO.
    """
    # mat = C C^T  =>  mat^-1 = (C^-1)^T (C^-1)
    chol = chol_with_gersh(mat)
    eye = jnp.eye(mat.shape[0], dtype=mat.dtype)
    chol_inv = solve_triangular(chol, eye, lower=True)
    return chol_inv.T @ chol_inv

1136 

1137 

1138def get_num_chains(x: PyTree) -> int | None: 

1139 """Get the number of chains of a pytree. 

1140 

1141 Find all nodes in the structure that define 'num_chains()', stopping 

1142 traversal at nodes that define it. Check all values obtained invoking 

1143 `num_chains` are equal, then return it. 

1144 """ 

1145 leaves, _ = tree.flatten(x, is_leaf=lambda x: hasattr(x, 'num_chains')) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L UbFbGbHbObJbKbIbPbLbMbNbQbRbSbTbVbgbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#bn o p c d N e f g a b J M O P Q

1146 num_chains = [x.num_chains() for x in leaves if hasattr(x, 'num_chains')] 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L UbFbGbHbObJbKbIbPbLbMbNbQbRbSbTbVbgbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#bn o p c d N e f g a b J M O P Q

1147 ref = num_chains[0] 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L UbFbGbHbObJbKbIbPbLbMbNbQbRbSbTbVbgbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#bn o p c d N e f g a b J M O P Q

1148 assert all(c == ref for c in num_chains) 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L UbFbGbHbObJbKbIbPbLbMbNbQbRbSbTbVbgbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#bn o p c d N e f g a b J M O P Q

1149 return ref 2{ ubibi , K z u B q | D jb} vbkb~ wblbm h l R C r abxbmbU tbs - ybE V zby [ F v % G ' ] H w bbAbnbcbBbobdbCbpb^ Db_ Z . / qbebrbS I t fbEbsbT A x j k L UbFbGbHbObJbKbIbPbLbMbNbQbRbSbTbVbgbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#bn o p c d N e f g a b J M O P Q

1150 

1151 

def _chain_axes_with_keys(x: PyTree) -> PyTree[int | None]:
    """Return `chain_vmap_axes(x)` but also set to 0 for random keys."""
    base_axes = chain_vmap_axes(x)
    # random keys are always mapped along their leading axis
    return tree.map(lambda leaf, axis: 0 if is_key(leaf) else axis, x, base_axes)

1163 

1164 

def _get_mc_out_axes(
    fun: Callable[[tuple, dict], PyTree], args: PyTree, in_axes: PyTree[int | None]
) -> PyTree[int | None]:
    """Decide chain vmap axes for outputs."""
    # eval_shape traces the vmapped function without running any compute
    out_shapes = eval_shape(vmap(fun, in_axes=in_axes), *args)
    return chain_vmap_axes(out_shapes)

1172 

1173 

def _find_mesh(x: PyTree) -> Mesh | None:
    """Find the mesh used for chains.

    Raises
    ------
    ValueError
        If `x` contains no `State`.
    """
    # stop traversal at State nodes so they show up as leaves
    leaves = tree.leaves(x, is_leaf=lambda node: isinstance(node, State))
    for leaf in leaves:
        if isinstance(leaf, State):
            # use the mesh of the first State encountered
            return leaf.config.mesh
    raise ValueError

1190 

1191 

def _split_all_keys(x: PyTree, num_chains: int) -> PyTree:
    """Split all random keys in `num_chains` keys."""
    mesh = _find_mesh(x)
    # the sharding decision is the same for every key, so compute it once
    shard_over_chains = mesh is not None and 'chains' in mesh.axis_names

    def split_one(leaf: object) -> object:
        if not is_key(leaf):
            return leaf
        keys = random.split(leaf, num_chains)
        if shard_over_chains:
            # place one key per chain shard
            keys = device_put(keys, NamedSharding(mesh, PartitionSpec('chains')))
        return keys

    return tree.map(split_one, x)

1204 

1205 

def vmap_chains(fun: Callable[..., T]) -> Callable[..., T]:
    """Apply vmap on chain axes automatically if the inputs are multichain."""

    @wraps(fun)
    def auto_vmapped_fun(*args: Any, **kwargs: Any) -> T:
        bundle = args, kwargs
        num_chains = get_num_chains(bundle)
        if num_chains is None:
            # single-chain inputs: call the function unchanged
            return fun(*args, **kwargs)

        # turn each random key into one key per chain
        bundle = _split_all_keys(bundle, num_chains)

        def packed_fun(pos_args: tuple[Any, ...], kw_args: dict[str, Any]) -> T:
            return fun(*pos_args, **kw_args)

        in_axes = _chain_axes_with_keys(bundle)
        out_axes = _get_mc_out_axes(packed_fun, bundle, in_axes)
        return vmap(packed_fun, in_axes=in_axes, out_axes=out_axes)(*bundle)

    return auto_vmapped_fun