Coverage for src/bartz/mcmcloop/_callback.py: 93%

233 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/mcmcloop/_callback.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Progress-reporting callbacks for `run_mcmc`.""" 

26 

27import itertools 

28from collections.abc import Callable 

29from dataclasses import dataclass, replace 

30from functools import partial, wraps 

31from typing import Any, TypeVar 

32 

33import numpy 

34from equinox import Module, field 

35from jax import debug, eval_shape, lax, tree 

36from jax import numpy as jnp 

37from jax.scipy.special import logsumexp 

38from jaxtyping import Array, ArrayLike, Bool, Float32, Int32, Integer, PyTree, Shaped 

39from tqdm.auto import tqdm 

40 

41from bartz._typing import kwdict 

42from bartz.grove import forest_mean_leaves 

43from bartz.mcmcloop._loop import _replicate 

44from bartz.mcmcstep import State 

45from bartz.mcmcstep._axes import chain_to_axis, chain_vmap_axes, chainful_axis 

46 

47 

class StatsReport(Module):
    """One report's worth of forest diagnostics.

    Instances are built by `StatsAccumulator.report`. Array-valued fields are
    scalar jax arrays; the fields declared with ``static=True`` are plain
    Python values.
    """

    grow_prop: Float32[Array, '']
    """Share of the trees for which a grow move was proposed."""

    move_acc: Float32[Array, '']
    """Share of the trees whose proposed move (grow or prune) was accepted."""

    mean_leaves: Float32[Array, '']
    """Average number of leaves in a tree."""

    peff: Float32[Array, ''] | None
    """Effective number of predictors; `None` if variable selection is off."""

    n_samples: Int32[Array, ''] | None
    """How many iterations were averaged; `None` if there is no averaging."""

    num_chains: int | None = field(static=True)
    """How many chains were averaged; `None` for a single-chain run."""

    max_leaves: int = field(static=True)
    """Upper bound on the number of leaves a tree can have."""

    p: int | None = field(static=True)
    """Total number of predictors; `None` if variable selection is off."""

74 

75 

class StatsAccumulator(Module):
    """Windowed average of the forest diagnostics displayed while sampling.

    An enabled accumulator keeps per-statistic running sums, so that a report
    can show the mean over the iterations elapsed since the last report. A
    disabled accumulator holds no sums (``sums is None``) and a report then
    describes only the most recent iteration.
    """

    sums: dict[str, Float32[Array, '']] | None
    """Per-statistic running totals; `None` marks a disabled accumulator."""

    count: Int32[Array, '']
    """How many iterations have been added since the last reset."""

    @classmethod
    def initial(cls, state: State, *, enabled: bool) -> 'StatsAccumulator':
        """Build an accumulator with zero sums, or an inert one if not `enabled`."""
        sums = None
        if enabled:
            # `eval_shape` yields only shapes and dtypes, which is all that is
            # needed to allocate the zeros; the statistics are not computed
            template = eval_shape(cls._avg_stats, state)
            sums = tree.map(lambda leaf: jnp.zeros(leaf.shape, leaf.dtype), template)
        return cls(sums=sums, count=jnp.int32(0))

    def update(self, state: State) -> 'StatsAccumulator':
        """Fold the statistics of `state` into the sums; inert when disabled."""
        if self.sums is None:
            return self
        latest = self._avg_stats(state)
        return replace(
            self,
            sums=tree.map(jnp.add, self.sums, latest),
            count=self.count + 1,
        )

    def reset_if(self, cond: bool | Bool[Array, '']) -> 'StatsAccumulator':
        """Set the sums and the count to zero if `cond`; inert when disabled."""
        if self.sums is None:
            return self

        def zero_if_cond(value: Array) -> Array:
            return jnp.where(cond, 0, value)

        return replace(
            self,
            sums=tree.map(zero_if_cond, self.sums),
            count=zero_if_cond(self.count),
        )

    def report(self, state: State) -> StatsReport:
        """Build the report: window means when enabled, else the values of `state`."""
        if self.sums is not None:
            n_samples = self.count
            averaged: kwdict = tree.map(lambda total: total / self.count, self.sums)
        else:
            n_samples = None
            averaged = self._avg_stats(state)
        return StatsReport(**averaged, **self._static_stats(state), n_samples=n_samples)

    @staticmethod
    def _avg_stats(state: State) -> dict[str, Float32[Array, ''] | None]:
        """Compute the diagnostics that get averaged over the report window."""
        forest = state.forest
        axes = chain_vmap_axes(forest)
        chain_axis = axes.split_tree
        num_trees_axis = chainful_axis(0, chain_axis)  # (num_trees, hts)
        split_tree = chain_to_axis(forest.split_tree, chain_axis)
        # the move counts are normalized by the size of the trees axis
        prop_total = forest.split_tree.shape[num_trees_axis]

        # the effective number of predictors exists only if `log_s` is tracked
        peff = None
        if forest.log_s is not None:
            log_s = chain_to_axis(forest.log_s, axes.log_s)
            peff = StatsAccumulator._effective_predictors(log_s)

        accepted = forest.grow_acc_count.mean() + forest.prune_acc_count.mean()
        return {
            'grow_prop': forest.grow_prop_count.mean() / prop_total,
            'move_acc': accepted / prop_total,
            'mean_leaves': forest_mean_leaves(split_tree),
            'peff': peff,
        }

    @staticmethod
    def _static_stats(state: State) -> dict[str, int | None]:
        """Collect the reported quantities that do not change during the run."""
        forest = state.forest
        axes = chain_vmap_axes(forest)
        split_tree = chain_to_axis(forest.split_tree, axes.split_tree)

        p = None
        if forest.log_s is not None:
            *_, p = chain_to_axis(forest.log_s, axes.log_s).shape

        return {
            'num_chains': state.num_chains(),
            'max_leaves': split_tree.shape[-1],
            'p': p,
        }

    @staticmethod
    def _effective_predictors(log_s: Float32[Array, '*chains p']) -> Float32[Array, '']:
        """Effective number of splitting predictors, pooling all the chains.

        This is the perplexity, i.e., the exponential of the Shannon entropy,
        of the split-variable distribution ``s = softmax(log_s)`` averaged
        over chains. A pooled distribution spread evenly over ``k`` predictors
        yields ``k``: the value is 1 if every chain puts all its mass on the
        same predictor, and ``p`` if the pooled distribution is uniform. The
        chains are averaged before computing the entropy, rather than after,
        because predictions average over all the chains, hence a predictor
        used by any one chain counts as used.
        """
        *_, p = log_s.shape
        # softmax in log space, separately for each chain
        normalized = log_s - logsumexp(log_s, axis=-1, keepdims=True)
        per_chain = normalized.reshape(-1, p)
        num_chains = per_chain.shape[0]
        # average the chains. WORKAROUND(jax<0.7.1): once we bump jax to v0.7.1
        # this is `jax.nn.logmeanexp(per_chain, axis=0)`
        log_pool = logsumexp(per_chain, axis=0) - jnp.log(num_chains)
        prob = jnp.exp(log_pool)
        # a zero probability would produce 0 * -inf = nan, so the log is
        # replaced where prob is 0; `jax.scipy.special.entr` has the same
        # guard, but here the log already computed is reused
        safe_log = jnp.where(prob, log_pool, 1.0)
        entropy = -jnp.sum(prob * safe_log)
        return jnp.exp(entropy)

188 

189 

def make_print_callback(
    state: State,
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
    average: bool = True,
) -> dict[str, Any]:
    """
    Set up a callback for `run_mcmc` that prints the progress.

    With the default settings, a dot is printed at each iteration, and a line
    of statistics is printed once in a while.

    Parameters
    ----------
    state
        The MCMC state the callback will run on; its device mesh determines
        how the callback state is replicated.
    dot_every
        Number of MCMC iterations between dots, `None` for no dots.
    report_every
        Number of MCMC iterations between one-line reports, `None` for no
        reports.
    average
        Whether the statistics in a report are averaged over the iterations
        elapsed since the last report (`True`) or taken from the current
        iteration alone (`False`). Has no effect if `report_every` is `None`.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(key, state, ..., **make_print_callback(state, ...))
    """
    mesh = state.config.mesh

    def as_replicated_array_or_none(
        val: Shaped[ArrayLike, '*shape'] | None,
    ) -> None | Shaped[Array, '*shape']:
        if val is None:
            return None
        return _replicate(jnp.asarray(val), mesh)

    # the accumulator does something only if reports are printed and averaged
    averaging = average and report_every is not None
    accumulator = tree.map(
        lambda leaf: _replicate(leaf, mesh=mesh),
        StatsAccumulator.initial(state, enabled=averaging),
    )

    callback_state = PrintCallbackState(
        dot_every=as_replicated_array_or_none(dot_every),
        report_every=as_replicated_array_or_none(report_every),
        accumulator=accumulator,
    )
    return {'callback': print_callback, 'callback_state': callback_state}

245 

246 

class PrintCallbackState(Module):
    """Carry of `print_callback` across MCMC iterations."""

    dot_every: Int32[Array, ''] | None
    """Number of MCMC iterations between dots, `None` for no dots."""

    report_every: Int32[Array, ''] | None
    """Number of MCMC iterations between one-line reports, `None` for no
    reports."""

    accumulator: StatsAccumulator
    """Running sums behind the reported averages; inert if not averaging."""

259 

260 

def print_callback(
    *,
    state: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_: Any,
) -> tuple[State, PrintCallbackState]:
    """Print a dot and/or a report periodically during the MCMC.

    Parameters
    ----------
    state
        The current MCMC state, from which the statistics are computed.
    burnin
        Forwarded to the report, which is tagged with 'burnin' when it is true.
    i_total
        Index of the current iteration; the displayed iteration number is
        ``i_total + 1``.
    n_burn
    n_save
    n_skip
        The total number of iterations is taken to be
        ``n_burn + n_save * n_skip``.
    callback_state
        The printing periods and the statistics accumulator.
    **_
        Other arguments passed to the callback, ignored.

    Returns
    -------
    The state, unchanged, and the callback state with its accumulator updated.
    """
    report_every = callback_state.report_every
    dot_every = callback_state.dot_every
    it = i_total + 1
    n_iters = n_burn + n_save * n_skip

    accumulator = callback_state.accumulator.update(state)

    def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
        return False if every is None else it % every == 0

    report_cond = get_cond(report_every)
    dot_cond = get_cond(dot_every)

    def line_report_branch() -> None:
        if report_every is None:
            return
        if dot_every is None:
            print_newline = False
        else:
            # This branch runs only when `it` is a multiple of `report_every`,
            # so the previous report line ended at `it - report_every`. The
            # most recent dot was printed at `it - it % dot_every`, which comes
            # after that exactly when `it % dot_every < report_every`; in that
            # case the pending line of dots must be terminated before the
            # report. (Comparing against `it % report_every`, which is 0 here,
            # would never request the newline.)
            print_newline = it % dot_every < report_every
        debug.callback(
            _print_report,
            accumulator.report(state),
            print_dot=dot_cond,
            print_newline=print_newline,
            burnin=burnin,
            it=it,
            n_iters=n_iters,
        )

    def just_dot_branch() -> None:
        if dot_every is None:
            return
        # terminate the dot line on the final iteration so subsequent output
        # doesn't continue on the same line as the dots
        last_iter = it == n_iters
        lax.cond(
            last_iter,
            lambda: debug.callback(lambda: print('.', flush=True)),  # noqa: T201
            lambda: debug.callback(lambda: print('.', end='', flush=True)),  # noqa: T201
        )
        # logging can't do in-line printing so we use print

    # a report takes precedence over a dot; `_print_report` prints the dot
    # itself when both are due on the same iteration
    lax.cond(
        report_cond,
        line_report_branch,
        lambda: lax.cond(dot_cond, just_dot_branch, lambda: None),
    )

    # start a new averaging window after each report
    accumulator = accumulator.reset_if(report_cond)
    return state, replace(callback_state, accumulator=accumulator)

323 

324 

def make_tqdm_callback(
    state: State,
    *,
    update_every: int = 1,
    report_every: int | None = 100,
    average: bool = True,
    **tqdm_kwargs: Any,
) -> dict[str, Any]:
    """
    Set up a callback for `run_mcmc` that drives a `tqdm` progress bar.

    The bar moves forward as the MCMC iterates, and can show the proposal
    acceptance statistics next to it.

    Parameters
    ----------
    state
        The MCMC state the callback will run on; its device mesh determines
        how the callback state is replicated.
    update_every
        Number of MCMC iterations between updates of the bar position (the
        actual redraws are further rate-limited by `tqdm` itself).
    report_every
        Number of MCMC iterations between updates of the acceptance
        statistics shown beside the bar, `None` to leave them out.
    average
        Whether the statistics are averaged over the iterations elapsed since
        the last update (`True`) or taken from the current iteration alone
        (`False`). Has no effect if `report_every` is `None`.
    **tqdm_kwargs
        Extra keyword arguments passed on to the `tqdm.tqdm` constructor,
        e.g., ``desc``, ``file``, or ``disable``.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Notes
    -----
    Chains sharded across multiple devices are supported. An interrupted run
    (e.g. with ^C) leaves its bar as it is; the bar is closed by the next call
    to `make_tqdm_callback`, such that the following run begins on a clean
    line.

    Examples
    --------
    >>> run_mcmc(key, state, ..., **make_tqdm_callback(state, ...))
    """
    # a previous run may have been interrupted before closing its bar
    _close_stale_bars()
    bar_id = next(_TQDM_BAR_COUNTER)
    _TQDM_REGISTRY[bar_id] = _TqdmEntry(tqdm_kwargs)

    mesh = state.config.mesh

    def as_replicated_array(
        val: Shaped[ArrayLike, '*shape'],
    ) -> Shaped[Array, '*shape']:
        return _replicate(jnp.asarray(val), mesh)

    replicated_bar_id = as_replicated_array(jnp.int32(bar_id))
    replicated_update_every = as_replicated_array(jnp.int32(update_every))
    if report_every is None:
        replicated_report_every = None
    else:
        replicated_report_every = as_replicated_array(jnp.int32(report_every))

    # the accumulator does something only if statistics are shown and averaged
    averaging = average and report_every is not None
    accumulator = tree.map(
        partial(_replicate, mesh=mesh),
        StatsAccumulator.initial(state, enabled=averaging),
    )

    callback_state = TqdmCallbackState(
        bar_id=replicated_bar_id,
        update_every=replicated_update_every,
        report_every=replicated_report_every,
        accumulator=accumulator,
    )
    return {'callback': tqdm_callback, 'callback_state': callback_state}

397 

398 

class TqdmCallbackState(Module):
    """Carry of `tqdm_callback` across MCMC iterations."""

    bar_id: Int32[Array, '']
    """Key of the bar in the module-level registry of `tqdm` bars."""

    update_every: Int32[Array, '']
    """Number of MCMC iterations between updates of the bar position."""

    report_every: Int32[Array, ''] | None
    """Number of MCMC iterations between updates of the acceptance
    statistics, `None` to leave them out."""

    accumulator: StatsAccumulator
    """Running sums behind the reported averages; inert if not averaging."""

414 

415 

def tqdm_callback(
    *,
    state: State,
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: TqdmCallbackState,
    **_: Any,
) -> tuple[State, TqdmCallbackState]:
    """Move a `tqdm` progress bar forward as the MCMC iterates."""
    bar_id = callback_state.bar_id
    report_every = callback_state.report_every
    it = i_total + 1
    n_iters = n_burn + n_save * n_skip
    last = it == n_iters

    updated = callback_state.accumulator.update(state)
    carried = updated

    def do_nothing() -> None:
        return None

    # No ordering is requested for the host callbacks, because `ordered=True`
    # is not supported with more than one device, while this has to work with
    # chains sharded across devices. For this reason `_tqdm_advance` tolerates
    # being invoked out of order.

    # the statistics are sent before the position, so that they are likely
    # to be already in place when the bar moves
    if report_every is not None:
        report_cond = (it % report_every == 0) | last

        def report_branch() -> None:
            debug.callback(_tqdm_report, updated.report(state), bar_id, n_iters)

        lax.cond(report_cond, report_branch, do_nothing)
        # start a new averaging window after each report
        carried = updated.reset_if(report_cond)

    def advance_branch() -> None:
        debug.callback(_tqdm_advance, bar_id, it, n_iters)

    advance_cond = (it % callback_state.update_every == 0) | last
    lax.cond(advance_cond, advance_branch, do_nothing)

    return state, replace(callback_state, accumulator=carried)

457 

458 

459T = TypeVar('T') 

460 

461 

462def _convert_jax_arrays_in_args(func: Callable[..., T]) -> Callable[..., T]: 

463 """Remove jax arrays from a function arguments. 

