Coverage for src / bartz / mcmcloop.py: 95%

205 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-13 00:35 +0000

1# bartz/src/bartz/mcmcloop.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions that implement the full BART posterior MCMC loop. 

26 

27The entry points are `run_mcmc` and `make_default_callback`. 

28""" 

29 

30from collections.abc import Callable 

31from dataclasses import fields 

32from functools import partial, wraps 

33from math import floor 

34from typing import Any, Protocol 

35 

36import jax 

37import numpy 

38from equinox import Module 

39from jax import ShapeDtypeStruct, debug, eval_shape, jit, tree 

40from jax import numpy as jnp 

41from jax.nn import softmax 

42from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt 

43 

44from bartz import jaxext, mcmcstep 

45from bartz._profiler import ( 

46 cond_if_not_profiling, 

47 jit_if_not_profiling, 

48 scan_if_not_profiling, 

49) 

50from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram 

51from bartz.jaxext import autobatch 

52from bartz.mcmcstep import State 

53from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains 

54 

55 

class BurninTrace(Module):
    """MCMC trace that stores only diagnostic quantities.

    All fields are flagged with ``field(chains=True)``; their leading axes
    (``*chains_and_samples``) index the chain, when there is more than one,
    and the MCMC sample.
    """

    error_cov_inv: (
        Float32[Array, '*chains_and_samples']
        | Float32[Array, '*chains_and_samples k k']
        | None
    ) = field(chains=True)
    theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a one-item burn-in trace by reading diagnostics off `state`.

        `error_cov_inv` is taken from the state itself; every other field is
        copied from the attribute with the same name on ``state.forest``.
        """
        forest = state.forest
        forest_fields = (
            'theta',
            'grow_prop_count',
            'grow_acc_count',
            'prune_prop_count',
            'prune_acc_count',
            'log_likelihood',
            'log_trans_prior',
        )
        return cls(
            error_cov_inv=state.error_cov_inv,
            **{name: getattr(forest, name) for name in forest_fields},
        )

85 

86 

