Coverage for src / bartz / jaxext / _autobatch.py: 100%

176 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 18:11 +0000

1# bartz/src/bartz/jaxext/_autobatch.py 

2# 

3# Copyright (c) 2025-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implementation of `autobatch`.""" 

26 

27import math 

28from collections.abc import Callable 

29from functools import partial, wraps 

30from typing import Any 

31from warnings import warn 

32 

33from jax import ShapeDtypeStruct, eval_shape, jit, lax, tree 

34from jax import numpy as jnp 

35from jax.typing import DTypeLike 

36from jaxtyping import Array, PyTree, Shaped 

37from numpy.lib.array_utils import normalize_axis_index 

38 

39 

def expand_axes(axes: PyTree[int | None], tree_arg: PyTree) -> PyTree[int | None]:
    """
    Broadcast an axes specification to the full structure of `tree_arg`.

    Parameters
    ----------
    axes
        A pytree whose structure is a prefix of the structure of `tree_arg`.
        Each leaf is an axis index or `None`; `None` counts as a leaf here.
    tree_arg
        The pytree the axes refer to.

    Returns
    -------
    A pytree shaped like `tree_arg` in which every leaf holds the value of the
    `axes` leaf that covers it.
    """

    def broadcast_leaf(axis_value: int | None, covered: PyTree) -> PyTree[int | None]:
        # replicate a single axis value over every leaf of the covered subtree
        return tree.map(lambda _leaf: axis_value, covered)

    return tree.map(broadcast_leaf, axes, tree_arg, is_leaf=lambda node: node is None)

47 

48 

49def normalize_axes( 

50 axes: PyTree[int | None, ' T'], tree_arg: PyTree[Array, ' T'] 

51) -> PyTree[int | None, ' T']: 

52 """Normalize axes to be non-negative and valid for the corresponding arrays in the tree_arg.""" 

53 

54 def normalize_axis(axis: int | None, x: Array) -> int | None: 1ab

55 if axis is None: 1abf

56 return None 1af

57 else: 

58 return normalize_axis_index(axis, len(x.shape)) 1ab

59 

60 return tree.map(normalize_axis, axes, tree_arg, is_leaf=lambda x: x is None) 1ab

61 

62 

def check_no_nones(axes: PyTree[int | None], tree_arg: PyTree) -> None:
    """
    Assert that every leaf of `tree_arg` is paired with a non-`None` axis.

    Parameters
    ----------
    axes
        Pytree with the structure of `tree_arg`, one axis (or `None`) per leaf.
    tree_arg
        The pytree the axes refer to.

    Raises
    ------
    AssertionError
        If some leaf of `tree_arg` has axis `None`.
    """

    def assert_has_axis(_leaf: object, axis: int | None) -> None:
        assert axis is not None

    tree.map(assert_has_axis, tree_arg, axes, is_leaf=lambda node: node is None)

68 

69 

def remove_axis(
    x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: jnp.ufunc
) -> PyTree[ShapeDtypeStruct, ' T']:
    """
    Describe the result of reducing each dummy array along its axis with `ufunc`.

    Parameters
    ----------
    x
        Pytree of dummy arrays.
    axis
        Pytree matching `x` with the axis to drop from each array. The slicing
        below assumes the axes are non-negative.
    ufunc
        The reducing ufunc; it determines the output dtype.

    Returns
    -------
    Pytree of `ShapeDtypeStruct` where each shape lacks the given axis and each
    dtype is the one produced by ``ufunc.reduce`` on the original dtype.
    """

    def drop_axis(struct: ShapeDtypeStruct, ax: int) -> ShapeDtypeStruct:
        shape = struct.shape
        reduced_shape = (*shape[:ax], *shape[ax + 1 :])
        return ShapeDtypeStruct(reduced_shape, reduction_dtype(ufunc, struct.dtype))

    return tree.map(drop_axis, x, axis)

81 

82 

