Coverage for src/bartz/mcmcloop/_callback.py: 93%
233 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/mcmcloop/_callback.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Progress-reporting callbacks for `run_mcmc`."""
27import itertools
28from collections.abc import Callable
29from dataclasses import dataclass, replace
30from functools import partial, wraps
31from typing import Any, TypeVar
33import numpy
34from equinox import Module, field
35from jax import debug, eval_shape, lax, tree
36from jax import numpy as jnp
37from jax.scipy.special import logsumexp
38from jaxtyping import Array, ArrayLike, Bool, Float32, Int32, Integer, PyTree, Shaped
39from tqdm.auto import tqdm
41from bartz._typing import kwdict
42from bartz.grove import forest_mean_leaves
43from bartz.mcmcloop._loop import _replicate
44from bartz.mcmcstep import State
45from bartz.mcmcstep._axes import chain_to_axis, chain_vmap_axes, chainful_axis
class StatsReport(Module):
    """Forest diagnostics for a single report, built by `StatsAccumulator.report`."""

    grow_prop: Float32[Array, '']
    """Fraction of trees proposed for a grow move."""

    move_acc: Float32[Array, '']
    """Fraction of trees on which a grow or prune move was accepted."""

    mean_leaves: Float32[Array, '']
    """Mean number of leaves per tree."""

    peff: Float32[Array, ''] | None
    """Effective number of predictors, or `None` when variable selection is off."""

    n_samples: Int32[Array, ''] | None
    """Number of iterations averaged over, or `None` when not averaging."""

    num_chains: int | None = field(static=True)
    """Number of chains averaged over, or `None` when single-chain."""

    max_leaves: int = field(static=True)
    """Maximum possible number of leaves per tree."""

    p: int | None = field(static=True)
    """Number of predictors, or `None` when variable selection is off."""
class StatsAccumulator(Module):
    """Running average of the forest diagnostics shown during the MCMC.

    When enabled, the per-iteration statistics are summed so that a report
    shows their average over the iterations since the previous report. When
    disabled, no running state is carried and a report shows the latest
    iteration only.
    """

    sums: dict[str, Float32[Array, '']] | None
    """Running sums of the averaged statistics, or `None` when disabled."""

    count: Int32[Array, '']
    """Number of iterations accumulated since the last reset."""

    @classmethod
    def initial(cls, state: State, *, enabled: bool) -> 'StatsAccumulator':
        """Create a zeroed accumulator, inert (``sums=None``) unless `enabled`."""
        if not enabled:
            return cls(sums=None, count=jnp.int32(0))
        # trace `_avg_stats` abstractly: only shapes and dtypes are needed, so
        # the statistics themselves are never computed here
        template = eval_shape(cls._avg_stats, state)
        zeros = tree.map(lambda leaf: jnp.zeros(leaf.shape, leaf.dtype), template)
        return cls(sums=zeros, count=jnp.int32(0))

    def update(self, state: State) -> 'StatsAccumulator':
        """Add the latest iteration's statistics; return `self` when disabled."""
        if self.sums is None:
            return self
        latest = self._avg_stats(state)
        totals = tree.map(jnp.add, self.sums, latest)
        return replace(self, sums=totals, count=self.count + 1)

    def reset_if(self, cond: bool | Bool[Array, '']) -> 'StatsAccumulator':
        """Zero the sums and the count where `cond` holds; return `self` when disabled."""
        if self.sums is None:
            return self
        cleared = tree.map(lambda total: jnp.where(cond, 0, total), self.sums)
        return replace(self, sums=cleared, count=jnp.where(cond, 0, self.count))

    def report(self, state: State) -> StatsReport:
        """Statistics to display: the windowed average if enabled, else the latest."""
        if self.sums is None:
            shown: kwdict = self._avg_stats(state)
            n_samples = None
        else:
            n_samples = self.count
            shown = tree.map(lambda total: total / n_samples, self.sums)
        return StatsReport(**shown, **self._static_stats(state), n_samples=n_samples)

    @staticmethod
    def _avg_stats(state: State) -> dict[str, Float32[Array, ''] | None]:
        """Per-iteration diagnostics that are averaged over the report window."""
        forest = state.forest

        split_axis = chain_vmap_axes(forest).split_tree
        trees_axis = chainful_axis(0, split_axis)  # (num_trees, hts)
        # size of the trees axis, used as the denominator of the move fractions
        prop_total = forest.split_tree.shape[trees_axis]
        split_tree = chain_to_axis(forest.split_tree, split_axis)

        if forest.log_s is None:
            peff = None
        else:
            log_s = chain_to_axis(forest.log_s, chain_vmap_axes(forest).log_s)
            peff = StatsAccumulator._effective_predictors(log_s)

        accepted = forest.grow_acc_count.mean() + forest.prune_acc_count.mean()
        return {
            'grow_prop': forest.grow_prop_count.mean() / prop_total,
            'move_acc': accepted / prop_total,
            'mean_leaves': forest_mean_leaves(split_tree),
            'peff': peff,
        }

    @staticmethod
    def _static_stats(state: State) -> dict[str, int | None]:
        """Per-iteration diagnostics shown as-is, constant over the run."""
        forest = state.forest
        split_tree = chain_to_axis(
            forest.split_tree, chain_vmap_axes(forest).split_tree
        )
        if forest.log_s is None:
            p = None
        else:
            log_s = chain_to_axis(forest.log_s, chain_vmap_axes(forest).log_s)
            *_, p = log_s.shape
        return {
            'num_chains': state.num_chains(),
            'max_leaves': split_tree.shape[-1],
            'p': p,
        }

    @staticmethod
    def _effective_predictors(log_s: Float32[Array, '*chains p']) -> Float32[Array, '']:
        """Effective number of predictors used for splitting across all chains.

        Perplexity (exponential of the Shannon entropy) of the split-variable
        distribution ``s = softmax(log_s)`` pooled (averaged) over chains. It is
        1 when all chains concentrate on a single shared predictor and ``p`` when
        the pooled distribution is uniform; in general a pooled distribution
        spread evenly over ``k`` predictors gives ``k``. Chains are pooled before
        taking the entropy because predictions average over all chains, so a
        predictor used by any chain counts as used.
        """
        *_, p = log_s.shape
        # per-chain log-softmax, then flatten all leading chain axes into one
        normalized = log_s - logsumexp(log_s, axis=-1, keepdims=True)
        flat = normalized.reshape(-1, p)
        num_chains, _ = flat.shape
        # mix over chains. WORKAROUND(jax<0.7.1): once we bump jax to v0.7.1
        # this is `jax.nn.logmeanexp(flat, axis=0)`
        log_pool = logsumexp(flat, axis=0) - jnp.log(num_chains)
        prob = jnp.exp(log_pool)
        # the where avoids the 0 * -inf = nan term where a probability is 0, the
        # same guard `jax.scipy.special.entr` uses, but reusing the log we have
        safe_log = jnp.where(prob, log_pool, 1.0)
        entropy = -jnp.sum(prob * safe_log)
        return jnp.exp(entropy)