class MainTrace(BurninTrace):
    """MCMC trace holding the tree heaps on top of the diagnostic values.

    Adds to `BurninTrace` the forest arrays, the offset and the
    variable-selection probabilities. Note that `offset` is the only field
    not flagged with ``field(chains=True)``.
    """

    leaf_tree: (
        Float32[Array, '*chains_and_samples 2**d']
        | Float32[Array, '*chains_and_samples k 2**d']
    ) = field(chains=True)
    var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
    varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a one-item main trace from `state`.

        `varprob` is the softmax of ``state.forest.log_s`` restricted to the
        entries where ``state.forest.max_split`` is nonzero, or `None` when
        the forest has no ``log_s``.
        """
        forest = state.forest
        if forest.log_s is None:
            varprob = None
        else:
            allowed = forest.max_split.astype(bool)
            varprob = softmax(forest.log_s, where=allowed)

        diagnostics = vars(BurninTrace.from_state(state))
        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            varprob=varprob,
            **diagnostics,
        )

117 

118 

# Type of the user-defined state threaded through successive `Callback`
# invocations in `run_mcmc`: any pytree.
CallbackState = PyTree[Any, 'T']

120 

121 

class Callback(Protocol):
    """Signature expected for the `callback` argument of `run_mcmc`."""

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        i_skip: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: int,
    ) -> tuple[State, CallbackState] | None:
        """Perform a user-defined action after each MCMC iteration.

        Parameters
        ----------
        key
            Random key reserved for the callback.
        bart
            The MCMC state right after it has been updated.
        burnin
            True if the iteration just performed belongs to the burn-in phase.
        i_total
            0-based index of the iteration just performed.
        i_skip
            How many MCMC updates happened since the last saved state. The
            initial state is considered saved, even though it is not copied
            into the trace.
        callback_state
            The callback's own state: on the first call, the value given to
            `run_mcmc`; afterwards, whatever the previous call returned.
        n_burn
        n_save
        n_skip
            The `run_mcmc` arguments of the same name, passed unchanged.
        i_outer
            0-based index of the current outer loop iteration.
        inner_loop_length
            How many MCMC iterations each inner loop runs.

        Returns
        -------
        bart : State
            The MCMC state, possibly modified. To keep it as is, return the
            `bart` argument unchanged.
        callback_state : CallbackState
            The value passed as `callback_state` on the next call.

        Notes
        -----
        Returning `None` instead of a tuple is shorthand for leaving both
        states unchanged.
        """
        ...

182 

183 

class _Carry(Module):
    """Carry used in the loop in `run_mcmc`.

    `run_mcmc` builds it with positional arguments, so the order of the
    fields below is part of its interface.
    """

    # current MCMC state
    bart: State
    # number of MCMC iterations completed so far; `run_mcmc` starts it at 0
    i_total: Int32[Array, '']
    # random key that the next loop iteration splits
    key: Key[Array, '']
    # preallocated burn-in trace, written slot by slot during the loop
    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    # preallocated main trace, written slot by slot during the loop
    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    # user state passed to, and possibly replaced by, the callback
    callback_state: CallbackState

196 callback_state: CallbackState 

197 

198 

def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> tuple[
    State,
    PyTree[Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']],
    PyTree[Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']],
]:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        Random key driving the whole MCMC run.
    bart
        Starting MCMC state, as built and updated by the functions in
        `bartz.mcmcstep`. Its buffers are donated to the loop to avoid copies,
        so this object is invalid once `run_mcmc` returns; copy it beforehand
        if you need it again.
    n_save
        How many iterations to store in the main trace.
    n_burn
        How many initial iterations to run without saving them in the main
        trace.
    n_skip
        One plus the number of iterations discarded between two consecutive
        saved ones. The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The loop has two levels: an outer loop in Python and an inner loop in
        JAX. This is how many iterations each run of the inner loop performs.
        By default the outer loop runs once and the inner loop does all the
        iterations. This stride is unrelated to the one used for saving the
        trace.
    callback
        Optional function invoked in the loop after each state update; see
        `Callback` for its signature. It is called under the jax jit, so its
        arguments are not concrete values when the Python code runs; use the
        utilities in `jax.debug` to access them at runtime. It may return
        replacements for the MCMC state and the callback state.
    callback_state
        Initial value of the custom state threaded through the callback.
    burnin_extractor
    main_extractor
        Functions mapping an MCMC state to the pytree stored at each step in,
        respectively, the burn-in trace and the main trace. They must return
        a pytree and must be vmappable.

    Returns
    -------
    bart : State
        The MCMC state after the last iteration.
    burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
        Values saved during the burn-in phase. For the default layout, see
        `BurninTrace`.
    main_trace : PyTree[Shaped[Array, 'n_save *']]
        Values saved during the main phase. For the default layout, see
        `MainTrace`.

    Notes
    -----
    In total ``n_burn + n_skip * n_save`` MCMC updates are performed. The
    traces exclude the initial state and include the final one.
    """
    burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
    main_trace = _empty_trace(n_save, bart, main_extractor)

    # split the total number of updates between the outer and inner loops
    total_iters = n_burn + n_skip * n_save
    chunk = total_iters if inner_loop_length is None else inner_loop_length
    if not chunk:
        # A zero-length inner loop does nothing, so skipping the outer loop
        # would be equivalent; running it once keeps the same code path for
        # benchmarking and testing.
        n_outer = 1
    else:
        # ceiling division: the last inner loop may have idle iterations
        full_chunks, leftover = divmod(total_iters, chunk)
        n_outer = full_chunks + bool(leftover)

    carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            chunk,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            total_iters,
        )

    return carry.bart, carry.burnin_trace, carry.main_trace

301 

302 

