Coverage for src / bartz / jaxext / scipy / special.py: 96%

63 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/jaxext/scipy/special.py 

2# 

3# Copyright (c) 2025-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Mockup of the :external:py:mod:`scipy.special` module.""" 

26 

27from collections.abc import Callable, Sequence 

28from functools import wraps 

29from typing import Any 

30 

31from jax import ShapeDtypeStruct, jit, pure_callback 

32from jax import numpy as jnp 

33from jax.typing import DTypeLike 

34from jaxtyping import Array, Float 

35from scipy.special import gammainccinv as scipy_gammainccinv 

36 

37 

38def _float_type(*args: DTypeLike | Array) -> jnp.dtype: 

39 """Determine the jax floating point result type given operands/types.""" 

40 t = jnp.result_type(*args) 1CDE

41 return jnp.sin(jnp.empty(0, t)).dtype 1CDE

42 

43 

44def _castto(func: Callable[..., Array], dtype: DTypeLike) -> Callable[..., Array]: 

45 @wraps(func) 1CDE

46 def newfunc(*args: Any, **kw: Any) -> Array: 1CDE

47 return func(*args, **kw).astype(dtype) 1FGHIJKLCMNOPQRSTUVWXYZ0123456789!#$%'()*+,-./:;=?@[]^_`{DE

48 

49 return newfunc 1CDE

50 

51 

@jit
def gammainccinv(a: Float[Array, '*'], y: Float[Array, '*']) -> Float[Array, '*']:
    """Survival function inverse of the Gamma(a, 1) distribution.

    The values are computed on the host by
    :external:py:func:`scipy.special.gammainccinv`, called through
    `jax.pure_callback` so that the function can be traced and jitted.

    Parameters
    ----------
    a
        Shape parameter of the Gamma distribution.
    y
        Survival probabilities, broadcast against `a`.

    Returns
    -------
    Array with the broadcast shape of `a` and `y`. Its dtype is the floating
    type obtained by promoting the dtypes of `a` and `y`.
    """
    out_dtype = _float_type(a.dtype, y.dtype)
    out_spec = ShapeDtypeStruct(jnp.broadcast_shapes(a.shape, y.shape), out_dtype)
    # scipy may return a different float precision; cast its result so that it
    # matches the declared `out_spec`.
    host_fn = _castto(scipy_gammainccinv, out_dtype)
    return pure_callback(host_fn, out_spec, a, y, vmap_method='expand_dims')

60 

61 

62################# COPIED AND ADAPTED FROM JAX ################## 

63# Copyright 2018 The JAX Authors. 

64# 

65# Licensed under the Apache License, Version 2.0 (the "License"); 

66# you may not use this file except in compliance with the License. 

67# You may obtain a copy of the License at 

68# 

69# https://www.apache.org/licenses/LICENSE-2.0 

70# 

71# Unless required by applicable law or agreed to in writing, software 

72# distributed under the License is distributed on an "AS IS" BASIS, 

73# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

74# See the License for the specific language governing permissions and 

75# limitations under the License. 

76 

77import numpy as np 

78from jax import debug_infs, lax 

79 

80 

def ndtri(p: Float[Array, '*']) -> Float[Array, '*']:
    """Compute the inverse of the CDF of the Normal distribution function.

    This is a patch of `jax.scipy.special.ndtri`.

    Parameters
    ----------
    p
        Probabilities. Only dtypes ``float32`` and ``float64`` are supported.

    Returns
    -------
    The standard Normal quantiles of `p`, elementwise. ``p == 0`` maps to
    ``-inf`` and ``p == 1`` maps to ``+inf``.

    Raises
    ------
    TypeError
        If the dtype of `p` is neither ``float32`` nor ``float64``.
    """
    dtype = lax.dtype(p)
    if dtype not in (jnp.float32, jnp.float64):
        # The message names the actual parameter (`p`); the supported types
        # are listed in the docstring above.
        msg = f'p.dtype={dtype} is not supported, see docstring for supported types.'
        raise TypeError(msg)
    return _ndtri(p)

91 

92 

93def _ndtri(p: Float[Array, '...']) -> Float[Array, '...']: 

94 # Constants used in piece-wise rational approximations. Taken from the cephes 

95 # library: 

96 # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html 

97 p0 = list( 1abcdefghijklmnopqrstuvwxyzA

98 reversed( 

99 [ 

100 -5.99633501014107895267e1, 

101 9.80010754185999661536e1, 

102 -5.66762857469070293439e1, 

103 1.39312609387279679503e1, 

104 -1.23916583867381258016e0, 

105 ] 

106 ) 

107 ) 

108 q0 = list( 1abcdefghijklmnopqrstuvwxyzA

109 reversed( 

110 [ 

111 1.0, 

112 1.95448858338141759834e0, 

113 4.67627912898881538453e0, 

114 8.63602421390890590575e1, 

115 -2.25462687854119370527e2, 

116 2.00260212380060660359e2, 

117 -8.20372256168333339912e1, 

118 1.59056225126211695515e1, 

119 -1.18331621121330003142e0, 

120 ] 

121 ) 

122 ) 

123 p1 = list( 1abcdefghijklmnopqrstuvwxyzA

124 reversed( 

125 [ 

126 4.05544892305962419923e0, 

127 3.15251094599893866154e1, 

128 5.71628192246421288162e1, 

129 4.40805073893200834700e1, 

130 1.46849561928858024014e1, 

131 2.18663306850790267539e0, 

132 -1.40256079171354495875e-1, 

133 -3.50424626827848203418e-2, 

134 -8.57456785154685413611e-4, 

135 ] 

136 ) 

137 ) 

