Coverage for src / bartz / mcmcstep / _state.py: 93%

301 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-13 00:35 +0000

1# bartz/src/bartz/mcmcstep/_state.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Module defining the BART MCMC state and initialization.""" 

26 

27from collections.abc import Callable, Hashable 

28from dataclasses import Field, fields 

29from functools import partial, wraps 

30from math import ceil, log2 

31from typing import Any, Literal, TypeVar 

32 

33from equinox import Module, error_if 

34from equinox import field as eqx_field 

35from jax import NamedSharding, device_put, eval_shape, make_mesh, random, tree, vmap 

36from jax import numpy as jnp 

37from jax.scipy.linalg import solve_triangular 

38from jax.sharding import AxisType, Mesh, PartitionSpec 

39from jax.tree import flatten 

40from jaxtyping import Array, Bool, Float32, Int32, Integer, PyTree, Shaped, UInt 

41 

42from bartz.grove import make_tree, tree_depths 

43from bartz.jaxext import get_default_device, is_key, minimal_unsigned_dtype 

44 

45 

def field(*, chains: bool = False, data: bool = False, **kwargs) -> Field:
    """Extend `equinox.field` with two new parameters.

    Parameters
    ----------
    chains
        Whether the arrays in the field have an optional first axis that
        represents independent Markov chains.
    data
        Whether the last axis of the arrays in the field represents units of
        the data.
    **kwargs
        Other parameters passed to `equinox.field`.

    Returns
    -------
    A dataclass field descriptor with the special attributes in the metadata, unset if False.
    """
    # copy so that the caller's metadata mapping is never mutated
    metadata = dict(kwargs.pop('metadata', {}))
    for flag_name, flag_value in (('chains', chains), ('data', data)):
        # the special keys are reserved for this wrapper
        assert flag_name not in metadata
        if flag_value:
            metadata[flag_name] = True
    return eqx_field(metadata=metadata, **kwargs)

72 

73 

def chain_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for chains.

    The result can be passed as the `in_axes` or `out_axes` argument of
    `jax.vmap` to map over every chain axis in `x`, and only those.

    Parameters
    ----------
    x
        A pytree. Subpytrees that are Module attributes marked with
        ``field(..., chains=True)`` are considered to have a leading chain axis.

    Returns
    -------
    A pytree with the same structure as `x` with 0 or None in the leaves.
    """
    # leading axis (0) for chain fields, no mapping (None) elsewhere
    return _find_metadata(x, key='chains', if_true=0, if_false=None)

92 

93 

def data_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for data.

    This is analogous to `chain_vmap_axes` but returns -1 for all fields
    marked with ``field(..., data=True)``.

    Parameters
    ----------
    x
        A pytree. Subpytrees that are Module attributes marked with
        ``field(..., data=True)`` are considered to have a trailing data axis.

    Returns
    -------
    A pytree with the same structure as `x` with -1 or None in the leaves.
    """
    # trailing axis (-1) for data fields, no mapping (None) elsewhere
    return _find_metadata(x, key='data', if_true=-1, if_false=None)

101 

102 

# Type of the placeholder values (`if_true` / `if_false`) that `_find_metadata`
# writes into the leaves of the pytree it returns.
T = TypeVar('T')

104 

105 

def _find_metadata(
    x: PyTree[Any, ' S'], key: Hashable, if_true: T, if_false: T
) -> PyTree[T, ' S']:
    """Replace all subtrees of x marked with a metadata key.

    Parameters
    ----------
    x
        A pytree, possibly containing `Module` instances at any level.
    key
        The field metadata key to look for.
    if_true
        Value put in every leaf of a Module field whose metadata has `key`
        set to a truthy value.
    if_false
        Value put in every other leaf.

    Returns
    -------
    A pytree with the same structure as `x`. Static Module fields are copied
    unchanged, since they are part of the tree structure rather than leaves.
    """
    if not isinstance(x, Module):
        # Generic pytree: stop at Modules (treated as leaves) and recurse
        # into them; every ordinary leaf gets `if_false`.
        def leaf_axes(node: Module | Any) -> PyTree[T]:
            if isinstance(node, Module):
                return _find_metadata(node, key, if_true, if_false)
            return if_false

        return tree.map(
            leaf_axes, x, is_leaf=lambda node: isinstance(node, Module)
        )

    def field_axes(f: Field) -> Any:
        value = getattr(x, f.name)
        if f.metadata.get('static', False):
            return value
        if f.metadata.get(key, False):
            return tree.map(lambda _: if_true, value)
        return _find_metadata(value, key, if_true, if_false)

    # rebuild the Module positionally, one argument per dataclass field
    return x.__class__(*map(field_axes, fields(x)))

133 

134 

