Coverage for src / bartz / jaxext / _autobatch.py: 100%
176 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
1# bartz/src/bartz/jaxext/_autobatch.py
2#
3# Copyright (c) 2025-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implementation of `autobatch`."""
27import math
28from collections.abc import Callable
29from functools import partial, wraps
30from typing import Any
31from warnings import warn
33from jax import ShapeDtypeStruct, eval_shape, jit, lax, tree
34from jax import numpy as jnp
35from jax.typing import DTypeLike
36from jaxtyping import Array, PyTree, Shaped
37from numpy.lib.array_utils import normalize_axis_index
40def expand_axes(axes: PyTree[int | None], tree_arg: PyTree) -> PyTree[int | None]:
41 """Expand `axes` such that they match the pytreedef of `tree_arg`."""
43 def expand_axis(axis: int | None, subtree: PyTree) -> PyTree[int | None]: 1ab
44 return tree.map(lambda _: axis, subtree) 1ab
46 return tree.map(expand_axis, axes, tree_arg, is_leaf=lambda x: x is None) 1ab
49def normalize_axes(
50 axes: PyTree[int | None, ' T'], tree_arg: PyTree[Array, ' T']
51) -> PyTree[int | None, ' T']:
52 """Normalize axes to be non-negative and valid for the corresponding arrays in the tree_arg."""
54 def normalize_axis(axis: int | None, x: Array) -> int | None: 1ab
55 if axis is None: 1abf
56 return None 1af
57 else:
58 return normalize_axis_index(axis, len(x.shape)) 1ab
60 return tree.map(normalize_axis, axes, tree_arg, is_leaf=lambda x: x is None) 1ab
63def check_no_nones(axes: PyTree[int | None], tree_arg: PyTree) -> None:
64 def check_not_none(_: object, axis: int | None) -> None: 1ab
65 assert axis is not None 1ab
67 tree.map(check_not_none, tree_arg, axes, is_leaf=lambda x: x is None) 1ab
def remove_axis(
    x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: jnp.ufunc
) -> PyTree[ShapeDtypeStruct, ' T']:
    """Remove an axis from dummy arrays and change the type to reduction type."""

    def drop_axis(spec: ShapeDtypeStruct, ax: int) -> ShapeDtypeStruct:
        # drop the reduced dimension and switch to the dtype the reduction yields
        shape = spec.shape[:ax] + spec.shape[ax + 1 :]
        dtype = reduction_dtype(ufunc, spec.dtype)
        return ShapeDtypeStruct(shape, dtype)

    return tree.map(drop_axis, x, axis)
83def extract_size(axes: PyTree[int | None], tree_arg: PyTree) -> int:
84 """Get the size of each array in tree_arg at the axis in axes, check they are equal and return it."""
86 def get_size(x: object, axis: int | None) -> int | None: 1ab
87 if axis is None: 1abf
88 return None 1af
89 else:
90 return x.shape[axis] 1ab
92 sizes = tree.map(get_size, tree_arg, axes) 1ab
93 sizes, _ = tree.flatten(sizes) 1ab
94 assert all(s == sizes[0] for s in sizes) 1ab
95 return sizes[0] 1ab
98def sum_nbytes(tree_arg: PyTree[Array | ShapeDtypeStruct]) -> int:
99 def nbytes(x: Array | ShapeDtypeStruct) -> int: 1ab
100 return math.prod(x.shape) * x.dtype.itemsize 1ab
102 return tree.reduce(lambda size, x: size + nbytes(x), tree_arg, 0) 1ab
def next_divisor_small(dividend: int, min_divisor: int) -> int:
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    Assumes ``min_divisor**2 <= dividend`` (the "small divisor" regime), so
    it is sufficient to scan up to the integer square root; if no divisor is
    found there, `dividend` itself is the answer.
    """
    # math.isqrt is exact for arbitrary ints, whereas int(math.sqrt(dividend))
    # can be off by one for very large dividends due to float rounding, which
    # would silently skip a divisor exactly at the square-root boundary.
    for divisor in range(min_divisor, math.isqrt(dividend) + 1):
        if dividend % divisor == 0:
            return divisor
    return dividend
def next_divisor_large(dividend: int, min_divisor: int) -> int:
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    Assumes ``min_divisor**2 > dividend`` (the "large divisor" regime), so the
    search runs over the cofactor ``dividend // divisor``, which is small.
    """
    inv_divisor = dividend // min_divisor
    while inv_divisor > 0:
        if dividend % inv_divisor == 0:
            return dividend // inv_divisor
        inv_divisor -= 1
    # reached only when min_divisor > dividend
    return dividend
