Coverage for src / bartz / mcmcloop.py: 95%

222 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 18:11 +0000

1# bartz/src/bartz/mcmcloop.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions that implement the full BART posterior MCMC loop. 

26 

27The entry points are `run_mcmc` and `make_default_callback`. 

28""" 

29 

30from collections.abc import Callable 

31from functools import partial, update_wrapper, wraps 

32from math import floor 

33from typing import Any, NamedTuple, Protocol, TypeVar 

34 

35import jax 

36import numpy 

37from equinox import Module 

38from jax import ( 

39 NamedSharding, 

40 ShapeDtypeStruct, 

41 debug, 

42 device_put, 

43 eval_shape, 

44 jit, 

45 lax, 

46 named_call, 

47 tree, 

48) 

49from jax import numpy as jnp 

50from jax.nn import softmax 

51from jax.sharding import Mesh, PartitionSpec 

52from jaxtyping import ( 

53 Array, 

54 ArrayLike, 

55 Bool, 

56 Float32, 

57 Int32, 

58 Integer, 

59 Key, 

60 PyTree, 

61 Shaped, 

62 UInt, 

63) 

64 

65from bartz import jaxext, mcmcstep 

66from bartz.grove import ( 

67 TreeHeaps, 

68 TreesTrace, 

69 evaluate_forest, 

70 forest_fill, 

71 var_histogram, 

72) 

73from bartz.jaxext import autobatch, jit_active 

74from bartz.mcmcstep import State 

75from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains 

76 

77 

class BurninTrace(Module):
    """MCMC trace with only diagnostic values.

    Each field is copied from the corresponding attribute of the MCMC state by
    `from_state`; when stacked over iterations, the leading axes index chains
    and/or samples.
    """

    error_cov_inv: (
        Float32[Array, '*chains_and_samples']
        | Float32[Array, '*chains_and_samples k k']
    ) = field(chains=True)
    theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
    log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
    log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a single-item burn-in trace out of a MCMC state.

        Parameters
        ----------
        state
            The MCMC state to read the diagnostic values from.

        Returns
        -------
        A `BurninTrace` whose fields are taken as-is from `state`.
        """
        forest = state.forest
        move_counts = dict(
            grow_prop_count=forest.grow_prop_count,
            grow_acc_count=forest.grow_acc_count,
            prune_prop_count=forest.prune_prop_count,
            prune_acc_count=forest.prune_acc_count,
        )
        return cls(
            error_cov_inv=state.error_cov_inv,
            theta=forest.theta,
            log_likelihood=forest.log_likelihood,
            log_trans_prior=forest.log_trans_prior,
            **move_counts,
        )

106 

107 