class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_tree
        The leaf values.
    var_tree
        The decision axes.
    split_tree
        The decision boundaries.
    affluence_tree
        Marks leaves that can be grown.
    max_split
        The maximum split index for each predictor.
    blocked_vars
        Indices of variables that are not used. This shall include at least
        the `i` such that ``max_split[i] == 0``, otherwise behavior is
        undefined.
    p_nonterminal
        The prior probability of each node being nonterminal, conditional on
        its ancestors. Includes the nodes at maximum depth which should be set
        to 0.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_decision_node
        The minimum number of data points in a decision node.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree.
    log_likelihood
        The log likelihood ratio.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    leaf_prior_cov_inv
        The prior precision matrix of a leaf, conditional on the tree structure.
        For the univariate case (k=1), this is a scalar (the inverse variance).
        The prior covariance of the sum of trees is
        ``num_trees * leaf_prior_cov_inv^-1``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If `None`, use a uniform distribution.
    theta
        The concentration parameter for the Dirichlet prior on the variable
        distribution `s`. Required only to update `s`.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
        See `step_theta`.
    """

    leaf_tree: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    split_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    max_split: UInt[Array, ' p']
    blocked_vars: UInt[Array, ' q'] | None
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    leaf_indices: UInt[Array, '*chains num_trees n'] = field(chains=True, data=True)
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    log_trans_prior: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    log_likelihood: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    grow_prop_count: Int32[Array, '*chains'] = field(chains=True)
    prune_prop_count: Int32[Array, '*chains'] = field(chains=True)
    grow_acc_count: Int32[Array, '*chains'] = field(chains=True)
    prune_acc_count: Int32[Array, '*chains'] = field(chains=True)
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
    log_s: Float32[Array, '*chains p'] | None = field(chains=True)
    theta: Float32[Array, '*chains'] | None = field(chains=True)
    a: Float32[Array, ''] | None
    b: Float32[Array, ''] | None
    rho: Float32[Array, ''] | None

    def num_chains(self) -> int | None:
        """Return the number of chains, or `None` if not multichain."""
        # TODO: maybe replace with chain_shape() -> () | (int,)
        # `var_tree` is (num_trees, 2**(d-1)) for a single chain; a third,
        # leading axis indexes the chains.
        return None if self.var_tree.ndim == 2 else self.var_tree.shape[0]

231 

232 

class StepConfig(Module):
    """Options for the MCMC step.

    Parameters
    ----------
    steps_done
        The number of MCMC steps completed so far.
    sparse_on_at
        After how many steps to turn on variable selection.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    mesh
        The mesh used to shard data and computation across multiple devices.
    """

    steps_done: Int32[Array, '']
    sparse_on_at: Int32[Array, ''] | None
    # The following are declared with static=True (forwarded to
    # `equinox.field`): they hold plain Python values (ints, a Mesh or None)
    # rather than arrays, and `_find_metadata` copies them unchanged.
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    mesh: Mesh | None = field(static=True)

255 

256 

class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors.
    y
        The response. If the data type is `bool`, the model is binary regression.
    resid
        The residuals (`y` or `z` minus sum of trees).
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    error_cov_inv
        The inverse error covariance (scalar for univariate, matrix for multivariate).
        `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    error_cov_df
    error_cov_scale
        The df and scale parameters of the inverse Wishart prior on the noise
        covariance. For the univariate case, the relationship to the inverse
        gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``.
        `None` in binary regression.
    forest
        The sum of trees model.
    config
        Metadata and configurations for the MCMC step.
    """

    # Field markers (see `field`): data=True means the last axis runs over the
    # n datapoints; chains=True means an optional leading axis over chains.
    # The declaration order matters because `_find_metadata` rebuilds the
    # module by passing one positional argument per field.
    X: UInt[Array, 'p n'] = field(data=True)
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'] = field(
        data=True
    )
    z: None | Float32[Array, '*chains n'] = field(chains=True, data=True)
    offset: Float32[Array, ''] | Float32[Array, ' k']
    resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=True, data=True
    )
    error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] | None = (
        field(chains=True)
    )
    prec_scale: Float32[Array, ' n'] | None = field(data=True)
    error_cov_df: Float32[Array, ''] | None
    error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None
    forest: Forest
    config: StepConfig

309 

310 

def _init_shape_shifting_parameters(
    y: Float32[Array, ' n'] | Float32[Array, 'k n'] | Bool[Array, ' n'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    error_scale: Float32[Any, ' n'] | None,
    error_cov_df: float | Float32[Any, ''] | None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Any, 'k k'] | None,
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
) -> tuple[
    bool,
    tuple[()] | tuple[int],
    None | Float32[Array, ''] | Float32[Array, 'k k'],
    None | Float32[Array, ''],
    None | Float32[Array, ''] | Float32[Array, 'k k'],
]:
    """
    Check and initialize parameters that change array type/shape based on outcome kind.

    Parameters
    ----------
    y
        The response variable; the outcome type is deduced from `y` and then
        all other parameters are checked against it.
    offset
        The offset to add to the predictions.
    error_scale
        Per-observation error scale (univariate only).
    error_cov_df
        The error covariance degrees of freedom.
    error_cov_scale
        The error covariance scale.
    leaf_prior_cov_inv
        The inverse of the leaf prior covariance.

    Returns
    -------
    is_binary
        Whether the outcome is binary.
    kshape
        The outcome shape, empty for univariate, (k,) for multivariate.
    error_cov_inv
        The initialized error covariance inverse (`None` if binary).
    error_cov_df
        The error covariance degrees of freedom (as array, `None` if binary).
    error_cov_scale
        The error covariance scale (as array, `None` if binary).

    Raises
    ------
    ValueError
        If `y` is binary and multivariate; if `y` is binary and any of
        `error_scale`, `error_cov_df`, `error_cov_scale` is set; if `y` is
        continuous and `error_cov_df` or `error_cov_scale` is missing; or if
        the shape of `error_cov_df`, `error_cov_scale`, `leaf_prior_cov_inv`
        or `offset` does not match the outcome shape.
    """
    # determine outcome kind, binary/continuous x univariate/multivariate
    is_binary = y.dtype == bool
    kshape = y.shape[:-1]
    # covariance-like parameters are scalars (univariate) or (k, k) matrices
    cov_shape = 2 * kshape

    # Binary vs continuous
    if is_binary:
        if kshape:
            msg = 'Binary multivariate regression not supported, open an issue at https://github.com/bartz-org/bartz/issues if you need it.'
            raise ValueError(msg)
        # Explicit exceptions instead of asserts: these are user arguments,
        # and asserts would vanish under `python -O`.
        unused = [
            name
            for name, value in (
                ('error_scale', error_scale),
                ('error_cov_df', error_cov_df),
                ('error_cov_scale', error_cov_scale),
            )
            if value is not None
        ]
        if unused:
            msg = f'{", ".join(unused)} must not be set in binary regression'
            raise ValueError(msg)
        error_cov_inv = None
    else:
        if error_cov_df is None or error_cov_scale is None:
            msg = (
                'error_cov_df and error_cov_scale must be set in continuous '
                'regression'
            )
            raise ValueError(msg)
        error_cov_df = jnp.asarray(error_cov_df)
        error_cov_scale = jnp.asarray(error_cov_scale)
        if error_cov_df.shape != ():
            msg = f'error_cov_df must be a scalar, got shape {error_cov_df.shape}'
            raise ValueError(msg)
        if error_cov_scale.shape != cov_shape:
            msg = (
                f'error_cov_scale has shape {error_cov_scale.shape}, '
                f'expected {cov_shape}'
            )
            raise ValueError(msg)

        # Multivariate vs univariate
        if kshape:
            error_cov_inv = error_cov_df * _inv_via_chol_with_gersh(error_cov_scale)
        else:
            # inverse gamma prior: alpha = df / 2, beta = scale / 2
            error_cov_inv = error_cov_df / error_cov_scale

    if leaf_prior_cov_inv.shape != cov_shape:
        msg = (
            f'leaf_prior_cov_inv has shape {leaf_prior_cov_inv.shape}, '
            f'expected {cov_shape}'
        )
        raise ValueError(msg)
    if offset.shape != kshape:
        msg = f'offset has shape {offset.shape}, expected {kshape}'
        raise ValueError(msg)

    return is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale

391 

392 

def _parse_p_nonterminal(
    p_nonterminal: Float32[Any, ' d_minus_1'],
) -> Float32[Array, ' d_minus_1+1']:
    """Validate the per-depth nonterminal probabilities and append a 0.

    Every entry must lie strictly inside (0, 1); NaN entries are rejected too.
    The trailing 0 makes the nodes at maximum depth certainly terminal.
    """
    probs = jnp.asarray(p_nonterminal)
    in_range = jnp.logical_and(probs > 0, probs < 1)
    probs = error_if(
        probs, jnp.logical_not(in_range), 'p_nonterminal must be in (0, 1)'
    )
    return jnp.pad(probs, (0, 1))

401 

402 

403def init( 

404 *, 

405 X: UInt[Any, 'p n'], 

406 y: Float32[Any, ' n'] | Float32[Any, ' k n'] | Bool[Any, ' n'], 

407 offset: float | Float32[Any, ''] | Float32[Any, ' k'], 

408 max_split: UInt[Any, ' p'], 

409 num_trees: int, 

410 p_nonterminal: Float32[Any, ' d_minus_1'], 

411 leaf_prior_cov_inv: float | Float32[Any, ''] | Float32[Array, 'k k'], 

412 error_cov_df: float | Float32[Any, ''] | None = None, 

413 error_cov_scale: float | Float32[Any, ''] | Float32[Array, 'k k'] | None = None, 

414 error_scale: Float32[Any, ' n'] | None = None, 

415 min_points_per_decision_node: int | Integer[Any, ''] | None = None, 

416 resid_batch_size: int | None | Literal['auto'] = 'auto', 

417 count_batch_size: int | None | Literal['auto'] = 'auto', 

418 save_ratios: bool = False, 

419 filter_splitless_vars: bool = True, 

420 min_points_per_leaf: int | Integer[Any, ''] | None = None, 

421 log_s: Float32[Any, ' p'] | None = None, 

422 theta: float | Float32[Any, ''] | None = None, 

423 a: float | Float32[Any, ''] | None = None, 

424 b: float | Float32[Any, ''] | None = None, 

425 rho: float | Float32[Any, ''] | None = None, 

426 sparse_on_at: int | Integer[Any, ''] | None = None, 

427 num_chains: int | None = None, 

428 mesh: Mesh | dict[str, int] | None = None, 

429 target_platform: Literal['cpu', 'gpu'] | None = None, 

430) -> State: 

