Coverage for src / bartz / mcmcloop.py: 96%

227 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/mcmcloop.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions that implement the full BART posterior MCMC loop. 

26 

27The entry points are `run_mcmc` and `make_default_callback`. 

28""" 

29 

30from collections.abc import Callable 

31from dataclasses import fields 

32from functools import partial, update_wrapper, wraps 

33from math import floor 

34from typing import Any, NamedTuple, Protocol, TypeVar 

35 

36import jax 

37import numpy 

38from equinox import Module 

39from jax import ( 

40 NamedSharding, 

41 ShapeDtypeStruct, 

42 debug, 

43 device_put, 

44 eval_shape, 

45 jit, 

46 lax, 

47 named_call, 

48 tree, 

49) 

50from jax import numpy as jnp 

51from jax.nn import softmax 

52from jax.sharding import Mesh, PartitionSpec 

53from jaxtyping import ( 

54 Array, 

55 ArrayLike, 

56 Bool, 

57 Float32, 

58 Int32, 

59 Integer, 

60 Key, 

61 PyTree, 

62 Shaped, 

63 UInt, 

64) 

65 

66from bartz import jaxext, mcmcstep 

67from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram 

68from bartz.jaxext import autobatch, jit_active 

69from bartz.mcmcstep import State 

70from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains 

71 

72 

class BurninTrace(Module):
    """Diagnostic-only MCMC trace; the default layout of the burn-in trace.

    Every field is declared with ``field(chains=True)``.
    """

    error_cov_inv: (
        Float32[Array, '*chains_and_samples']
        | Float32[Array, '*chains_and_samples k k']
        | None
    ) = field(chains=True)
    theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a single-item burn-in trace out of an MCMC state.

        Parameters
        ----------
        state
            The MCMC state to copy the diagnostics from.

        Returns
        -------
        A trace holding ``state.error_cov_inv`` together with `theta`, the
        grow/prune proposal and acceptance counts, the log-likelihood and the
        log transition prior of ``state.forest``.
        """
        forest = state.forest
        return cls(
            error_cov_inv=state.error_cov_inv,
            theta=forest.theta,
            grow_prop_count=forest.grow_prop_count,
            grow_acc_count=forest.grow_acc_count,
            prune_prop_count=forest.prune_prop_count,
            prune_acc_count=forest.prune_acc_count,
            log_likelihood=forest.log_likelihood,
            log_trans_prior=forest.log_trans_prior,
        )

102 

103 