class MainTrace(BurninTrace):
    """MCMC trace with trees and diagnostic values.

    Extends `BurninTrace` with the tree heaps, the offset and the predictor
    weights `varprob`.
    """

    leaf_tree: (
        Float32[Array, '*chains_and_samples 2**d']
        | Float32[Array, '*chains_and_samples k 2**d']
    ) = field(chains=True)
    var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
    offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
    varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a single-item main trace out of a MCMC state.

        Parameters
        ----------
        state
            The MCMC state to read the trees and diagnostics from.

        Returns
        -------
        A `MainTrace`. `varprob` is `None` when the state has no `log_s`,
        otherwise it is the softmax of `log_s` restricted to the predictors
        with a nonzero `max_split`.
        """
        forest = state.forest

        if forest.log_s is not None:
            varprob = softmax(forest.log_s, where=forest.max_split.astype(bool))
        else:
            varprob = None

        diagnostics = vars(BurninTrace.from_state(state))

        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            varprob=varprob,
            **diagnostics,
        )

138 

139 

# Arbitrary user-defined pytree threaded through successive invocations of the
# `run_mcmc` callback (see `Callback`): each invocation receives the value
# returned by the previous one.
CallbackState = PyTree[Any, 'T']

141 

142 

class RunMCMCResult(NamedTuple):
    """Return value of `run_mcmc`.

    In the traces, the iteration index is the leading axis of each leaf, or
    the axis right after the chain axis for leaves that carry one.
    """

    final_state: State
    """The final MCMC state."""

    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    """The trace of the burn-in phase. For the default layout, see `BurninTrace`."""

    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    """The trace of the main phase. For the default layout, see `MainTrace`."""

158 

159 

class Callback(Protocol):
    """Callback type for `run_mcmc`.

    `run_mcmc` calls the callback with keyword arguments only, once per MCMC
    iteration, from inside the jit-compiled inner loop.
    """

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: Int32[Array, ''],
    ) -> tuple[State, CallbackState] | None:
        """Do an arbitrary action after an iteration of the MCMC.

        Parameters
        ----------
        key
            A key for random number generation.
        bart
            The MCMC state just after updating it.
        burnin
            Whether the last iteration was in the burn-in phase.
        i_total
            The index of the last MCMC iteration (0-based).
        callback_state
            The callback state, initially set to the argument passed to
            `run_mcmc`, afterwards to the value returned by the last invocation
            of the callback.
        n_burn
        n_save
        n_skip
            The corresponding `run_mcmc` arguments as-is.
        i_outer
            The index of the last outer loop iteration (0-based).
        inner_loop_length
            The number of MCMC iterations in the inner loop.

        Returns
        -------
        bart : State
            A possibly modified MCMC state. To avoid modifying the state,
            return the `bart` argument passed to the callback as-is.
        callback_state : CallbackState
            The new state to be passed on the next callback invocation.

        Notes
        -----
        For convenience, the callback may return `None`, and the states won't
        be updated.
        """
        ...

216 

217 

class _Carry(Module):
    """Carry used in the loop in `run_mcmc`.

    The same carry is threaded through the `lax.while_loop` of
    `_run_mcmc_inner_loop` and across the outer Python loop of `run_mcmc`.
    """

    # current MCMC state
    bart: State
    # number of MCMC iterations completed so far, across all outer iterations
    i_total: Int32[Array, '']
    # random key, split at every iteration
    key: Key[Array, '']
    # preallocated traces, written at every iteration (out-of-range writes
    # are dropped)
    burnin_trace: PyTree[
        Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
    ]
    main_trace: PyTree[
        Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
    ]
    # user state passed to, and returned by, the callback
    callback_state: CallbackState

231 

232 

def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> RunMCMCResult:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        A key for random number generation.
    bart
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. Its buffers are donated to the MCMC loop to avoid
        copies, so it can not be used after `run_mcmc` returns; copy it
        beforehand if you need it again.
    n_save
        The number of iterations to save in the main trace.
    n_burn
        The number of initial iterations which are not saved in the main
        trace (they are saved in the burn-in trace instead).
    n_skip
        One plus the number of iterations discarded between two consecutive
        saved iterations. The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The MCMC runs as an outer Python loop around an inner JAX loop; this
        is the number of MCMC iterations done by each run of the inner loop.
        If not specified, all the iterations are done in a single run of the
        inner loop. This is unrelated to the stride used to save the trace.
    callback
        An arbitrary function called in the loop after each state update, see
        `Callback` for its signature. Since it runs under the jax jit, its
        Python code sees only abstract values; use the utilities in
        `jax.debug` to access the actual values at runtime. The callback may
        return new values for the MCMC state and the callback state.
    callback_state
        The initial custom state passed to the callback.
    burnin_extractor
    main_extractor
        Functions that, given the MCMC state, return the pytree of values to
        save respectively in the burn-in trace and in the main trace. They
        must be vmappable.

    Returns
    -------
    A namedtuple with the final state, the burn-in trace, and the main trace.

    Raises
    ------
    RuntimeError
        If `run_mcmc` is called inside a `jit`-compiled function with
        settings that would require more than one outer loop iteration, i.e.,
        unrolled loops in the trace.

    Notes
    -----
    The MCMC performs ``n_burn + n_skip * n_save`` updates. The traces do not
    contain the initial state, and contain the final state.
    """
    # preallocate the traces; the inner loop writes into them
    burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
    main_trace = _empty_trace(n_save, bart, main_extractor)

    # total number of MCMC updates, and how many runs of the inner loop
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if not inner_loop_length:
        # zero outer iterations would be a clean noop, but running one keeps
        # the same code path for benchmarking and testing
        n_outer = 1
    else:
        n_full_runs, leftover = divmod(n_iters, inner_loop_length)
        n_outer = n_full_runs + bool(leftover)

    # under an outer jit, several outer iterations would be unrolled
    if jit_active() and n_outer > 1:
        msg = (
            '`run_mcmc` was called within a jit-compiled function and '
            'there are more than 1 outer loops, '
            'please either do not jit or set `inner_loop_length=None`'
        )
        raise RuntimeError(msg)

    mesh = bart.config.mesh
    i_total_start = _replicate(jnp.int32(0), mesh)
    replicated_key = _replicate(key, mesh)
    carry = _Carry(
        bart=bart,
        i_total=i_total_start,
        key=replicated_key,
        burnin_trace=burnin_trace,
        main_trace=main_trace,
        callback_state=callback_state,
    )

    # the inner loop must be traced at most once per `run_mcmc` call
    _run_mcmc_inner_loop._fun.reset_call_counter()  # noqa: SLF001
    for i_outer in range(n_outer):
        # arguments must stay positional to match the jit static/donated argnums
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return RunMCMCResult(
        final_state=carry.bart,
        burnin_trace=carry.burnin_trace,
        main_trace=carry.main_trace,
    )