431 """ 

432 Make a BART posterior sampling MCMC initial state. 

433 

434 Parameters 

435 ---------- 

436 X 

437 The predictors. Note this is trasposed compared to the usual convention. 

438 y 

439 The response. If the data type is `bool`, the regression model is binary 

440 regression with probit. If two-dimensional, the outcome is multivariate 

441 with the first axis indicating the component. 

442 offset 

443 Constant shift added to the sum of trees. 0 if not specified. 

444 max_split 

445 The maximum split index for each variable. All split ranges start at 1. 

446 num_trees 

447 The number of trees in the forest. 

448 p_nonterminal 

449 The probability of a nonterminal node at each depth. The maximum depth 

450 of trees is fixed by the length of this array. 

451 leaf_prior_cov_inv 

452 The prior precision matrix of a leaf, conditional on the tree structure. 

453 For the univariate case (k=1), this is a scalar (the inverse variance). 

454 The prior covariance of the sum of trees is 

455 ``num_trees * leaf_prior_cov_inv^-1``. The prior mean of leaves is 

456 always zero. 

457 error_cov_df 

458 error_cov_scale 

459 The df and scale parameters of the inverse Wishart prior on the error 

460 covariance. For the univariate case, the relationship to the inverse 

461 gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``. 

462 Leave unspecified for binary regression. 

463 error_scale 

464 Each error is scaled by the corresponding factor in `error_scale`, so 

465 the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``. 

466 Not supported for binary regression. If not specified, defaults to 1 for 

467 all points, but potentially skipping calculations. 

468 min_points_per_decision_node 

469 The minimum number of data points in a decision node. 0 if not 

470 specified. 

471 resid_batch_size 

472 count_batch_size 

473 The batch sizes, along datapoints, for summing the residuals and 

474 counting the number of datapoints in each leaf. `None` for no batching. 

475 If 'auto', it's chosen automatically based on the target platform; see 

476 the description of `target_platform` below for how it is determined. 

477 save_ratios 

478 Whether to save the Metropolis-Hastings ratios. 

479 filter_splitless_vars 

480 Whether to check `max_split` for variables without available cutpoints. 

481 If any are found, they are put into a list of variables to exclude from 

482 the MCMC. If `False`, no check is performed, but the results may be 

483 wrong if any variable is blocked. The function is jax-traceable only 

484 if this is set to `False`. 

485 min_points_per_leaf 

486 The minimum number of datapoints in a leaf node. 0 if not specified. 

487 Unlike `min_points_per_decision_node`, this constraint is not taken into 

488 account in the Metropolis-Hastings ratio because it would be expensive 

489 to compute. Grow moves that would violate this constraint are vetoed. 

490 This parameter is independent of `min_points_per_decision_node` and 

491 there is no check that they are coherent. It makes sense to set 

492 ``min_points_per_decision_node >= 2 * min_points_per_leaf``. 

493 log_s 

494 The logarithm of the prior probability for choosing a variable to split 

495 along in a decision rule, conditional on the ancestors. Not normalized. 

496 If not specified, use a uniform distribution. If not specified and 

497 `theta` or `rho`, `a`, `b` are, it's initialized automatically. 

498 theta 

499 The concentration parameter for the Dirichlet prior on `s`. Required 

500 only to update `log_s`. If not specified, and `rho`, `a`, `b` are 

501 specified, it's initialized automatically. 

502 a 

503 b 

504 rho 

505 Parameters of the prior on `theta`. Required only to sample `theta`. 

506 sparse_on_at 

507 After how many MCMC steps to turn on variable selection. 

508 num_chains 

509 The number of independent MCMC chains to represent in the state. Single 

510 chain with scalar values if not specified. 

511 mesh 

512 A jax mesh used to shard data and computation across multiple devices. 

513 If it has a 'chains' axis, that axis is used to shard the chains. If it 

514 has a 'data' axis, that axis is used to shard the datapoints. 

515 

516 As a shorthand, if a dictionary mapping axis names to axis size is 

517 passed, the corresponding mesh is created, e.g., ``dict(chains=4, 

518 data=2)`` will let jax pick 8 devices to split chains (which must be a 

519 multiple of 4) across 4 pairs of devices, where in each pair the data is 

520 split in two. 

521 

522 Note: if a mesh is passed, the arrays are always sharded according to 

523 it. In particular even if the mesh has no 'chains' or 'data' axis, the 

524 arrays will be replicated on all devices in the mesh. 

525 target_platform 

526 Platform ('cpu' or 'gpu') used to determine the batch sizes 

527 automatically. If `mesh` is specified, the platform is inferred from the 

528 devices in the mesh. Otherwise, if `y` is a concrete array (i.e., `init` 

529 is not invoked in a `jax.jit` context), the platform is set to the 

530 platform of `y`. Otherwise, use `target_platform`. 

531 

532 To avoid confusion, in all cases where the `target_platform` argument 

533 would be ignored, `init` raises an exception if `target_platform` is 

534 set. 

535 

536 Returns 

537 ------- 

538 An initialized BART MCMC state. 

539 

540 Raises 

541 ------ 

542 ValueError 

543 If `y` is boolean and arguments unused in binary regression are set. 

544 

545 Notes 

546 ----- 

547 In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out 

548 of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left 

549 child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be 

550 integers in the range ``[0, 1, ..., max_split[i]]``. 

551 """ 

552 # convert to array all array-like arguments that are used in other 

553 # configurations but don't need further processing themselves 

554 X = jnp.asarray(X) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

555 y = jnp.asarray(y) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

556 offset = jnp.asarray(offset) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

557 leaf_prior_cov_inv = jnp.asarray(leaf_prior_cov_inv) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

558 max_split = jnp.asarray(max_split) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

559 

560 # check p_nonterminal and pad it with a 0 at the end (still not final shape) 

561 p_nonterminal = _parse_p_nonterminal(p_nonterminal) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

562 

563 # process arguments that change depending on outcome type 

564 is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale = ( 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

565 _init_shape_shifting_parameters( 

566 y, offset, error_scale, error_cov_df, error_cov_scale, leaf_prior_cov_inv 

567 ) 

568 ) 

569 

570 # extract array sizes from arguments 

571 (max_depth,) = p_nonterminal.shape 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

572 p, n = X.shape 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

573 

574 # check and initialize sparsity parameters 

575 if not _all_none_or_not_none(rho, a, b): 575 ↛ 576line 575 didn't jump to line 576 because the condition on line 575 was never true2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

576 msg = 'rho, a, b are not either all `None` or all set' 

577 raise ValueError(msg) 

578 if theta is None and rho is not None: 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

579 theta = rho 1K4u56H72IJP,89!#$V%Xr'

580 if log_s is None and theta is not None: 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

581 log_s = jnp.zeros(max_split.size) 1Kh4Lui5M6NHj7O2tIkJlPm,n8Q9R!S#T$UVo%WXpqr'Y

582 if not _all_none_or_not_none(theta, sparse_on_at): 582 ↛ 583line 582 didn't jump to line 583 because the condition on line 582 was never true2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

583 msg = 'sparsity params (either theta or rho,a,b) and sparse_on_at must be either all None or all set' 

584 raise ValueError(msg) 

585 

586 # process multichain settings 

