Coverage for src/bartz/_jaxext/_autobatch.py: 99%
178 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/_jaxext/_autobatch.py
2#
3# Copyright (c) 2025-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implementation of `autobatch`."""
27import math
28from collections.abc import Callable
29from functools import partial, wraps
30from typing import Any, Protocol, runtime_checkable
31from warnings import warn
33from jax import ShapeDtypeStruct, eval_shape, lax, tree
34from jax import numpy as jnp
35from jax.typing import ArrayLike, DTypeLike
36from jaxtyping import Array, PyTree, Shaped
37from numpy.lib.array_utils import normalize_axis_index
39from bartz._jaxext._jit import jit
@runtime_checkable
class BinaryUfunc(Protocol):
    """Structural type for binary JAX ufuncs such as `jnp.add`.

    JAX declares an equivalent protocol, ``jax.numpy.BinaryUfunc``, only in
    its type stubs; it does not exist at runtime, so this module provides its
    own copy.
    """

    @property
    def identity(self) -> bool | int | float:
        """Identity element of the binary operation."""
        ...

    def __call__(
        self, x: Shaped[ArrayLike, '...'], y: Shaped[ArrayLike, '...'], /
    ) -> Shaped[Array, '...']:
        """Apply the operation elementwise to `x` and `y`."""
        ...

    def reduce(
        self, a: Shaped[ArrayLike, '...'], /, *, axis: int | None = 0
    ) -> Shaped[Array, '...']:
        """Reduce `a` along `axis` by repeated application of the operation."""
        ...
def expand_axes(
    axes: PyTree[int | None], tree_arg: PyTree, *, none_is_leaf: bool = True
) -> PyTree[int | None]:
    """Broadcast the prefix tree `axes` to the full structure of `tree_arg`.

    Every axis specification in `axes` is copied onto each leaf of the
    corresponding subtree of `tree_arg`. When `none_is_leaf` is true, a `None`
    in `axes` counts as a specification and is broadcast as well, instead of
    being treated as an empty subtree.
    """

    def is_none(node: object) -> bool:
        return node is None

    def broadcast_spec(spec: int | None, subtree: PyTree) -> PyTree[int | None]:
        return tree.map(lambda _: spec, subtree)

    leaf_predicate = is_none if none_is_leaf else None
    return tree.map(broadcast_spec, axes, tree_arg, is_leaf=leaf_predicate)
def normalize_axes(
    axes: PyTree[int | None, ' T'],
    tree_arg: PyTree[Array | ShapeDtypeStruct | None, ' T'],
) -> PyTree[int | None, ' T']:
    """Turn every axis in `axes` into a non-negative index valid for its array.

    `None` axes (non-batched leaves) are passed through unchanged. Each
    integer axis is checked against the rank of the matching array in
    `tree_arg` with numpy's `normalize_axis_index`, which raises `AxisError`
    if it is out of bounds.
    """

    def canonical_axis(
        axis: int | None, arr: Shaped[Array, '...'] | ShapeDtypeStruct | None
    ) -> int | None:
        if axis is not None:
            assert arr is not None
            return normalize_axis_index(axis, len(arr.shape))
        return None

    return tree.map(canonical_axis, axes, tree_arg, is_leaf=lambda node: node is None)
def remove_axis(
    x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: BinaryUfunc
) -> PyTree[ShapeDtypeStruct, ' T']:
    """Describe the result of reducing each dummy array in `x` with `ufunc`.

    For every leaf, the dimension at the matching index in `axis` is dropped
    from the shape, and the dtype is replaced by the output dtype of
    ``ufunc.reduce`` (see `reduction_dtype`).
    """

    def drop_dimension(leaf: ShapeDtypeStruct, index: int) -> ShapeDtypeStruct:
        dims = leaf.shape
        kept = (*dims[:index], *dims[index + 1 :])
        out_dtype = reduction_dtype(ufunc, leaf.dtype)
        return ShapeDtypeStruct(kept, out_dtype)

    return tree.map(drop_dimension, x, axis)
def extract_size(axes: PyTree[int | None], tree_arg: PyTree) -> int:
    """Return the length of the batched axis shared by all batched arrays.

    Leaves of `tree_arg` whose axis in `axes` is `None` are skipped. The
    lengths along the remaining axes are asserted to be all equal, and that
    common length is returned.
    """

    def axis_length(
        arr: Shaped[Array, '...'] | ShapeDtypeStruct, axis: int | None
    ) -> int | None:
        return None if axis is None else arr.shape[axis]

    # `None` results are empty pytree nodes, so flattening discards them
    lengths = tree.leaves(tree.map(axis_length, tree_arg, axes))
    assert all(length == lengths[0] for length in lengths)
    return lengths[0]