351 

352 

353def _replicate(x: Array, mesh: Mesh | None) -> Array: 

354 if mesh is None: 1wajhd

355 return x 1ad

356 else: 

357 return device_put(x, NamedSharding(mesh, PartitionSpec())) 1wjh

358 

359 

@partial(jit, static_argnums=(0, 2))
def _empty_trace(
    length: int, bart: State, extractor: Callable[[State], PyTree]
) -> PyTree:
    """Allocate a trace by stacking `length` copies of ``extractor(bart)``.

    The new sample axis is the leading axis, except in the leaves that carry a
    chain axis, where it is placed right after the chain axis.
    """
    if get_num_chains(bart) is None:
        out_axes = 0
    else:
        chain_axes = chain_vmap_axes(eval_shape(extractor, bart))
        out_axes = tree.map(
            lambda axis: 1 if axis is not None else 0,
            chain_axes,
            is_leaf=lambda axis: axis is None,
        )

    # in_axes=None: the same state is used for every position along the new axis
    stacked = jax.vmap(extractor, in_axes=None, out_axes=out_axes, axis_size=length)
    return stacked(bart)

374 

375 

# Return type of the callables wrapped by `_CallCounter` and
# `_convert_jax_arrays_in_args`.
T = TypeVar('T')

377 

378 

class _CallCounter:
    """Wrap a callable and raise if it is called more than once.

    Applied under `jit` to `_run_mcmc_inner_loop`, where a second call means
    the function was traced again; `reset_call_counter` re-arms the check.
    """

    def __init__(self, func: Callable[..., T]) -> None:
        self.func = func
        self.n_calls = 0
        # copy name, docstring, etc. of `func` onto the wrapper
        update_wrapper(self, func)

    def reset_call_counter(self) -> None:
        """Allow the wrapped callable to be called once more."""
        self.n_calls = 0

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        already_called = self.n_calls > 0
        if already_called:
            msg = (
                'The inner loop of `run_mcmc` was traced more than once, '
                'which indicates a double compilation of the MCMC code. This '
                'probably depends on the input state having different type from the '
                'output state. Check the input is in a format that is the '
                'same jax would output, e.g., all arrays and scalars are jax '
                'arrays, with the right shardings.'
            )
            raise RuntimeError(msg)
        self.n_calls += 1
        return self.func(*args, **kwargs)

404 

405 

