Coverage for src / bartz / jaxext / _autobatch.py: 100%

179 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/jaxext/_autobatch.py 

2# 

3# Copyright (c) 2025-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implementation of `autobatch`.""" 

26 

27import math 

28from collections.abc import Callable 

29from functools import partial, wraps 

30from typing import Any 

31from warnings import warn 

32 

33from jax.typing import DTypeLike 

34 

35try: 

36 from numpy.lib.array_utils import normalize_axis_index # numpy 2 

37except ImportError: 

38 from numpy.core.numeric import normalize_axis_index # numpy 1 

39 

40from jax import ShapeDtypeStruct, eval_shape, jit, lax, tree 

41from jax import numpy as jnp 

42from jaxtyping import Array, PyTree, Shaped 

43 

44 

def expand_axes(axes: PyTree[int | None], tree_arg: PyTree) -> PyTree[int | None]:
    """Broadcast the prefix tree `axes` to the full structure of `tree_arg`.

    Each leaf of `axes` (an integer or `None`) is replicated over all the
    leaves of the subtree of `tree_arg` it corresponds to.
    """

    def is_none(node: object) -> bool:
        # treat `None` as a leaf, so that it is paired with a whole subtree of
        # `tree_arg` like the integer axes are
        return node is None

    def fill(axis: int | None, subtree: PyTree) -> PyTree[int | None]:
        return tree.map(lambda _leaf: axis, subtree)

    return tree.map(fill, axes, tree_arg, is_leaf=is_none)

52 

53 

def normalize_axes(
    axes: PyTree[int | None, ' T'], tree_arg: PyTree[Array, ' T']
) -> PyTree[int | None, ' T']:
    """Make each axis non-negative, validating it against the matching array in `tree_arg`.

    `None` axes are passed through unchanged.
    """

    def to_canonical(axis: int | None, x: Array) -> int | None:
        if axis is not None:
            # the axis is checked against the number of dimensions of `x`
            return normalize_axis_index(axis, len(x.shape))
        return None

    return tree.map(to_canonical, axes, tree_arg, is_leaf=lambda node: node is None)

66 

67 

def check_no_nones(axes: PyTree[int | None], tree_arg: PyTree) -> None:
    """Assert that no leaf of `tree_arg` is paired with a `None` axis."""

    def assert_axis_set(_leaf: object, axis: int | None) -> None:
        assert axis is not None

    tree.map(assert_axis_set, tree_arg, axes, is_leaf=lambda node: node is None)

73 

74 

def remove_axis(
    x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: jnp.ufunc
) -> PyTree[ShapeDtypeStruct, ' T']:
    """Describe the result of reducing each dummy array in `x` along its axis with `ufunc`.

    The given axis is dropped from the shape, and the dtype is replaced with
    the one produced by a reduction with `ufunc`.
    """

    def drop_axis(struct: ShapeDtypeStruct, ax: int) -> ShapeDtypeStruct:
        shape_without_axis = (*struct.shape[:ax], *struct.shape[ax + 1 :])
        return ShapeDtypeStruct(
            shape_without_axis, reduction_dtype(ufunc, struct.dtype)
        )

    return tree.map(drop_axis, x, axis)

86 

87 

def extract_size(axes: PyTree[int | None], tree_arg: PyTree) -> int:
    """Return the common length of the arrays in `tree_arg` along the axes in `axes`.

    Arrays whose axis is `None` are ignored. The lengths found are asserted
    to be all equal.
    """

    def axis_length(x: object, axis: int | None) -> int | None:
        return None if axis is None else x.shape[axis]

    # the `None` placeholders of non-batched arrays are not leaves, so they
    # disappear when listing the leaves
    lengths = tree.leaves(tree.map(axis_length, tree_arg, axes))
    first = lengths[0]
    assert all(length == first for length in lengths)
    return first

101 

102 

def sum_nbytes(tree_arg: PyTree[Array | ShapeDtypeStruct]) -> int:
    """Return the total size in bytes of all the arrays in `tree_arg`."""
    return sum(
        math.prod(x.shape) * x.dtype.itemsize for x in tree.leaves(tree_arg)
    )

108 

109 

def next_divisor_small(dividend: int, min_divisor: int) -> int:
    """Return the smallest divisor of `dividend` between `min_divisor` and its square root.

    The candidates are scanned upwards from `min_divisor` up to the integer
    square root of `dividend` (included). If none of them divides `dividend`,
    `dividend` itself is returned.
    """
    # use the exact integer square root: `int(math.sqrt(dividend))` goes
    # through floating point and can fall short of the true root for large
    # dividends, skipping a valid divisor
    for divisor in range(min_divisor, math.isqrt(dividend) + 1):
        if dividend % divisor == 0:
            return divisor
    return dividend

