Coverage for src / bartz / jaxext / __init__.py: 94%
112 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
1# bartz/src/bartz/jaxext/__init__.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Additions to jax."""
27import math
28from collections.abc import Callable, Sequence
29from functools import partial
30from typing import Any
32# WORKAROUND(jax<0.6.1): shard_map was promoted from jax.experimental to top-level in 0.6.1
33try:
34 from jax import shard_map
35except ImportError:
36 from jax.experimental.shard_map import shard_map
38import jax
39from jax import (
40 Device,
41 device_count,
42 ensure_compile_time_eval,
43 jit,
44 lax,
45 random,
46 tree,
47 typeof,
48)
49from jax import numpy as jnp
50from jax.dtypes import prng_key
51from jax.scipy.special import ndtr
52from jax.sharding import PartitionSpec
53from jaxtyping import Array, Bool, Float32, Key, PyTree, Scalar, Shaped
55from bartz.jaxext._autobatch import autobatch # noqa: F401
56from bartz.jaxext.scipy.special import ndtri
def vmap_nodoc(fun: Callable, *args: Any, **kw: Any) -> Callable:
    """
    Vectorize `fun` with `jax.vmap` while keeping its original docstring.

    Useful when the docstring already describes the batched behavior of the
    arguments, so the generic `jax.vmap` docstring should not replace it.
    """
    original_doc = fun.__doc__
    vmapped = jax.vmap(fun, *args, **kw)
    vmapped.__doc__ = original_doc
    return vmapped
def minimal_unsigned_dtype(value: int) -> jnp.dtype:
    """Return the smallest unsigned integer dtype that can represent `value`."""
    # try the candidate widths from narrowest to widest
    for bits, dtype in ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32)):
        if value < 2**bits:
            return dtype
    return jnp.uint64
@partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
    actual_length : int
        The number of used values in `out`.
    """
    # degenerate cases: nothing to scan over
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    sorted_x = jnp.sort(x)

    def step(
        carry: tuple[Scalar, Scalar, Shaped[Array, ' {size}']], value: Scalar
    ) -> tuple[tuple[Scalar, Scalar, Shaped[Array, ' {size}']], None]:
        cursor, previous, acc = carry
        # advance the write cursor only when a new value appears
        cursor = jnp.where(value == previous, cursor, cursor + 1)
        acc = acc.at[cursor].set(value)
        return (cursor, value, acc), None

    # the first scanned element equals the initial `previous`, so the cursor
    # stays at 0 and slot 0 is overwritten with sorted_x[0]
    init = 0, sorted_x[0], jnp.full(size, fill_value, x.dtype)
    (cursor, _, out), _ = lax.scan(step, init, sorted_x[:size])
    return out, cursor + 1
class split:
    """
    Split a key into `num` keys.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.
    """

    _keys: tuple[Key[Array, ''], ...]
    _num_used: int

    def __init__(self, key: Key[Array, ''], num: int = 2) -> None:
        self._keys = _split_unpack(key, num)
        self._num_used = 0

    def __len__(self) -> int:
        # only count the keys that have not been handed out yet
        return len(self._keys) - self._num_used

    def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, ' {shape}']:
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape
            The shape of the keys to pop. If empty (default), a single key is
            popped and returned. If not empty, the popped key is split and
            reshaped to the target shape.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If the list is empty.
        """
        if not len(self):
            msg = 'No keys left to pop'
            raise IndexError(msg)
        if not isinstance(shape, tuple):
            shape = (shape,)
        popped = self._keys[self._num_used]
        self._num_used += 1
        return _split_shaped(popped, shape) if shape else popped
@partial(jit, static_argnums=(1,))
def _split_unpack(key: Key[Array, ''], num: int) -> tuple[Key[Array, ''], ...]:
    # split eagerly under jit and hand the keys back as a python tuple
    return tuple(random.split(key, num))
@partial(jit, static_argnums=(1,))
def _split_shaped(
    key: Key[Array, ''], shape: tuple[int, ...]
) -> Key[Array, ' {shape}']:
    # one flat key per output element, then fold them into the target shape
    flat_keys = random.split(key, math.prod(shape))
    return flat_keys.reshape(shape)