def make_print_callback(
    state: State,
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
    average: bool = True,
) -> dict[str, Any]:
    """
    Prepare a progress-printing callback for `run_mcmc`.

    The callback prints a dot on every iteration, and a longer report
    periodically.

    Parameters
    ----------
    state
        The MCMC state to use the callback with, used to determine device
        sharding.
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.
    average
        If `True`, the reported statistics are averaged over the iterations
        since the previous report; if `False`, they reflect the current
        iteration only. Ignored when `report_every` is `None`.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(key, state, ..., **make_print_callback(state, ...))
    """
    mesh = state.config.mesh

    def replicated_or_none(
        val: Shaped[ArrayLike, '*shape'] | None,
    ) -> None | Shaped[Array, '*shape']:
        # `None` means "disabled" and must stay `None` in the callback state
        if val is None:
            return None
        return _replicate(jnp.asarray(val), mesh)

    # averaging only makes sense when reports are printed at all
    averaging = average and report_every is not None
    accumulator = tree.map(
        partial(_replicate, mesh=mesh),
        StatsAccumulator.initial(state, enabled=averaging),
    )

    callback_state = PrintCallbackState(
        replicated_or_none(dot_every),
        replicated_or_none(report_every),
        accumulator,
    )
    return dict(callback=print_callback, callback_state=callback_state)
class PrintCallbackState(Module):
    """Carry-over state threaded through `print_callback`."""

    dot_every: Int32[Array, ''] | None
    """A dot is printed every `dot_every` MCMC iterations, `None` to disable."""

    report_every: Int32[Array, ''] | None
    """A one line report is printed every `report_every` MCMC iterations,
    `None` to disable."""

    accumulator: StatsAccumulator
    """Running average of the reported statistics, inert unless averaging."""
