Coverage for src/bartz/_jaxext/_jaxext.py: 96%
122 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/_jaxext/_jaxext.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implementation of miscellaneous jax extension utilities."""
27import math
28import sys
29from collections.abc import Callable, Generator, Sequence
30from contextlib import contextmanager
31from functools import partial
32from typing import Any
34import jax
35from jax import Device, ensure_compile_time_eval, lax, random, shard_map, tree, vmap
36from jax import numpy as jnp
37from jax.dtypes import prng_key
38from jax.scipy.special import ndtr, ndtri
39from jax.sharding import PartitionSpec
40from jax.typing import DTypeLike
41from jaxtyping import (
42 Array,
43 Bool,
44 Float32,
45 Integer,
46 Key,
47 PyTree,
48 Scalar,
49 ScalarLike,
50 Shaped,
51)
52from jaxtyping import config as jaxtyping_config
54from bartz._jaxext._jit import jit
56if sys.version_info >= (3, 13):
57 from typing import TypeIs
58else: # WORKAROUND(python<3.13): typing.TypeIs was added in 3.13
59 from typing_extensions import TypeIs
@contextmanager
def jaxtyping_disabled() -> Generator[None, None, None]:
    """Switch off jaxtyping runtime type-checking inside a ``with`` block.

    The ``jaxtyping_disable`` configuration flag is set to `True` on entry and
    put back to the value it had before on exit, also when the body raises.

    Disabling jaxtyping also disables `beartype`, because the jaxtyping import
    hook applies it as ``jaxtyped(typechecker=beartype)`` and `jaxtyped`
    short-circuits to the undecorated function when type-checking is disabled.
    Used to park deliberately wrong-typed intermediates (e.g. `_LazyArray`
    leaves) in an `equinox.Module` during construction.
    """
    flag = 'jaxtyping_disable'
    # remember the current setting so that nested uses restore correctly
    previous = jaxtyping_config.jaxtyping_disable
    jaxtyping_config.update(flag, True)
    try:
        yield
    finally:
        jaxtyping_config.update(flag, previous)
def vmap_nodoc(fun: Callable, *args: Any, **kw: Any) -> Callable:
    """
    Vectorize `fun` with `jax.vmap` while keeping its docstring as it is.

    `jax.vmap` replaces the docstring of the function it wraps; this wrapper
    copies the original one back. This is useful if the docstring already
    takes into account that the arguments have additional axes due to vmap.

    Parameters
    ----------
    fun
        The function to vectorize.
    *args
    **kw
        Additional arguments forwarded to `jax.vmap`.

    Returns
    -------
    The vectorized function, carrying the docstring of `fun`.
    """
    original_doc = fun.__doc__
    vmapped = vmap(fun, *args, **kw)
    vmapped.__doc__ = original_doc
    return vmapped
def minimal_unsigned_dtype(value: int) -> DTypeLike:
    """Return the smallest unsigned integer dtype that can represent `value`.

    Parameters
    ----------
    value
        The largest integer that has to fit in the dtype.

    Returns
    -------
    The first of `uint8`, `uint16`, `uint32` whose range contains `value`, or
    `uint64` if none does. No range validation is performed on `value`.
    """
    # candidate widths in increasing order; the first that fits wins
    candidates = ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32))
    for bits, dtype in candidates:
        if value < 2**bits:
            return dtype
    return jnp.uint64
@jit(static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: ScalarLike
) -> tuple[Shaped[Array, ' {size}'], int | Integer[Array, '']]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output. It is a static argument of the jitted
        function.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    sorted_x = jnp.sort(x)

    def loop(
        carry: tuple[Scalar, Scalar, Shaped[Array, ' size']], value: Scalar
    ) -> tuple[tuple[Scalar, Scalar, Shaped[Array, ' size']], None]:
        i_out, last, out = carry
        # advance the output slot only when a new value shows up
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # once `i_out` reaches `size` the output is full: those writes are
        # out of bounds and are explicitly dropped
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    carry = 0, sorted_x[0], jnp.full(size, fill_value, x.dtype)

    def run(unroll: int) -> tuple[Shaped[Array, ' size'], Scalar]:
        # Scan over the whole sorted array, not just its first `size`
        # elements: duplicates in a truncated prefix would otherwise hide
        # unique values that do fit in the output.
        (last_index, _, out), _ = lax.scan(loop, carry, sorted_x, unroll=unroll)
        # `last_index + 1` counts all unique values; at most `size` are stored
        return out, jnp.minimum(last_index + 1, size)

    # The optimal scan unroll is opposite on cpu and gpu (benchmarked):
    # - gpu: the loop is dominated by per-step overhead, so a large unroll is up
    #   to ~6x faster; the run time plateaus by ~32 while compile time then grows
    #   steeply, so 32 is the sweet spot.
    # - cpu: past ~6 the backend stops aliasing `out` in place and copies the
    #   size-`size` buffer each step (O(size**2), ~100x slower), so 2 is safest.
    # `default` (cpu, tpu, untested backends) takes the conservative value.
    return lax.platform_dependent(
        cuda=partial(run, 32), rocm=partial(run, 32), default=partial(run, 2)
    )
class split:
    """
    Split a key into `num` keys, to be consumed one at a time.

    The keys are generated once at construction; `pop` hands them out in
    order and `len` reports how many are still unused.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.
    """

    # all the keys produced at construction, including those already popped
    _keys: tuple[Key[Array, ''], ...]
    # how many keys from the start of `_keys` have been handed out
    _num_used: int

    def __init__(self, key: Key[Array, ''], num: int = 2) -> None:
        self._num_used = 0
        self._keys = _split_unpack(key, num)

    def __len__(self) -> int:
        total = len(self._keys)
        return total - self._num_used

    def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, ' *shape']:
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape
            The shape of the keys to pop. If empty (default), a single key is
            popped and returned. If not empty, the popped key is split and
            reshaped to the target shape. An integer is interpreted as a
            1-tuple.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If the list is empty.
        """
        if not len(self):
            msg = 'No keys left to pop'
            raise IndexError(msg)

        index = self._num_used
        self._num_used = index + 1
        key = self._keys[index]

        target_shape = shape if isinstance(shape, tuple) else (shape,)
        if not target_shape:
            return key
        # a single stored key is consumed and expanded into the whole batch
        return _split_shaped(key, target_shape)