464 

465 Converts all `jax.Array` instances in the arguments to either Python scalars 

466 or numpy arrays. 

467 """ 

468 

469 def convert_jax_arrays(pytree: PyTree) -> PyTree: 

470 def convert_jax_array(val: object) -> object: 

471 if not isinstance(val, Array): 

472 return val 

473 elif val.shape: 473 ↛ 474line 473 didn't jump to line 474 because the condition on line 473 was never true

474 return numpy.array(val) 

475 else: 

476 return val.item() 

477 

478 return tree.map(convert_jax_array, pytree) 

479 

480 @wraps(func) 

481 def new_func(*args: Any, **kw: Any) -> T: 

482 args = convert_jax_arrays(args) 

483 kw = convert_jax_arrays(kw) 

484 return func(*args, **kw) 

485 

486 return new_func 

487 

488 

@_convert_jax_arrays_in_args
# the jax arrays in the arguments are all converted, since operating on them
# here could deadlock with the main thread
def _print_report(
    report: StatsReport,
    *,
    print_dot: bool,
    print_newline: bool,
    burnin: bool,
    it: int,
    n_iters: int,
) -> None:
    """Format and print one report line on behalf of `print_callback`."""
    # what goes before the line: a final dot, a bare line break, or nothing
    prefix = '.\n' if print_dot else ('\n' if print_newline else '')

    # the parenthesized suffix says what the statistics are averaged over
    avg_over = [
        f'{value} {unit}'
        for value, unit in (
            (report.num_chains, 'chains'),
            (report.n_samples, 'samples'),
        )
        if value is not None
    ]
    msgs = []
    if avg_over:
        msgs.append('avg. ' + ' x '.join(avg_over))
    if burnin:
        msgs.append('burnin')
    suffix = f' ({", ".join(msgs)})' if msgs else ''

    # the variable-selection concentration appears only if it is enabled
    var_msg = '' if report.peff is None else f'var: {report.peff:.1f}/{report.p}, '

    line = (
        f'{prefix}Iteration {it}/{n_iters}, '
        f'grow prob: {report.grow_prop:.0%}, '
        f'move acc: {report.move_acc:.0%}, '
        f'{var_msg}'
        f'leaves: {report.mean_leaves:.1f}/{report.max_leaves}{suffix}'
    )
    print(line)  # noqa: T201, see print_callback for why not logging

536 

537 

@dataclass(frozen=True)
class _TqdmEntry:
    """What the `tqdm` bar registry stores for one bar."""

    kwargs: dict[str, Any]
    """Keyword arguments given to `make_tqdm_callback` for constructing the bar."""

    bar: tqdm | None = None
    """The bar itself; `None` until the first callback invocation creates it."""

547 

548 

# A tqdm bar holds Python state that can not be placed in a jax pytree. The
# bars are thus stored in this module-level registry, and the jax loop refers
# to them with the integer handle in `TqdmCallbackState.bar_id`; the handle is
# a traceable scalar, so the pytree of the loop remains the same from one run
# to the next and no recompilation is triggered.
_TQDM_REGISTRY: dict[int, _TqdmEntry] = {}
_TQDM_BAR_COUNTER = itertools.count()

# Same layout as tqdm's default, minus the ': ' that `format_meter` always
# appends to a non-empty description; the label is instead given as a `{desc}`
# that ends with a space
_TQDM_BAR_FORMAT = (
    '{desc}{percentage:3.0f}%|{bar}| '
    '{n_fmt}/{total_fmt} '
    '[{elapsed}<{remaining}, {rate_fmt}{postfix}]'
)

562 

563 

def _close_stale_bars() -> None:
    """Close the bars that an earlier (e.g. interrupted) run left open, and forget them."""
    leftover_bars = [
        entry.bar for entry in _TQDM_REGISTRY.values() if entry.bar is not None
    ]
    for bar in leftover_bars:
        bar.close()
    _TQDM_REGISTRY.clear()

570 

571 

def _get_or_create_bar(bar_id: int, n_iters: int) -> tqdm | None:
    """Look up the bar `bar_id`, making it if needed; `None` if it is gone."""
    entry = _TQDM_REGISTRY.get(bar_id)
    if entry is None:
        # no entry means the bar was closed already, because the loop is over
        # (this invocation may be arriving out of order)
        return None
    if entry.bar is not None:
        return entry.bar
    # first use: the user's keyword arguments override the defaults
    options: dict[str, Any] = {'total': n_iters, 'bar_format': _TQDM_BAR_FORMAT}
    options.update(entry.kwargs)
    bar = tqdm(**options)
    _TQDM_REGISTRY[bar_id] = replace(entry, bar=bar)
    return bar

583 

584 

@_convert_jax_arrays_in_args
# the jax arrays in the arguments are all converted, see _print_report for why
def _tqdm_advance(bar_id: int, it: int, n_iters: int) -> None:
    """Bring the bar up to the absolute position `it`; close it once at the end."""
    bar = _get_or_create_bar(bar_id, n_iters)
    if bar is None:
        return
    # invocations may come out of order, so the bar is never moved backwards
    step = max(0, it - bar.n)
    bar.update(step)
    if it < n_iters:
        return
    bar.close()
    del _TQDM_REGISTRY[bar_id]

596 

597 

@_convert_jax_arrays_in_args
# the jax arrays in the arguments are all converted, see _print_report for why
def _tqdm_report(report: StatsReport, bar_id: int, n_iters: int) -> None:
    """Write the label and the acceptance statistics next to the bar."""
    bar = _get_or_create_bar(bar_id, n_iters)
    if bar is None:
        return
    # `set_description_str` is used instead of `set_description` because the
    # latter appends ': '; the space at the end keeps the label off the bar
    bar.set_description_str('train ', refresh=False)
    # the text is kept short to limit the width of the bar, e.g.
    # '4ch 100sa acc 25% leaves 3.4/32'
    parts = [
        None if report.num_chains is None else f'{report.num_chains}ch',
        None if report.n_samples is None else f'{report.n_samples}sa',
        f'acc {report.move_acc:.0%}',
        None if report.peff is None else f'var {report.peff:.1f}/{report.p}',
        f'leaves {report.mean_leaves:.1f}/{report.max_leaves}',
    ]
    bar.set_postfix_str(' '.join(part for part in parts if part is not None))