@partial(jit, static_argnums=(0, 2))
def _empty_trace(
    length: int, bart: State, extractor: Callable[[State], PyTree]
) -> PyTree:
    """Allocate a trace made of `length` stacked copies of ``extractor(bart)``.

    The stacking axis is the first one for leaves without a chain axis, and
    the second one (presumably right after a leading chain axis) for leaves
    that have one.
    """
    if get_num_chains(bart) is None:
        # no chain axes anywhere: stack along the leading axis
        sample_axes = 0
    else:
        # trace the extractor abstractly to find out which outputs have chains
        output_shapes = eval_shape(extractor, bart)
        sample_axes = tree.map(
            lambda chain_ax: 1 if chain_ax is not None else 0,
            chain_vmap_axes(output_shapes),
            is_leaf=lambda chain_ax: chain_ax is None,
        )
    stacked = jax.vmap(
        extractor, in_axes=None, out_axes=sample_axes, axis_size=length
    )
    return stacked(bart)

317 

318 

@jit
def _compute_i_skip(
    i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
) -> Int32[Array, '']:
    """Compute the `i_skip` argument passed to `callback`.

    It counts the updates since the last state saved in the main trace. The
    initial state counts as saved; burn-in states do not.
    """
    done_after_burnin = i_total - n_burn + 1
    # before the first main save, the count still runs from the initial state
    main_phase_value = done_after_burnin % n_skip + jnp.where(
        done_after_burnin < n_skip, n_burn, 0
    )
    return jnp.where(i_total < n_burn, i_total + 1, main_phase_value)

331 

332 