@jit(static_argnums=(1,))
def _split_unpack(key: Key[Array, ''], num: int) -> tuple[Key[Array, ''], ...]:
    """Split `key` into `num` keys and return them as a tuple of scalar keys."""
    # iterating the array of keys yields one scalar key per entry
    return tuple(random.split(key, num))
@jit(static_argnums=(1,))
def _split_shaped(key: Key[Array, ''], shape: tuple[int, ...]) -> Key[Array, ' *shape']:
    """Split `key` into an array of keys with the given `shape`."""
    # generate a flat batch with one key per output element, then reshape
    flat_keys = random.split(key, math.prod(shape))
    return flat_keys.reshape(shape)
def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '...'],
    bound: Float32[Array, '...'],
    *,
    clip: bool = True,
) -> Float32[Array, '...']:
    """
    Sample from a one-sided truncated standard normal distribution.

    The samples are obtained by inverse transform sampling: a uniform variate
    is mapped into the interval of the normal cdf allowed by the truncation,
    then passed through the inverse cdf.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        True for (-∞, bound], False for [bound, ∞).
    bound
        The truncation boundary.
    clip
        Whether to clip the truncated uniform samples to (0, 1) before
        transforming them to truncated normal. Intended for debugging purposes.

    Returns
    -------
    Array of samples from the truncated normal distribution.
    """
    # With u ~ uniform, the four cases are:
    #
    #   upper  bound   sample
    #   -----  -----   ------------------------------------------------
    #   yes    < 0      ndtri(ndtr(bound) * u)
    #   yes    > 0     -ndtri(ndtr(-bound) + ndtr(bound) * (1 - u))
    #   no     < 0      ndtri(ndtr(bound) + ndtr(-bound) * (1 - u))
    #   no     > 0     -ndtri(ndtr(-bound) * u)
    #
    # i.e., for positive bounds the problem is mirrored around zero and the
    # result negated. The code swaps the roles of `u` and `1 - u`, which have
    # the same distribution.
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    u = random.uniform(key, out_shape)

    cdf_at_bound = ndtr(bound)
    cdf_at_neg_bound = ndtr(-bound)
    # `scale` is the probability mass of the allowed interval, `shift` the
    # mass of the excluded one
    scale = jnp.where(upper, cdf_at_bound, cdf_at_neg_bound)
    shift = jnp.where(upper, cdf_at_neg_bound, cdf_at_bound)

    is_pos = bound > 0
    # after mirroring, the allowed interval starts at 0 in these cases
    starts_at_zero = upper ^ is_pos
    truncated_u = jnp.where(
        starts_at_zero,
        scale * (1 - u),  # ~ uniform in (0, ndtr(±bound)]
        shift + scale * u,  # ~ uniform in [ndtr(∓bound), 1)
    )

    if clip:
        # on gpu the accuracy is lower and sometimes u can reach the boundaries
        dtype = truncated_u.dtype
        zero = jnp.zeros((), dtype)
        one = jnp.ones((), dtype)
        smallest = jnp.nextafter(zero, one)
        largest = jnp.nextafter(one, zero)
        truncated_u = jnp.clip(truncated_u, smallest, largest)

    mirrored = ndtri(truncated_u)
    return jnp.where(is_pos, -mirrored, mirrored)