def extract_size(axes: PyTree[int | None], tree_arg: PyTree) -> int:
    """
    Return the common length of the batched axis of the arrays in `tree_arg`.

    Parameters
    ----------
    axes
        Pytree matching the structure of `tree_arg`, holding for each leaf
        the batched axis, or `None` if the leaf is not batched.
    tree_arg
        Pytree of arrays (anything with a ``shape`` attribute).

    Returns
    -------
    The size of every batched array along its batched axis.

    Raises
    ------
    AssertionError
        If the batched arrays disagree on the size of their batched axis; the
        message lists the sizes found.
    """

    def get_size(x: object, axis: int | None) -> int | None:
        # unbatched leaves map to None, which is dropped by the flatten below
        if axis is None:
            return None
        return x.shape[axis]

    sizes, _ = tree.flatten(tree.map(get_size, tree_arg, axes))
    # a mismatch usually means a wrong `in_axes`/`out_axes`: report the sizes
    assert all(s == sizes[0] for s in sizes), (
        f'batched axes have inconsistent sizes: {sizes}'
    )
    return sizes[0]

96 

97 

def sum_nbytes(tree_arg: PyTree[Array | ShapeDtypeStruct]) -> int:
    """
    Return the total number of bytes of the arrays in `tree_arg`.

    Each leaf contributes ``prod(shape) * dtype.itemsize``, so dummy
    `ShapeDtypeStruct` leaves are accepted as well. `None` entries are empty
    pytrees and contribute nothing.
    """
    return sum(
        math.prod(leaf.shape) * leaf.dtype.itemsize for leaf in tree.leaves(tree_arg)
    )

103 

104 

def next_divisor_small(dividend: int, min_divisor: int) -> int:
    """
    Return the smallest divisor of `dividend` that is at least `min_divisor`.

    Intended for positive integers with ``min_divisor ** 2 <= dividend``, where
    both loops below run at most ``sqrt(dividend)`` times. If no divisor
    ``>= min_divisor`` exists (i.e., ``dividend < min_divisor``), `dividend`
    itself is returned.

    Parameters
    ----------
    dividend
        The positive integer to divide.
    min_divisor
        The positive lower bound on the divisor.

    Returns
    -------
    The smallest ``d >= min_divisor`` with ``dividend % d == 0``.
    """
    root = math.isqrt(dividend)  # exact, unlike int(math.sqrt(...)) on big ints

    # divisors not exceeding sqrt(dividend), in increasing order
    for divisor in range(min_divisor, root + 1):
        if dividend % divisor == 0:
            return divisor

    # Every divisor above sqrt(dividend) is dividend // co for a co-divisor
    # co <= root; since none was found in [min_divisor, root], the candidates
    # have co < min_divisor, and the largest such co gives the smallest
    # divisor. The first loop used to stop here and return `dividend`,
    # missing e.g. 7 for (14, 3).
    for co_divisor in range(min(min_divisor, root + 1) - 1, 0, -1):
        if dividend % co_divisor == 0 and dividend // co_divisor >= min_divisor:
            return dividend // co_divisor
    return dividend

110 

111 

def next_divisor_large(dividend: int, min_divisor: int) -> int:
    """
    Return the smallest divisor of `dividend` that is at least `min_divisor`.

    Intended for ``min_divisor ** 2 > dividend``: it scans the co-divisors
    downward from ``dividend // min_divisor``, which in that regime takes
    fewer than ``sqrt(dividend)`` steps. If `dividend` is smaller than
    `min_divisor`, `dividend` itself is returned.
    """
    # the largest co-divisor <= dividend // min_divisor yields the smallest
    # divisor >= min_divisor
    co_divisor = dividend // min_divisor
    while co_divisor > 0:
        if dividend % co_divisor == 0:
            return dividend // co_divisor
        co_divisor -= 1
    return dividend

118 

119 

def next_divisor(dividend: int, min_divisor: int) -> int:
    """
    Return a divisor of `dividend` that is at least `min_divisor`.

    Special cases: a zero `dividend` returns `min_divisor` unchanged, and a
    positive `dividend` smaller than `min_divisor` returns `dividend` itself
    (so the result can then be below `min_divisor`).
    """
    if dividend == 0:
        # every integer divides 0
        return min_divisor
    # pick the search whose loop length is bounded by sqrt(dividend)
    search = (
        next_divisor_large
        if min_divisor * min_divisor > dividend
        else next_divisor_small
    )
    return search(dividend, min_divisor)

127 

128 