138 q1 = list( 1abcdefghijklmnopqrstuvwxyzA

139 reversed( 

140 [ 

141 1.0, 

142 1.57799883256466749731e1, 

143 4.53907635128879210584e1, 

144 4.13172038254672030440e1, 

145 1.50425385692907503408e1, 

146 2.50464946208309415979e0, 

147 -1.42182922854787788574e-1, 

148 -3.80806407691578277194e-2, 

149 -9.33259480895457427372e-4, 

150 ] 

151 ) 

152 ) 

153 p2 = list( 1abcdefghijklmnopqrstuvwxyzA

154 reversed( 

155 [ 

156 3.23774891776946035970e0, 

157 6.91522889068984211695e0, 

158 3.93881025292474443415e0, 

159 1.33303460815807542389e0, 

160 2.01485389549179081538e-1, 

161 1.23716634817820021358e-2, 

162 3.01581553508235416007e-4, 

163 2.65806974686737550832e-6, 

164 6.23974539184983293730e-9, 

165 ] 

166 ) 

167 ) 

168 q2 = list( 1abcdefghijklmnopqrstuvwxyzA

169 reversed( 

170 [ 

171 1.0, 

172 6.02427039364742014255e0, 

173 3.67983563856160859403e0, 

174 1.37702099489081330271e0, 

175 2.16236993594496635890e-1, 

176 1.34204006088543189037e-2, 

177 3.28014464682127739104e-4, 

178 2.89247864745380683936e-6, 

179 6.79019408009981274425e-9, 

180 ] 

181 ) 

182 ) 

183 

184 dtype = lax.dtype(p).type 1abcdefghijklmnopqrstuvwxyzA

185 shape = jnp.shape(p) 1abcdefghijklmnopqrstuvwxyzA

186 

187 def _create_polynomial( 1abcdefghijklmnopqrstuvwxyzA

188 var: Float[Array, '...'], coeffs: Sequence[float] 

189 ) -> Float[Array, '...']: 

190 """Compute n_th order polynomial via Horner's method.""" 

191 coeffs = np.array(coeffs, dtype) 1abcdefghijklmnopqrstuvwxyzA

192 if not coeffs.size: 1abcdefghijklmnopqrstuvwxyzA

193 return jnp.zeros_like(var) 1abcdefghijklmnopqrstuvwxyzA

194 return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var 1abcdefghijklmnopqrstuvwxyzA

195 

196 maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.0)), dtype(1.0) - p, p) 1abcdefghijklmnopqrstuvwxyzA

197 # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs 

198 # later on. The result from the computation when p == 0 is not used so any 

199 # number that doesn't result in NaNs is fine. 

200 sanitized_mcp = jnp.where( 1abcdefghijklmnopqrstuvwxyzA

201 maybe_complement_p == dtype(0.0), 

202 jnp.full(shape, dtype(0.5)), 

203 maybe_complement_p, 

204 ) 

205 

206 # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2). 

207 w = sanitized_mcp - dtype(0.5) 1abcdefghijklmnopqrstuvwxyzA

208 ww = lax.square(w) 1abcdefghijklmnopqrstuvwxyzA

209 x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) / _create_polynomial(ww, q0)) 1abcdefghijklmnopqrstuvwxyzA

210 x_for_big_p *= -dtype(np.sqrt(2.0 * np.pi)) 1abcdefghijklmnopqrstuvwxyzA

211 

212 # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z), 

213 # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different 

214 # arrays based on whether p < exp(-32). 

215 z = lax.sqrt(dtype(-2.0) * lax.log(sanitized_mcp)) 1abcdefghijklmnopqrstuvwxyzA

216 first_term = z - lax.log(z) / z 1abcdefghijklmnopqrstuvwxyzA

217 second_term_small_p = ( 1abcdefghijklmnopqrstuvwxyzA

218 _create_polynomial(dtype(1.0) / z, p2) 

219 / _create_polynomial(dtype(1.0) / z, q2) 

220 / z 

221 ) 

222 second_term_otherwise = ( 1abcdefghijklmnopqrstuvwxyzA

223 _create_polynomial(dtype(1.0) / z, p1) 

224 / _create_polynomial(dtype(1.0) / z, q1) 

225 / z 

226 ) 

227 x_for_small_p = first_term - second_term_small_p 1abcdefghijklmnopqrstuvwxyzA

228 x_otherwise = first_term - second_term_otherwise 1abcdefghijklmnopqrstuvwxyzA

229 

230 x = jnp.where( 1abcdefghijklmnopqrstuvwxyzA

231 sanitized_mcp > dtype(np.exp(-2.0)), 

232 x_for_big_p, 

233 jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise), 

234 ) 

235 

236 x = jnp.where(p > dtype(1.0 - np.exp(-2.0)), x, -x) 1abcdefghijklmnopqrstuvwxyzA

237 with debug_infs(False): 1abcdefghijklmnopqrstuvwxyzA

238 infinity = jnp.full(shape, dtype(np.inf)) 1abcdefghijklmnopqrstuvwxyzA

239 neg_infinity = -infinity 1abcdefghijklmnopqrstuvwxyzA

240 return jnp.where( 1abcdefghijklmnopqrstuvwxyzA

241 p == dtype(0.0), neg_infinity, jnp.where(p == dtype(1.0), infinity, x) 

242 ) 

243 

244 

245################################################################