def print_callback(
    *,
    state: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_: Any,
) -> tuple[State, PrintCallbackState]:
    """Print a dot and/or a report periodically during the MCMC.

    Parameters
    ----------
    state
        The current MCMC state; returned unchanged.
    burnin
        Whether the current iteration is a burn-in one; shown in the report.
    i_total
        Zero-based index of the current iteration; the displayed iteration
        number is ``i_total + 1``.
    n_burn, n_save, n_skip
        Loop sizes; the total number of iterations shown is
        ``n_burn + n_save * n_skip``.
    callback_state
        The printing periods and the statistics accumulator.
    **_
        Other callback arguments, ignored.

    Returns
    -------
    The unmodified `state` and the callback state with the updated accumulator.
    """
    report_every = callback_state.report_every
    dot_every = callback_state.dot_every
    it = i_total + 1
    n_iters = n_burn + n_save * n_skip

    accumulator = callback_state.accumulator.update(state)

    def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
        return False if every is None else it % every == 0

    report_cond = get_cond(report_every)
    dot_cond = get_cond(dot_every)

    def line_report_branch() -> None:
        if report_every is None:
            return
        if dot_every is None:
            print_newline = False
        else:
            # This branch only runs when `it` is a multiple of `report_every`.
            # The last dot was printed at iteration `it - it % dot_every`; the
            # current line holds pending dots iff that is after the previous
            # report, i.e. `it % dot_every < report_every`. (Comparing against
            # `it % report_every`, which is 0 here, could never be true.)
            print_newline = it % dot_every < report_every
        debug.callback(
            _print_report,
            accumulator.report(state),
            print_dot=dot_cond,
            print_newline=print_newline,
            burnin=burnin,
            it=it,
            n_iters=n_iters,
        )

    def just_dot_branch() -> None:
        if dot_every is None:
            return
        # terminate the dot line on the final iteration so subsequent output
        # doesn't continue on the same line as the dots
        last_iter = it == n_iters
        # logging can't do in-line printing so we use print
        lax.cond(
            last_iter,
            lambda: debug.callback(lambda: print('.', flush=True)),  # noqa: T201
            lambda: debug.callback(lambda: print('.', end='', flush=True)),  # noqa: T201
        )

    # a report takes precedence over a lone dot; the report itself emits the
    # dot (and the line break) when both fall on the same iteration
    lax.cond(
        report_cond,
        line_report_branch,
        lambda: lax.cond(dot_cond, just_dot_branch, lambda: None),
    )

    # start a fresh averaging window right after each report
    accumulator = accumulator.reset_if(report_cond)
    return state, replace(callback_state, accumulator=accumulator)
def make_tqdm_callback(
    state: State,
    *,
    update_every: int = 1,
    report_every: int | None = 100,
    average: bool = True,
    **tqdm_kwargs: Any,
) -> dict[str, Any]:
    """
    Prepare a `tqdm` progress-bar callback for `run_mcmc`.

    The callback shows a progress bar that advances with the MCMC iterations,
    optionally annotated with the proposal acceptance statistics.

    Parameters
    ----------
    state
        The MCMC state to use the callback with, used to determine device
        sharding.
    update_every
        The bar position is refreshed every `update_every` MCMC iterations
        (`tqdm` further throttles the actual redraw rate on its own).
    report_every
        The acceptance statistics shown next to the bar are refreshed every
        `report_every` MCMC iterations, `None` to omit them.
    average
        If `True`, the statistics shown are averaged over the iterations since
        the previous refresh; if `False`, they reflect the current iteration
        only. Ignored when `report_every` is `None`.
    **tqdm_kwargs
        Additional keyword arguments forwarded to the `tqdm.tqdm` constructor,
        e.g., ``desc``, ``file``, or ``disable``.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Notes
    -----
    Works with chains sharded across multiple devices. If the run is interrupted
    (e.g. with ^C), the bar is left as-is; the next `make_tqdm_callback` call
    closes it, so a subsequent run starts from a clean line.

    Examples
    --------
    >>> run_mcmc(key, state, ..., **make_tqdm_callback(state, ...))
    """
    _close_stale_bars()  # clean up after any previous run that was interrupted

    # register the (not yet created) bar under a fresh integer handle
    bar_id = next(_TQDM_BAR_COUNTER)
    _TQDM_REGISTRY[bar_id] = _TqdmEntry(tqdm_kwargs)

    mesh = state.config.mesh

    def replicated_int32(val: Shaped[ArrayLike, '*shape']) -> Shaped[Array, '*shape']:
        return _replicate(jnp.asarray(jnp.int32(val)), mesh)

    bar_id_arr = replicated_int32(bar_id)
    update_every_arr = replicated_int32(update_every)
    if report_every is None:
        report_every_arr = None
    else:
        report_every_arr = replicated_int32(report_every)

    # averaging only makes sense when the statistics are shown at all
    averaging = average and report_every is not None
    accumulator = tree.map(
        partial(_replicate, mesh=mesh),
        StatsAccumulator.initial(state, enabled=averaging),
    )

    callback_state = TqdmCallbackState(
        bar_id=bar_id_arr,
        update_every=update_every_arr,
        report_every=report_every_arr,
        accumulator=accumulator,
    )
    return dict(callback=tqdm_callback, callback_state=callback_state)
