Coverage for src / bartz / jaxext / __init__.py: 95%
112 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/jaxext/__init__.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Additions to jax."""
27import math
28from collections.abc import Callable, Sequence
29from functools import partial
30from typing import Any
32try:
33 from jax import shard_map # available since jax v0.6.1
34except ImportError:
35 from jax.experimental.shard_map import shard_map
37import jax
38from jax import (
39 Device,
40 device_count,
41 ensure_compile_time_eval,
42 jit,
43 lax,
44 random,
45 tree,
46 typeof,
47)
48from jax import numpy as jnp
49from jax.dtypes import prng_key
50from jax.scipy.special import ndtr
51from jax.sharding import PartitionSpec
52from jaxtyping import Array, Bool, Float32, Key, PyTree, Scalar, Shaped
54from bartz.jaxext._autobatch import autobatch # noqa: F401
55from bartz.jaxext.scipy.special import ndtri
def vmap_nodoc(fun: Callable, *args: Any, **kw: Any) -> Callable:
    """
    Vectorize `fun` with `jax.vmap`, keeping the docstring of `fun` as is.

    Use this when the docstring of `fun` is already written for the
    vectorized signature, i.e., it already describes the additional axes
    that vmap introduces.

    Parameters
    ----------
    fun
        The function to vectorize.
    *args
    **kw
        Forwarded unchanged to `jax.vmap`.

    Returns
    -------
    The vectorized function, whose ``__doc__`` is set to ``fun.__doc__``.
    """
    vectorized = jax.vmap(fun, *args, **kw)
    # overwrite whatever docstring the vmap wrapper carries with the original
    vectorized.__doc__ = fun.__doc__
    return vectorized
def minimal_unsigned_dtype(value: int) -> jnp.dtype:
    """
    Return the smallest unsigned integer dtype that can represent `value`.

    Parameters
    ----------
    value
        The largest integer that has to be representable.

    Returns
    -------
    One of `jnp.uint8`, `jnp.uint16`, `jnp.uint32`, `jnp.uint64`.

    Notes
    -----
    The input is not validated: only upper thresholds are compared, so a
    negative `value` yields `jnp.uint8` and a value of 2**64 or more yields
    `jnp.uint64`.
    """
    # candidates in increasing width; the first one wide enough wins
    candidates = ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32))
    for bits, dtype in candidates:
        if value < 2**bits:
            return dtype
    return jnp.uint64
@partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out`, at most `size`.
    """
    # degenerate cases: nothing to scan or nowhere to write
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(
        carry: tuple[Scalar, Scalar, Shaped[Array, ' {size}']], x: Scalar
    ) -> tuple[tuple[Scalar, Scalar, Shaped[Array, ' {size}']], None]:
        i_out, last, out = carry
        # advance the output slot only when a new value is encountered
        i_out = jnp.where(x == last, i_out, i_out + 1)
        # once more than `size` distinct values were seen, `i_out` points
        # past the end of `out`; those writes are explicitly discarded
        out = out.at[i_out].set(x, mode='drop')
        return (i_out, x, out), None

    # start with `last = x[0]` so the first element lands in slot 0
    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # scan the whole sorted array: stopping after `size` elements would miss
    # distinct values whenever the leading elements contain duplicates
    (last_index, _, out), _ = lax.scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)
