Coverage for src / bartz / jaxext / __init__.py: 94%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 18:11 +0000

1# bartz/src/bartz/jaxext/__init__.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Additions to jax.""" 

26 

27import math 

28from collections.abc import Callable, Sequence 

29from functools import partial 

30from typing import Any 

31 

32# WORKAROUND(jax<0.6.1): shard_map was promoted from jax.experimental to top-level in 0.6.1 

33try: 

34 from jax import shard_map 

35except ImportError: 

36 from jax.experimental.shard_map import shard_map 

37 

38import jax 

39from jax import ( 

40 Device, 

41 device_count, 

42 ensure_compile_time_eval, 

43 jit, 

44 lax, 

45 random, 

46 tree, 

47 typeof, 

48) 

49from jax import numpy as jnp 

50from jax.dtypes import prng_key 

51from jax.scipy.special import ndtr 

52from jax.sharding import PartitionSpec 

53from jaxtyping import Array, Bool, Float32, Key, PyTree, Scalar, Shaped 

54 

55from bartz.jaxext._autobatch import autobatch # noqa: F401 

56from bartz.jaxext.scipy.special import ndtri 

57 

58 

def vmap_nodoc(fun: Callable, *args: Any, **kw: Any) -> Callable:
    """
    Vectorize `fun` with `jax.vmap` while keeping its original docstring.

    Useful when the docstring already describes the arguments as carrying the
    extra batch axes introduced by vmap, so the docstring `jax.vmap` attaches
    would be redundant or misleading.
    """
    original_doc = fun.__doc__
    vmapped = jax.vmap(fun, *args, **kw)
    vmapped.__doc__ = original_doc
    return vmapped

70 

71 

def minimal_unsigned_dtype(value: int) -> jnp.dtype:
    """Return the smallest unsigned integer dtype that can represent `value`."""
    # Try the widths in increasing order and pick the first one that fits.
    for bits, dtype in ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32)):
        if value < 1 << bits:
            return dtype
    return jnp.uint64

81 

82 

@partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output.
    fill_value
        The value used to pad the output when `size` exceeds the number of
        unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
    actual_length : int
        The number of used values in `out`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    sorted_x = jnp.sort(x)

    def collect(
        carry: tuple[Scalar, Scalar, Shaped[Array, ' {size}']], value: Scalar
    ) -> tuple[tuple[Scalar, Scalar, Shaped[Array, ' {size}']], None]:
        # advance the write cursor only when a value differs from its
        # predecessor in the sorted order; duplicates overwrite in place
        cursor, previous, acc = carry
        cursor = jnp.where(value == previous, cursor, cursor + 1)
        acc = acc.at[cursor].set(value)
        return (cursor, value, acc), None

    # scanning the first `size` elements suffices: at most `size` unique
    # values can be reported
    init = 0, sorted_x[0], jnp.full(size, fill_value, x.dtype)
    (last_cursor, _, out), _ = lax.scan(collect, init, sorted_x[:size])
    # the cursor is a 0-based index, so the used length is one more
    return out, last_cursor + 1

124 

125 

