Coverage for src/bartz/_jaxext/_jaxext.py: 96%

122 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/_jaxext/_jaxext.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implementation of miscellaneous jax extension utilities.""" 

26 

27import math 

28import sys 

29from collections.abc import Callable, Generator, Sequence 

30from contextlib import contextmanager 

31from functools import partial 

32from typing import Any 

33 

34import jax 

35from jax import Device, ensure_compile_time_eval, lax, random, shard_map, tree, vmap 

36from jax import numpy as jnp 

37from jax.dtypes import prng_key 

38from jax.scipy.special import ndtr, ndtri 

39from jax.sharding import PartitionSpec 

40from jax.typing import DTypeLike 

41from jaxtyping import ( 

42 Array, 

43 Bool, 

44 Float32, 

45 Integer, 

46 Key, 

47 PyTree, 

48 Scalar, 

49 ScalarLike, 

50 Shaped, 

51) 

52from jaxtyping import config as jaxtyping_config 

53 

54from bartz._jaxext._jit import jit 

55 

56if sys.version_info >= (3, 13): 

57 from typing import TypeIs 

58else: # WORKAROUND(python<3.13): typing.TypeIs was added in 3.13 

59 from typing_extensions import TypeIs 

60 

61 

@contextmanager
def jaxtyping_disabled() -> Generator[None, None, None]:
    """Switch off jaxtyping runtime type-checking for the duration of a block.

    The jaxtyping import hook applies `beartype` as
    ``jaxtyped(typechecker=beartype)``, and `jaxtyped` short-circuits to the
    undecorated function when type-checking is disabled, so `beartype` is
    switched off as well. Used to park deliberately wrong-typed intermediates
    (e.g. `_LazyArray` leaves) in an `equinox.Module` during construction.

    The value of the ``jaxtyping_disable`` setting read on entry is written
    back on exit, also when the body raises.
    """
    previous = jaxtyping_config.jaxtyping_disable
    jaxtyping_config.update('jaxtyping_disable', True)
    try:
        yield
    finally:
        # restore whatever was configured before, not unconditionally False
        jaxtyping_config.update('jaxtyping_disable', previous)

78 

79 

def vmap_nodoc(fun: Callable, *args: Any, **kw: Any) -> Callable:
    """
    Vectorize `fun` with `jax.vmap` while keeping its docstring as it is.

    This is useful if the docstring already takes into account that the
    arguments have additional axes due to vmap.

    Parameters
    ----------
    fun
        The function to vectorize.
    *args
    **kw
        Forwarded unchanged to `jax.vmap`.

    Returns
    -------
    The vmapped function, with ``__doc__`` set to the docstring of `fun`.
    """
    vectorized = vmap(fun, *args, **kw)
    # overwrite whatever docstring vmap attached with the original one
    vectorized.__doc__ = fun.__doc__
    return vectorized

91 

92 

def minimal_unsigned_dtype(value: int) -> DTypeLike:
    """Return the smallest unsigned integer dtype that can represent `value`.

    Parameters
    ----------
    value
        The largest integer that has to fit in the dtype.

    Returns
    -------
    The first of `uint8`, `uint16`, `uint32` whose bit width ``b`` satisfies
    ``value < 2**b``, or `uint64` if none does.
    """
    # candidates in increasing width; the first that is wide enough wins
    candidates = ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32))
    for bits, dtype in candidates:
        if value < 2**bits:
            return dtype
    return jnp.uint64

102 

103 

@jit(static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: ScalarLike
) -> tuple[Shaped[Array, ' {size}'], int | Integer[Array, '']]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output. If `x` has more than `size` unique values,
        only the `size` smallest ones are kept.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
    actual_length : int
        The number of used values in `out`, at most `size`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(
        carry: tuple[Scalar, Scalar, Shaped[Array, ' size']], x: Scalar
    ) -> tuple[tuple[Scalar, Scalar, Shaped[Array, ' size']], None]:
        i_out, last, out = carry
        # advance the write position only when a new value starts
        i_out = jnp.where(x == last, i_out, i_out + 1)
        # once `i_out` reaches `size` the write is out of bounds and is
        # dropped, so values beyond the first `size` unique ones are ignored
        out = out.at[i_out].set(x, mode='drop')
        return (i_out, x, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)

    def run(unroll: int) -> tuple[Shaped[Array, ' size'], Scalar]:
        # Scan the whole sorted array: stopping at `x[:size]` would miss
        # unique values that fit in the output whenever duplicates appear
        # among the first `size` sorted elements.
        (last_index, _, out), _ = lax.scan(loop, carry, x, unroll=unroll)
        # `last_index + 1` counts all unique values; cap it to what fits
        return out, jnp.minimum(last_index + 1, size)

    # The optimal scan unroll is opposite on cpu and gpu (benchmarked):
    # - gpu: the loop is dominated by per-step overhead, so a large unroll is up
    #   to ~6x faster; the run time plateaus by ~32 while compile time then grows
    #   steeply, so 32 is the sweet spot.
    # - cpu: past ~6 the backend stops aliasing `out` in place and copies the
    #   size-`size` buffer each step (O(size**2), ~100x slower), so 2 is safest.
    # `default` (cpu, tpu, untested backends) takes the conservative value.
    return lax.platform_dependent(
        cuda=partial(run, 32), rocm=partial(run, 32), default=partial(run, 2)
    )

158 

159 

class split:
    """
    Split a key into `num` keys that are then handed out one at a time.

    ``len(obj)`` gives the number of keys that have not been popped yet.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.
    """

    # all the keys produced by the initial split, in pop order
    _keys: tuple[Key[Array, ''], ...]
    # how many keys from the start of `_keys` were already handed out
    _num_used: int

    def __init__(self, key: Key[Array, ''], num: int = 2) -> None:
        self._num_used = 0
        self._keys = _split_unpack(key, num)

    def __len__(self) -> int:
        total = len(self._keys)
        return total - self._num_used

    def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, ' *shape']:
        """
        Take the next unused key, optionally expanding it to an array of keys.

        Parameters
        ----------
        shape
            The shape of the keys to pop. If empty (default), a single key is
            popped and returned. If not empty, the popped key is split and
            reshaped to the target shape. An integer is treated as a
            one-element tuple.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If the list is empty.
        """
        if not len(self):
            msg = 'No keys left to pop'
            raise IndexError(msg)
        dims = shape if isinstance(shape, tuple) else (shape,)
        index = self._num_used
        popped = self._keys[index]
        self._num_used = index + 1
        # only a non-empty target shape requires a further split
        return _split_shaped(popped, dims) if dims else popped

212 

213 

@jit(static_argnums=(1,))
def _split_unpack(key: Key[Array, ''], num: int) -> tuple[Key[Array, ''], ...]:
    """Split `key` into `num` keys and return them as a tuple of scalar keys."""
    # iterating the split result along its leading axis yields the single keys
    return tuple(random.split(key, num))

218 

219 

@jit(static_argnums=(1,))
def _split_shaped(key: Key[Array, ''], shape: tuple[int, ...]) -> Key[Array, ' *shape']:
    """Split `key` into an array of keys with the given `shape`."""
    # one key per element of the target array, then fold into the shape
    flat_keys = random.split(key, math.prod(shape))
    return flat_keys.reshape(shape)

225 

226 

def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '...'],
    bound: Float32[Array, '...'],
    *,
    clip: bool = True,
) -> Float32[Array, '...']:
    """
    Draw samples from a standard normal truncated on one side.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        True for (-∞, bound], False for [bound, ∞).
    bound
        The truncation boundary.
    clip
        Whether to clip the truncated uniform samples to (0, 1) before
        transforming them to truncated normal. Intended for debugging purposes.

    Returns
    -------
    Array of samples from the truncated normal distribution.
    """
    # Inverse-cdf sampling, always evaluated in the lower tail of the normal
    # and flipped in sign when the bound is positive. With u ~ uniform:
    # | upper, bound < 0:      ndtri(ndtr(bound) * u)
    # | upper, bound > 0:     -ndtri(ndtr(-bound) + ndtr(bound) * (1 - u))
    # | not upper, bound < 0:  ndtri(ndtr(bound) + ndtr(-bound) * (1 - u))
    # | not upper, bound > 0: -ndtri(ndtr(-bound) * u)
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    flip_sign = bound > 0
    cdf_at_bound = ndtr(bound)
    cdf_at_neg_bound = ndtr(-bound)

    # probability mass of the allowed interval, and of its complement
    width = jnp.where(upper, cdf_at_bound, cdf_at_neg_bound)
    offset = jnp.where(upper, cdf_at_neg_bound, cdf_at_bound)

    u = random.uniform(key, out_shape)
    lower_piece = width * (1 - u)  # ~ uniform in (0, ndtr(±bound)]
    upper_piece = offset + width * u  # ~ uniform in [ndtr(∓bound), 1)
    take_lower_piece = upper ^ flip_sign
    prob = jnp.where(take_lower_piece, lower_piece, upper_piece)

    if clip:
        # on gpu the accuracy is lower and sometimes u can reach the boundaries
        zero = jnp.zeros((), prob.dtype)
        one = jnp.ones((), prob.dtype)
        smallest = jnp.nextafter(zero, one)
        largest = jnp.nextafter(one, zero)
        prob = jnp.clip(prob, smallest, largest)

    sample = ndtri(prob)
    return jnp.where(flip_sign, -sample, sample)

290 

291 

292def get_default_device() -> Device: 

293 """Get the current default JAX device.""" 

294 with ensure_compile_time_eval(): 

295 return jnp.empty(0).device 

296 

297 

def get_default_devices() -> list[Device]:
    """Return every JAX device that shares the platform of the default device."""
    platform = get_default_device().platform
    return jax.devices(platform)

301 

302 

def get_device_count() -> int:
    """Return how many devices are available on the default platform."""
    devices = get_default_devices()
    return len(devices)

306 

307 

def is_key(x: object) -> TypeIs[Key[Array, ' *shape']]:
    """Tell whether `x` is a jax array whose dtype is a typed random key dtype."""
    if not isinstance(x, Array):
        return False
    return jnp.issubdtype(x.dtype, prng_key)

311 

312 

def jit_active() -> bool:
    """Check if we are under jit.

    A fresh array is created and probed for a ``platform`` attribute; its
    absence is taken as the sign that the array is being traced.
    """
    probe = jnp.empty(0)
    # NOTE(review): relies on traced arrays lacking `platform` while concrete
    # ones have it; confirm this holds on the supported jax versions
    has_platform = hasattr(probe, 'platform')
    return not has_platform

316 

317 

def _equal_shards(x: Shaped[Array, '...'], axis_name: str) -> Bool[Array, '']:
    """Check if all shards of `x` are equal, to be used in a `shard_map` context.

    Each shard is compared elementwise with the one sent by the previous index
    along `axis_name` (cyclically); the result is true only if no comparison
    along the axis finds a difference.
    """
    n_shards = lax.axis_size(axis_name)
    # cyclic shift: the shard at index i is sent to index i + 1
    ring = [(src, (src + 1) % n_shards) for src in range(n_shards)]
    neighbor = lax.ppermute(x, axis_name, ring)
    mismatch = jnp.any(x != neighbor)
    # number of positions along the axis that saw a mismatch
    n_mismatch = lax.psum(mismatch, axis_name)
    return jnp.logical_not(n_mismatch)

325 

326 

def equal_shards(
    x: PyTree[Array, ' S'], axis_name: str, **shard_map_kwargs: Any
) -> PyTree[Bool[Array, ''], ' S']:
    """Check that all shards of `x` are equal across axis `axis_name`.

    Parameters
    ----------
    x
        A pytree of arrays to check. Each array is checked separately.
    axis_name
        The mesh axis name across which equality is checked. It's not checked
        across other axes.
    **shard_map_kwargs
        Additional arguments passed to `jax.shard_map` to set up the function
        that checks equality. You may need to specify `in_specs` passing
        the (pytree of) `jax.sharding.PartitionSpec` that specifies how `x`
        is sharded, if the axes are not explicit, and `mesh` if there is not
        a default mesh set by `jax.set_mesh`.

    Returns
    -------
    A pytree of booleans indicating whether each leaf is equal across devices along the mesh axis.
    """

    def check_equal(local: PyTree[Array, ' S']) -> PyTree[Bool[Array, ''], ' S']:
        # apply the per-array check leaf by leaf, keeping the tree structure
        return tree.map(
            lambda leaf: _equal_shards(leaf, axis_name=axis_name), local
        )

    # each leaf of the result is a single flag, declared as not partitioned
    sharded_check = shard_map(
        check_equal, out_specs=PartitionSpec(), **shard_map_kwargs
    )
    return sharded_check(x)