def sum_nbytes(tree_arg: PyTree[Array | ShapeDtypeStruct]) -> int:
    """Return the total number of bytes of all arrays in `tree_arg`.

    Only ``shape`` and ``dtype`` are read, so `ShapeDtypeStruct` placeholders
    are counted like concrete arrays. `None` entries are empty subtrees and
    contribute nothing.
    """
    return sum(
        math.prod(leaf.shape) * leaf.dtype.itemsize for leaf in tree.leaves(tree_arg)
    )
def next_divisor_small(dividend: int, min_divisor: int) -> int:
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    Intended for ``min_divisor ** 2 <= dividend``, where it takes
    O(sqrt(dividend)) steps.

    The search has two phases:

    1. Try the candidates in ``[min_divisor, isqrt(dividend)]`` in increasing
       order.
    2. If none of them divides, the answer is larger than the square root.
       It is therefore the cofactor ``dividend // d`` of some ``d <= isqrt``.
       Cofactors are tried in increasing order.

    If no divisor is found, `dividend` itself is returned.
    """
    root = math.isqrt(dividend)  # exact, unlike int(math.sqrt(...))
    for divisor in range(min_divisor, root + 1):
        if dividend % divisor == 0:
            return divisor
    # No divisor in [min_divisor, root]. Any divisor q > root satisfies
    # dividend // q <= root, so walk the small factor d downwards: larger d
    # gives a smaller cofactor. d = 1 always divides, so in the worst case
    # this returns `dividend`.
    for small in range(root, 0, -1):
        if dividend % small == 0:
            cofactor = dividend // small
            if cofactor >= min_divisor:
                return cofactor
    return dividend
def next_divisor_large(dividend: int, min_divisor: int) -> int:
    """Find the smallest divisor of `dividend` that is >= `min_divisor`.

    Suited to a large `min_divisor`. It takes ``dividend // min_divisor``
    steps at most. Returns `dividend` when no cofactor is found, which in
    particular happens when ``dividend < min_divisor``.
    """
    # Any divisor q >= min_divisor has cofactor dividend // q <= dividend //
    # min_divisor. Walking the cofactor downward therefore visits the
    # candidate divisors in increasing order.
    cofactor = dividend // min_divisor
    while cofactor > 0:
        if not dividend % cofactor:
            return dividend // cofactor
        cofactor -= 1
    return dividend
def next_divisor(dividend: int, min_divisor: int) -> int:
    """Find a divisor of `dividend` that is at least `min_divisor`.

    The search looks for the smallest such divisor. If `dividend` is smaller
    than `min_divisor`, no such divisor exists and `dividend` itself is
    returned. A zero `dividend` returns `min_divisor`.
    """
    if not dividend:
        # 0 is divisible by every positive integer, so min_divisor qualifies
        return min_divisor
    # Pick the scan with fewer iterations. Scanning upward from min_divisor
    # is bounded by about sqrt(dividend) steps. Scanning cofactors downward
    # from dividend // min_divisor is cheap when min_divisor is large.
    use_small = min_divisor * min_divisor <= dividend
    search = next_divisor_small if use_small else next_divisor_large
    return search(dividend, min_divisor)
def pull_nonbatched(
    axes: PyTree[int | None], tree_arg: PyTree
) -> tuple[PyTree, PyTree]:
    """Separate batched arguments from non-batched ones.

    Returns a pair with two elements:

    - a copy of `tree_arg` in which every leaf whose axis in `axes` is `None`
      is replaced by `None`;
    - `tree_arg` itself, unchanged, which `push_nonbatched` later uses to put
      the replaced leaves back.
    """
    batched_only = tree.map(
        lambda leaf, axis: leaf if axis is not None else None, tree_arg, axes
    )
    return batched_only, tree_arg
def push_nonbatched(
    axes: PyTree[int | None], tree_arg: PyTree, original_tree: PyTree
) -> PyTree[Any]:
    """Undo `pull_nonbatched`: restore the non-batched leaves.

    Where the axis in `axes` is `None`, the leaf is taken from
    `original_tree`; everywhere else the leaf from `tree_arg` is kept.
    """

    def select(original_leaf: object, leaf: object, axis: int | None) -> object:
        return leaf if axis is not None else original_leaf

    return tree.map(select, original_tree, tree_arg, axes)
def move_axes_out(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
    """Move the axis given in `axes` of each array in `tree_arg` to position 0."""
    return tree.map(lambda arr, ax: jnp.moveaxis(arr, ax, 0), tree_arg, axes)
def move_axes_in(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
    """Inverse of `move_axes_out`: move axis 0 of each array to its index in `axes`."""
    return tree.map(lambda arr, ax: jnp.moveaxis(arr, 0, ax), tree_arg, axes)
def batch(tree_arg: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
    """Reshape the leading axis of every array into ``(nbatches, rest)``.

    The leading length is assumed to be divisible by `nbatches`.
    """

    def split_leading(arr: Shaped[Array, '...']) -> Shaped[Array, '...']:
        lead = arr.shape[0]
        trailing = arr.shape[1:]
        return arr.reshape(nbatches, lead // nbatches, *trailing)

    return tree.map(split_leading, tree_arg)
def unbatch(tree_arg: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
    """Inverse of `batch`: flatten the first two axes of every array into one."""

    def merge_leading(arr: Shaped[Array, '...']) -> Shaped[Array, '...']:
        merged = arr.shape[0] * arr.shape[1]
        return arr.reshape(merged, *arr.shape[2:])

    return tree.map(merge_leading, tree_arg)
def reduce(
    ufunc: BinaryUfunc,
    x: PyTree[Array, ' T'],
    axes: PyTree[int, ' T'],
    initial: PyTree[Array, ' T'] | None,
) -> PyTree[Array, ' T']:
    """Reduce every array of `x` along its axis in `axes` with `ufunc.reduce`.

    When `initial` is given, each reduced array is then combined with the
    matching array of `initial` as ``ufunc(initial, reduced)``. This lets a
    reduction be accumulated across successive batches.
    """
    if initial is not None:

        def reduce_and_combine(
            arr: Shaped[Array, '...'], start: Shaped[Array, '...'], axis: int
        ) -> Shaped[Array, '...']:
            return ufunc(start, ufunc.reduce(arr, axis=axis))

        return tree.map(reduce_and_combine, x, initial, axes)

    return tree.map(lambda arr, axis: ufunc.reduce(arr, axis=axis), x, axes)
def identity(
    ufunc: BinaryUfunc, x: PyTree[ShapeDtypeStruct, ' T']
) -> PyTree[Array, ' T']:
    """Build arrays shaped like the leaves of `x`, filled with `ufunc`'s identity.

    The fill value for each leaf comes from `identity_for` applied to the
    leaf's dtype.
    """

    def filled(leaf: ShapeDtypeStruct) -> Shaped[Array, '...']:
        return jnp.broadcast_to(identity_for(ufunc, leaf.dtype), leaf.shape)

    return tree.map(filled, x)
def reduction_dtype(ufunc: BinaryUfunc, input_dtype: DTypeLike) -> DTypeLike:
    """Return the dtype that ``ufunc.reduce`` produces for `input_dtype` inputs.

    The dtype is found empirically by reducing a one-element probe array. It
    therefore follows whatever promotion `ufunc` applies, e.g. small integer
    types accumulating in a wider type.
    """
    probe = jnp.empty(1, input_dtype)
    reduced = ufunc.reduce(probe)
    return reduced.dtype
def identity_for(ufunc: BinaryUfunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
    """Return ``ufunc.identity`` as a 0-d array with the reduction output dtype."""
    # A reduction may accumulate in a wider type than its input (e.g., int8
    # is accumulated to int32), so type the identity like the reduced result.
    out_dtype = reduction_dtype(ufunc, input_dtype)
    return jnp.array(ufunc.identity, out_dtype)
def check_same(tree1: PyTree, tree2: PyTree) -> None:
    """Assert that paired leaves of `tree1` and `tree2` agree in shape and dtype.

    Leaves may be concrete arrays or `ShapeDtypeStruct` placeholders.
    `jax.tree.map` requires `tree1` to be a prefix of `tree2`, so structural
    mismatches are reported by it rather than by the assertions here.
    """

    def compare(
        left: Shaped[Array, '*shape'] | ShapeDtypeStruct,
        right: Shaped[Array, '*shape'] | ShapeDtypeStruct,
    ) -> None:
        assert left.shape == right.shape
        assert left.dtype == right.dtype

    tree.map(compare, tree1, tree2)
class NotDefined:
    """Sentinel meaning "argument not provided".

    `autobatch` uses the class object itself (not an instance) as the default
    of `result_shape_dtype`. This is presumably because `None` is itself a
    valid (empty) pytree and so cannot mark absence.
    """
def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    *,
    return_nbatches: bool = False,
    reduce_ufunc: BinaryUfunc | None = None,
    reduce_vary_axes: tuple[str, ...] = (),
    warn_on_overflow: bool = True,
    result_shape_dtype: PyTree[ShapeDtypeStruct] | type[NotDefined] = NotDefined,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, whose inputs and
        outputs are pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments). Must be positive.
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axes each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.
    reduce_ufunc
        Function used to reduce the output along the batched axis (e.g.,
        `jax.numpy.add`).
    reduce_vary_axes
        Manual `jax.shard_map` mesh axes over which the reduction accumulator
        varies. Under a `shard_map`, the reduction seed is `pcast` to vary over
        these axes so its type matches the shard-varying loop body, satisfying
        the VMA checker. Ignored unless `reduce_ufunc` is set.
    warn_on_overflow
        If True, a warning is emitted if the memory limit could not be
        respected.
    result_shape_dtype
        A pytree of dummy arrays matching the expected output. If not provided,
        the function is traced an additional time to determine the output
        structure.

    Returns
    -------
    A function with the same signature as `func`. If `return_nbatches` is set,
    it returns the pair ``(result, nbatches)`` instead of just ``result``.

    Raises
    ------
    ValueError
        If `max_io_nbytes` is not positive.

    Notes
    -----
    Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
    arguments is idempotent. Furthermore, `autobatch` can be applied multiple
    times over multiple axes with the same `max_io_nbytes` limit to work on
    multiple axes; in this case it won't unnecessarily loop over additional axes
    if one or more outer `autobatch` are already sufficient.

    To handle memory used in intermediate values: assuming all intermediate
    values have size that scales linearly with the axis batched over, say the
    batched input/output total size is ``batched_size * core_io_size``, and the
    intermediate values have size ``batched_size * core_int_size``, then to take
    them into account divide `max_io_nbytes` by ``(1 + core_int_size /
    core_io_size)``.
    """
    # The batch count is obtained by dividing by this limit. Zero would only
    # fail with ZeroDivisionError once the wrapper is called. A negative
    # value would make every batch exceed the limit. Reject both up front.
    if max_io_nbytes <= 0:
        msg = f'max_io_nbytes must be positive, got {max_io_nbytes}'
        raise ValueError(msg)

    @jit
    @wraps(func)
    def autobatch_wrapper(*args: PyTree) -> PyTree:
        return batched_func(
            func,
            max_io_nbytes,
            in_axes,
            out_axes,
            return_nbatches,
            reduce_ufunc,
            reduce_vary_axes,
            warn_on_overflow,
            result_shape_dtype,
            args,
        )

    return autobatch_wrapper