class split:
    """
    Split a key into `num` keys.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.
    """

    # the pre-split keys, produced eagerly at construction
    _keys: tuple[Key[Array, ''], ...]
    # how many of the keys have been handed out by `pop` so far
    _num_used: int

    def __init__(self, key: Key[Array, ''], num: int = 2) -> None:
        self._keys = _split_unpack(key, num)
        self._num_used = 0

    def __len__(self) -> int:
        """Return the number of keys still available to pop."""
        return len(self._keys) - self._num_used

    def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, ' {shape}']:
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape
            The shape of the keys to pop. If empty (default), a single key
            is popped and returned as-is. Otherwise, the popped key is split
            further and reshaped to this shape.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If the list is empty.
        """
        if len(self) == 0:
            msg = 'No keys left to pop'
            raise IndexError(msg)
        target = shape if isinstance(shape, tuple) else (shape,)
        key = self._keys[self._num_used]
        self._num_used += 1
        return _split_shaped(key, target) if target else key

178 

179 

@partial(jit, static_argnums=(1,))
def _split_unpack(key: Key[Array, ''], num: int) -> tuple[Key[Array, ''], ...]:
    # split under jit and hand back the keys as a tuple of scalar keys
    return tuple(random.split(key, num))

184 

185 

@partial(jit, static_argnums=(1,))
def _split_shaped(
    key: Key[Array, ''], shape: tuple[int, ...]
) -> Key[Array, ' {shape}']:
    # one key per element of the target shape, then fold into that shape
    return random.split(key, math.prod(shape)).reshape(shape)

193 

194 

def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
    *,
    clip: bool = True,
) -> Float32[Array, '*']:
    """
    Sample from a one-sided truncated standard normal distribution.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        True for (-∞, bound], False for [bound, ∞).
    bound
        The truncation boundary.
    clip
        Whether to clip the truncated uniform samples to (0, 1) before
        transforming them to truncated normal. Intended for debugging purposes.

    Returns
    -------
    Array of samples from the truncated normal distribution.
    """
    # Inverse-cdf sampling, with u ~ uniform(0, 1); the four cases are:
    # | upper, bound < 0:  ndtri(uniform(0, ndtr(bound)))
    # |                  = ndtri(ndtr(bound) * u)
    # | upper, bound > 0:  -ndtri(uniform(ndtr(-bound), 1))
    # |                  = -ndtri(ndtr(-bound) + ndtr(bound) * (1 - u))
    # | lower, bound < 0:  ndtri(uniform(ndtr(bound), 1))
    # |                  = ndtri(ndtr(bound) + ndtr(-bound) * (1 - u))
    # | lower, bound > 0:  -ndtri(uniform(0, ndtr(-bound)))
    # |                  = -ndtri(ndtr(-bound) * u)
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    positive_bound = bound > 0
    cdf_bound = ndtr(bound)
    cdf_neg_bound = ndtr(-bound)
    width = jnp.where(upper, cdf_bound, cdf_neg_bound)
    offset = jnp.where(upper, cdf_neg_bound, cdf_bound)
    u = random.uniform(key, out_shape)
    low_u = width * (1 - u)  # ~ uniform in (0, ndtr(±bound)]
    high_u = offset + width * u  # ~ uniform in [ndtr(∓bound), 1)
    trunc_u = jnp.where(upper ^ positive_bound, low_u, high_u)
    if clip:
        # on gpu the accuracy is lower and sometimes u can reach the boundaries
        zero = jnp.zeros((), trunc_u.dtype)
        one = jnp.ones((), trunc_u.dtype)
        trunc_u = jnp.clip(
            trunc_u, jnp.nextafter(zero, one), jnp.nextafter(one, zero)
        )
    sample = ndtri(trunc_u)
    return jnp.where(positive_bound, -sample, sample)

258 

259 

260def get_default_device() -> Device: 

261 """Get the current default JAX device.""" 

262 with ensure_compile_time_eval(): 1bcaj

263 return jnp.empty(0).device 1bcaj

264 

265 

266def get_device_count() -> int: 

267 """Get the number of available devices on the default platform.""" 

268 device = get_default_device() 1aq

269 return device_count(device.platform) 1aq

270 

271 

def is_key(x: object) -> bool:
    """Determine if `x` is a jax random key."""
    if not isinstance(x, Array):
        return False
    return jnp.issubdtype(x.dtype, prng_key)

275 

276 

def jit_active() -> bool:
    """Check if we are under jit."""
    # relies on concrete arrays exposing a `platform` attribute while the
    # tracers produced under jit do not — presumably version-sensitive,
    # confirm against the supported jax versions
    probe = jnp.empty(0)
    return not hasattr(probe, 'platform')

280 

281 

def _equal_shards(x: Array, axis_name: str) -> Bool[Array, '']:
    """Check if all shards of `x` are equal, to be used in a `shard_map` context."""
    # WORKAROUND(jax<0.6.1): could be `size = lax.axis_size(axis_name)`
    mesh = typeof(x).sharding.mesh
    axis_index = mesh.axis_names.index(axis_name)
    axis_size = mesh.axis_sizes[axis_index]

    # rotate the shards by one position around the axis and compare each
    # shard with its neighbor: if no pair differs, all shards are equal
    ring = [(src, (src + 1) % axis_size) for src in range(axis_size)]
    rotated = lax.ppermute(x, axis_name, ring)
    mismatch = jnp.any(x != rotated)
    return jnp.logical_not(lax.psum(mismatch, axis_name))

293 

294 

def equal_shards(
    x: PyTree[Array, ' S'], axis_name: str, **shard_map_kwargs: Any
) -> PyTree[Bool[Array, ''], ' S']:
    """Check that all shards of `x` are equal across axis `axis_name`.

    Parameters
    ----------
    x
        A pytree of arrays to check. Each array is checked separately.
    axis_name
        The mesh axis name across which equality is checked. It's not checked
        across other axes.
    **shard_map_kwargs
        Additional arguments passed to `jax.shard_map` to set up the function
        that checks equality. You may need to specify `in_specs` passing
        the (pytree of) `jax.sharding.PartitionSpec` that specifies how `x`
        is sharded, if the axes are not explicit, and `mesh` if there is not
        a default mesh set by `jax.set_mesh`.

    Returns
    -------
    A pytree of booleans indicating whether each leaf is equal across devices along the mesh axis.
    """

    def check_leaves(tree_x: PyTree[Array, ' S']) -> PyTree[Bool[Array, ''], ' S']:
        # run the per-leaf shard comparison on every array in the pytree
        return tree.map(
            lambda leaf: _equal_shards(leaf, axis_name=axis_name), tree_x
        )

    sharded_checker = shard_map(
        check_leaves, out_specs=PartitionSpec(), **shard_map_kwargs
    )
    return sharded_checker(x)