def pull_nonbatched(
    axes: PyTree[int | None], tree_arg: PyTree
) -> tuple[PyTree, PyTree]:
    """
    Blank out the leaves of `tree_arg` that are not batched.

    Parameters
    ----------
    axes
        Pytree matching `tree_arg`; `None` marks an unbatched leaf.
    tree_arg
        The pytree to split.

    Returns
    -------
    batched
        `tree_arg` with every unbatched leaf replaced by `None`.
    original
        `tree_arg` itself, unchanged, from which `push_nonbatched` can restore
        the unbatched leaves.
    """

    def keep_if_batched(leaf: object, axis: int | None) -> object:
        return leaf if axis is not None else None

    batched = tree.map(keep_if_batched, tree_arg, axes)
    return batched, tree_arg

139 

140 

def push_nonbatched(
    axes: PyTree[int | None], tree_arg: PyTree, original_tree: PyTree
) -> PyTree[Any]:
    """
    Refill the unbatched leaves of `tree_arg` from `original_tree`.

    Parameters
    ----------
    axes
        Pytree of axes; `None` marks an unbatched leaf.
    tree_arg
        Pytree whose unbatched positions hold placeholders (e.g. `None`).
    original_tree
        Pytree providing the values for the unbatched positions.

    Returns
    -------
    A pytree taking each leaf from `original_tree` where the axis is `None`
    and from `tree_arg` elsewhere.
    """

    def restore(orig_leaf: object, leaf: object, axis: int | None) -> object:
        return orig_leaf if axis is None else leaf

    return tree.map(restore, original_tree, tree_arg, axes)

151 

152 