class TqdmCallbackState(Module):
    """Carry-over state threaded through `tqdm_callback`."""

    bar_id: Int32[Array, '']
    """Handle identifying the bar in the module-level `tqdm` bar registry."""

    update_every: Int32[Array, '']
    """The bar position is refreshed every `update_every` MCMC iterations."""

    report_every: Int32[Array, ''] | None
    """The acceptance statistics are refreshed every `report_every` MCMC
    iterations, `None` to omit them."""

    accumulator: StatsAccumulator
    """Running average of the reported statistics, inert unless averaging."""
def tqdm_callback(
    *,
    state: State,
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: TqdmCallbackState,
    **_: Any,
) -> tuple[State, TqdmCallbackState]:
    """Advance a `tqdm` progress bar during the MCMC.

    Returns the unmodified `state` and the callback state with the updated
    statistics accumulator.
    """
    bar_id = callback_state.bar_id
    report_every = callback_state.report_every

    it = i_total + 1
    n_iters = n_burn + n_save * n_skip
    is_last = it == n_iters

    accumulator = callback_state.accumulator.update(state)

    # The callbacks are unordered: `ordered=True` is unsupported with more than
    # one device, and we need this to work with chains sharded across devices.
    # `_tqdm_advance` is therefore robust to out-of-order invocations.

    # refresh the statistics first so they tend to be visible by the time the
    # bar is advanced
    if report_every is not None:
        # always report on the final iteration, even off-period
        report_cond = (it % report_every == 0) | is_last

        def report_branch() -> None:
            debug.callback(_tqdm_report, accumulator.report(state), bar_id, n_iters)

        lax.cond(report_cond, report_branch, lambda: None)
        accumulator = accumulator.reset_if(report_cond)

    advance_cond = (it % callback_state.update_every == 0) | is_last
    lax.cond(
        advance_cond,
        lambda: debug.callback(_tqdm_advance, bar_id, it, n_iters),
        lambda: None,
    )

    return state, replace(callback_state, accumulator=accumulator)
459T = TypeVar('T')
462def _convert_jax_arrays_in_args(func: Callable[..., T]) -> Callable[..., T]:
463 """Remove jax arrays from a function arguments.
465 Converts all `jax.Array` instances in the arguments to either Python scalars
466 or numpy arrays.
467 """
469 def convert_jax_arrays(pytree: PyTree) -> PyTree:
470 def convert_jax_array(val: object) -> object:
471 if not isinstance(val, Array):
472 return val
473 elif val.shape: 473 ↛ 474line 473 didn't jump to line 474 because the condition on line 473 was never true
474 return numpy.array(val)
475 else:
476 return val.item()
478 return tree.map(convert_jax_array, pytree)
480 @wraps(func)
481 def new_func(*args: Any, **kw: Any) -> T:
482 args = convert_jax_arrays(args)
483 kw = convert_jax_arrays(kw)
484 return func(*args, **kw)
486 return new_func
@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    report: StatsReport,
    *,
    print_dot: bool,
    print_newline: bool,
    burnin: bool,
    it: int,
    n_iters: int,
) -> None:
    """Print the one-line report for `print_callback`.

    `print_dot` prepends a dot and a line break; otherwise `print_newline`
    prepends just a line break.
    """
    # what goes in front of the report line
    if print_dot:
        prefix = '.\n'
    else:
        prefix = '\n' if print_newline else ''

    # parenthesized suffix: what the statistics are averaged over, and burn-in
    averaged_over = []
    if report.num_chains is not None:
        averaged_over.append(f'{report.num_chains} chains')
    if report.n_samples is not None:
        averaged_over.append(f'{report.n_samples} samples')

    notes = []
    if averaged_over:
        notes.append('avg. ' + ' x '.join(averaged_over))
    if burnin:
        notes.append('burnin')
    suffix = f' ({", ".join(notes)})' if notes else ''

    # variable-selection concentration, only shown when it is enabled
    var_msg = '' if report.peff is None else f'var: {report.peff:.1f}/{report.p}, '

    print(  # noqa: T201, see print_callback for why not logging
        f'{prefix}Iteration {it}/{n_iters}, '
        f'grow prob: {report.grow_prop:.0%}, '
        f'move acc: {report.move_acc:.0%}, '
        f'{var_msg}'
        f'leaves: {report.mean_leaves:.1f}/{report.max_leaves}{suffix}'
    )