@partial(jit, donate_argnums=(0,), static_argnums=(2, 3, 4))
@_CallCounter
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: Int32[Array, ''],
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Run one batch of at most `inner_loop_length` MCMC iterations.

    The loop stops early once `n_iters` total iterations are reached, so the
    last outer iteration of `run_mcmc` may run fewer iterations.

    Parameters
    ----------
    carry
        The loop state. It is donated (``donate_argnums=(0,)``), so the object
        passed in must not be used after the call.
    inner_loop_length
        The maximum number of iterations to run in this call.
    callback
    burnin_extractor
    main_extractor
        Static arguments (``static_argnums=(2, 3, 4)``), see `run_mcmc`.
    n_burn
    n_save
    n_skip
    i_outer
        See `run_mcmc`; forwarded to the callback and to the trace writer.
    n_iters
        The total number of MCMC iterations across all outer iterations.

    Returns
    -------
    The updated carry.

    Notes
    -----
    Wrapped by `_CallCounter`, which raises if this function is traced more
    than once between two calls to its ``reset_call_counter``.
    """
    # determine number of iterations for this loop batch
    i_upper = jnp.minimum(carry.i_total + inner_loop_length, n_iters)

    def cond(carry: _Carry) -> Bool[Array, '']:
        """Whether to continue the MCMC loop."""
        return carry.i_total < i_upper

    def body(carry: _Carry) -> _Carry:
        """Update the MCMC state."""
        # split random key: one key is kept for the next iteration, one feeds
        # the MCMC step, one goes to the callback
        keys = jaxext.split(carry.key, 3)
        key = keys.pop()

        # update state
        bart = mcmcstep.step(keys.pop(), carry.bart)

        # invoke callback; `i_total` is still the index of this iteration
        callback_state = carry.callback_state
        if callback is not None:
            rt = callback(
                key=keys.pop(),
                bart=bart,
                burnin=carry.i_total < n_burn,
                i_total=carry.i_total,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # a `None` return leaves both the MCMC and the callback state as-is
            if rt is not None:
                bart, callback_state = rt

        # save to trace
        burnin_trace, main_trace = _save_state_to_trace(
            carry.burnin_trace,
            carry.main_trace,
            burnin_extractor,
            main_extractor,
            bart,
            carry.i_total,
            n_burn,
            n_skip,
        )

        return _Carry(
            bart=bart,
            i_total=carry.i_total + 1,
            key=key,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
            callback_state=callback_state,
        )

    return lax.while_loop(cond, body, carry)

476 

477 

@named_call
def _save_state_to_trace(
    burnin_trace: PyTree,
    main_trace: PyTree,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    bart: State,
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_skip: Int32[Array, ''],
) -> tuple[PyTree, PyTree]:
    """Write the values extracted from `bart` at iteration `i_total` into the traces.

    Both traces are written on every iteration. Out-of-range positions are
    dropped by `_set`, which makes the write a no-op in the phase a trace does
    not belong to.
    """
    num_chains = get_num_chains(bart)

    # burn-in: the iteration index is the trace position; once past the
    # burn-in it is out of range, so nothing is written
    new_burnin_trace = _set(
        burnin_trace, i_total, burnin_extractor(bart), num_chains
    )

    # main phase: the position advances every `n_skip` iterations, so the last
    # iteration of each group is the one that stays; during the burn-in the
    # largest int32 forces an out-of-range (dropped) write
    still_burning_in = i_total < n_burn
    main_position = jnp.where(
        still_burning_in, jnp.iinfo(jnp.int32).max, (i_total - n_burn) // n_skip
    )
    new_main_trace = _set(main_trace, main_position, main_extractor(bart), num_chains)

    return new_burnin_trace, new_main_trace

506 

507 

def _set(
    trace: PyTree[Array, ' T'],
    index: Int32[Array, ''],
    val: PyTree[Array, ' T'],
    num_chains: int | None,
) -> PyTree[Array, ' T']:
    """Write `val` at sample position `index` of every leaf of `trace`.

    Leaves that carry a chain axis (when the state has chains) are indexed on
    their second axis, all the others on their first. Out-of-bounds indices
    are dropped, turning the write into a no-op. `None` leaves and empty
    arrays are returned unchanged.
    """
    val_chain_axes = chain_vmap_axes(val)

    def write_leaf(
        leaf: Shaped[Array, 'chains samples *shape']
        | Shaped[Array, ' samples *shape']
        | None,
        new_value: Shaped[Array, ' chains *shape'] | Shaped[Array, '*shape'] | None,
        axis: int | None,
    ) -> Shaped[Array, 'chains samples *shape'] | Shaped[Array, ' samples *shape'] | None:
        # jax refuses to index into a length-0 axis, even in the abstract;
        # `None` leaves come from the `is_leaf` below, needed to traverse the
        # optional entries of the chain axes tree
        if leaf is None or leaf.size == 0:
            return leaf

        has_chain_axis = num_chains is not None and axis is not None
        position = (slice(None), index, ...) if has_chain_axis else (index, ...)
        return leaf.at[position].set(new_value, mode='drop')

    return tree.map(
        write_leaf, trace, val, val_chain_axes, is_leaf=lambda x: x is None
    )

540 

541 

def make_default_callback(
    state: State,
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
) -> dict[str, Any]:
    """
    Prepare a default callback for `run_mcmc`.

    The callback (`print_callback`) prints a dot every `dot_every` MCMC
    iterations and a one-line progress report every `report_every`
    iterations. It only prints: it leaves the MCMC state unchanged.

    Parameters
    ----------
    state
        The bart state to use the callback with, used to determine device
        sharding.
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword
    arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(key, state, ..., **make_default_callback(state, ...))
    """

    def as_replicated_array_or_none(val: ArrayLike | None) -> None | Array:
        # the periods live in the callback state, so put them on the same
        # mesh as the MCMC state
        return None if val is None else _replicate(jnp.asarray(val), state.config.mesh)

    return dict(
        callback=print_callback,
        callback_state=PrintCallbackState(
            as_replicated_array_or_none(dot_every),
            as_replicated_array_or_none(report_every),
        ),
    )

584 

585 

class PrintCallbackState(Module):
    """State for `print_callback`.

    Built by `make_default_callback`, which stores each period either as an
    array replicated on the mesh of the MCMC state, or as `None`.
    """

    dot_every: Int32[Array, ''] | None
    """A dot is printed every `dot_every` MCMC iterations, `None` to disable."""

    report_every: Int32[Array, ''] | None
    """A one line report is printed every `report_every` MCMC iterations,
    `None` to disable."""

595 

596 

def print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_: Any,
) -> None:
    """Print a dot and/or a report periodically during the MCMC.

    Meant to be passed as `callback` to `run_mcmc` together with a
    `PrintCallbackState`, see `make_default_callback`. Dots accumulate on a
    single line; a report first terminates any pending line of dots.

    Parameters
    ----------
    bart
        The MCMC state, used to compute the report statistics.
    burnin
        Whether the iteration was in the burn-in phase.
    i_total
        The 0-based index of the iteration just completed.
    n_burn
    n_save
    n_skip
        The `run_mcmc` arguments, used to compute the total number of
        iterations shown in the report.
    callback_state
        The printing periods.
    **_
        Other callback arguments, ignored.

    Returns
    -------
    None, so `run_mcmc` leaves the MCMC state and the callback state unchanged.
    """
    report_every = callback_state.report_every
    dot_every = callback_state.dot_every
    it = i_total + 1  # 1-based count of completed iterations

    def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
        return False if every is None else it % every == 0

    report_cond = get_cond(report_every)
    dot_cond = get_cond(dot_every)

    def line_report_branch() -> None:
        # `lax.cond` traces both branches, so guard against a disabled report
        if report_every is None:
            return
        if dot_every is None:
            print_newline = False
        else:
            # This branch runs only on report iterations, so the previous
            # report was printed `report_every` iterations ago, while the last
            # dot was printed `it % dot_every` iterations ago. If a dot came
            # after the previous report, its line is still open and must be
            # terminated before the report. (A dot due at this very iteration
            # is handled by `print_dot`, which takes precedence.)
            print_newline = it % dot_every < report_every
        debug.callback(
            _print_report,
            print_dot=dot_cond,
            print_newline=print_newline,
            burnin=burnin,
            it=it,
            n_iters=n_burn + n_save * n_skip,
            num_chains=bart.forest.num_chains(),
            grow_prop_count=bart.forest.grow_prop_count.mean(),
            grow_acc_count=bart.forest.grow_acc_count.mean(),
            prune_acc_count=bart.forest.prune_acc_count.mean(),
            prop_total=bart.forest.split_tree.shape[-2],
            fill=forest_fill(bart.forest.split_tree),
        )

    def just_dot_branch() -> None:
        if dot_every is None:
            return
        debug.callback(
            lambda: print('.', end='', flush=True)  # noqa: T201
        )
        # logging can't do in-line printing so we use print

    lax.cond(
        report_cond,
        line_report_branch,
        lambda: lax.cond(dot_cond, just_dot_branch, lambda: None),
    )

654 

655 

def _convert_jax_arrays_in_args(func: Callable[..., T]) -> Callable[..., T]:
    """Decorate `func` so that it never receives jax arrays.

    Every `jax.Array` leaf in the positional and keyword arguments is replaced
    by a numpy array if it has at least one axis, or by the equivalent Python
    scalar if it is 0-dimensional. Other values pass through unchanged.
    """

    def to_host(leaf: object) -> object:
        if isinstance(leaf, Array):
            return numpy.array(leaf) if leaf.shape else leaf.item()
        return leaf

    @wraps(func)
    def new_func(*args: Any, **kw: Any) -> T:
        host_args = tree.map(to_host, args)
        host_kw = tree.map(to_host, kw)
        return func(*host_args, **host_kw)

    return new_func

681 

682 

@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    *,
    print_dot: bool,
    print_newline: bool,
    burnin: bool,
    it: int,
    n_iters: int,
    num_chains: int | None,
    grow_prop_count: float,
    grow_acc_count: float,
    prune_acc_count: float,
    prop_total: int,
    fill: float,
) -> None:
    """Print the one-line progress report of `print_callback`.

    The report is preceded by this iteration's dot and/or by a newline
    closing a pending line of dots, as requested by the flags.
    """
    # a dot due at this iteration also closes the line of dots
    prefix = '.\n' if print_dot else '\n' if print_newline else ''

    labels = []
    if num_chains is not None:
        labels.append(f'avg. {num_chains} chains')
    if burnin:
        labels.append('burnin')
    suffix = f' ({", ".join(labels)})' if labels else ''

    # fractions relative to the number of proposals, one per tree
    grow_prop = grow_prop_count / prop_total
    move_acc = (grow_acc_count + prune_acc_count) / prop_total

    print(  # noqa: T201, see print_callback for why not logging
        f'{prefix}Iteration {it}/{n_iters}, '
        f'grow prob: {grow_prop:.0%}, '
        f'move acc: {move_acc:.0%}, '
        f'fill: {fill:.0%}{suffix}'
    )

727 

728 

class Trace(TreeHeaps, Protocol):
    """Protocol for a MCMC trace.

    Besides the tree heaps of `TreeHeaps`, a trace exposes the `offset` that
    `evaluate_trace` adds to the sum of trees.
    """

    # added to the forest predictions, broadcast over the observations
    offset: Float32[Array, '*trace_shape']

733 

734 

@jit
def evaluate_trace(
    X: UInt[Array, 'p n'], trace: Trace
) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    X
        The predictors matrix, with `p` predictors and `n` observations.
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    The predictions for each chain and iteration of the MCMC, i.e., the sum
    over trees of each sample, plus the sample's `offset`.

    Notes
    -----
    The evaluation is batched twice with `autobatch`: over trees, adding up
    the partial results, and over MCMC samples, to bound the memory used.
    """
    # per-device memory limit
    max_io_nbytes = 2**27  # 128 MiB

    # adjust memory limit for number of devices
    mesh = jax.typeof(trace.leaf_tree).sharding.mesh
    num_devices = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
    max_io_nbytes *= num_devices

    # determine batching axes
    has_chains = trace.split_tree.ndim > 3  # chains, samples, trees, nodes
    if has_chains:
        sample_axis = 1
        tree_axis = 2
    else:
        sample_axis = 0
        tree_axis = 1

    # batch and sum over trees
    batched_eval = autobatch(
        evaluate_forest,
        max_io_nbytes,
        (None, tree_axis),
        tree_axis,
        reduce_ufunc=jnp.add,
    )

    # determine output shape (to avoid autobatch tracing everything 4 times)
    # multivariate: leaf_tree has an extra `k` axis before the nodes
    is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
    k = trace.leaf_tree.shape[-2] if is_mv else 1
    mv_shape = (k,) if is_mv else ()
    _, n = X.shape
    out_shape = (*trace.split_tree.shape[:-2], *mv_shape, n)

    # adjust memory limit keeping into account that trees are summed over
    num_trees, hts = trace.split_tree.shape[-2:]
    out_size = k * n * jnp.float32.dtype.itemsize  # the value of the forest
    # per-sample bytes: the leaf heap has twice as many nodes (2**d) as the
    # var/split heaps (2**(d-1)), hence the factor 2
    core_io_size = (
        num_trees
        * hts
        * (
            2 * k * trace.leaf_tree.itemsize
            + trace.var_tree.itemsize
            + trace.split_tree.itemsize
        )
        + out_size
    )
    # NOTE(review): presumably the per-tree partial outputs that exist before
    # they are summed over trees; confirm against `autobatch`
    core_int_size = (num_trees - 1) * out_size
    # shrink the budget of the sample batching by the intermediate overhead
    max_io_nbytes = max(1, floor(max_io_nbytes / (1 + core_int_size / core_io_size)))

    # batch over mcmc samples
    batched_eval = autobatch(
        batched_eval,
        max_io_nbytes,
        (None, sample_axis),
        sample_axis,
        warn_on_overflow=False,  # the inner autobatch will handle it
        result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
    )

    # extract only the trees from the trace
    trees = TreesTrace.from_dataclass(trace)

    # evaluate trees
    y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
    y_centered = batched_eval(X, trees)
    # broadcast the offset over the observations axis
    return y_centered + trace.offset[..., None]

819 

820 

@partial(jit, static_argnums=(0,))
def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
    """
    Count the uses of each predictor in the trees of each MCMC state.

    Parameters
    ----------
    p
        The number of predictors (static under jit).
    trace
        A main trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    For each MCMC state, a histogram of how many times each predictor is used.
    """
    var_heaps = trace.var_tree
    split_heaps = trace.split_tree
    # the heaps are shaped (chains? samples trees nodes): sum over the trees
    return var_histogram(p, var_heaps, split_heaps, sum_batch_axis=-1)