def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
    *,
    clip: bool = True,
) -> Float32[Array, '*']:
    """
    Sample from a one-sided truncated standard normal distribution.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        True for (-∞, bound], False for [bound, ∞).
    bound
        The truncation boundary.
    clip
        Whether to clip the truncated uniform samples to (0, 1) before
        transforming them to truncated normal. Intended for debugging purposes.

    Returns
    -------
    Array of samples from the truncated normal distribution.
    """
    # Strategy: draw u ~ U(0, 1), map it affinely into the image of ndtr over
    # the allowed interval, then pull it back through ndtri. When the bound is
    # on the positive side, solve the mirrored problem and negate at the end,
    # keeping ndtr/ndtri in their accurate regime:
    # | if upper:
    # |     if bound < 0:  ndtri(ndtr(bound) * u)
    # |     if bound > 0: -ndtri(ndtr(-bound) + ndtr(bound) * (1 - u))
    # | if not upper:
    # |     if bound < 0:  ndtri(ndtr(bound) + ndtr(-bound) * (1 - u))
    # |     if bound > 0: -ndtri(ndtr(-bound) * u)
    shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    mirrored = bound > 0
    cdf_bound = ndtr(bound)
    cdf_neg_bound = ndtr(-bound)
    width = jnp.where(upper, cdf_bound, cdf_neg_bound)
    offset = jnp.where(upper, cdf_neg_bound, cdf_bound)
    u = random.uniform(key, shape)
    low_u = width * (1 - u)  # ~ uniform in (0, ndtr(±bound)]
    high_u = offset + width * u  # ~ uniform in [ndtr(∓bound), 1)
    trunc_u = jnp.where(upper ^ mirrored, low_u, high_u)
    if clip:
        # on gpu the accuracy is lower and sometimes u can reach the boundaries
        zero = jnp.zeros((), trunc_u.dtype)
        one = jnp.ones((), trunc_u.dtype)
        trunc_u = jnp.clip(
            trunc_u, jnp.nextafter(zero, one), jnp.nextafter(one, zero)
        )
    sample = ndtri(trunc_u)
    return jnp.where(mirrored, -sample, sample)
260def get_default_device() -> Device:
261 """Get the current default JAX device."""
262 with ensure_compile_time_eval(): 1bcaj
263 return jnp.empty(0).device 1bcaj
def get_device_count() -> int:
    """Get the number of available devices on the default platform."""
    platform = get_default_device().platform
    return device_count(platform)
def is_key(x: object) -> bool:
    """Determine if `x` is a jax random key."""
    # anything that is not a jax array cannot be a key
    if not isinstance(x, Array):
        return False
    return jnp.issubdtype(x.dtype, prng_key)
def jit_active() -> bool:
    """Check if we are under jit."""
    # concrete arrays expose `platform`; tracers inside jit do not
    probe = jnp.empty(0)
    return not hasattr(probe, 'platform')
def _equal_shards(x: Array, axis_name: str) -> Bool[Array, '']:
    """Check if all shards of `x` are equal, to be used in a `shard_map` context."""
    # WORKAROUND(jax<0.6.1): could be `size = lax.axis_size(axis_name)`
    mesh = typeof(x).sharding.mesh
    axis_pos = mesh.axis_names.index(axis_name)
    axis_size = mesh.axis_sizes[axis_pos]

    # rotate the shards by one position along the axis and compare neighbors:
    # all shards are equal iff no adjacent pair differs anywhere
    rotation = [(src, (src + 1) % axis_size) for src in range(axis_size)]
    neighbor = lax.ppermute(x, axis_name, rotation)
    any_diff = jnp.any(x != neighbor)
    return jnp.logical_not(lax.psum(any_diff, axis_name))
def equal_shards(
    x: PyTree[Array, ' S'], axis_name: str, **shard_map_kwargs: Any
) -> PyTree[Bool[Array, ''], ' S']:
    """Check that all shards of `x` are equal across axis `axis_name`.

    Parameters
    ----------
    x
        A pytree of arrays to check. Each array is checked separately.
    axis_name
        The mesh axis name across which equality is checked. It's not checked
        across other axes.
    **shard_map_kwargs
        Additional arguments passed to `jax.shard_map` to set up the function
        that checks equality. You may need to specify `in_specs` passing
        the (pytree of) `jax.sharding.PartitionSpec` that specifies how `x`
        is sharded, if the axes are not explicit, and `mesh` if there is not
        a default mesh set by `jax.set_mesh`.

    Returns
    -------
    A pytree of booleans indicating whether each leaf is equal across devices along the mesh axis.
    """

    def check_equal(tree_x: PyTree[Array, ' S']) -> PyTree[Bool[Array, ''], ' S']:
        # check every leaf independently against its neighbors on the axis
        return tree.map(lambda leaf: _equal_shards(leaf, axis_name), tree_x)

    sharded_check = shard_map(
        check_equal, out_specs=PartitionSpec(), **shard_map_kwargs
    )
    return sharded_check(x)