587 chain_shape = () if num_chains is None else (num_chains,) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z zb{ | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

588 resid_shape = chain_shape + y.shape 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z zb{ | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

589 tree_shape = (*chain_shape, num_trees) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

590 add_chains = partial(_add_chains, chain_shape=chain_shape) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

591 

592 # determine batch sizes for reductions 

593 mesh = _parse_mesh(num_chains, mesh) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

594 target_platform = _parse_target_platform( 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

595 y, mesh, target_platform, resid_batch_size, count_batch_size 

596 ) 

597 resid_batch_size, count_batch_size = _choose_suffstat_batch_size( 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

598 resid_batch_size, 

599 count_batch_size, 

600 y, 

601 max_depth, 

602 num_trees, 

603 num_chains, 

604 mesh, 

605 target_platform, 

606 ) 

607 

608 # initialize all remaining stuff and put it in an unsharded state 

609 state = State( 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

610 X=X, 

611 y=y, 

612 z=jnp.full(resid_shape, offset) if is_binary else None, 

613 offset=offset, 

614 resid=jnp.zeros(resid_shape) 

615 if is_binary 

616 else jnp.broadcast_to(y - offset[..., None], resid_shape), 

617 error_cov_inv=add_chains(error_cov_inv), 

618 prec_scale=( 

619 None if error_scale is None else jnp.reciprocal(jnp.square(error_scale)) 

620 ), 

621 error_cov_df=error_cov_df, 

622 error_cov_scale=error_cov_scale, 

623 forest=Forest( 

624 leaf_tree=make_tree(max_depth, jnp.float32, tree_shape + kshape), 

625 var_tree=make_tree( 

626 max_depth - 1, minimal_unsigned_dtype(p - 1), tree_shape 

627 ), 

628 split_tree=make_tree(max_depth - 1, max_split.dtype, tree_shape), 

629 affluence_tree=( 

630 make_tree(max_depth - 1, bool, tree_shape) 

631 .at[..., 1] 

632 .set( 

633 True 

634 if min_points_per_decision_node is None 

635 else n >= min_points_per_decision_node 

636 ) 

637 ), 

638 blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split), 

639 max_split=max_split, 

640 grow_prop_count=jnp.zeros(chain_shape, int), 

641 grow_acc_count=jnp.zeros(chain_shape, int), 

642 prune_prop_count=jnp.zeros(chain_shape, int), 

643 prune_acc_count=jnp.zeros(chain_shape, int), 

644 p_nonterminal=p_nonterminal[tree_depths(2**max_depth)], 

645 p_propose_grow=p_nonterminal[tree_depths(2 ** (max_depth - 1))], 

646 leaf_indices=jnp.ones( 

647 (*tree_shape, n), minimal_unsigned_dtype(2**max_depth - 1) 

648 ), 

649 min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node), 

650 min_points_per_leaf=_asarray_or_none(min_points_per_leaf), 

651 log_trans_prior=jnp.zeros((*chain_shape, num_trees)) 

652 if save_ratios 

653 else None, 

654 log_likelihood=jnp.zeros((*chain_shape, num_trees)) 

655 if save_ratios 

656 else None, 

657 leaf_prior_cov_inv=leaf_prior_cov_inv, 

658 log_s=add_chains(_asarray_or_none(log_s)), 

659 theta=add_chains(_asarray_or_none(theta)), 

660 rho=_asarray_or_none(rho), 

661 a=_asarray_or_none(a), 

662 b=_asarray_or_none(b), 

663 ), 

664 config=StepConfig( 

665 steps_done=jnp.int32(0), 

666 sparse_on_at=_asarray_or_none(sparse_on_at), 

667 resid_batch_size=resid_batch_size, 

668 count_batch_size=count_batch_size, 

669 mesh=mesh, 

670 ), 

671 ) 

672 

673 # move all arrays to the appropriate device 

674 return _shard_state(state) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

675 

676 

def _get_blocked_vars(
    filter_splitless_vars: bool, max_split: UInt[Array, ' p']
) -> None | UInt[Array, ' q']:
    """Initialize the `blocked_vars` field.

    Parameters
    ----------
    filter_splitless_vars
        Whether to block the variables that have no available cutpoints.
    max_split
        Number of cutpoints of each variable.

    Returns
    -------
    The indices of the variables with ``max_split == 0``, cast to the smallest
    unsigned dtype that can hold ``p``, or ``None`` if
    `filter_splitless_vars` is false.
    """
    if not filter_splitless_vars:
        return None

    (num_vars,) = max_split.shape
    (splitless,) = jnp.nonzero(max_split == 0)
    # see `fully_used_variables` for the type cast
    return splitless.astype(minimal_unsigned_dtype(num_vars))

688 

689 

def _add_chains(
    x: Shaped[Array, '*shape'] | None, chain_shape: tuple[int, ...]
) -> Shaped[Array, '*shape'] | Shaped[Array, ' num_chains *shape'] | None:
    """Broadcast `x` to all chains.

    The result has shape ``(*chain_shape, *x.shape)``; with an empty
    `chain_shape` the shape is unchanged. ``None`` is passed through.
    """
    if x is not None:
        return jnp.broadcast_to(x, (*chain_shape, *x.shape))
    return None

698 

699 

700def _parse_mesh( 

701 num_chains: int | None, mesh: Mesh | dict[str, int] | None 

702) -> Mesh | None: 

703 """Parse the `mesh` argument.""" 

704 if mesh is None: 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