@dataclass(frozen=True)
class _TqdmEntry:
    """One slot of the `tqdm` bar registry."""

    kwargs: dict[str, Any]
    """Keyword arguments to construct the bar with, from `make_tqdm_callback`."""

    bar: tqdm | None = None
    """The bar, created lazily on the first callback invocation, `None` until then."""
# tqdm carries Python state that cannot live in a jax pytree, so the bars are
# kept here and referenced from the jax loop through the integer handle stored
# in `TqdmCallbackState.bar_id` (a traceable scalar, so the loop pytree stays
# stable across runs and is not recompiled).
_TQDM_REGISTRY: dict[int, _TqdmEntry] = {}
_TQDM_BAR_COUNTER = itertools.count()

# tqdm's default layout, but without the ': ' that `format_meter` forces after a
# non-empty description; the label is set as a `{desc}` ending in a space instead
_TQDM_BAR_FORMAT = (
    '{desc}{percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} '
    '[{elapsed}<{remaining}, {rate_fmt}{postfix}]'
)
def _close_stale_bars() -> None:
    """Close every bar still registered, then empty the registry.

    Bars can be left over from a previous (e.g. interrupted) run.
    """
    # entries whose bar was never created have nothing to close
    open_bars = [entry.bar for entry in _TQDM_REGISTRY.values() if entry.bar is not None]
    for bar in open_bars:
        bar.close()
    _TQDM_REGISTRY.clear()
def _get_or_create_bar(bar_id: int, n_iters: int) -> tqdm | None:
    """Return the bar for `bar_id`, creating it on first use, `None` if finished."""
    entry = _TQDM_REGISTRY.get(bar_id)
    if entry is None:
        # the bar was already closed (the loop finished, possibly out of order)
        return None
    if entry.bar is not None:
        return entry.bar
    # defaults first, so user-provided kwargs override them
    options = {'total': n_iters, 'bar_format': _TQDM_BAR_FORMAT}
    options.update(entry.kwargs)
    bar = tqdm(**options)
    # entries are frozen: store an updated copy rather than mutating in place
    _TQDM_REGISTRY[bar_id] = replace(entry, bar=bar)
    return bar
@_convert_jax_arrays_in_args
# convert all jax arrays in arguments, see _print_report for why
def _tqdm_advance(bar_id: int, it: int, n_iters: int) -> None:
    """Advance the bar towards absolute position `it`, closing it at the end."""
    bar = _get_or_create_bar(bar_id, n_iters)
    if bar is None:
        return
    # forward-only: callbacks may arrive out of order, so never step back
    step = max(0, it - bar.n)
    bar.update(step)
    if it >= n_iters:
        # final iteration: close the bar and drop it from the registry
        bar.close()
        del _TQDM_REGISTRY[bar_id]
@_convert_jax_arrays_in_args
# convert all jax arrays in arguments, see _print_report for why
def _tqdm_report(report: StatsReport, bar_id: int, n_iters: int) -> None:
    """Set the bar description and acceptance-statistics postfix."""
    bar = _get_or_create_bar(bar_id, n_iters)
    if bar is None:
        return

    # set_description_str (not set_description) to avoid tqdm's ': ' suffix; the
    # trailing space separates the label from the bar
    bar.set_description_str('train ', refresh=False)

    # keep this terse so the bar stays narrow, e.g. '4ch 100sa acc 25% leaves 3.4/32'
    parts = []
    if report.num_chains is not None:
        parts.append(f'{report.num_chains}ch')
    if report.n_samples is not None:
        parts.append(f'{report.n_samples}sa')
    parts.append(f'acc {report.move_acc:.0%}')
    if report.peff is not None:
        parts.append(f'var {report.peff:.1f}/{report.p}')
    parts.append(f'leaves {report.mean_leaves:.1f}/{report.max_leaves}')

    postfix = ' '.join(parts)
    bar.set_postfix_str(postfix)