def batched_func(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    return_nbatches: bool,
    reduce_ufunc: BinaryUfunc | None,
    reduce_vary_axes: tuple[str, ...],
    warn_on_overflow: bool,
    result_shape_dtype: PyTree[ShapeDtypeStruct] | type[NotDefined],
    args: tuple[PyTree[Array], ...],
) -> PyTree[Array] | tuple[PyTree[Array], int]:
    """Implement the wrapper used in `autobatch`.

    All parameters except `args` are those given to `autobatch` (see its
    docstring). `args` is the tuple of positional arguments the wrapped
    function was called with.

    The number of batches is a divisor of the batched-axis length that, when
    possible, is at least ``ceil(total_nbytes / max_io_nbytes)``. Here
    ``total_nbytes`` counts the batched inputs plus the outputs. With more
    than one batch, `func` runs inside a `lax.scan` over the batches;
    otherwise it is called once on the original arguments.

    Returns
    -------
    The output of `func` (reduced along the batch axis if `reduce_ufunc` is
    set), or the pair ``(result, nbatches)`` if `return_nbatches` is set.
    """
    # determine the output structure of the function
    if result_shape_dtype is NotDefined:
        example_result = eval_shape(func, *args)
    else:
        example_result = result_shape_dtype

    # expand the axes pytrees if they are prefixes
    in_axes = expand_axes(in_axes, args)
    out_axes = expand_axes(out_axes, example_result, none_is_leaf=False)

    # check the axes are valid
    in_axes = normalize_axes(in_axes, args)
    out_axes = normalize_axes(out_axes, example_result)

    # get the size of the batched axis
    size = extract_size((in_axes, out_axes), (args, example_result))

    # split arguments in batched and not batched
    # (`args` now has None in place of non-batched leaves; `nonbatched_args`
    # is the untouched original tree used to restore them inside the loop)
    original_args = args
    args, nonbatched_args = pull_nonbatched(in_axes, args)

    # determine the number of batches to respect the memory limit
    total_nbytes = sum_nbytes((args, example_result))
    # ceiling division: fewest batches that keep each batch within the limit
    min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
    min_nbatches = max(1, min_nbatches)
    # batches must split the axis evenly; if no divisor of `size` reaches
    # min_nbatches, next_divisor returns `size` itself (one element per batch)
    nbatches = next_divisor(size, min_nbatches)
    assert 1 <= nbatches <= max(1, size)
    assert size % nbatches == 0
    assert total_nbytes % nbatches == 0

    # warn if the memory limit could not be respected
    batch_nbytes = total_nbytes // nbatches
    if batch_nbytes > max_io_nbytes and warn_on_overflow:
        # the limit can only be exceeded when batches are single elements
        assert size == nbatches
        msg = f'batch_nbytes = {batch_nbytes:_} > max_io_nbytes = {max_io_nbytes:_}'
        warn(msg)

    # squeeze out the output dims that will be reduced
    if reduce_ufunc is not None:
        example_result = remove_axis(example_result, out_axes, reduce_ufunc)

    if nbatches > 1:
        # prepare arguments for looping
        # (batched axis moved to the front, then split so that axis 0 indexes
        # the batch, which is the axis `lax.scan` iterates over)
        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)

        # prepare carry for reduction
        if reduce_ufunc is None:
            initial = None
        else:
            initial = identity(reduce_ufunc, example_result)
            # under a `shard_map`, the loop body output varies over the manual
            # axes while this seed is replicated; mark it varying so the scan's
            # carry types match and the VMA checker is satisfied
            if reduce_vary_axes:
                initial = tree.map(
                    lambda x: lax.pcast(x, reduce_vary_axes, to='varying'), initial
                )

        # loop and invoke the function in batches
        # (`batching_loop` returns (carry, per-batch output): the carry holds
        # the running reduction, while unreduced outputs are stacked by scan)
        loop = partial(
            batching_loop,
            func=func,
            nonbatched_args=nonbatched_args,
            in_axes=in_axes,
            out_axes=out_axes,
            reduce_ufunc=reduce_ufunc,
        )
        reduced_result, result = lax.scan(loop, initial, args)

        # remove auxiliary batching axis and reverse transposition
        if reduce_ufunc is None:
            assert reduced_result is None
            result = unbatch(result)
            result = move_axes_in(out_axes, result)
        else:
            assert result is None
            result = reduced_result

    # trivial case: no batching needed
    else:
        result = func(*original_args)
        if reduce_ufunc is not None:
            result = reduce(reduce_ufunc, result, out_axes, None)

    # the computed result must match the declared/traced output structure
    check_same(example_result, result)

    if return_nbatches:
        return result, nbatches
    return result
