Coverage for src/bartz/mcmcstep/_axes.py: 97%
87 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/mcmcstep/_axes.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Batch-axis bookkeeping: `field` markers and the chain/data/sample resolvers.
27Module dataclasses tag their array fields with chain/data/sample axis positions
28via `field`; the resolvers here read those markers off a pytree to build the
29`in_axes`/`out_axes` for `jax.vmap` and the axis positions used when reshaping.
30"""
32import os
33from collections.abc import Callable, Hashable
34from dataclasses import fields
35from typing import Any, TypeVar
37from equinox import Module as EquinoxModule
38from jax import numpy as jnp
39from jax import tree
40from jaxtyping import Array, PyTree, Shaped
41from numpy.lib.array_utils import normalize_axis_index
43from bartz.mcmcstep._lazy import DummyArray, _LazyArray
# Type variable naming the tree structure shared by the `PyTree[..., 'T']`
# annotations in this module.
T = TypeVar('T')

# Default position of the chain axis in chain-bearing leaves (see `field`).
# Read once, at import time, from the ``CHAIN_AXIS`` environment variable;
# falls back to 0 when the variable is unset.
CHAIN_AXIS = int(os.getenv('CHAIN_AXIS', '0'))
def chain_vmap_axes(
    x: PyTree[EquinoxModule | Any, 'T'],
) -> PyTree[int | None, 'T ...']:
    """Compute the `in_axes`/`out_axes` argument that vmaps over chains.

    The returned pytree can be given as the `in_axes` or `out_axes`
    parameter of `jax.vmap` so that all, and only, the chain axes present in
    `x` are vmapped.

    Parameters
    ----------
    x
        A pytree. A subtree stored in a Module attribute declared with
        ``field(chains=<int>)`` is taken to carry a chain axis at that
        index. `x` (or one of its subtrees) must define a `has_chains`
        property (see `get_has_chains`).

    Returns
    -------
    A pytree mirroring the structure of `x`. Each leaf inside a chain-marked
    field holds that field's chain axis index, normalized against the leaf's
    ``ndim`` so that it is non-negative; every other leaf is `None`. When
    `has_chains` is `False`, every leaf is `None`.

    Raises
    ------
    ValueError
        If no node in `x` defines `has_chains`.
    """
    if get_has_chains(x):
        marker = _normalize_axis_for_leaf
    else:
        # Without chains, marked fields are still visited but resolve to None.
        marker = _none_marker
    return _find_metadata(x, 'chains', marker_value=marker)
77def _none_marker(leaf: object, raw: int) -> None: # noqa: ARG001
78 """Marker mapper that always returns `None`."""
79 return None # noqa: RET501
def data_vmap_axes(x: PyTree[EquinoxModule | Any, 'T']) -> PyTree[int | None, 'T ...']:
    """Compute the `in_axes`/`out_axes` argument that vmaps over data.

    Parameters
    ----------
    x
        A pytree. A subtree stored in a Module attribute declared with
        ``field(data=<int>)`` is taken to carry a data axis at that position
        of the chain-less layout. `x` (or one of its subtrees) must define a
        `has_chains` property (see `get_has_chains`).

    Returns
    -------
    A pytree mirroring the structure of `x`. Each leaf inside a data-marked
    field holds the position of the data axis in the actual (possibly
    chainful) leaf. The marker is normalized against the leaf's chainless
    ``ndim`` and shifted past the chain axis when needed. Every other leaf
    is `None`.
    """
    # Chain positions are resolved first; they decide both the chainless
    # ndim used for normalization and the shift applied afterwards.
    chain_axes = chain_vmap_axes(x)
    return tree.map(
        _compute_core_axis,
        x,
        _find_metadata(x, 'data', marker_value=_raw_marker),
        chain_axes,
        is_leaf=_is_core_axis_leaf,
    )
def trace_sample_axes(
    trace: PyTree[EquinoxModule | Any, 'T'],
) -> PyTree[int | None, 'T ...']:
    """Locate the sample axis in every leaf of a trace.

    Parameters
    ----------
    trace
        A trace pytree (typically a `~bartz.mcmcloop.BurninTrace` or
        `~bartz.mcmcloop.MainTrace`). `trace` (or one of its subtrees) must
        define a `has_chains` property.

    Returns
    -------
    A pytree mirroring the structure of `trace`. Each leaf inside a field
    declared with a ``samples`` marker holds the sample-axis position in the
    actual (possibly chainful) leaf. The marker is interpreted in the
    chainless layout and shifted past the chain axis when needed; see
    `field`. Every other leaf is `None`.
    """
    chain_positions = chain_vmap_axes(trace)
    raw_sample_markers = _find_metadata(
        trace, 'samples', marker_value=_raw_marker
    )
    return tree.map(
        _compute_core_axis,
        trace,
        raw_sample_markers,
        chain_positions,
        is_leaf=_is_core_axis_leaf,
    )
127def _raw_marker(leaf: object, raw: int) -> int: # noqa: ARG001
128 """Marker mapper that returns the raw marker value."""
129 return raw
def _is_core_axis_leaf(x: object) -> bool:
    """Leaf predicate for resolving core-axis markers.

    `None` and `_LazyArray` nodes are treated as leaves rather than being
    flattened further.
    """
    if x is None:
        return True
    return _is_lazy_array(x)
137def chainful_axis(core_axis: int, chain_axis: int | None) -> int:
138 """Position of a chainless-layout axis in the corresponding chainful array.
140 Parameters
141 ----------
142 core_axis
143 Non-negative axis position in the chainless ("core") layout.
144 chain_axis
145 Non-negative position of the chain axis in the chainful layout, or
146 `None` if there is no chain axis.
148 Returns
149 -------
150 The non-negative position of `core_axis` after inserting the chain axis at `chain_axis`.
151 """
152 if chain_axis is None or core_axis < chain_axis:
153 return core_axis
154 return core_axis + 1
def chain_to_axis(
    arr: Shaped[Array, '...'], chain_axis: int | None, target: int = 0
) -> Shaped[Array, '...']:
    """Relocate the chain axis of `arr` to position `target`.

    This is a convenience for normalizing where the chain axis sits in
    arrays derived from chain-marked Module fields. The source position is
    typically obtained from `chain_vmap_axes`.

    Parameters
    ----------
    arr
        Array to reorder.
    chain_axis
        Current position of the chain axis, or `None` when `arr` has no chain
        axis. In that case `arr` is returned as is.
    target
        Position the chain axis should end up at.

    Returns
    -------
    `arr` with its chain axis moved, or `arr` itself when `chain_axis` is
    `None`.
    """
    if chain_axis is not None:
        arr = jnp.moveaxis(arr, chain_axis, target)
    return arr
def _compute_core_axis(
    leaf: Shaped[DummyArray, '...'] | None, raw_axis: int | None, chain_axis: int | None
) -> int | None:
    """Turn a raw core-layout marker into an axis position of `leaf`.

    Parameters
    ----------
    leaf
        The array the marker refers to. It must not be `None` when
        `raw_axis` is set.
    raw_axis
        Marker value as declared, counted in the chainless layout (may be
        negative), or `None` for unmarked leaves.
    chain_axis
        Non-negative chain axis position in `leaf`, or `None` if `leaf` has
        no chain axis.

    Returns
    -------
    The non-negative axis position in `leaf`, or `None` if `raw_axis` is
    `None`. `numpy.exceptions.AxisError` propagates from
    `normalize_axis_index` if `raw_axis` is out of range for the chainless
    ndim.
    """
    if raw_axis is None:
        return None
    assert leaf is not None
    # The marker counts axes of the chainless layout, so exclude the chain
    # axis from the ndim it is normalized against.
    core_ndim = leaf.ndim if chain_axis is None else leaf.ndim - 1
    return chainful_axis(normalize_axis_index(raw_axis, core_ndim), chain_axis)
198class _HasChainsFound(Exception):
199 """Internal control-flow signal carrying a found `has_chains` value."""
201 def __init__(self, value: bool) -> None:
202 self.value = value
def get_has_chains(x: PyTree) -> bool:
    """Read the `has_chains` flag from the first node of `x` that defines it.

    The pytree is traversed with `jax.tree.map`. Its ``is_leaf`` callback
    inspects every visited node and, at the first one with a non-`None`
    `has_chains` attribute, raises an internal exception carrying that value.
    The traversal therefore stops there.

    Parameters
    ----------
    x
        A pytree, possibly containing nodes that define a `has_chains`
        attribute.

    Returns
    -------
    The value of `has_chains` on the first matching node.

    Raises
    ------
    ValueError
        If no node in `x` defines a `has_chains` property.
    """

    def stop_at_flag(node: object) -> bool:
        flag = getattr(node, 'has_chains', None)
        if flag is not None:
            # Abort the traversal immediately and carry the flag out.
            raise _HasChainsFound(flag)
        return False

    try:
        tree.map(lambda _: None, x, is_leaf=stop_at_flag)
    except _HasChainsFound as found:
        return found.value
    msg = 'no `has_chains` property found in the pytree'
    raise ValueError(msg)
def _normalize_axis_for_leaf(leaf: Shaped[DummyArray, '...'], raw: int) -> int:
    """Convert a possibly negative marker index into a non-negative axis of `leaf`.

    The index is interpreted against the full ``leaf.ndim``. Raises
    `numpy.exceptions.AxisError` if `raw` is out of bounds for that ndim.
    """
    ndim = leaf.ndim
    return normalize_axis_index(raw, ndim)
def _is_lazy_array(x: object) -> bool:
    """Return whether `x` is a `_LazyArray` instance.

    Used as an ``is_leaf`` predicate so that lazy arrays are treated as
    single leaves instead of being flattened.
    """
    return isinstance(x, _LazyArray)
def _is_module(x: object) -> bool:
    """Return whether `x` is an Equinox Module that is not a `_LazyArray`."""
    if _is_lazy_array(x):
        return False
    return isinstance(x, EquinoxModule)
def _find_metadata(
    x: PyTree[Any, ' S'],
    key: Hashable,
    *,
    marker_value: Callable[[DummyArray, int], object] = _normalize_axis_for_leaf,
    default_value: object = None,
) -> PyTree[Any, ' S ...']:
    """Replace every marked subtree of `x` with values derived from its marker.

    For each Module field whose dataclass metadata contains `key`, every
    leaf of the field's subtree is replaced by ``marker_value(leaf, raw)``.
    Here `raw` is the metadata value as declared, without normalization.
    Fields flagged ``static`` in their metadata are copied through
    unchanged. All other leaves, including those outside any Module, become
    `default_value`. Module nodes are rebuilt as instances of the same type,
    so the result has the same pytree structure as `x`.
    """

    def recurse(sub: object) -> PyTree:
        return _find_metadata(
            sub, key, marker_value=marker_value, default_value=default_value
        )

    if not _is_module(x):
        # Not a Module itself: descend into any Modules it contains and give
        # every other leaf (lazy arrays included) the default value.
        def resolve(node: object) -> PyTree:
            if _is_module(node):
                return recurse(node)
            return tree.map(lambda _: default_value, node, is_leaf=_is_lazy_array)

        return tree.map(
            resolve, x, is_leaf=lambda node: isinstance(node, EquinoxModule)
        )

    module_fields = fields(x)
    resolved = []
    for fld in module_fields:
        value = getattr(x, fld.name)
        meta = fld.metadata
        if meta.get('static', False):
            resolved.append(value)
        elif key in meta:
            # Bind the marker as a default argument so each lambda keeps
            # this field's value.
            resolved.append(
                tree.map(
                    lambda leaf, raw=meta[key]: marker_value(leaf, raw),
                    value,
                    is_leaf=_is_lazy_array,
                )
            )
        else:
            resolved.append(recurse(value))

    # Build the output without calling the (type-checked) __init__. Its
    # leaves are axis markers (int/None), not the arrays that the field
    # annotations require.
    out = object.__new__(type(x))
    for fld, value in zip(module_fields, resolved, strict=True):
        object.__setattr__(out, fld.name, value)
    return out