def next_divisor(dividend: int, min_divisor: int) -> int:
    """Return divisor >= min_divisor such that dividend % divisor == 0."""
    if dividend == 0:
        # everything divides zero, so the minimum allowed divisor works
        return min_divisor
    # pick the search strategy according to which side of sqrt(dividend)
    # the requested minimum falls on
    below_sqrt = min_divisor * min_divisor <= dividend
    finder = next_divisor_small if below_sqrt else next_divisor_large
    return finder(dividend, min_divisor)
129def pull_nonbatched(
130 axes: PyTree[int | None], tree_arg: PyTree
131) -> tuple[PyTree, PyTree]:
132 def pull_nonbatched(x: object, axis: int | None) -> object: 1ab
133 if axis is None: 1abf
134 return None 1af
135 else:
136 return x 1ab
138 return tree.map(pull_nonbatched, tree_arg, axes), tree_arg 1ab
141def push_nonbatched(
142 axes: PyTree[int | None], tree_arg: PyTree, original_tree: PyTree
143) -> PyTree[Any]:
144 def push_nonbatched(original_x: object, x: object, axis: int | None) -> object: 1ac
145 if axis is None: 1acf
146 return original_x 1af
147 else:
148 return x 1ac
150 return tree.map(push_nonbatched, original_tree, tree_arg, axes) 1ac
153def move_axes_out(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
154 def move_axis_out(x: Array, axis: int) -> Array: 1ac
155 return jnp.moveaxis(x, axis, 0) 1ac
157 return tree.map(move_axis_out, tree_arg, axes) 1ac
160def move_axes_in(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
161 def move_axis_in(x: Array, axis: int) -> Array: 1ac
162 return jnp.moveaxis(x, 0, axis) 1ac
164 return tree.map(move_axis_in, tree_arg, axes) 1ac
167def batch(tree_arg: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
168 """Split the first axis into two axes, the first of size `nbatches`."""
170 def batch(x: Array) -> Array: 1ac
171 return x.reshape(nbatches, x.shape[0] // nbatches, *x.shape[1:]) 1ac
173 return tree.map(batch, tree_arg) 1ac
176def unbatch(tree_arg: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
177 """Merge the first two axes into a single axis."""
179 def unbatch(x: Array) -> Array: 1ac
180 return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) 1ac
182 return tree.map(unbatch, tree_arg) 1ac
185def reduce(
186 ufunc: jnp.ufunc,
187 x: PyTree[Array, ' T'],
188 axes: PyTree[int, ' T'],
189 initial: PyTree[Array, ' T'] | None,
190) -> PyTree[Array, ' T']:
191 """Reduce each array in `x` along the axes in `axes` starting from `initial` using `ufunc.reduce`."""
192 if initial is None: 1ae
194 def reduce(x: Array, axis: int) -> Array: 1ae
195 return ufunc.reduce(x, axis=axis) 1ae
197 return tree.map(reduce, x, axes) 1ae
199 else:
201 def reduce(x: Array, initial: Array, axis: int) -> Array: 1e
202 reduced = ufunc.reduce(x, axis=axis) 1e
203 return ufunc(initial, reduced) 1e
205 return tree.map(reduce, x, initial, axes) 1e
def identity(
    ufunc: jnp.ufunc, x: PyTree[ShapeDtypeStruct, ' T']
) -> PyTree[Array, ' T']:
    """Get the identity element for `ufunc` and each array in `x`."""

    def fill_with_identity(spec: ShapeDtypeStruct) -> Array:
        scalar = identity_for(ufunc, spec.dtype)
        return jnp.broadcast_to(scalar, spec.shape)

    return tree.map(fill_with_identity, x)
def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike:
    """Return the output dtype for a reduction with `ufunc` on inputs of type `dtype`."""
    # probe with a dummy 1-element array instead of hardcoding promotion rules
    probe = jnp.empty(1, input_dtype)
    return ufunc.reduce(probe).dtype
def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
    """Return the identity for ufunc as an array scalar with the right dtype."""
    # the reduction may widen the dtype (e.g., int8 accumulated to int32),
    # so the identity must be cast to the output dtype, not the input one
    out_dtype = reduction_dtype(ufunc, input_dtype)
    return jnp.array(ufunc.identity, out_dtype)
234def check_same(tree1: PyTree, tree2: PyTree) -> None:
235 def check_same(x1: Array | ShapeDtypeStruct, x2: Array | ShapeDtypeStruct) -> None: 1ab
236 assert x1.shape == x2.shape 1ab
237 assert x1.dtype == x2.dtype 1ab
239 tree.map(check_same, tree1, tree2) 1ab
class NotDefined:
    """Sentinel: the class itself is the default for `result_shape_dtype` in `autobatch`, meaning "not provided" (checked with ``is NotDefined``)."""

    pass
def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    *,
    return_nbatches: bool = False,
    reduce_ufunc: jnp.ufunc | None = None,
    warn_on_overflow: bool = True,
    result_shape_dtype: PyTree[ShapeDtypeStruct] = NotDefined,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function taking positional arguments only, whose inputs
        and outputs are pytrees of arrays.
    max_io_nbytes
        Upper bound on the number of input + output bytes per batch
        (unbatched arguments are not counted).
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        giving the axis along which each array is batched; a `None` axis
        means the argument is not batched.
    out_axes
        The same for outputs, except `None` is not allowed.
    return_nbatches
        If True, also return the number of batches as a second output.
    reduce_ufunc
        Function used to reduce the output along the batched axis (e.g.,
        `jax.numpy.add`).
    warn_on_overflow
        If True, warn when the memory limit could not be respected.
    result_shape_dtype
        A pytree of dummy arrays matching the expected output. When omitted,
        the function is traced an extra time to determine the output
        structure.

    Returns
    -------
    A function with the same signature as `func`, save for the return value if `return_nbatches`.

    Notes
    -----
    Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
    arguments is idempotent. Furthermore, `autobatch` can be applied multiple
    times over multiple axes with the same `max_io_nbytes` limit to work on
    multiple axes; in this case it won't unnecessarily loop over additional
    axes if one or more outer `autobatch` are already sufficient.

    To handle memory used in intermediate values: assuming all intermediate
    values have size that scales linearly with the axis batched over, say the
    batched input/output total size is ``batched_size * core_io_size``, and
    the intermediate values have size ``batched_size * core_int_size``, then
    to take them into account divide `max_io_nbytes` by
    ``(1 + core_int_size / core_io_size)``.
    """

    @jit
    @wraps(func)
    def wrapper(*args: PyTree) -> PyTree:
        # all the real work happens at trace time inside batched_func
        return batched_func(
            func,
            max_io_nbytes,
            in_axes,
            out_axes,
            return_nbatches,
            reduce_ufunc,
            warn_on_overflow,
            result_shape_dtype,
            args,
        )

    return wrapper
def batched_func(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    return_nbatches: bool,
    reduce_ufunc: jnp.ufunc | None,
    warn_on_overflow: bool,
    result_shape_dtype: PyTree[ShapeDtypeStruct] | NotDefined,
    args: tuple[PyTree[Array], ...],
) -> PyTree[Array]:
    """Implement the wrapper used in `autobatch`."""
    # determine the output structure of the function
    # (trace `func` only if the caller did not provide it)
    if result_shape_dtype is NotDefined:
        example_result = eval_shape(func, *args)
    else:
        example_result = result_shape_dtype

    # expand the axes pytrees if they are prefixes
    in_axes = expand_axes(in_axes, args)
    out_axes = expand_axes(out_axes, example_result)
    # outputs must all be batched: no None leaves allowed in out_axes
    check_no_nones(out_axes, example_result)

    # check the axes are valid (and make them non-negative)
    in_axes = normalize_axes(in_axes, args)
    out_axes = normalize_axes(out_axes, example_result)

    # get the size of the batched axis (checked equal across all arrays)
    size = extract_size((in_axes, out_axes), (args, example_result))

    # split arguments in batched and not batched; the non-batched ones are
    # closed over by the loop body instead of being scanned over
    original_args = args
    args, nonbatched_args = pull_nonbatched(in_axes, args)

    # determine the number of batches to respect the memory limit:
    # ceil(total_nbytes / max_io_nbytes), clipped to at least 1, then rounded
    # up to a divisor of `size` so batches split evenly
    total_nbytes = sum_nbytes((args, example_result))
    min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
    min_nbatches = max(1, min_nbatches)
    nbatches = next_divisor(size, min_nbatches)
    assert 1 <= nbatches <= max(1, size)
    assert size % nbatches == 0
    assert total_nbytes % nbatches == 0

    # warn if the memory limit could not be respected
    batch_nbytes = total_nbytes // nbatches
    if batch_nbytes > max_io_nbytes and warn_on_overflow:
        # overflow can only happen when batches are already single items
        assert size == nbatches
        msg = f'batch_nbytes = {batch_nbytes:_} > max_io_nbytes = {max_io_nbytes:_}'
        warn(msg)

    # squeeze out the output dims that will be reduced
    if reduce_ufunc is not None:
        example_result = remove_axis(example_result, out_axes, reduce_ufunc)

    if nbatches > 1:
        # prepare arguments for looping: batched axis to front, then split
        # into (nbatches, size // nbatches, ...) for lax.scan
        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)

        # prepare carry for reduction (identity element of the ufunc)
        if reduce_ufunc is None:
            initial = None
        else:
            initial = identity(reduce_ufunc, example_result)

        # loop and invoke the function in batches
        loop = partial(
            batching_loop,
            func=func,
            nonbatched_args=nonbatched_args,
            in_axes=in_axes,
            out_axes=out_axes,
            reduce_ufunc=reduce_ufunc,
        )
        # when reducing, the result travels in the scan carry; otherwise it
        # is stacked by scan along the batch axis
        reduced_result, result = lax.scan(loop, initial, args)

        # remove auxiliary batching axis and reverse transposition
        if reduce_ufunc is None:
            assert reduced_result is None
            result = unbatch(result)
            result = move_axes_in(out_axes, result)
        else:
            assert result is None
            result = reduced_result

    # trivial case: no batching needed
    else:
        result = func(*original_args)
        if reduce_ufunc is not None:
            result = reduce(reduce_ufunc, result, out_axes, None)

    # sanity check: the result matches the (possibly reduced) expected spec
    check_same(example_result, result)

    if return_nbatches:
        return result, nbatches
    return result
def batching_loop(
    initial: PyTree[Array] | None,
    args: PyTree[Array],
    *,
    func: Callable,
    nonbatched_args: PyTree,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    reduce_ufunc: jnp.ufunc | None,
) -> tuple[PyTree[Array], None] | tuple[None, PyTree[Array]]:
    """Implement the batching loop in `autobatch`."""
    # rebuild the full argument list for this batch: restore the batched
    # axes to their original positions, then splice the non-batched args in
    batch_args = move_axes_in(in_axes, args)
    batch_args = push_nonbatched(in_axes, batch_args, nonbatched_args)
    result = func(*batch_args)

    # reduced case: fold this batch's result into the scan carry
    if reduce_ufunc is not None:
        return reduce(reduce_ufunc, result, out_axes, initial), None

    # unreduced case: transpose for concatenation by scan and return
    return None, move_axes_out(out_axes, result)