292def get_default_device() -> Device:
293 """Get the current default JAX device."""
294 with ensure_compile_time_eval():
295 return jnp.empty(0).device
def get_default_devices() -> list[Device]:
    """Get all JAX devices on the default platform.

    The default platform is the platform of the device returned by
    `get_default_device`.
    """
    platform = get_default_device().platform
    return jax.devices(platform)
def get_device_count() -> int:
    """Get the number of available devices on the default platform.

    This is the length of the list returned by `get_default_devices`.
    """
    devices = get_default_devices()
    return len(devices)
def is_key(x: object) -> TypeIs[Key[Array, ' *shape']]:
    """Determine if `x` is a jax random key.

    Parameters
    ----------
    x
        Any object.

    Returns
    -------
    Whether `x` is a jax array with a PRNG key dtype. Anything that is not a
    jax array yields `False` without looking at its dtype.
    """
    if not isinstance(x, Array):
        return False
    return jnp.issubdtype(x.dtype, prng_key)
def jit_active() -> bool:
    """Check if we are under jit.

    The check creates a throwaway empty array and reports `True` when it has
    no ``platform`` attribute.
    """
    # NOTE(review): relies on traced arrays not exposing `platform` while
    # concrete arrays do; confirm when upgrading jax
    probe = jnp.empty(0)
    return not hasattr(probe, 'platform')
def _equal_shards(x: Shaped[Array, '...'], axis_name: str) -> Bool[Array, '']:
    """Check if all shards of `x` are equal, to be used in a `shard_map` context.

    Each shard is compared elementwise with the one of the previous device
    along a ring over `axis_name`; the shards are all equal if no device
    sees a difference.
    """
    n_shards = lax.axis_size(axis_name)
    # send the shard of device i to device i + 1, wrapping around at the end
    ring = [(src, (src + 1) % n_shards) for src in range(n_shards)]
    neighbor_x = lax.ppermute(x, axis_name, ring)
    differs = jnp.any(x != neighbor_x)
    # combine the per-device flags across the axis
    any_differs = lax.psum(differs, axis_name)
    return jnp.logical_not(any_differs)
def equal_shards(
    x: PyTree[Array, ' S'], axis_name: str, **shard_map_kwargs: Any
) -> PyTree[Bool[Array, ''], ' S']:
    """Check that all shards of `x` are equal across axis `axis_name`.

    Parameters
    ----------
    x
        A pytree of arrays to check. Each array is checked separately.
    axis_name
        The mesh axis name across which equality is checked. It's not checked
        across other axes.
    **shard_map_kwargs
        Additional arguments passed to `jax.shard_map` to set up the function
        that checks equality. You may need to specify `in_specs` passing
        the (pytree of) `jax.sharding.PartitionSpec` that specifies how `x`
        is sharded, if the axes are not explicit, and `mesh` if there is not
        a default mesh set by `jax.set_mesh`.

    Returns
    -------
    A pytree of booleans indicating whether each leaf is equal across devices along the mesh axis.
    """

    def check_leaf(leaf: Shaped[Array, '...']) -> Bool[Array, '']:
        return _equal_shards(leaf, axis_name=axis_name)

    def check_equal(subtree: PyTree[Array, ' S']) -> PyTree[Bool[Array, ''], ' S']:
        return tree.map(check_leaf, subtree)

    # the output is a replicated scalar per leaf, hence the empty partition spec
    sharded_check_equal = shard_map(
        check_equal, out_specs=PartitionSpec(), **shard_map_kwargs
    )
    return sharded_check_equal(x)