@partial(jit_if_not_profiling, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: int,
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Run one inner loop of `run_mcmc`, i.e., `inner_loop_length` iterations.

    The iterations go through `scan_if_not_profiling`. Once ``carry.i_total``
    reaches `n_iters`, the remaining iterations return the carry unchanged,
    so the last inner loop may be longer than the number of iterations left.

    `carry` is donated (``donate_argnums=(0,)``) and must not be reused by the
    caller; `inner_loop_length`, `callback` and the two extractors are passed
    as static arguments.
    """

    def loop_impl(carry: _Carry) -> _Carry:
        """Loop body to run if i_total < n_iters."""
        # split random key; the order of the pops below decides which subkey
        # goes to the next carry, the step and the callback, so changing it
        # changes the random stream
        keys = jaxext.split(carry.key, 3)
        key = keys.pop()

        # update state
        bart = mcmcstep.step(keys.pop(), carry.bart)

        # invoke callback
        callback_state = carry.callback_state
        if callback is not None:
            i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
            rt = callback(
                key=keys.pop(),
                bart=bart,
                burnin=carry.i_total < n_burn,
                i_total=carry.i_total,
                i_skip=i_skip,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # a callback returning None leaves both states as they are
            if rt is not None:
                bart, callback_state = rt

        # save to trace; this happens after the callback, so any change the
        # callback made to the state is what gets recorded
        burnin_trace, main_trace = _save_state_to_trace(
            carry.burnin_trace,
            carry.main_trace,
            burnin_extractor,
            main_extractor,
            bart,
            carry.i_total,
            n_burn,
            n_skip,
        )

        return _Carry(
            bart=bart,
            i_total=carry.i_total + 1,
            key=key,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
            callback_state=callback_state,
        )

    def loop_noop(carry: _Carry) -> _Carry:
        """Loop body to run if i_total >= n_iters; it does nothing."""
        return carry

    def loop(carry: _Carry, _) -> tuple[_Carry, None]:
        # scan body: do a real iteration only while iterations are left
        carry = cond_if_not_profiling(
            carry.i_total < n_iters, loop_impl, loop_noop, carry
        )
        return carry, None

    carry, _ = scan_if_not_profiling(loop, carry, None, inner_loop_length)
    return carry

408 

409 

@partial(jit, donate_argnums=(0, 1), static_argnums=(2, 3))
# this is jitted because under profiling _run_mcmc_inner_loop and the loop
# within it are not, so I need the donate_argnums feature of jit to avoid
# creating copies of the traces
def _save_state_to_trace(
    burnin_trace: PyTree,
    main_trace: PyTree,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    bart: State,
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_skip: Int32[Array, ''],
) -> tuple[PyTree, PyTree]:
    """Record `bart`, reached at iteration `i_total`, into the proper trace.

    Both traces are always written. The write that does not apply uses an
    out-of-bounds index, which `_set` drops.
    """
    in_burnin = i_total < n_burn

    # Burn-in slot: the iteration index itself, which falls past the end of
    # the burn-in trace once the burn-in is over.
    burnin_slot = i_total

    # Main slot: groups of n_skip consecutive iterations share a slot, so the
    # last iteration of each group is what remains. During burn-in it is
    # pushed to the largest int32 so that the write is out of bounds.
    out_of_bounds = jnp.iinfo(jnp.int32).max
    main_slot = jnp.where(in_burnin, out_of_bounds, (i_total - n_burn) // n_skip)

    chains = get_num_chains(bart)
    new_burnin = _set(burnin_trace, burnin_slot, burnin_extractor(bart), chains)
    new_main = _set(main_trace, main_slot, main_extractor(bart), chains)
    return new_burnin, new_main

441 

442 

def _set(
    trace: PyTree[Array, ' T'],
    index: Int32[Array, ''],
    val: PyTree[Array, ' T'],
    num_chains: int | None,
) -> PyTree[Array, ' T']:
    """Write `val` into `trace` at position `index` of the sample axis.

    This is ``trace[index] = val`` applied leaf by leaf. Leaves with a chain
    axis are indexed as ``[:, index]``, and out-of-bounds writes are
    silently dropped.
    """
    val_chain_axes = chain_vmap_axes(val)

    def write_leaf(
        dest: Shaped[Array, 'chains samples *shape']
        | Shaped[Array, ' samples *shape']
        | None,
        src: Shaped[Array, ' chains *shape'] | Shaped[Array, '*shape'] | None,
        src_chain_axis: int | None,
    ):
        # Leave untouched two kinds of leaves:
        # - empty arrays, because jax refuses to index into an axis of length
        #   0, even just in the abstract;
        # - `None` placeholders, which are treated as leaves because of the
        #   `is_leaf` below, needed to walk `val_chain_axes` in parallel.
        if dest is None or dest.size == 0:
            return dest

        chained = num_chains is not None and src_chain_axis is not None
        where = (slice(None), index, ...) if chained else (index, ...)
        return dest.at[where].set(src, mode='drop')

    return tree.map(
        write_leaf, trace, val, val_chain_axes, is_leaf=lambda node: node is None
    )

474 

475 

def make_default_callback(
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
) -> dict[str, Any]:
    """
    Prepare a default callback for `run_mcmc`.

    The callback (`print_callback`) prints a dot every `dot_every` MCMC
    iterations and a one-line progress report every `report_every` MCMC
    iterations. It does not modify the MCMC state.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments
    to set up the callback, i.e., ``callback`` and ``callback_state``.

    Examples
    --------
    >>> run_mcmc(..., **make_default_callback())
    """

    def asarray_or_none(val: None | Any) -> None | Array:
        # `None` is kept as-is because `print_callback` treats it as "disabled"
        return None if val is None else jnp.asarray(val)

    return dict(
        callback=print_callback,
        callback_state=PrintCallbackState(
            asarray_or_none(dot_every), asarray_or_none(report_every)
        ),
    )

513 

514 

class PrintCallbackState(Module):
    """State for `print_callback`.

    `make_default_callback` builds it with positional arguments, so the
    field order is part of its interface.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.
    """

    # period, in MCMC iterations, of the dot output; None disables dots
    dot_every: Int32[Array, ''] | None
    # period, in MCMC iterations, of the report line; None disables reports
    report_every: Int32[Array, ''] | None

529 

530 

def print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_,
):
    """Print a dot and/or a report periodically during the MCMC.

    Implements the `Callback` protocol. With ``it = i_total + 1``:
    - a report line is printed when ``it`` is a multiple of ``report_every``;
    - otherwise a dot is printed when ``it`` is a multiple of ``dot_every``.

    Either period can be `None` to disable that output. The function returns
    `None`, so neither the MCMC state nor the callback state is modified.
    """
    report_every = callback_state.report_every
    dot_every = callback_state.dot_every
    it = i_total + 1  # number of iterations completed, including this one

    def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
        return False if every is None else it % every == 0

    report_cond = get_cond(report_every)
    dot_cond = get_cond(dot_every)

    def line_report_branch():
        if report_every is None:
            return
        if dot_every is None:
            print_newline = False
        else:
            # This branch runs only when `it` is a multiple of `report_every`.
            # So the previous report (or the start) was `report_every`
            # iterations ago, while the last dot was `it % dot_every`
            # iterations ago. If that dot came after the previous report, it
            # is still on the current line, and the report must start on a
            # new line.
            print_newline = it % dot_every < report_every
        debug.callback(
            _print_report,
            print_dot=dot_cond,
            print_newline=print_newline,
            burnin=burnin,
            it=it,
            n_iters=n_burn + n_save * n_skip,
            num_chains=bart.forest.num_chains(),
            grow_prop_count=bart.forest.grow_prop_count.mean(),
            grow_acc_count=bart.forest.grow_acc_count.mean(),
            prune_acc_count=bart.forest.prune_acc_count.mean(),
            prop_total=bart.forest.split_tree.shape[-2],
            fill=forest_fill(bart.forest.split_tree),
        )

    def just_dot_branch():
        if dot_every is None:
            return
        debug.callback(
            lambda: print('.', end='', flush=True)  # noqa: T201
        )
        # logging can't do in-line printing so we use print

    # the report takes precedence; the dot alone is printed only otherwise
    cond_if_not_profiling(
        report_cond,
        line_report_branch,
        lambda: cond_if_not_profiling(dot_cond, just_dot_branch, lambda: None),
    )

588 

589 

590def _convert_jax_arrays_in_args(func: Callable) -> Callable: 

591 """Remove jax arrays from a function arguments. 

592 

593 Converts all `jax.Array` instances in the arguments to either Python scalars 

594 or numpy arrays. 

595 """ 

596 

597 def convert_jax_arrays(pytree: PyTree) -> PyTree: 

598 def convert_jax_array(val: Any) -> Any: 1LaMKb

599 if not isinstance(val, Array): 599 ↛ 600line 599 didn't jump to line 600 because the condition on line 599 was never true1LaMKb

600 return val 

601 elif val.shape: 601 ↛ 602line 601 didn't jump to line 602 because the condition on line 601 was never true1LaMKb

602 return numpy.array(val) 

603 else: 

604 return val.item() 1LaMKb

605 

606 return tree.map(convert_jax_array, pytree) 1LaMKb

607 

608 @wraps(func) 

609 def new_func(*args, **kw): 

610 args = convert_jax_arrays(args) 1LaMKb

611 kw = convert_jax_arrays(kw) 1LaMKb

612 return func(*args, **kw) 1LaMKb

613 

614 return new_func 

615 

616 

@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    *,
    print_dot: bool,
    print_newline: bool,
    burnin: bool,
    it: int,
    n_iters: int,
    num_chains: int | None,
    grow_prop_count: float,
    grow_acc_count: float,
    prune_acc_count: float,
    prop_total: int,
    fill: float,
):
    """Print the one-line progress report for `print_callback`.

    The line is preceded by a dot and a newline if `print_dot`, otherwise by
    just a newline if `print_newline`. It reports the fraction of grow
    proposals and of accepted moves out of `prop_total`, and the tree fill.
    """
    grow_frac = grow_prop_count / prop_total
    accept_frac = (grow_acc_count + prune_acc_count) / prop_total

    # a pending dot takes precedence over a bare newline
    prefix = '.\n' if print_dot else ('\n' if print_newline else '')

    # optional notes shown in parentheses at the end of the line
    notes = [] if num_chains is None else [f'avg. {num_chains} chains']
    if burnin:
        notes.append('burnin')
    suffix = f' ({", ".join(notes)})' if notes else ''

    print(  # noqa: T201, see print_callback for why not logging
        f'{prefix}Iteration {it}/{n_iters}, '
        f'grow prob: {grow_frac:.0%}, '
        f'move acc: {accept_frac:.0%}, '
        f'fill: {fill:.0%}{suffix}'
    )

661 

662 

class Trace(TreeHeaps, Protocol):
    """Protocol for a MCMC trace.

    Tree heaps (as in `bartz.grove.TreeHeaps`) plus the `offset` that
    `evaluate_trace` adds to the sum of trees.
    """

    # value added to the forest output, one per trace entry
    offset: Float32[Array, '*trace_shape']

667 

668 

class TreesTrace(Module):
    """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace.

    Holds only the three tree arrays, with arbitrary leading trace axes.
    """

    leaf_tree: (
        Float32[Array, '*trace_shape num_trees 2**d']
        | Float32[Array, '*trace_shape num_trees k 2**d']
    )
    var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
    split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: TreeHeaps):
        """Copy the tree arrays of any `bartz.grove.TreeHeaps` into a `TreesTrace`.

        Attributes of `obj` other than the fields of this class are ignored.
        """
        kwargs = {}
        for fld in fields(cls):
            kwargs[fld.name] = getattr(obj, fld.name)
        return cls(**kwargs)

683 

684 

@jit
def evaluate_trace(
    X: UInt[Array, 'p n'], trace: Trace
) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    X
        The predictors matrix, with `p` predictors and `n` observations.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    The predictions for each chain and iteration of the MCMC.

    Notes
    -----
    The evaluation is batched at two levels with `autobatch`:
    - the inner level sums the trees in groups;
    - the outer level loops over MCMC samples.

    The byte limits given to `autobatch` start from 128 MiB per device.
    """
    # byte limit for autobatch, scaled by the number of devices
    mesh = jax.typeof(trace.leaf_tree).sharding.mesh
    n_devices = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
    io_budget = 2**27 * n_devices  # 128 MiB per device

    # axis layout of the tree arrays: ([chains,] samples, trees, nodes)
    if trace.split_tree.ndim > 3:
        sample_axis, tree_axis = 1, 2
    else:
        sample_axis, tree_axis = 0, 1

    # inner level: evaluate groups of trees and add up the partial sums
    sum_over_trees = autobatch(
        evaluate_forest,
        io_budget,
        (None, tree_axis),
        tree_axis,
        reduce_ufunc=jnp.add,
    )

    # shape of the result, given explicitly so that autobatch does not have
    # to trace everything 4 times to infer it
    is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
    k = trace.leaf_tree.shape[-2] if is_mv else 1
    _, n = X.shape
    out_shape = (*trace.split_tree.shape[:-2], *((k,) if is_mv else ()), n)

    # shrink the limit for the outer level, because summing over trees keeps
    # intermediate results alive on top of the inputs and outputs
    num_trees, hts = trace.split_tree.shape[-2:]
    forest_out_bytes = k * n * jnp.float32.dtype.itemsize  # value of the forest
    core_io_bytes = (
        num_trees
        * hts
        * (
            2 * k * trace.leaf_tree.itemsize
            + trace.var_tree.itemsize
            + trace.split_tree.itemsize
        )
        + forest_out_bytes
    )
    core_intermediate_bytes = (num_trees - 1) * forest_out_bytes
    sample_budget = max(
        1, floor(io_budget / (1 + core_intermediate_bytes / core_io_bytes))
    )

    # outer level: loop over MCMC samples
    batched_eval = autobatch(
        sum_over_trees,
        sample_budget,
        (None, sample_axis),
        sample_axis,
        warn_on_overflow=False,  # the inner autobatch will handle it
        result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
    )

    # evaluate only the tree arrays of the trace, then add back the offset
    y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
    y_centered = batched_eval(X, TreesTrace.from_dataclass(trace))
    return y_centered + trace.offset[..., None]

769 

770 

@partial(jit, static_argnums=(0,))
def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
    """
    Count how many times each predictor is used in each MCMC state.

    Parameters
    ----------
    p
        The number of predictors.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """
    # the tree arrays are laid out as ([chains,] samples, trees, nodes); the
    # last batch axis (presumably the trees) is summed over
    counts = var_histogram(
        p, trace.var_tree, trace.split_tree, sum_batch_axis=-1
    )
    return counts

787 # var_tree has shape (chains? samples trees nodes) 

788 return var_histogram(p, trace.var_tree, trace.split_tree, sum_batch_axis=-1) 2; = ? @ [ ~ _ abm n .