Coverage for src/bartz/mcmcstep/_axes.py: 97%

87 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/mcmcstep/_axes.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Batch-axis bookkeeping: `field` markers and the chain/data/sample resolvers. 

26 

27Module dataclasses tag their array fields with chain/data/sample axis positions 

28via `field`; the resolvers here read those markers off a pytree to build the 

29`in_axes`/`out_axes` for `jax.vmap` and the axis positions used when reshaping. 

30""" 

31 

32import os 

33from collections.abc import Callable, Hashable 

34from dataclasses import fields 

35from typing import Any, TypeVar 

36 

37from equinox import Module as EquinoxModule 

38from jax import numpy as jnp 

39from jax import tree 

40from jaxtyping import Array, PyTree, Shaped 

41from numpy.lib.array_utils import normalize_axis_index 

42 

43from bartz.mcmcstep._lazy import DummyArray, _LazyArray 

44 

# Structure variable for the `PyTree[..., 'T']` annotations below.
T = TypeVar('T')


def _read_chain_axis() -> int:
    """Read the default chain-axis position from the environment.

    Returns
    -------
    The integer value of the ``CHAIN_AXIS`` environment variable, or ``0`` if
    the variable is not set.

    Raises
    ------
    ValueError
        If ``CHAIN_AXIS`` is set to something that is not an integer literal.
        Unlike the bare `int` parse error, the message names the variable.
    """
    raw = os.environ.get('CHAIN_AXIS', '0')
    try:
        return int(raw)
    except ValueError as err:
        msg = f'environment variable CHAIN_AXIS must be an integer, got {raw!r}'
        raise ValueError(msg) from err


# Default position of the chain axis in chain-bearing leaves; see `field`.
# Read once, at import time.
CHAIN_AXIS = _read_chain_axis()

50 

51 

def chain_vmap_axes(x: PyTree[EquinoxModule | Any, 'T']) -> PyTree[int | None, 'T ...']:
    """Compute the `jax.vmap` axis spec that maps over chain axes only.

    The result can be passed as ``in_axes`` or ``out_axes`` to `jax.vmap` so
    that every chain axis found in the pytree `x`, and nothing else, is
    vectorized over.

    Parameters
    ----------
    x
        A pytree. Leaves under a Module attribute declared with
        ``field(chains=<int>)`` carry a chain axis at that index. `x` or one
        of its subtrees must expose a `has_chains` property, which is looked
        up with `get_has_chains`.

    Returns
    -------
    A pytree shaped like `x`. Leaves under a ``field(chains=...)`` marker hold the marker index normalized against the leaf's ``ndim`` (hence non-negative); all other leaves hold `None`. When `has_chains` is `False`, every leaf is `None`.

    Raises
    ------
    ValueError
        If no node of `x` defines `has_chains` (raised by `get_has_chains`).
    """
    # Without chains, chain markers are ignored: every marked leaf maps to None.
    marker = _normalize_axis_for_leaf if get_has_chains(x) else _none_marker
    return _find_metadata(x, 'chains', marker_value=marker)

75 

76 

def _none_marker(leaf: object, raw: int) -> None:  # noqa: ARG001
    """Marker mapper that ignores both of its inputs and yields `None`.

    `chain_vmap_axes` uses it when the pytree has no chains, so that every
    chain-marked leaf resolves to `None`.
    """

80 

81 

def data_vmap_axes(x: PyTree[EquinoxModule | Any, 'T']) -> PyTree[int | None, 'T ...']:
    """Compute the `jax.vmap` axis spec that maps over data axes only.

    Parameters
    ----------
    x
        A pytree. Leaves under a Module attribute declared with
        ``field(data=<int>)`` carry a data axis at that index, counted in the
        chain-less layout. `x` or one of its subtrees must expose a
        `has_chains` property (see `get_has_chains`).

    Returns
    -------
    A pytree shaped like `x`. Each data-marked leaf holds the position of its data axis in the actual leaf: the marker is normalized against the chain-less rank and then shifted past the chain axis when needed (see `chainful_axis`). All other leaves hold `None`.
    """
    per_leaf_chain_axis = chain_vmap_axes(x)
    per_leaf_raw_data_axis = _find_metadata(x, 'data', marker_value=_raw_marker)
    resolved = tree.map(
        _compute_core_axis,
        x,
        per_leaf_raw_data_axis,
        per_leaf_chain_axis,
        is_leaf=_is_core_axis_leaf,
    )
    return resolved

102 

103 

def trace_sample_axes(
    trace: PyTree[EquinoxModule | Any, 'T'],
) -> PyTree[int | None, 'T ...']:
    """Locate the sample axis in every leaf of a trace.

    Parameters
    ----------
    trace
        A trace pytree (typically a `~bartz.mcmcloop.BurninTrace` or
        `~bartz.mcmcloop.MainTrace`). `trace` or one of its subtrees must
        expose a `has_chains` property (see `get_has_chains`).

    Returns
    -------
    A pytree shaped like `trace`. Leaves under a ``field(samples=<int>)`` marker hold the sample axis position in the actual leaf: the marker is normalized against the chain-less rank and shifted past the chain axis when needed. All other leaves hold `None`. See `field`.
    """
    per_leaf_chain_axis = chain_vmap_axes(trace)
    per_leaf_raw_sample_axis = _find_metadata(
        trace, 'samples', marker_value=_raw_marker
    )
    resolved = tree.map(
        _compute_core_axis,
        trace,
        per_leaf_raw_sample_axis,
        per_leaf_chain_axis,
        is_leaf=_is_core_axis_leaf,
    )
    return resolved

125 

126 

def _raw_marker(leaf: object, raw: int) -> int:  # noqa: ARG001
    """Marker mapper that passes the unnormalized marker value through.

    `leaf` is ignored here. Normalization happens later, in
    `_compute_core_axis`, once the chain axis of the leaf is known.
    """
    unnormalized = raw
    return unnormalized

130 

131 

def _is_core_axis_leaf(x: object) -> bool:
    """Leaf predicate used when resolving core-axis markers.

    `None` and `_LazyArray` instances are both treated as leaves, so that
    `None` placeholders line up across the trees that are mapped together.
    """
    if x is None:
        return True
    return _is_lazy_array(x)

135 

136 

137def chainful_axis(core_axis: int, chain_axis: int | None) -> int: 

138 """Position of a chainless-layout axis in the corresponding chainful array. 