def move_axes_out(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
    """Move, in each array of `tree_arg`, the axis given in `axes` to position 0."""
    return tree.map(lambda arr, ax: jnp.moveaxis(arr, ax, 0), tree_arg, axes)

158 

159 

def move_axes_in(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
    """Move, in each array of `tree_arg`, axis 0 to the position given in `axes`."""

    def from_front(arr: Array, destination: int) -> Array:
        return jnp.moveaxis(arr, 0, destination)

    return tree.map(from_front, tree_arg, axes)

165 

166 

def batch(tree_arg: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
    """
    Reshape the leading axis of every array into ``(nbatches, size // nbatches)``.

    The leading size must be a multiple of `nbatches`, otherwise the reshape
    fails.
    """

    def split_leading(arr: Array) -> Array:
        leading = arr.shape[0]
        return arr.reshape(nbatches, leading // nbatches, *arr.shape[1:])

    return tree.map(split_leading, tree_arg)

174 

175 

def unbatch(tree_arg: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
    """Collapse the two leading axes of every array into a single one (inverse of `batch`)."""

    def merge_leading(arr: Array) -> Array:
        n_outer, n_inner = arr.shape[0], arr.shape[1]
        return arr.reshape(n_outer * n_inner, *arr.shape[2:])

    return tree.map(merge_leading, tree_arg)

183 

184 

185def reduce( 

186 ufunc: jnp.ufunc, 

187 x: PyTree[Array, ' T'], 

188 axes: PyTree[int, ' T'], 

189 initial: PyTree[Array, ' T'] | None, 

190) -> PyTree[Array, ' T']: 

191 """Reduce each array in `x` along the axes in `axes` starting from `initial` using `ufunc.reduce`.""" 

192 if initial is None: 1ae

193 

194 def reduce(x: Array, axis: int) -> Array: 1ae

195 return ufunc.reduce(x, axis=axis) 1ae

196 

197 return tree.map(reduce, x, axes) 1ae

198 

199 else: 

200 

201 def reduce(x: Array, initial: Array, axis: int) -> Array: 1e

202 reduced = ufunc.reduce(x, axis=axis) 1e

203 return ufunc(initial, reduced) 1e

204 

205 return tree.map(reduce, x, initial, axes) 1e

206 

207 

def identity(
    ufunc: jnp.ufunc, x: PyTree[ShapeDtypeStruct, ' T']
) -> PyTree[Array, ' T']:
    """
    Build, for each dummy array in `x`, an array of its shape filled with the identity of `ufunc`.

    The fill value and dtype come from `identity_for`, so they match the dtype
    produced by reducing an input of the dummy array's dtype.
    """

    def filled_with_identity(struct: ShapeDtypeStruct) -> Array:
        neutral = identity_for(ufunc, struct.dtype)
        return jnp.broadcast_to(neutral, struct.shape)

    return tree.map(filled_with_identity, x)

218 

219 

220def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike: 

221 """Return the output dtype for a reduction with `ufunc` on inputs of type `dtype`.""" 

222 return ufunc.reduce(jnp.empty(1, input_dtype)).dtype 1ae

223 

224 

def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
    """
    Return the identity element of `ufunc` as a 0-d array.

    Its dtype is the one ``ufunc.reduce`` produces for `input_dtype` inputs
    (e.g., int8 is accumulated to int32), so that it matches the dtype of
    reduced values it is combined with.
    """
    accumulator_dtype = reduction_dtype(ufunc, input_dtype)
    # explicitly typed, rather than letting jnp infer from the Python scalar
    return jnp.array(ufunc.identity, dtype=accumulator_dtype)

232 

233 

def check_same(tree1: PyTree, tree2: PyTree) -> None:
    """
    Assert that corresponding leaves of two pytrees agree in shape and dtype.

    Raises
    ------
    AssertionError
        If a pair of corresponding leaves differs in shape or dtype.
    """

    def assert_same_aval(
        left: Array | ShapeDtypeStruct, right: Array | ShapeDtypeStruct
    ) -> None:
        assert left.shape == right.shape
        assert left.dtype == right.dtype

    tree.map(assert_same_aval, tree1, tree2)

240 

241 

class NotDefined:
    """Sentinel meaning "argument not provided"; the class object itself is used as the marker."""

244 

245 

def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    *,
    return_nbatches: bool = False,
    reduce_ufunc: jnp.ufunc | None = None,
    warn_on_overflow: bool = True,
    result_shape_dtype: PyTree[ShapeDtypeStruct] = NotDefined,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments).
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axes each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.
    reduce_ufunc
        Function used to reduce the output along the batched axis (e.g.,
        `jax.numpy.add`). The batched axis is then absent from the outputs.
    warn_on_overflow
        If True, a warning is raised if the memory limit could not be
        respected, i.e., if even single-element batches exceed it.
    result_shape_dtype
        A pytree of dummy arrays matching the expected output. If not provided,
        the function is traced an additional time to determine the output
        structure.

    Returns
    -------
    A jitted function with the same signature as `func`. If `return_nbatches`,
    it returns the pair ``(result, nbatches)`` instead of ``result``.

    Notes
    -----
    Since the returned function is jitted, the number of batches is decided,
    and any overflow warning emitted, when it is traced.

    Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
    arguments is idempotent. Furthermore, `autobatch` can be applied multiple
    times over multiple axes with the same `max_io_nbytes` limit; in this case
    it won't unnecessarily loop over additional axes if one or more outer
    `autobatch` are already sufficient.

    To handle memory used in intermediate values: assuming all intermediate
    values have size that scales linearly with the axis batched over, say the
    batched input/output total size is ``batched_size * core_io_size``, and the
    intermediate values have size ``batched_size * core_int_size``, then to take
    them into account divide `max_io_nbytes` by ``(1 + core_int_size /
    core_io_size)``.
    """

    def autobatch_wrapper(*args: PyTree) -> PyTree:
        return batched_func(
            func,
            max_io_nbytes,
            in_axes,
            out_axes,
            return_nbatches,
            reduce_ufunc,
            warn_on_overflow,
            result_shape_dtype,
            args,
        )

    # copy func's metadata onto the wrapper, then compile it
    return jit(wraps(func)(autobatch_wrapper))

323 

324 

def batched_func(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    return_nbatches: bool,
    reduce_ufunc: jnp.ufunc | None,
    warn_on_overflow: bool,
    result_shape_dtype: PyTree[ShapeDtypeStruct] | NotDefined,
    args: tuple[PyTree[Array], ...],
) -> PyTree[Array]:
    """
    Evaluate `func` on `args` in batches; traced body of the `autobatch` wrapper.

    The parameters are those of `autobatch`, plus `args`, the tuple of
    positional arguments for `func`. The batched axis is split into
    ``nbatches`` equal chunks processed by `lax.scan`. ``nbatches`` is a
    divisor of the axis length chosen with `next_divisor`, starting from the
    smallest count compatible with `max_io_nbytes`. If a single chunk suffices,
    `func` is called directly.

    Returns
    -------
    The (optionally reduced) result, or ``(result, nbatches)`` if
    `return_nbatches`.
    """
    # output structure: user-supplied, or obtained by abstract evaluation
    example_result = (
        eval_shape(func, *args)
        if result_shape_dtype is NotDefined
        else result_shape_dtype
    )

    # broadcast prefix axes trees to full trees; every output must be batched
    in_axes = expand_axes(in_axes, args)
    out_axes = expand_axes(out_axes, example_result)
    check_no_nones(out_axes, example_result)

    # make all axes non-negative, validating them against the array ranks
    in_axes = normalize_axes(in_axes, args)
    out_axes = normalize_axes(out_axes, example_result)

    # common length of the batched axis across inputs and outputs
    size = extract_size((in_axes, out_axes), (args, example_result))

    # keep the full arguments aside; `args` keeps only the batched ones
    original_args = args
    args, nonbatched_args = pull_nonbatched(in_axes, args)

    # fewest batches keeping each within the byte budget (ceil division),
    # rounded up to a divisor of the axis length
    total_nbytes = sum_nbytes((args, example_result))
    quotient, remainder = divmod(total_nbytes, max_io_nbytes)
    min_nbatches = max(1, quotient + bool(remainder))
    nbatches = next_divisor(size, min_nbatches)
    assert 1 <= nbatches <= max(1, size)
    assert size % nbatches == 0
    assert total_nbytes % nbatches == 0

    batch_nbytes = total_nbytes // nbatches
    if warn_on_overflow and batch_nbytes > max_io_nbytes:
        # only possible when batches are already single elements
        assert size == nbatches
        warn(f'batch_nbytes = {batch_nbytes:_} > max_io_nbytes = {max_io_nbytes:_}')

    # reduced outputs lose their batched axis
    if reduce_ufunc is not None:
        example_result = remove_axis(example_result, out_axes, reduce_ufunc)

    if nbatches <= 1:
        # everything fits in one go: call func on the untouched arguments
        result = func(*original_args)
        if reduce_ufunc is not None:
            result = reduce(reduce_ufunc, result, out_axes, None)
    else:
        # put the batched axis in front, then split it into (nbatches, chunk)
        stacked_args = batch(move_axes_out(in_axes, args), nbatches)

        # scan carry: the running reduction, if any
        carry = None if reduce_ufunc is None else identity(reduce_ufunc, example_result)

        body = partial(
            batching_loop,
            func=func,
            nonbatched_args=nonbatched_args,
            in_axes=in_axes,
            out_axes=out_axes,
            reduce_ufunc=reduce_ufunc,
        )
        reduced_result, result = lax.scan(body, carry, stacked_args)

        if reduce_ufunc is not None:
            assert result is None
            result = reduced_result
        else:
            # merge the scan axis back and restore the requested axis positions
            assert reduced_result is None
            result = move_axes_in(out_axes, unbatch(result))

    check_same(example_result, result)

    return (result, nbatches) if return_nbatches else result

421 

422 

def batching_loop(
    initial: PyTree[Array] | None,
    args: PyTree[Array],
    *,
    func: Callable,
    nonbatched_args: PyTree,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    reduce_ufunc: jnp.ufunc | None,
) -> tuple[PyTree[Array], None] | tuple[None, PyTree[Array]]:
    """
    Process one batch: the body of the `lax.scan` loop in `autobatch`.

    Parameters
    ----------
    initial
        The scan carry: `None` without reduction, otherwise the running
        reduced result.
    args
        One batch of the batched arguments, batched axis in front, with `None`
        in place of the unbatched arguments.
    func, nonbatched_args, in_axes, out_axes, reduce_ufunc
        Bound with `partial` by the caller: the function, the full original
        arguments, the normalized axes, and the optional reducing ufunc.

    Returns
    -------
    A ``(carry, output)`` pair for `lax.scan`: ``(None, outputs with batched
    axis in front)`` without reduction, ``(updated reduction, None)`` with it.
    """
    # put each batched axis back in place and refill the unbatched arguments
    call_args = push_nonbatched(
        in_axes, move_axes_in(in_axes, args), nonbatched_args
    )
    output = func(*call_args)

    if reduce_ufunc is not None:
        # fold this batch into the running reduction
        return reduce(reduce_ufunc, output, out_axes, initial), None

    # move the batched axis in front so that scan stacks the batches along it
    return None, move_axes_out(out_axes, output)