115 

116 

def next_divisor_large(dividend: int, min_divisor: int) -> int:
    """Return the smallest divisor of `dividend` not below `min_divisor`, searching by cofactor.

    The cofactors ``dividend // divisor`` are scanned downwards starting from
    ``dividend // min_divisor``. If `min_divisor` is larger than `dividend`
    there is nothing to scan and `dividend` itself is returned.
    """
    largest_cofactor = dividend // min_divisor
    cofactor = next(
        (c for c in range(largest_cofactor, 0, -1) if dividend % c == 0), None
    )
    if cofactor is None:
        return dividend
    return dividend // cofactor

123 

124 

def next_divisor(dividend: int, min_divisor: int) -> int:
    """Return divisor >= min_divisor such that dividend % divisor == 0.

    If `dividend` is 0, `min_divisor` is returned as is. If `min_divisor` is
    larger than a nonzero `dividend`, no such divisor exists and `dividend`
    is returned instead.
    """
    if dividend == 0:
        return min_divisor

    # scan the divisors directly when starting below the square root of the
    # dividend, otherwise scan the cofactors
    below_sqrt = min_divisor * min_divisor <= dividend
    search = next_divisor_small if below_sqrt else next_divisor_large
    return search(dividend, min_divisor)

132 

133 

def pull_nonbatched(
    axes: PyTree[int | None], tree_arg: PyTree
) -> tuple[PyTree, PyTree]:
    """Mask out the leaves of `tree_arg` that are not batched.

    Returns a pair: first a version of `tree_arg` where the leaves with `None`
    axis are replaced by `None`, then `tree_arg` itself, unmodified.
    """

    def keep_if_batched(x: object, axis: int | None) -> object:
        return x if axis is not None else None

    batched_only = tree.map(keep_if_batched, tree_arg, axes)
    return batched_only, tree_arg

144 

145 

def push_nonbatched(
    axes: PyTree[int | None], tree_arg: PyTree, original_tree: PyTree
) -> PyTree[Any]:
    """Put back into `tree_arg` the non-batched leaves, taken from `original_tree`.

    Where the axis is `None` the leaf of `original_tree` is used, elsewhere
    the leaf of `tree_arg`.
    """

    def select(original_x: object, x: object, axis: int | None) -> object:
        return original_x if axis is None else x

    return tree.map(select, original_tree, tree_arg, axes)

156 

157 