139 

140 Parameters 

141 ---------- 

142 core_axis 

143 Non-negative axis position in the chainless ("core") layout. 

144 chain_axis 

145 Non-negative position of the chain axis in the chainful layout, or 

146 `None` if there is no chain axis. 

147 

148 Returns 

149 ------- 

150 The non-negative position of `core_axis` after inserting the chain axis at `chain_axis`. 

151 """ 

152 if chain_axis is None or core_axis < chain_axis: 

153 return core_axis 

154 return core_axis + 1 

155 

156 

def chain_to_axis(
    arr: Shaped[Array, '...'], chain_axis: int | None, target: int = 0
) -> Shaped[Array, '...']:
    """Relocate the chain axis of `arr` to position `target`.

    Convenience for normalizing where the chain axis sits in arrays derived
    from chain-marked Module fields. The source position is typically
    obtained from `chain_vmap_axes` applied to the owning dataclass.

    Parameters
    ----------
    arr
        Array to be reordered.
    chain_axis
        Current position of the chain axis, or `None` when `arr` has no chain
        axis; in that case `arr` is returned as is.
    target
        Position the chain axis should end up at.

    Returns
    -------
    `arr` with its chain axis moved to `target` (or unchanged, see above).
    """
    if chain_axis is not None:
        arr = jnp.moveaxis(arr, chain_axis, target)
    return arr

183 

184 

def _compute_core_axis(
    leaf: Shaped[DummyArray, '...'] | None, raw_axis: int | None, chain_axis: int | None
) -> int | None:
    """Resolve a raw core-layout marker into a position in the actual leaf.

    `raw_axis` is counted in the chain-less layout and may be negative. It is
    normalized against the leaf's rank without the chain axis, then shifted
    past the chain axis (already normalized) with `chainful_axis`.

    Returns `None` when `raw_axis` is `None`, i.e. the leaf is unmarked.
    `numpy.lib.array_utils.normalize_axis_index` raises if `raw_axis` is out
    of bounds.
    """
    if raw_axis is not None:
        assert leaf is not None
        ndim_without_chain = leaf.ndim if chain_axis is None else leaf.ndim - 1
        core_axis = normalize_axis_index(raw_axis, ndim_without_chain)
        return chainful_axis(core_axis, chain_axis)
    return None

196 

197 

class _HasChainsFound(Exception):
    """Internal control-flow signal carrying a found `has_chains` value.

    Raised from the ``is_leaf`` callback in `get_has_chains` to stop the
    pytree walk early, and caught by that function.

    Parameters
    ----------
    value
        The `has_chains` value found on the node.
    """

    def __init__(self, value: bool) -> None:
        # Forward to `Exception.__init__` so that `args == (value,)`. The
        # default reduce protocol rebuilds exceptions as `type(exc)(*exc.args)`,
        # so without this `copy.copy`/`pickle` fail with a TypeError. It also
        # gives the exception a meaningful repr.
        super().__init__(value)
        self.value = value

203 

204 

def get_has_chains(x: PyTree) -> bool:
    """Return the `has_chains` flag of the first node of `x` that defines it.

    The pytree is walked with `jax.tree.map`. Its ``is_leaf`` callback
    inspects every node and, as soon as one exposes a non-`None`
    `has_chains` attribute, raises an internal exception that aborts the
    walk and carries the value out.

    Parameters
    ----------
    x
        A pytree, possibly containing nodes that define a `has_chains`
        attribute.

    Returns
    -------
    The value of `has_chains` on the first matching node.

    Raises
    ------
    ValueError
        If no node in `x` defines a `has_chains` property.
    """

    def _probe(node: object) -> bool:
        found = getattr(node, 'has_chains', None)
        if found is not None:
            raise _HasChainsFound(found)
        return False

    try:
        tree.map(lambda _leaf: None, x, is_leaf=_probe)
    except _HasChainsFound as signal:
        return signal.value
    else:
        msg = 'no `has_chains` property found in the pytree'
        raise ValueError(msg)

240 

241 

def _normalize_axis_for_leaf(leaf: Shaped[DummyArray, '...'], raw: int) -> int:
    """Map a possibly negative marker axis to a non-negative index for `leaf`.

    Raises
    ------
    numpy.exceptions.AxisError
        If `raw` is out of bounds for ``leaf.ndim``.
    """
    ndim = leaf.ndim
    return normalize_axis_index(raw, ndim)

249 

250 

def _is_lazy_array(x: object) -> bool:
    """Tell whether `x` is a `_LazyArray` (also used as an ``is_leaf`` predicate)."""
    return isinstance(x, (_LazyArray,))

253 

254 

def _is_module(x: object) -> bool:
    """Tell whether `x` is an Equinox Module that is not a `_LazyArray`.

    `_LazyArray` instances are excluded: the resolvers in this module treat
    them as array leaves rather than as containers of marked fields.
    """
    if not isinstance(x, EquinoxModule):
        return False
    return not _is_lazy_array(x)

257 

258 

def _find_metadata(
    x: PyTree[Any, ' S'],
    key: Hashable,
    *,
    marker_value: Callable[[DummyArray, int], object] = _normalize_axis_for_leaf,
    default_value: object = None,
) -> PyTree[Any, ' S ...']:
    """Replace the leaves of `x` with values derived from field markers.

    Parameters
    ----------
    x
        A pytree, possibly containing Equinox Modules.
    key
        Metadata key looked up on each Module field, e.g. ``'chains'``.
    marker_value
        Called as ``marker_value(leaf, raw)`` on every leaf of a field whose
        metadata contains `key`, where ``raw`` is the unnormalized metadata
        value.
    default_value
        Value given to leaves that are not inside any marked field.

    Returns
    -------
    A pytree with the structure of `x`. Modules are rebuilt field by field:
    static fields are copied verbatim, fields marked with `key` have their
    leaves replaced by `marker_value` results (`_LazyArray` instances count
    as single leaves), and any other field is processed recursively.
    """
    if not _is_module(x):
        # Plain container or bare leaf: recurse into each Module found inside,
        # and send every other leaf to `default_value`.
        def convert(node: object) -> PyTree:
            if not _is_module(node):
                return tree.map(lambda _: default_value, node, is_leaf=_is_lazy_array)
            return _find_metadata(
                node, key, marker_value=marker_value, default_value=default_value
            )

        return tree.map(
            convert, x, is_leaf=lambda node: isinstance(node, EquinoxModule)
        )

    def resolve_field(spec: Any) -> object:
        """Compute the replacement value for one dataclass field of `x`."""
        value = getattr(x, spec.name)
        if spec.metadata.get('static', False):
            # static fields are not leaves: keep them unchanged
            return value
        if key in spec.metadata:
            raw = spec.metadata[key]
            return tree.map(
                lambda leaf: marker_value(leaf, raw), value, is_leaf=_is_lazy_array
            )
        return _find_metadata(
            value, key, marker_value=marker_value, default_value=default_value
        )

    specs = fields(x)
    resolved = [resolve_field(spec) for spec in specs]
    # rebuild bypassing the (type-checked) __init__: the result is a
    # same-structure pytree whose leaves are axis markers (int/None), not
    # the arrays the field annotations require.
    out = object.__new__(type(x))
    for spec, value in zip(specs, resolved, strict=True):
        object.__setattr__(out, spec.name, value)
    return out