705 return None 23 mbA nbobB . C pbqbD , E rbsbtbub) * + fbF - G vb{ | } wb~ abbbxbcbdbebyb^ _ ` lbv w x 0 1 z ( gbhbibjb? @ [ ] / : ; =

706 

707 # convert dict format to actual mesh 

708 if isinstance(mesh, dict): 1Kh4Lui5M6NHj7O2tIkJlPmn8Q9R!S#T$UVo%WXpqr'YsZabefgcdy

709 assert set(mesh).issubset({'chains', 'data'}) 1abefgcdy

710 mesh = make_mesh( 1abefgcdy

711 tuple(mesh.values()), tuple(mesh), axis_types=(AxisType.Auto,) * len(mesh) 

712 ) 

713 

714 # check there's no chain mesh axis if there are no chains 

715 if num_chains is None: 1Kh4Lui5M6NHj7O2tIkJlPmn8Q9R!S#T$UVo%WXpqr'YsZabefgcdy

716 assert 'chains' not in mesh.axis_names 1K456H72IJP89!#$V%X'y

717 

718 # check the axes we use are in auto mode 

719 assert 'chains' not in mesh.axis_names or 'chains' in _auto_axes(mesh) 1Kh4Lui5M6NHj7O2tIkJlPmn8Q9R!S#T$UVo%WXpqr'YsZabefgcdy

720 assert 'data' not in mesh.axis_names or 'data' in _auto_axes(mesh) 1Kh4Lui5M6NHj7O2tIkJlPmn8Q9R!S#T$UVo%WXpqr'YsZabefgcdy

721 

722 return mesh 1Kh4Lui5M6NHj7O2tIkJlPmn8Q9R!S#T$UVo%WXpqr'YsZabefgcdy

723 

724 

725def _parse_target_platform( 

726 y: Array, 

727 mesh: Mesh | None, 

728 target_platform: Literal['cpu', 'gpu'] | None, 

729 resid_batch_size: int | None | Literal['auto'], 

730 count_batch_size: int | None | Literal['auto'], 

731) -> Literal['cpu', 'gpu'] | None: 

732 if mesh is not None: 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

733 assert target_platform is None, 'mesh provided, do not set target_platform' 1Kh4Lui5M6NHj7O2tIkJlPmn8Q9R!S#T$UVo%WXpqr'YsZabefgcdy

734 return mesh.devices.flat[0].platform 1Kh4Lui5M6NHj7O2tIkJlPmn8Q9R!S#T$UVo%WXpqr'YsZabefgcdy

735 elif hasattr(y, 'platform'): 23 mbA nbobB . C pbqbD , E rbsbtbub) * + fbF - G vb{ | } wb~ abbbxbcbdbebyb^ _ ` lbv w x 0 1 z ( gbhbibjb? @ [ ] / : ; =

736 assert target_platform is None, 'device inferred from y, unset target_platform' 23 mbA nbobB . C pbqbD , E rbsbtbub) * + fbF - G vbv w x 0 1 z ( gbhbibjb? @ [ ] / : ; =

737 return y.platform() 23 mbA nbobB . C pbqbD , E rbsbtbub) * + fbF - G vbv w x 0 1 z ( gbhbibjb? @ [ ] / : ; =

738 elif resid_batch_size == 'auto' or count_batch_size == 'auto': 2C { | } wb~ abbbxbcbdbebyb^ _ ` lb

739 assert target_platform in ('cpu', 'gpu') 2{ | } wb~ abbbxbcbdbebyb^ _ ` lb

740 return target_platform 2{ | } wb~ abbbxbcbdbebyb^ _ ` lb

741 else: 

742 assert target_platform is None, 'target_platform not used, unset it' 1C

743 return target_platform 1C

744 

745 

746def _auto_axes(mesh: Mesh) -> list[str]: 

747 """Re-implement `Mesh.auto_axes` because that's missing in jax v0.5.""" 

748 # Mesh.auto_axes added in jax v0.6.0 

749 return [ 1Kh4Lui5M6NHj7O2tIkJlPmn8Q9R!S#T$UVo%WXpqr'YsZabefgcdy

750 n 

751 for n, t in zip(mesh.axis_names, mesh.axis_types, strict=True) 

752 if t == AxisType.Auto 

753 ] 

754 

755 

def _shard_state(state: State) -> State:
    """Place all fields in the state on the appropriate devices.

    If ``state.config.mesh`` is ``None``, `state` is returned unchanged.
    Otherwise every non-``None`` leaf is moved with ``device_put`` (donating
    the input) to a `NamedSharding` on that mesh. The leaf's chain axis, as
    given by `chain_vmap_axes`, is mapped to the ``'chains'`` mesh axis, and
    its data axis, as given by `data_vmap_axes`, to the ``'data'`` mesh axis,
    when those mesh axes exist. All other dimensions are replicated.
    """
    mesh = state.config.mesh
    if mesh is None:
        return state

    mesh_has_chains = 'chains' in mesh.axis_names
    mesh_has_data = 'data' in mesh.axis_names

    def place_leaf(
        leaf: Array | None, chain_axis: int | None, data_axis: int | None
    ) -> Array | None:
        if leaf is None:
            return None

        dims = [None] * leaf.ndim
        if mesh_has_chains and chain_axis is not None:
            dims[chain_axis] = 'chains'
        if mesh_has_data and data_axis is not None:
            dims[data_axis] = 'data'

        sharding = NamedSharding(mesh, PartitionSpec(*dims))
        return device_put(leaf, sharding, donate=True)

    return tree.map(
        place_leaf,
        state,
        chain_vmap_axes(state),
        data_vmap_axes(state),
        is_leaf=lambda node: node is None,
    )

784 

785 

786def _all_none_or_not_none(*args): 

787 is_none = [x is None for x in args] 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

788 return all(is_none) or not any(is_none) 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

789 

790 

791def _asarray_or_none(x): 

792 if x is None: 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

793 return None 23 h mbL A i nbM obN H B j . O C t I pbk J qbl D m E n rbQ sbR tbS ubT ) * + fbU F o - W G p q vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

794 return jnp.asarray(x) 2K h 4 L u A i 5 M 6 N H B j 7 O 2 t I k J l P m , n 8 Q 9 R ! S # T $ U V o % W X p q r ' Y s Z { | } wb~ abbbxbcbdbebyb^ _ ` lb

795 

796 

797def _get_platform(mesh: Mesh | None) -> str: 

798 if mesh is None: 

799 return get_default_device().platform 

800 else: 

801 return mesh.devices.flat[0].platform 

802 

803 

def _choose_suffstat_batch_size(
    resid_batch_size: int | None | Literal['auto'],
    count_batch_size: int | None | Literal['auto'],
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'],
    max_depth: int,
    num_trees: int,
    num_chains: int | None,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
) -> tuple[int | None, int | None]:
    """Determine batch sizes for reductions.

    Parameters
    ----------
    resid_batch_size
    count_batch_size
        Requested batch sizes. A value other than ``'auto'`` is returned
        as-is. ``'auto'`` selects a heuristic size that depends on
        `target_platform`.
    y
        The outcome. Only its number of outcomes ``k`` and of datapoints
        ``n`` are used.
    max_depth
    num_trees
        Tree settings. On GPU they bound the memory of the count accumulator.
    num_chains
        Number of chains, or ``None`` for a single chain.
    mesh
        Device mesh, or ``None``. `num_chains` and ``n`` are divided by the
        sizes of the ``'chains'`` and ``'data'`` mesh axes to get per-device
        values.
    target_platform
        Platform the heuristics are tuned for. It must be ``'cpu'`` or
        ``'gpu'`` if either batch size is ``'auto'``.

    Returns
    -------
    resid_batch_size, count_batch_size
        For ``'auto'`` inputs, either a power of 2 smaller than the per-device
        number of datapoints, or ``None`` to disable batching.

    Raises
    ------
    ValueError
        If a batch size is ``'auto'`` and `target_platform` is neither
        ``'cpu'`` nor ``'gpu'`` (e.g. a TPU device was detected). Without
        this check the batch size variables would never be assigned.
    """
    wants_auto = resid_batch_size == 'auto' or count_batch_size == 'auto'
    if wants_auto and target_platform not in ('cpu', 'gpu'):
        msg = (
            "automatic batch size selection supports only target platforms 'cpu' "
            f"and 'gpu', got {target_platform!r}"
        )
        raise ValueError(msg)

    # get number of outcomes and of datapoints
    k, n = _get_k_n(y)

    # get per-device values
    if num_chains is None:
        num_chains = 1
    num_chains //= get_axis_size(mesh, 'chains')
    n //= get_axis_size(mesh, 'data')

    # compute auxiliary sizes
    batch_size = k * num_chains
    unbatched_accum_bytes_times_batch_size = num_trees * 2**max_depth * 4 * n

    def final_round(s: float) -> int | None:
        # multiply by batch_size because if the calculation is already
        # parallelizable over batching dims there is correspondingly less need
        # to parallelize across datapoints
        s *= batch_size

        # at least 1, i.e., each datapoint is its own batch
        s = max(1, s)

        # round to the nearest power of 2 because I guess XLA and the hardware
        # will like that
        s = 2 ** round(log2(s))

        # disable batching if the batch is as large as the whole dataset
        return s if s < n else None

    if resid_batch_size != 'auto':
        rbs = resid_batch_size
    elif target_platform == 'cpu':
        rbs = final_round(n / 6)
        # instead of 6 I guess I should have in general the number of "good"
        # physical cores
    else:  # target_platform == 'gpu', guaranteed by the check above
        rbs = final_round((2 * n) ** (1 / 3))

    if count_batch_size != 'auto':
        cbs = count_batch_size
    elif target_platform == 'cpu':
        cbs = None
    else:  # target_platform == 'gpu', guaranteed by the check above
        cbs = (n / 16) ** 0.5

        # ensure we don't exceed ~512MiB of memory usage per device
        max_memory = 2**29
        min_batch_size = ceil(unbatched_accum_bytes_times_batch_size / max_memory)
        cbs = max(cbs, min_batch_size)

        cbs = final_round(cbs)

    return rbs, cbs

868 

869 

870def get_axis_size(mesh: Mesh | None, axis_name: str) -> int: 

871 if mesh is None or axis_name not in mesh.axis_names: 2K 3 h 4 mbL ,b-b.b+b/b*bu A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

872 return 1 2K 3 h 4 mbL ,b-b.b/b*bA i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g y 1 z ( gbhbibjb? @ [ ] / : ; =

873 else: 

874 i = mesh.axis_names.index(axis_name) 2K h 4 L +b*bu i 5 M 6 N H j 7 O 2 t I k J l P m n 8 Q 9 R ! S # T $ U V o % W X p q r ' Y s Z a b e f g c d y

875 return mesh.axis_sizes[i] 2K h 4 L +b*bu i 5 M 6 N H j 7 O 2 t I k J l P m n 8 Q 9 R ! S # T $ U V o % W X p q r ' Y s Z a b e f g c d y

876 

877 

878def _get_k_n(y: Array) -> tuple[int, int]: 

879 if y.ndim == 2: 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z { | } wb~ abbbxbcbdbebyb^ _ ` lbv w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

880 return y.shape 2{ | } ~ abbbcbdbeb^ _ ` v w x a b 0 e f g c d y 1 z ( gbhbibjb? @ [ ] / : ; =

881 else: 

882 (n,) = y.shape 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z wbxbyblbgbhbibjb/ : ; =

883 return 1, n 2K 3 h 4 mbL u A i 5 nbM 6 obN H B j 7 . O 2 C t I pbk J qbl P D m , E n 8 rbQ 9 sbR ! tbS # ubT ) * + fb$ U V F o % - W X G p q r ' vbY s Z wbxbyblbgbhbibjb/ : ; =

884 

885 

def chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool = False
) -> Float32[Array, '*batch_shape k k']:
    """Cholesky with Gershgorin stabilization, supports batching.

    Before factorizing, each ``k x k`` matrix gets ``k * rho * eps`` added to
    its diagonal. Here ``rho`` is the largest absolute row sum (a Gershgorin
    bound on the eigenvalue magnitudes) and ``eps`` is the machine epsilon of
    the dtype. If `absolute_eps` is true, ``eps`` is added on top of that.
    Leading batch dimensions are vectorized over.

    Returns
    -------
    The lower triangular Cholesky factors, same shape as `mat`.
    """
    factor = _chol_with_gersh_impl(mat, absolute_eps)
    return factor

891 

892 

@partial(jnp.vectorize, signature='(k,k)->(k,k)', excluded=(1,))
def _chol_with_gersh_impl(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool
) -> Float32[Array, '*batch_shape k k']:
    # Single-matrix kernel of `chol_with_gersh`. `jnp.vectorize` maps it over
    # leading batch dims; argument 1 (`absolute_eps`) is not vectorized.
    size = mat.shape[0]
    # largest absolute row sum; `initial` makes the max well-defined for k=0
    gersh_bound = jnp.max(jnp.sum(jnp.abs(mat), axis=1), initial=0.0)
    eps = jnp.finfo(mat.dtype).eps
    shift = size * gersh_bound * eps
    if absolute_eps:
        shift += eps
    shifted = mat.at[jnp.diag_indices_from(mat)].add(shift)
    return jnp.linalg.cholesky(shifted)

904 

905 

def _inv_via_chol_with_gersh(mat: Float32[Array, 'k k']) -> Float32[Array, 'k k']:
    """Compute matrix inverse via Cholesky with Gershgorin stabilization.

    DO NOT USE THIS FUNCTION UNLESS YOU REALLY NEED TO.

    With ``L = chol_with_gersh(mat)``, returns ``inv(L).T @ inv(L)``, where
    ``inv(L)`` is obtained by a triangular solve against the identity.
    """
    chol = chol_with_gersh(mat)
    identity = jnp.eye(mat.shape[0], dtype=mat.dtype)
    chol_inv = solve_triangular(chol, identity, lower=True)
    return chol_inv.T @ chol_inv

915 

916 

def get_num_chains(x: PyTree) -> int | None:
    """Get the number of chains of a pytree.

    Find all nodes in the structure that define ``num_chains()``, stopping
    traversal at nodes that define it. Check that all the values obtained by
    invoking ``num_chains`` are equal, then return that value.

    Returns
    -------
    The common number of chains. ``None`` is returned if the nodes report
    ``None``, and also if no node defines ``num_chains`` at all; previously
    the latter case crashed with `IndexError`.
    """
    leaves, _ = flatten(x, is_leaf=lambda node: hasattr(node, 'num_chains'))
    counts = [leaf.num_chains() for leaf in leaves if hasattr(leaf, 'num_chains')]
    if not counts:
        # nothing in the pytree carries a chain axis
        return None
    ref = counts[0]
    assert all(c == ref for c in counts)
    return ref

929 

930 

def _chain_axes_with_keys(x: PyTree) -> PyTree[int | None]:
    """Return `chain_vmap_axes(x)`, but with axis 0 for every random key."""
    base_axes = chain_vmap_axes(x)

    def key_aware_axis(leaf, axis):
        return 0 if is_key(leaf) else axis

    return tree.map(key_aware_axis, x, base_axes)

942 

943 

def _get_mc_out_axes(
    fun: Callable[[tuple, dict], PyTree], args: PyTree, in_axes: PyTree[int | None]
) -> PyTree[int | None]:
    """Decide chain vmap axes for outputs.

    Abstractly evaluates `fun`, vmapped over `in_axes`, on `args` with
    `eval_shape`, and returns `chain_vmap_axes` of the resulting output
    structure.
    """
    abstract_out = eval_shape(vmap(fun, in_axes=in_axes), *args)
    return chain_vmap_axes(abstract_out)

951 

952 

def _find_mesh(x: PyTree) -> Mesh | None:
    """Find the mesh used for chains.

    Returns ``config.mesh`` of the first `State` found in `x`, in pytree
    flattening order (traversal does not descend into `State` nodes).

    Raises
    ------
    ValueError
        If `x` contains no `State`.
    """

    def is_state(node) -> bool:
        return isinstance(node, State)

    # scan the leaves directly instead of raising an exception out of
    # `tree.map` to stop at the first match
    for leaf in tree.leaves(x, is_leaf=is_state):
        if is_state(leaf):
            return leaf.config.mesh

    msg = 'no `State` found in the pytree, cannot determine the mesh'
    raise ValueError(msg)

969 

970 

def _split_all_keys(x: PyTree, num_chains: int) -> PyTree:
    """Replace every random key in `x` with `num_chains` keys split from it.

    If the mesh of the `State` in `x` has a ``'chains'`` axis, the new keys
    are sharded along it. Non-key leaves are returned unchanged.
    """
    mesh = _find_mesh(x)
    shard_over_chains = mesh is not None and 'chains' in mesh.axis_names

    def split_if_key(leaf):
        if not is_key(leaf):
            return leaf
        keys = random.split(leaf, num_chains)
        if shard_over_chains:
            keys = device_put(keys, NamedSharding(mesh, PartitionSpec('chains')))
        return keys

    return tree.map(split_if_key, x)

983 

984 

def vmap_chains(
    fun: Callable[..., T], *, auto_split_keys: bool = False
) -> Callable[..., T]:
    """Wrap `fun` so it is vmapped over chain axes when the inputs are multichain.

    Parameters
    ----------
    fun
        The function to wrap.
    auto_split_keys
        If `True` and the inputs are multichain, every random key among the
        arguments is split into `num_chains` keys before vmapping.

    Returns
    -------
    A wrapper around `fun` with the same metadata (via `functools.wraps`).
    If `get_num_chains` returns `None` for the arguments, the wrapper simply
    calls `fun`. Otherwise it vmaps `fun` with input axes from
    `_chain_axes_with_keys` and output axes from `_get_mc_out_axes`.
    """

    @wraps(fun)
    def auto_vmapped_fun(*args, **kwargs) -> T:
        packed = args, kwargs
        num_chains = get_num_chains(packed)

        # No chain axis detected: call the function directly.
        if num_chains is None:
            return fun(*args, **kwargs)

        if auto_split_keys:
            packed = _split_all_keys(packed, num_chains)

        def call_packed(pos, kw):
            # `jax.vmap` applies `in_axes` to positional arguments only, so
            # args and kwargs travel as two positional pytrees.
            return fun(*pos, **kw)

        in_axes = _chain_axes_with_keys(packed)
        out_axes = _get_mc_out_axes(call_packed, packed, in_axes)
        return vmap(call_packed, in_axes=in_axes, out_axes=out_axes)(*packed)

    return auto_vmapped_fun