def move_axes_out(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
    """Move to the first position the axis given in `axes` of each array in `tree_arg`."""
    return tree.map(lambda x, axis: jnp.moveaxis(x, axis, 0), tree_arg, axes)

163 

164 

def move_axes_in(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
    """Move the first axis of each array in `tree_arg` to the position given in `axes`."""
    return tree.map(lambda x, axis: jnp.moveaxis(x, 0, axis), tree_arg, axes)

170 

171 

def batch(tree_arg: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
    """Reshape the first axis of each array into two axes, the outer one of size `nbatches`."""

    def split_leading(x: Array) -> Array:
        batch_length = x.shape[0] // nbatches
        return x.reshape((nbatches, batch_length) + x.shape[1:])

    return tree.map(split_leading, tree_arg)

179 

180 

def unbatch(tree_arg: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
    """Collapse the first two axes of each array into a single one."""

    def merge_leading(x: Array) -> Array:
        merged_length = x.shape[0] * x.shape[1]
        return x.reshape((merged_length,) + x.shape[2:])

    return tree.map(merge_leading, tree_arg)

188 

189 

190def reduce( 

191 ufunc: jnp.ufunc, 

192 x: PyTree[Array, ' T'], 

193 axes: PyTree[int, ' T'], 

194 initial: PyTree[Array, ' T'] | None, 

195) -> PyTree[Array, ' T']: 

196 """Reduce each array in `x` along the axes in `axes` starting from `initial` using `ufunc.reduce`.""" 

197 if initial is None: 1yzAheiBfCcbdDEFGHIJKLMNOvwxPQRSTUVWXajg

198 

199 def reduce(x: Array, axis: int) -> Array: 1yzAheiBfCcbdDEFGHIJKLMNOvwxPQRSTUVWXa

200 return ufunc.reduce(x, axis=axis) 1yzAheiBfCcbdDEFGHIJKLMNOvwxPQRSTUVWXa

201 

202 return tree.map(reduce, x, axes) 1yzAheiBfCcbdDEFGHIJKLMNOvwxPQRSTUVWXa

203 

204 else: 

205 

206 def reduce(x: Array, initial: Array, axis: int) -> Array: 1ajg

207 reduced = ufunc.reduce(x, axis=axis) 1ajg

208 return ufunc(initial, reduced) 1ajg

209 

210 return tree.map(reduce, x, initial, axes) 1ajg

211 

212 

def identity(
    ufunc: jnp.ufunc, x: PyTree[ShapeDtypeStruct, ' T']
) -> PyTree[Array, ' T']:
    """Build, for each dummy array in `x`, an array of that shape filled with the identity of `ufunc`."""

    def filled(struct: ShapeDtypeStruct) -> Array:
        neutral = identity_for(ufunc, struct.dtype)
        return jnp.broadcast_to(neutral, struct.shape)

    return tree.map(filled, x)

223 

224 

225def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike: 

226 """Return the output dtype for a reduction with `ufunc` on inputs of type `dtype`.""" 

227 return ufunc.reduce(jnp.empty(1, input_dtype)).dtype 1yzAheiBfCcbdDEFGHIJKLMNOvwxPQRSTUVWXajg

228 

229 

def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
    """Return the identity element of `ufunc` as a scalar array with the reduction dtype."""
    # the reduction may accumulate in a different type than the input, e.g.,
    # int8 is accumulated to int32
    accumulator_dtype = reduction_dtype(ufunc, input_dtype)

    # wrap in an array to make the type explicit
    return jnp.array(ufunc.identity, dtype=accumulator_dtype)

237 

238 

def check_same(tree1: PyTree, tree2: PyTree) -> None:
    """Assert that the matching leaves of the two trees have equal shape and dtype."""

    def assert_match(
        x1: Array | ShapeDtypeStruct, x2: Array | ShapeDtypeStruct
    ) -> None:
        assert x1.shape == x2.shape
        assert x1.dtype == x2.dtype

    tree.map(assert_match, tree1, tree2)

245 

246 

class NotDefined:
    """Sentinel class used as default value to mark an argument as not provided."""

249 

250 

def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    *,
    return_nbatches: bool = False,
    reduce_ufunc: jnp.ufunc | None = None,
    warn_on_overflow: bool = True,
    result_shape_dtype: PyTree[ShapeDtypeStruct] = NotDefined,
) -> Callable:
    """
    Wrap a function to evaluate it in batches that fit a memory budget.

    Parameters
    ----------
    func
        A jittable function that takes only positional arguments. Its inputs
        and outputs are pytrees of arrays.
    max_io_nbytes
        The upper bound on the number of bytes of inputs plus outputs handled
        in a single batch. Arguments that are not batched do not count.
    in_axes
        A pytree with the same structure as the function input, or a prefix
        of it, giving the axis to batch over for each array. `None` means the
        argument is not batched.
    out_axes
        Like `in_axes`, but for the outputs. `None` is not allowed here.
    return_nbatches
        Whether to return the number of batches as an additional second
        output.
    reduce_ufunc
        If specified, the output is reduced along the batched axis with this
        function (e.g., `jax.numpy.add`).
    warn_on_overflow
        Whether to emit a warning when the batches can not be made small
        enough to respect `max_io_nbytes`.
    result_shape_dtype
        A pytree of dummy arrays describing the expected output of `func`
        (before any reduction). If omitted, the function is traced one more
        time to find out the output structure.

    Returns
    -------
    A function with the same signature as `func`, save for the return value if `return_nbatches`.

    Notes
    -----
    At given arguments, `autobatch` is idempotent unless `return_nbatches` or
    `reduce_ufunc` are set. To batch over more than one axis, `autobatch` can
    be applied repeatedly, once per axis, with the same `max_io_nbytes`; an
    inner `autobatch` does not loop needlessly if the outer ones already
    bring the size within the limit.

    The limit accounts only for inputs and outputs. To take into account the
    memory used by intermediate values: assuming their size scales linearly
    with the batched axis, say the batched input/output total size is
    ``batched_size * core_io_size`` and the intermediate values have size
    ``batched_size * core_int_size``, then divide `max_io_nbytes` by
    ``(1 + core_int_size / core_io_size)``.
    """

    @wraps(func)
    def autobatch_wrapper(*args: PyTree) -> PyTree:
        # all the settings are closed over, only the arguments are traced
        return batched_func(
            func,
            max_io_nbytes,
            in_axes,
            out_axes,
            return_nbatches,
            reduce_ufunc,
            warn_on_overflow,
            result_shape_dtype,
            args,
        )

    return jit(autobatch_wrapper)

328 

329 

def batched_func(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    return_nbatches: bool,
    reduce_ufunc: jnp.ufunc | None,
    warn_on_overflow: bool,
    result_shape_dtype: PyTree[ShapeDtypeStruct] | NotDefined,
    args: tuple[PyTree[Array], ...],
) -> PyTree[Array]:
    """Implement the wrapper used in `autobatch`.

    The parameters are the ones of `autobatch`, plus the tuple `args` of
    positional arguments to invoke `func` with. Returns the output of `func`,
    reduced along the batched axis if `reduce_ufunc` is set, together with
    the number of batches if `return_nbatches` is set.
    """
    reducing = reduce_ufunc is not None

    # output structure: trace the function unless the caller provided it
    if result_shape_dtype is NotDefined:
        example_result = eval_shape(func, *args)
    else:
        example_result = result_shape_dtype

    # the axes may be prefix trees, expand them to full trees; all outputs
    # must be batched
    in_axes = expand_axes(in_axes, args)
    out_axes = expand_axes(out_axes, example_result)
    check_no_nones(out_axes, example_result)

    # validate the axes and make them non-negative
    in_axes = normalize_axes(in_axes, args)
    out_axes = normalize_axes(out_axes, example_result)

    # common length of all the batched axes
    size = extract_size((in_axes, out_axes), (args, example_result))

    # mask out the non-batched arguments, keeping the complete ones around
    batched_args, nonbatched_args = pull_nonbatched(in_axes, args)

    # least number of batches that respects the memory limit (a ceil
    # division), then moved up to a divisor of the axis length
    total_nbytes = sum_nbytes((batched_args, example_result))
    quotient, remainder = divmod(total_nbytes, max_io_nbytes)
    min_nbatches = max(1, quotient + bool(remainder))
    nbatches = next_divisor(size, min_nbatches)
    assert 1 <= nbatches <= max(1, size)
    assert size % nbatches == 0
    assert total_nbytes % nbatches == 0

    # warn if the memory limit could not be respected
    batch_nbytes = total_nbytes // nbatches
    if warn_on_overflow and batch_nbytes > max_io_nbytes:
        assert size == nbatches
        warn(f'batch_nbytes = {batch_nbytes:_} > max_io_nbytes = {max_io_nbytes:_}')

    # the reduced output lacks the batched axis
    if reducing:
        example_result = remove_axis(example_result, out_axes, reduce_ufunc)

    if nbatches <= 1:
        # trivial case: a single invocation on the complete arguments
        result = func(*args)
        if reducing:
            result = reduce(reduce_ufunc, result, out_axes, None)

    else:
        # bring the batched axis in front and split it in (batch, element)
        batched_args = batch(move_axes_out(in_axes, batched_args), nbatches)

        # starting value of the carry for the reduction
        initial = identity(reduce_ufunc, example_result) if reducing else None

        # loop and invoke the function in batches
        loop = partial(
            batching_loop,
            func=func,
            nonbatched_args=nonbatched_args,
            in_axes=in_axes,
            out_axes=out_axes,
            reduce_ufunc=reduce_ufunc,
        )
        reduced_result, stacked_result = lax.scan(loop, initial, batched_args)

        if reducing:
            # the result is the final carry
            assert stacked_result is None
            result = reduced_result
        else:
            # merge the batch axis back and undo the transposition
            assert reduced_result is None
            result = move_axes_in(out_axes, unbatch(stacked_result))

    check_same(example_result, result)

    if return_nbatches:
        return result, nbatches
    return result

426 

427 

def batching_loop(
    initial: PyTree[Array] | None,
    args: PyTree[Array],
    *,
    func: Callable,
    nonbatched_args: PyTree,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    reduce_ufunc: jnp.ufunc | None,
) -> tuple[PyTree[Array], None] | tuple[None, PyTree[Array]]:
    """Implement the batching loop in `autobatch`.

    Evaluates `func` on one batch of arguments. The return value is a pair
    ``(carry, output)``: without reduction the carry is `None` and the output
    is the result with the batched axes moved to the front; with reduction
    the carry is the result reduced onto `initial` and the output is `None`.
    """
    # move the batched axes back in place and fill in the non-batched arguments
    full_args = push_nonbatched(
        in_axes, move_axes_in(in_axes, args), nonbatched_args
    )
    result = func(*full_args)

    # reduced case: reduce starting from initial
    if reduce_ufunc is not None:
        return reduce(reduce_ufunc, result, out_axes, initial), None

    # unreduced case: transpose for concatenation and return
    return None, move_axes_out(out_axes, result)