Coverage for src/bartz/_jaxext/_jit.py: 91%
18 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/_jaxext/_jit.py
2#
3# Copyright (c) 2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Signature-preserving `jax.jit` wrapper."""
27from collections.abc import Callable, Sequence
28from typing import (
29 TYPE_CHECKING,
30 Any,
31 ParamSpec,
32 Protocol,
33 TypeVar,
34 overload,
35 runtime_checkable,
36)
38from jax import ShapeDtypeStruct
39from jax import jit as _jax_jit
40from jax.stages import Lowered, Traced
41from jaxtyping import PyTree
43_P = ParamSpec('_P')
44_R = TypeVar('_R')
45_R_co = TypeVar('_R_co', covariant=True)
@runtime_checkable
class JitWrapped(Protocol[_P, _R_co]):
    """Structural type of a function compiled by `jit`.

    A jitted callable accepts exactly the parameters ``_P`` of the function it
    wraps and returns that function's result type ``_R_co``. On top of the call
    signature it exposes the extra methods that jax attaches to the compiled
    object, declared as protocol members below.
    """

    # Calling the jitted object: same parameters and result as the original.
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ...

    # Added by jax to the jitted callable at runtime, but absent from jax's own
    # static ``JitWrapped`` type.
    def clear_cache(self) -> None: ...

    # Output structure of a call, as a pytree of `ShapeDtypeStruct` leaves.
    def eval_shape(
        self, *args: _P.args, **kwargs: _P.kwargs
    ) -> PyTree[ShapeDtypeStruct]: ...

    # Staged-compilation entry points, returning jax's stage objects.
    def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> Lowered: ...

    def trace(self, *args: _P.args, **kwargs: _P.kwargs) -> Traced: ...

    if not TYPE_CHECKING:
        # WORKAROUND(beartype<99): beartype chokes on ParamSpec-subscripted
        # generics, and the jaxtyping import hook used by the test suite makes
        # it process the `JitWrapped[_P, _R]` hints in `jit`'s overloads. Erase
        # the subscript at runtime so beartype sees the plain runtime-checkable
        # protocol, which jitted functions genuinely satisfy. `99` is a
        # placeholder for the beartype release gaining PEP 612 generics support.
        def __class_getitem__(cls, item: object) -> type:
            # Whatever the subscript, hand back the unsubscripted protocol.
            return cls
# WORKAROUND(jax<99): `jax.jit` is typed to return `JitWrapped`, which erases the
# wrapped function's signature, so static checkers can't validate calls to jitted
# functions. This shim recovers the signature via `ParamSpec`, declaring our own
# `JitWrapped` protocol that combines it with the jit-specific methods (including
# `clear_cache`, which jax adds to the jitted callable at runtime and omits from
# its own static `JitWrapped` type). Tracked upstream at jax-ml/jax#23719; the
# jax maintainers are blocked on migrating internal Google code to a type checker
# that understands `ParamSpec` (jax itself has moved to pyrefly). Once `jax.jit`
# preserves the signature natively, this whole module can go and `jit` can be
# imported straight from jax. `99` is a placeholder for that unknown future jax
# release.
# Overload 1: a function is passed, so the jitted callable comes back directly.
@overload
def jit(
    fun: Callable[_P, _R],
    /,
    *,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Sequence[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    **kwargs: Any,
) -> JitWrapped[_P, _R]: ...


# Overload 2: only options are passed, so a decorator comes back.
@overload
def jit(
    fun: None = ...,
    /,
    *,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Sequence[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    **kwargs: Any,
) -> Callable[[Callable[_P, _R]], JitWrapped[_P, _R]]: ...


def jit(fun: Any = None, /, **kwargs: Any) -> Any:
    """Compile a function with `jax.jit` without losing its static signature.

    At runtime this only forwards to `jax.jit`. Its purpose is static: the
    `ParamSpec`-typed overloads let type checkers see the real parameters and
    return type of a jitted function. `jax.jit`'s own annotations instead
    produce an opaque ``JitWrapped``, so every result of a jitted call becomes
    an unknown type and false positives cascade from there.

    Both decorator spellings are supported: bare ``@jit`` and ``@jit(...)``
    with options.

    Parameters
    ----------
    fun
        The function to compile. Leave it out (or pass `None`) to get a
        decorator that applies the given options.
    **kwargs
        Options passed through unchanged to `jax.jit`, such as
        `static_argnums`, `static_argnames` or `donate_argnums`.

    Returns
    -------
    The jitted function when `fun` is given, otherwise a decorator producing it.
    """
    # WORKAROUND(jax<0.8.1): jax gained native `@jit(...)` two-stage decorator
    # support in 0.8.1. Once the floor reaches 0.8.1 the runtime fallback could
    # defer to jax's native form, but keep the shim regardless, because jax's
    # own overloads still return `JitWrapped` and erase the signature; the
    # ParamSpec typing here is the whole point.
    if fun is not None:
        # Direct form (``@jit`` or ``jit(f, ...)``): compile immediately.
        return _jax_jit(fun, **kwargs)
    # Factory form (``@jit(...)``): hold on to the options and apply them once
    # the decorator receives the function.
    return lambda f: _jax_jit(f, **kwargs)