class split:
    """
    Split a key into `num` keys that are then consumed one at a time.

    The length of the object is the number of keys not yet popped.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.
    """

    # all the keys produced by the initial split, in pop order
    _keys: tuple[Key[Array, ''], ...]
    # how many entries of `_keys` have already been handed out
    _num_used: int

    def __init__(self, key: Key[Array, ''], num: int = 2) -> None:
        self._keys = _split_unpack(key, num)
        self._num_used = 0

    def __len__(self) -> int:
        return len(self._keys) - self._num_used

    def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, ' {shape}']:
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape
            The shape of the keys to pop. If empty (default), a single key is
            popped and returned. If not empty, the popped key is split and
            reshaped to the target shape. An integer is interpreted as a
            1-tuple.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If the list is empty.
        """
        remaining = len(self)
        if remaining == 0:
            msg = 'No keys left to pop'
            raise IndexError(msg)

        dims = shape if isinstance(shape, tuple) else (shape,)

        # exactly one stored key is consumed, whatever the requested shape
        index = self._num_used
        popped = self._keys[index]
        self._num_used = index + 1

        if not dims:
            return popped
        return _split_shaped(popped, dims)
@partial(jit, static_argnums=(1,))
def _split_unpack(key: Key[Array, ''], num: int) -> tuple[Key[Array, ''], ...]:
    """Split `key` into `num` keys and return them as a tuple of single keys.

    `num` is a static argument of the jitted function, so the split and the
    unpacking into separate arrays are compiled together for each `num`.
    """
    return tuple(random.split(key, num))
@partial(jit, static_argnums=(1,))
def _split_shaped(
    key: Key[Array, ''], shape: tuple[int, ...]
) -> Key[Array, ' {shape}']:
    """Split `key` into an array of keys with the given `shape`.

    The key is split into as many keys as there are elements in `shape`, and
    the flat result is then reshaped. `shape` is a static argument of the
    jitted function.
    """
    flat_keys = random.split(key, math.prod(shape))
    return flat_keys.reshape(shape)
def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
    *,
    clip: bool = True,
) -> Float32[Array, '*']:
    """
    Sample from a one-sided truncated standard normal distribution.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        True for (-∞, bound], False for [bound, ∞).
    bound
        The truncation boundary.
    clip
        Whether to clip the truncated uniform samples to (0, 1) before
        transforming them to truncated normal. Intended for debugging purposes.

    Returns
    -------
    Array of samples from the truncated normal distribution.
    """
    # Inverse-cdf sampling. With u ~ uniform in [0, 1), the four cases
    # computed below are:
    # | if upper:
    # |     if bound <= 0:
    # |         ndtri(ndtr(bound) * (1 - u))
    # |     if bound > 0:
    # |         -ndtri(ndtr(-bound) + ndtr(bound) * u)
    # | if not upper:
    # |     if bound <= 0:
    # |         ndtri(ndtr(bound) + ndtr(-bound) * u)
    # |     if bound > 0:
    # |         -ndtri(ndtr(-bound) * (1 - u))
    # For bound > 0 the problem is mirrored (sample for -bound, then flip the
    # sign of the result), so that ndtr is only ever combined as below.
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    positive = bound > 0
    cdf_bound = ndtr(bound)
    cdf_neg_bound = ndtr(-bound)

    # width of the probability interval to sample, and mass left outside it
    # (the two add up to 1 since ndtr(bound) + ndtr(-bound) = 1)
    width = jnp.where(upper, cdf_bound, cdf_neg_bound)
    offset = jnp.where(upper, cdf_neg_bound, cdf_bound)

    u = random.uniform(key, out_shape)
    low_interval = width * (1 - u)  # ~ uniform in (0, ndtr(±bound)]
    high_interval = offset + width * u  # ~ uniform in [ndtr(∓bound), 1)
    truncated_u = jnp.where(upper ^ positive, low_interval, high_interval)

    if clip:
        # on gpu the accuracy is lower and sometimes u can reach the boundaries
        dtype = truncated_u.dtype
        zero = jnp.zeros((), dtype)
        one = jnp.ones((), dtype)
        smallest = jnp.nextafter(zero, one)
        largest = jnp.nextafter(one, zero)
        truncated_u = jnp.clip(truncated_u, smallest, largest)

    sample = ndtri(truncated_u)
    # undo the mirroring applied for positive bounds
    return jnp.where(positive, -sample, sample)
259def get_default_device() -> Device:
260 """Get the current default JAX device."""
261 with ensure_compile_time_eval(): 2e lbmbnbf obpbqbrbsbtbubndJ vbwbxbybzbAbBbCbod-dDbEbFbz GbicpdHb:d;d=d?d@d[d]d^d_d`d{d|d}d~daebeIbS T U V W X Y 8
262 return jnp.empty(0).device 2e lbmbnbf obpbqbrbsbtbubndJ vbwbxbybzbAbBbCbod-dDbEbFbz GbicpdHb:d;d=d?d@d[d]d^d_d`d{d|d}d~daebeIbS T U V W X Y 8
def get_device_count() -> int:
    """Get the number of available devices on the default platform.

    The default platform is the platform of the device returned by
    `get_default_device`.
    """
    platform = get_default_device().platform
    return device_count(platform)
def is_key(x: object) -> bool:
    """Determine if `x` is a jax random key.

    `x` is considered a key if it is a jax array whose dtype is a sub-dtype
    of `jax.dtypes.prng_key`.
    """
    if not isinstance(x, Array):
        return False
    return jnp.issubdtype(x.dtype, prng_key)
def jit_active() -> bool:
    """Check if we are under jit.

    The check creates an empty array and looks for its ``platform``
    attribute: if the attribute is missing, the array is taken to be a traced
    value rather than a concrete one.
    """
    probe = jnp.empty(0)
    return not hasattr(probe, 'platform')
281def _equal_shards(x: Array, axis_name: str) -> Bool[Array, '']:
282 """Check if all shards of `x` are equal, to be used in a `shard_map` context."""
283 # get axis size, this could be `size = lax.axis_size(axis_name)`, but it's
284 # supported only since jax v0.6.1
285 mesh = typeof(x).sharding.mesh 2F o f G H I h a p q r s K L M N O Q i z t RbSbTbUb
286 i = mesh.axis_names.index(axis_name) 2F o f G H I h a p q r s K L M N O Q i z t RbSbTbUb
287 size = mesh.axis_sizes[i] 2F o f G H I h a p q r s K L M N O Q i z t RbSbTbUb
289 perm = [(i, (i + 1) % size) for i in range(size)] 2F o f G H I h a p q r s K L M N O Q i z t RbSbTbUb
290 perm_x = lax.ppermute(x, axis_name, perm) 2F o f G H I h a p q r s K L M N O Q i z t RbSbTbUb
291 diff = jnp.any(x != perm_x) 2F o f G H I h a p q r s K L M N O Q i z t RbSbTbUb
292 return jnp.logical_not(lax.psum(diff, axis_name)) 2F o f G H I h a p q r s K L M N O Q i z t RbSbTbUb
def equal_shards(
    x: PyTree[Array, ' S'], axis_name: str, **shard_map_kwargs: Any
) -> PyTree[Bool[Array, ''], ' S']:
    """Check that all shards of `x` are equal across axis `axis_name`.

    Parameters
    ----------
    x
        A pytree of arrays to check. Each array is checked separately.
    axis_name
        The mesh axis name across which equality is checked. It's not checked
        across other axes.
    **shard_map_kwargs
        Additional arguments passed to `jax.shard_map` to set up the function
        that checks equality. You may need to specify `in_specs` passing
        the (pytree of) `jax.sharding.PartitionSpec` that specifies how `x`
        is sharded, if the axes are not explicit, and `mesh` if there is not
        a default mesh set by `jax.set_mesh`.

    Returns
    -------
    A pytree of booleans indicating whether each leaf is equal across devices along the mesh axis.
    """

    def check_equal(leaves: PyTree[Array, ' S']) -> PyTree[Bool[Array, ''], ' S']:
        # run the per-array check on every leaf, keeping the tree structure
        return tree.map(
            lambda leaf: _equal_shards(leaf, axis_name=axis_name), leaves
        )

    # each result is a single flag per leaf, hence the empty partition spec
    # for the outputs
    sharded_check_equal = shard_map(
        check_equal, out_specs=PartitionSpec(), **shard_map_kwargs
    )
    return sharded_check_equal(x)