class MainTrace(BurninTrace):
    """Diagnostic values plus trees; the default layout of the main trace."""

    leaf_tree: (
        Float32[Array, '*chains_and_samples 2**d']
        | Float32[Array, '*chains_and_samples k 2**d']
    ) = field(chains=True)
    var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
    varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a single-item main trace out of an MCMC state.

        Parameters
        ----------
        state
            The MCMC state to copy the trees and diagnostics from.

        Returns
        -------
        A trace with the tree heaps of ``state.forest``, ``state.offset``,
        `varprob`, and all the fields filled by `BurninTrace.from_state`.
        """
        forest = state.forest

        # `varprob` is the softmax of `log_s`, restricted to the variables with
        # a nonzero `max_split`; it is absent when `log_s` is absent
        if forest.log_s is None:
            varprob = None
        else:
            allowed = forest.max_split.astype(bool)
            varprob = softmax(forest.log_s, where=allowed)

        burnin_fields = vars(BurninTrace.from_state(state))
        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            varprob=varprob,
            **burnin_fields,
        )

134 

135 

# Arbitrary user pytree handed from one callback invocation to the next.
CallbackState = PyTree[Any, 'T']

137 

138 

class RunMCMCResult(NamedTuple):
    """What `run_mcmc` returns: the final state and the two traces."""

    final_state: State
    """The MCMC state after the last iteration."""

    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    """Trace recorded during the burn-in phase. `BurninTrace` describes the
    default layout."""

    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    """Trace recorded during the main phase. `MainTrace` describes the default
    layout."""

154 

155 

class Callback(Protocol):
    """Signature of the callbacks accepted by `run_mcmc`."""

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: Int32[Array, ''],
    ) -> tuple[State, CallbackState] | None:
        """Perform an arbitrary action after each MCMC iteration.

        Parameters
        ----------
        key
            A random key reserved for the callback.
        bart
            The MCMC state right after it was updated.
        burnin
            Whether the iteration just done belongs to the burn-in phase.
        i_total
            The 0-based index of the iteration just done.
        callback_state
            On the first invocation, the value passed to `run_mcmc`; later,
            the value returned by the previous invocation.
        n_burn
        n_save
        n_skip
            Passed through unchanged from the `run_mcmc` arguments.
        i_outer
            The 0-based index of the current outer loop iteration.
        inner_loop_length
            How many MCMC iterations make up one inner loop.

        Returns
        -------
        bart : State
            The MCMC state, possibly modified. Return the `bart` argument
            itself to leave the state untouched.
        callback_state : CallbackState
            The state handed to the next invocation of the callback.

        Notes
        -----
        Returning `None` is a shorthand for leaving both states unchanged.
        """
        ...

212 

213 

class _Carry(Module):
    """Loop carry of the `lax.while_loop` inside `_run_mcmc_inner_loop`."""

    # the current MCMC state
    bart: State
    # number of MCMC iterations completed so far (starts at 0 in `run_mcmc`)
    i_total: Int32[Array, '']
    # random key, re-split at every iteration
    key: Key[Array, '']
    # preallocated traces, written in place as the loop proceeds
    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    # user state handed from one callback invocation to the next
    callback_state: CallbackState

227 

228 

def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> RunMCMCResult:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        A key for random number generation.
    bart
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. Its buffers are donated to the MCMC loop to avoid
        copies, so `bart` is invalid once `run_mcmc` has run; copy it
        beforehand if you need it again.
    n_save
        The number of iterations to save.
    n_burn
        The number of initial iterations which are not saved.
    n_skip
        One plus the number of iterations skipped between two consecutive
        saved iterations. The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The MCMC runs as a Python outer loop wrapped around a JAX inner loop;
        `inner_loop_length` is how many iterations the inner loop does per
        outer iteration. If not specified, every iteration is done in a
        single run of the inner loop. This stride is unrelated to the one
        used to save the trace.
    callback
        An arbitrary function run during the loop after updating the state,
        with the signature described by `Callback`. It is called under jit,
        so argument values are not available when its Python code executes;
        use the utilities in `jax.debug` to reach them at runtime. The
        callback may return new values for the MCMC state and the callback
        state.
    callback_state
        The initial custom state for the callback.
    burnin_extractor
    main_extractor
        Functions that, given the MCMC state, extract the variables to save
        in the burn-in trace and in the main trace respectively. They must
        return a pytree and be vmappable.

    Returns
    -------
    A namedtuple with the final state, the burn-in trace, and the main trace.

    Raises
    ------
    RuntimeError
        If `run_mcmc` detects it's being invoked in a `jit`-wrapped context
        with settings that would create unrolled loops in the trace, or if
        the inner loop gets traced more than once.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state.
    """
    # preallocate the traces, later filled in place by the inner loop
    burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
    main_trace = _empty_trace(n_save, bart, main_extractor)

    # split the iterations between the Python outer loop and the JAX inner loop
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if not inner_loop_length:
        # zero outer iterations would be a clean noop, but always running at
        # least one keeps a single code path for benchmarking and testing
        n_outer = 1
    else:
        n_outer = -(-n_iters // inner_loop_length)  # ceiling division

    # a Python loop under jit would be unrolled into the compiled code
    if jit_active() and n_outer > 1:
        msg = (
            '`run_mcmc` was called within a jit-compiled function and '
            'there are more than 1 outer loops, '
            'please either do not jit or set `inner_loop_length=None`'
        )
        raise RuntimeError(msg)

    to_replicated = partial(_replicate, mesh=bart.config.mesh)
    carry = _Carry(
        bart=bart,
        i_total=to_replicated(jnp.int32(0)),
        key=to_replicated(key),
        burnin_trace=burnin_trace,
        main_trace=main_trace,
        callback_state=callback_state,
    )

    # arm the check that the inner loop is not traced twice during this call
    _run_mcmc_inner_loop._fun.reset_call_counter()  # noqa: SLF001
    for i_outer in range(n_outer):
        # positional arguments: `jit` static/donated args are set by position
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return RunMCMCResult(
        final_state=carry.bart,
        burnin_trace=carry.burnin_trace,
        main_trace=carry.main_trace,
    )

347 

348 

349def _replicate(x: Array, mesh: Mesh | None) -> Array: 

350 if mesh is None: 2nbabobB q f c b j a pbg qbrbbbsbtbcbubd h e u k l M K L y dbr N ebE O fbF v m n ; z : w G H gb~ hbvbibwbxbjbyb= kb? x s t lbzbAbi o p BbmbCbP I J C D [ 1 + W X Y Z Q R S 0 T U V 2 3 4 5 6 7 8 9 ! # $ % ' ( ) * , - . /

351 return x 2abf j g bbcbd h e k K dbebfbv m z w G ~ ibjbkbx s t lbo mbI 1 + W X Y Z Q R S 0 T U V 2 3 4 5 6 7 8 9 ! # $ % ' ( ) * , - . /

352 else: 

353 return device_put(x, NamedSharding(mesh, PartitionSpec())) 2nbobB q c b a pbqbrbsbtbubd h e u l M L y r N E O F n ; : H gbhbvbwbxbyb= ? zbAbi p BbCbP J C D [

354 

355 

@partial(jit, static_argnums=(0, 2))
def _empty_trace(
    length: int, bart: State, extractor: Callable[[State], PyTree]
) -> PyTree:
    """Preallocate a trace with room for `length` outputs of `extractor`.

    `extractor` is vmapped without mapping its input, so every leaf of its
    output gains a new axis of size `length`. When `bart` has no chains, the
    new axis goes in front. Otherwise it goes in front only for the leaves
    that `chain_vmap_axes` marks with `None`, and right after the leading
    axis for the others.
    """
    if get_num_chains(bart) is None:
        out_axes = 0
    else:
        # inspect the output structure abstractly, without running `extractor`
        chain_axes = chain_vmap_axes(eval_shape(extractor, bart))
        out_axes = tree.map(
            lambda axis: 1 if axis is not None else 0,
            chain_axes,
            is_leaf=lambda axis: axis is None,
        )
    stacked = jax.vmap(
        extractor, in_axes=None, out_axes=out_axes, axis_size=length
    )
    return stacked(bart)

370 

371 

# generic return type of the callables wrapped by the helpers below
T = TypeVar('T')

373 

374 

class _CallCounter:
    """Wrapper raising `RuntimeError` if the wrapped callable is called twice.

    The counter is only cleared by `reset_call_counter`. `run_mcmc` uses it
    to detect that the jitted inner loop got traced more than once.
    """

    def __init__(self, func: Callable[..., T]) -> None:
        self.func = func
        self.n_calls = 0
        # copy name, docstring, etc. of `func` onto the wrapper
        update_wrapper(self, func)

    def reset_call_counter(self) -> None:
        """Forget all previous calls."""
        self.n_calls = 0

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        if self.n_calls != 0:
            msg = (
                'The inner loop of `run_mcmc` was traced more than once, '
                'which indicates a double compilation of the MCMC code. This '
                'probably depends on the input state having different type from the '
                'output state. Check the input is in a format that is the '
                'same jax would output, e.g., all arrays and scalars are jax '
                'arrays, with the right shardings.'
            )
            raise RuntimeError(msg)
        self.n_calls += 1
        return self.func(*args, **kwargs)

400 

401 

@partial(jit, donate_argnums=(0,), static_argnums=(2, 3, 4))
@_CallCounter
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: Int32[Array, ''],
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Run one batch of MCMC iterations inside a `lax.while_loop`.

    The batch does at most `inner_loop_length` iterations, and stops early
    when the global iteration counter ``carry.i_total`` reaches `n_iters`.
    `carry` is donated to `jit`. `callback` and the two extractors are
    static arguments.
    """
    # end of this batch: the batch length, capped at the total length
    stop_at = jnp.minimum(carry.i_total + inner_loop_length, n_iters)

    def keep_going(state: _Carry) -> Bool[Array, '']:
        """Tell whether the batch has iterations left."""
        return state.i_total < stop_at

    def one_iteration(state: _Carry) -> _Carry:
        """Step the MCMC, run the callback, and record the traces."""
        # the first subkey is carried over, the second drives the MCMC step,
        # the third (if any) is handed to the callback
        subkeys = jaxext.split(state.key, 3)
        next_key = subkeys.pop()

        bart = mcmcstep.step(subkeys.pop(), state.bart)

        callback_state = state.callback_state
        if callback is not None:
            returned = callback(
                key=subkeys.pop(),
                bart=bart,
                burnin=state.i_total < n_burn,
                i_total=state.i_total,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # `None` means: keep both states as they are
            if returned is not None:
                bart, callback_state = returned

        burnin_trace, main_trace = _save_state_to_trace(
            state.burnin_trace,
            state.main_trace,
            burnin_extractor,
            main_extractor,
            bart,
            state.i_total,
            n_burn,
            n_skip,
        )

        return _Carry(
            bart=bart,
            i_total=state.i_total + 1,
            key=next_key,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
            callback_state=callback_state,
        )

    return lax.while_loop(keep_going, one_iteration, carry)

472 

473 

@named_call
def _save_state_to_trace(
    burnin_trace: PyTree,
    main_trace: PyTree,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    bart: State,
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_skip: Int32[Array, ''],
) -> tuple[PyTree, PyTree]:
    """Record the extracted MCMC state into the burn-in and main traces.

    Both traces are always written. The write into the trace that is not
    active is sent out of bounds, and `_set` drops out-of-bounds writes.
    In the main phase, the target index is ``(i_total - n_burn) // n_skip``.
    Consecutive iterations that share an index therefore overwrite each
    other.
    """
    in_burnin = i_total < n_burn

    # once past the burn-in, `i_total` overflows the burn-in trace, assuming it
    # has length `n_burn` as allocated by `run_mcmc`
    burnin_idx = i_total

    # during the burn-in, aim the main-trace write far out of bounds
    out_of_bounds = jnp.iinfo(jnp.int32).max
    main_idx = jnp.where(in_burnin, out_of_bounds, (i_total - n_burn) // n_skip)

    num_chains = get_num_chains(bart)
    new_burnin_trace = _set(
        burnin_trace, burnin_idx, burnin_extractor(bart), num_chains
    )
    new_main_trace = _set(main_trace, main_idx, main_extractor(bart), num_chains)
    return new_burnin_trace, new_main_trace

502 

503 

def _set(
    trace: PyTree[Array, ' T'],
    index: Int32[Array, ''],
    val: PyTree[Array, ' T'],
    num_chains: int | None,
) -> PyTree[Array, ' T']:
    """Do ``trace[index] = val`` but fancier.

    The write is done leaf by leaf. For leaves that `chain_vmap_axes` marks
    as per-chain (when `num_chains` is not `None`), `index` addresses the
    second axis; otherwise it addresses the first. Out-of-bounds indices are
    silently dropped.
    """

    def write_leaf(
        leaf_trace: Shaped[Array, 'chains samples *shape']
        | Shaped[Array, ' samples *shape']
        | None,
        leaf_val: Shaped[Array, ' chains *shape'] | Shaped[Array, '*shape'] | None,
        leaf_chain_axis: int | None,
    ) -> Shaped[Array, 'chains samples *shape'] | None:
        # Pass through untouched:
        # - zero-size arrays, since jax refuses to index into an axis of
        #   length 0, even if just in the abstract;
        # - `None` optional elements, which are leaves only because of the
        #   `is_leaf` needed below to traverse `chain_axis` in parallel.
        if leaf_trace is None or leaf_trace.size == 0:
            return leaf_trace

        if num_chains is not None and leaf_chain_axis is not None:
            position = (slice(None), index, ...)
        else:
            position = (index, ...)

        return leaf_trace.at[position].set(leaf_val, mode='drop')

    chain_axis = chain_vmap_axes(val)
    return tree.map(
        write_leaf, trace, val, chain_axis, is_leaf=lambda x: x is None
    )

536 

537 

def make_default_callback(
    state: State,
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
) -> dict[str, Any]:
    """
    Prepare a default callback for `run_mcmc`.

    The callback is `print_callback`. It prints a dot every `dot_every` MCMC
    iterations and a one-line progress report every `report_every` MCMC
    iterations.

    Parameters
    ----------
    state
        The bart state to use the callback with, used to determine device
        sharding: the periods are replicated over ``state.config.mesh``.
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword
    arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(key, state, ..., **make_default_callback(state, ...))
    """

    def as_replicated_array_or_none(val: ArrayLike | None) -> None | Array:
        """Turn `val` into a jax array replicated on the mesh; keep `None`."""
        return None if val is None else _replicate(jnp.asarray(val), state.config.mesh)

    return dict(
        callback=print_callback,
        callback_state=PrintCallbackState(
            as_replicated_array_or_none(dot_every),
            as_replicated_array_or_none(report_every),
        ),
    )

580 

581 

class PrintCallbackState(Module):
    """Callback state read by `print_callback`."""

    dot_every: Int32[Array, ''] | None
    """Period, in MCMC iterations, at which a dot is printed; `None` turns the
    dots off."""

    report_every: Int32[Array, ''] | None
    """Period, in MCMC iterations, at which a one-line report is printed;
    `None` turns the reports off."""

591 

592 

def print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_: Any,
) -> None:
    """Print a dot and/or a report periodically during the MCMC.

    Follows the `Callback` protocol. The callback arguments not listed here
    are accepted and ignored. It returns `None`, so neither the MCMC state
    nor the callback state is modified.

    With ``it = i_total + 1``:
    - a dot is printed when `it` is a multiple of ``callback_state.dot_every``;
    - a one-line report (formatted by `_print_report`) is printed when `it`
      is a multiple of ``callback_state.report_every``.

    A `None` period disables the corresponding output.
    """
    report_every = callback_state.report_every
    dot_every = callback_state.dot_every
    it = i_total + 1

    def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
        return False if every is None else it % every == 0

    report_cond = get_cond(report_every)
    dot_cond = get_cond(dot_every)

    def line_report_branch() -> None:
        if report_every is None:
            return
        if dot_every is None:
            print_newline = False
        else:
            # This branch runs only when `it` is a multiple of `report_every`,
            # so the previous report was at `it - report_every`, and it ended
            # its line. The last dot was at `it - it % dot_every`.
            # - If that dot is `it` itself, `print_dot` handles it.
            # - Otherwise the report must go on a new line iff the dot came
            #   after the previous report, i.e., iff
            #   `it % dot_every < report_every`.
            print_newline = it % dot_every < report_every
        debug.callback(
            _print_report,
            print_dot=dot_cond,
            print_newline=print_newline,
            burnin=burnin,
            it=it,
            n_iters=n_burn + n_save * n_skip,
            num_chains=bart.forest.num_chains(),
            grow_prop_count=bart.forest.grow_prop_count.mean(),
            grow_acc_count=bart.forest.grow_acc_count.mean(),
            prune_acc_count=bart.forest.prune_acc_count.mean(),
            prop_total=bart.forest.split_tree.shape[-2],
            fill=forest_fill(bart.forest.split_tree),
        )

    def just_dot_branch() -> None:
        if dot_every is None:
            return
        # logging can't do in-line printing so we use print
        debug.callback(
            lambda: print('.', end='', flush=True)  # noqa: T201
        )

    lax.cond(
        report_cond,
        line_report_branch,
        lambda: lax.cond(dot_cond, just_dot_branch, lambda: None),
    )

650 

651 

def _convert_jax_arrays_in_args(func: Callable[..., T]) -> Callable[..., T]:
    """Decorate `func` so that it never receives jax arrays.

    Before `func` is called, every `jax.Array` leaf of the positional and
    keyword arguments is replaced:
    - by a Python scalar if it is 0-dimensional;
    - by a numpy array otherwise.

    Other leaves are passed through.
    """

    def to_host(leaf: object) -> object:
        if isinstance(leaf, Array):
            return numpy.array(leaf) if leaf.shape else leaf.item()
        return leaf

    @wraps(func)
    def new_func(*args: Any, **kw: Any) -> T:
        host_args = tree.map(to_host, args)
        host_kw = tree.map(to_host, kw)
        return func(*host_args, **host_kw)

    return new_func

677 

678 

# jax arrays in the arguments are converted beforehand, because operating on
# them from here could deadlock with the main thread
@_convert_jax_arrays_in_args
def _print_report(
    *,
    print_dot: bool,
    print_newline: bool,
    burnin: bool,
    it: int,
    n_iters: int,
    num_chains: int | None,
    grow_prop_count: float,
    grow_acc_count: float,
    prune_acc_count: float,
    prop_total: int,
    fill: float,
) -> None:
    """Print the one-line progress report of `print_callback`.

    The line may be preceded by a dot and/or a line break. A parenthesized
    note is appended when averaging over chains or during the burn-in.
    """
    # proportion of grow proposals, and of accepted moves of either kind
    grow_prop = grow_prop_count / prop_total
    move_acc = (grow_acc_count + prune_acc_count) / prop_total

    prefix = '.\n' if print_dot else ('\n' if print_newline else '')

    notes = []
    if num_chains is not None:
        notes.append(f'avg. {num_chains} chains')
    if burnin:
        notes.append('burnin')
    suffix = '' if not notes else f' ({", ".join(notes)})'

    line = (
        f'{prefix}Iteration {it}/{n_iters}, '
        f'grow prob: {grow_prop:.0%}, '
        f'move acc: {move_acc:.0%}, '
        f'fill: {fill:.0%}{suffix}'
    )
    print(line)  # noqa: T201, see print_callback for why not logging

723 

724 

class Trace(TreeHeaps, Protocol):
    """
    Structural interface of an MCMC trace.

    On top of the tree heaps of `bartz.grove.TreeHeaps`, a trace carries an
    `offset` that `evaluate_trace` adds to the sum of the trees.
    """

    # constant added to the tree predictions, one value per trace entry
    offset: Float32[Array, '*trace_shape']

729 

730 

class TreesTrace(Module):
    """
    MCMC trace restricted to the tree arrays of `bartz.grove.TreeHeaps`.

    It holds only `leaf_tree`, `var_tree` and `split_tree`, without any
    other attribute of the trace it was built from.
    """

    # leaf values; the variant with an extra `k` axis is the multivariate case
    leaf_tree: (
        Float32[Array, '*trace_shape num_trees 2**d']
        | Float32[Array, '*trace_shape num_trees k 2**d']
    )
    # decision variable of each internal node
    var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
    # split value of each internal node
    split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: TreeHeaps) -> 'TreesTrace':
        """
        Build a `TreesTrace` by copying the tree arrays out of `obj`.

        `obj` may be any `bartz.grove.TreeHeaps`. Only the attributes named
        after the fields of this class are read; anything else is ignored.
        """
        kwargs = {}
        for descriptor in fields(cls):
            kwargs[descriptor.name] = getattr(obj, descriptor.name)
        return cls(**kwargs)

745 

746 

@jit
def evaluate_trace(
    X: UInt[Array, 'p n'], trace: Trace
) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    X
        The predictors matrix, with `p` predictors and `n` observations.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    The predictions for each chain and iteration of the MCMC.

    Notes
    -----
    The forest is evaluated with two nested levels of `autobatch`:
    - the inner level splits the trees into batches and adds up the partial sums;
    - the outer level splits the MCMC samples into batches.

    Its memory budget is reduced to leave room for the per-tree
    intermediate results of the inner level.
    """
    # memory budget: 128 MiB for each device the trace is spread over
    mesh = jax.typeof(trace.leaf_tree).sharding.mesh
    device_count = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
    budget = 2**27 * device_count

    # split_tree is (chains? samples trees nodes): an optional leading chains
    # axis shifts both the samples axis and the trees axis by one
    sample_axis = 1 if trace.split_tree.ndim > 3 else 0
    tree_axis = sample_axis + 1

    # inner level: evaluate batches of trees and sum them together
    per_sample_eval = autobatch(
        evaluate_forest,
        budget,
        (None, tree_axis),
        tree_axis,
        reduce_ufunc=jnp.add,
    )

    # precompute the result shape, so the outer autobatch does not have to
    # trace everything 4 times to find it
    if trace.leaf_tree.ndim > trace.split_tree.ndim:  # multivariate leaves
        k = trace.leaf_tree.shape[-2]
        mv_shape = (k,)
    else:
        k = 1
        mv_shape = ()
    _, n = X.shape
    trace_shape = trace.split_tree.shape[:-2]
    out_shape = (*trace_shape, *mv_shape, n)

    # shrink the budget of the outer level, taking into account that the
    # inner level materializes one forest-sized partial result per tree
    # before summing them
    num_trees, hts = trace.split_tree.shape[-2:]
    forest_nbytes = k * n * jnp.float32.dtype.itemsize
    per_node_nbytes = (
        2 * k * trace.leaf_tree.itemsize  # leaf_tree has 2 * hts entries
        + trace.var_tree.itemsize
        + trace.split_tree.itemsize
    )
    io_nbytes = num_trees * hts * per_node_nbytes + forest_nbytes
    intermediate_nbytes = (num_trees - 1) * forest_nbytes
    outer_budget = max(1, floor(budget / (1 + intermediate_nbytes / io_nbytes)))

    # outer level: batch over the MCMC samples
    batched_eval = autobatch(
        per_sample_eval,
        outer_budget,
        (None, sample_axis),
        sample_axis,
        warn_on_overflow=False,  # the inner autobatch will handle it
        result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
    )

    # evaluate only the tree arrays, then add the offset back, broadcast over
    # the observations axis
    # NOTE(review): in the multivariate case the sum below needs `offset` to
    # end with a `k` axis, while the `Trace` protocol annotates it as
    # '*trace_shape' — confirm the annotation
    y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
    y_centered = batched_eval(X, TreesTrace.from_dataclass(trace))
    return y_centered + trace.offset[..., None]

831 

832 

@partial(jit, static_argnums=(0,))
def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
    """
    Count how many times each predictor is used in each MCMC state.

    Parameters
    ----------
    p
        The number of predictors. It is a static argument of the jitted
        function, so every distinct value compiles a new version.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """
    # var_tree is (chains? samples trees nodes); summing over the last batch
    # axis (presumably the trees) leaves one count vector per MCMC state
    var_tree = trace.var_tree
    split_tree = trace.split_tree
    return var_histogram(p, var_tree, split_tree, sum_batch_axis=-1)