def batching_loop(
    initial: PyTree[Array] | None,
    args: PyTree[Array],
    *,
    func: Callable,
    nonbatched_args: PyTree,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    reduce_ufunc: BinaryUfunc | None,
) -> tuple[PyTree[Array], None] | tuple[None, PyTree[Array]]:
    """Process a single batch; this is the `lax.scan` body used by `autobatch`.

    Parameters
    ----------
    initial
        The scan carry: the running reduction, or `None` when not reducing.
    args
        The batched arguments for this iteration, with the batched axis in
        front and `None` in place of non-batched arguments.
    func, nonbatched_args, in_axes, out_axes, reduce_ufunc
        Bound with `functools.partial` by the caller.

    Returns
    -------
    A ``(carry, output)`` pair for `lax.scan`. It is ``(reduced, None)`` when
    `reduce_ufunc` is set. Otherwise it is ``(None, result)``, with the output
    batch axes moved to the front so that scan stacks the batches.
    """
    # restore the original axis positions, then re-insert non-batched args
    positioned = move_axes_in(in_axes, args)
    call_args = push_nonbatched(in_axes, positioned, nonbatched_args)
    out = func(*call_args)

    if reduce_ufunc is not None:
        # fold this batch into the running reduction carried by the scan
        return reduce(reduce_ufunc, out, out_axes, initial), None

    return None, move_axes_out(out_axes, out)