Coverage for src/bartz/_jaxext/_jit.py: 91%

18 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/_jaxext/_jit.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Signature-preserving `jax.jit` wrapper.""" 

26 

27from collections.abc import Callable, Sequence 

28from typing import ( 

29 TYPE_CHECKING, 

30 Any, 

31 ParamSpec, 

32 Protocol, 

33 TypeVar, 

34 overload, 

35 runtime_checkable, 

36) 

37 

38from jax import ShapeDtypeStruct 

39from jax import jit as _jax_jit 

40from jax.stages import Lowered, Traced 

41from jaxtyping import PyTree 

42 

43_P = ParamSpec('_P') 

44_R = TypeVar('_R') 

45_R_co = TypeVar('_R_co', covariant=True) 

46 

47 

@runtime_checkable
class JitWrapped(Protocol[_P, _R_co]):
    """Structural type of a jitted function.

    Pairs the call signature of the wrapped function (parameters captured by
    ``_P``, return type ``_R_co``) with the extra methods available on jitted
    callables. Since the protocol is `runtime_checkable`, it can be used in
    ``isinstance`` checks; those only test that the members below exist, not
    their signatures.
    """

    # Calling the jitted function: same parameters and return type as the
    # original function.
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ...

    # Abstract evaluation: the output as a pytree of shape/dtype structs for
    # the given arguments.
    def eval_shape(
        self, *args: _P.args, **kwargs: _P.kwargs
    ) -> PyTree[ShapeDtypeStruct]: ...

    # Staged compilation: return jax `Lowered` / `Traced` stage objects for
    # arguments accepted by a regular call.
    def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> Lowered: ...

    def trace(self, *args: _P.args, **kwargs: _P.kwargs) -> Traced: ...

    # Attached by jax to jitted callables at runtime, although jax's own
    # static `JitWrapped` type does not declare it.
    def clear_cache(self) -> None: ...

    if not TYPE_CHECKING:
        # WORKAROUND(beartype<99): beartype cannot handle generics subscripted
        # with a `ParamSpec`. The jaxtyping import hook used by the test suite
        # makes it inspect the `JitWrapped[_P, _R]` annotations on the `jit`
        # overloads. Turning subscription into a no-op at runtime means
        # beartype only ever sees the plain runtime-checkable protocol, which
        # jitted functions do satisfy. `99` is a placeholder for the beartype
        # release that adds PEP 612 generics support. Python implicitly makes
        # `__class_getitem__` a classmethod, hence the `cls` parameter.
        def __class_getitem__(cls, item: object) -> type:
            # Discard the subscript: `JitWrapped[...]` is the bare class.
            return cls

73 

74 

# WORKAROUND(jax<99): the static return type of `jax.jit` is jax's own
# `JitWrapped`, which forgets the signature of the compiled function, so static
# checkers cannot validate calls made through it. This shim restores that
# signature with a `ParamSpec` and pairs it with the jit-specific methods in
# the `JitWrapped` protocol defined in this module. That protocol also declares
# `clear_cache`, which jax attaches to jitted callables at runtime but leaves
# out of its own static `JitWrapped` type. Upstream tracking issue:
# jax-ml/jax#23719. The jax maintainers are blocked on migrating internal
# Google code to a type checker that understands `ParamSpec` (jax itself has
# moved to pyrefly). Once `jax.jit` preserves signatures natively, this module
# can be deleted and `jit` imported directly from jax. `99` is a placeholder
# for that future, not-yet-known jax release.
@overload
def jit(
    fun: Callable[_P, _R],
    /,
    *,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Sequence[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    **kwargs: Any,
) -> JitWrapped[_P, _R]: ...


@overload
def jit(
    fun: None = ...,
    /,
    *,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Sequence[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    **kwargs: Any,
) -> Callable[[Callable[_P, _R]], JitWrapped[_P, _R]]: ...
def jit(fun: Any = None, /, **kwargs: Any) -> Any:
    """Compile `fun` with `jax.jit`, keeping its signature visible to checkers.

    `jax.jit` is annotated to return an opaque ``JitWrapped`` object, so a
    static checker loses the parameters and return type of the compiled
    function and treats every call through it as untyped, which then spreads
    spurious errors downstream. The overloads of this function are expressed
    with a `ParamSpec` instead, so calls to jitted functions are checked
    against the real signature. At runtime all the work is delegated to
    `jax.jit`.

    Works as a drop-in replacement in both decorator spellings, bare ``@jit``
    and parametrized ``@jit(...)``.

    Parameters
    ----------
    fun
        Function to compile. Pass `None` (or omit it) to get back a decorator
        that applies the given keyword arguments later.
    **kwargs
        Options passed through unchanged to `jax.jit`, such as
        `static_argnums`, `static_argnames` or `donate_argnums`.

    Returns
    -------
    The result of `jax.jit` applied to `fun`. If `fun` is `None`, a
    one-argument decorator that jits the function it receives with `kwargs`.
    """
    # WORKAROUND(jax<0.8.1): `jax.jit` only accepts the two-stage `@jit(...)`
    # spelling natively from 0.8.1 onwards, hence the decorator built by hand
    # below. Even once the minimum jax version reaches 0.8.1 and this fallback
    # could defer to jax, keep the shim. jax's overloads still return its own
    # `JitWrapped` and erase the signature, and the `ParamSpec` typing here is
    # the reason this wrapper exists.

    def decorate(target: Any) -> Any:
        # The options are captured now; the function may only arrive later.
        return _jax_jit(target, **kwargs)

    if fun is None:
        # Keyword-only form `@jit(...)`: hand back the decorator itself.
        return decorate
    return decorate(fun)