Coverage for src/bartz/_jaxext/_autobatch.py: 99%

178 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/_jaxext/_autobatch.py 

2# 

3# Copyright (c) 2025-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implementation of `autobatch`.""" 

26 

27import math 

28from collections.abc import Callable 

29from functools import partial, wraps 

30from typing import Any, Protocol, runtime_checkable 

31from warnings import warn 

32 

33from jax import ShapeDtypeStruct, eval_shape, lax, tree 

34from jax import numpy as jnp 

35from jax.typing import ArrayLike, DTypeLike 

36from jaxtyping import Array, PyTree, Shaped 

37from numpy.lib.array_utils import normalize_axis_index 

38 

39from bartz._jaxext._jit import jit 

40 

41 

@runtime_checkable
class BinaryUfunc(Protocol):
    """Structural type for binary `jax.numpy.ufunc` objects such as `jnp.add`.

    JAX declares `jax.numpy.BinaryUfunc` only in its type stubs, so it cannot
    be used at runtime; this protocol stands in for it. Being
    `runtime_checkable`, ``isinstance`` checks against it only verify that the
    members below are present.
    """

    @property
    def identity(self) -> bool | int | float:
        """Identity element of the binary operation."""
        ...

    def __call__(
        self, x: Shaped[ArrayLike, '...'], y: Shaped[ArrayLike, '...'], /
    ) -> Shaped[Array, '...']:
        """Apply the operation elementwise to `x` and `y`."""
        ...

    def reduce(
        self, a: Shaped[ArrayLike, '...'], /, *, axis: int | None = 0
    ) -> Shaped[Array, '...']:
        """Fold the operation over `axis` of `a`."""
        ...

60 

61 

def expand_axes(
    axes: PyTree[int | None], tree_arg: PyTree, *, none_is_leaf: bool = True
) -> PyTree[int | None]:
    """Broadcast the prefix tree `axes` to the full structure of `tree_arg`.

    Each leaf of `axes` is copied onto every leaf of the matching subtree of
    `tree_arg`.

    Parameters
    ----------
    axes
        A tree whose structure is a prefix of the structure of `tree_arg`.
    tree_arg
        The tree whose structure the result must have.
    none_is_leaf
        If True, a `None` in `axes` is treated as an axis value and is copied
        over the whole corresponding subtree. If False, `None` keeps its
        default `jax.tree` meaning of an empty subtree.

    Returns
    -------
    A tree with the structure of `tree_arg` whose leaves are the axes.
    """

    def broadcast(axis: int | None, subtree: PyTree) -> PyTree[int | None]:
        return tree.map(lambda _leaf: axis, subtree)

    if none_is_leaf:
        return tree.map(broadcast, axes, tree_arg, is_leaf=lambda node: node is None)
    return tree.map(broadcast, axes, tree_arg)

72 

73 

def normalize_axes(
    axes: PyTree[int | None, ' T'],
    tree_arg: PyTree[Array | ShapeDtypeStruct | None, ' T'],
) -> PyTree[int | None, ' T']:
    """Turn every axis in `axes` into a valid non-negative index.

    Negative axes count from the end of the matching array in `tree_arg`.
    `None` axes pass through unchanged. Out-of-range axes raise through
    `numpy.lib.array_utils.normalize_axis_index`.
    """

    def to_nonnegative(
        axis: int | None, arr: Shaped[Array, '...'] | ShapeDtypeStruct | None
    ) -> int | None:
        if axis is None:
            return None
        assert arr is not None
        ndim = len(arr.shape)
        return normalize_axis_index(axis, ndim)

    return tree.map(to_nonnegative, axes, tree_arg, is_leaf=lambda node: node is None)

90 

91 

def remove_axis(
    x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: BinaryUfunc
) -> PyTree[ShapeDtypeStruct, ' T']:
    """Describe the result of reducing each dummy array in `x` with `ufunc`.

    For each leaf, the entry at index `axis` is dropped from the shape, and
    the dtype becomes the one `ufunc.reduce` would produce.
    """

    def squeeze_struct(struct: ShapeDtypeStruct, ax: int) -> ShapeDtypeStruct:
        kept_shape = (*struct.shape[:ax], *struct.shape[ax + 1 :])
        return ShapeDtypeStruct(kept_shape, reduction_dtype(ufunc, struct.dtype))

    return tree.map(squeeze_struct, x, axis)

103 

104 

def extract_size(axes: PyTree[int | None], tree_arg: PyTree) -> int:
    """Return the common length of the batched axes in `tree_arg`.

    Leaves whose axis is `None` are skipped. All the other leaves must have
    the same length along their axis, which is asserted.
    """

    def size_along(
        arr: Shaped[Array, '...'] | ShapeDtypeStruct, axis: int | None
    ) -> int | None:
        return None if axis is None else arr.shape[axis]

    # `None` results are empty subtrees, so they vanish from the leaves
    lengths = tree.leaves(tree.map(size_along, tree_arg, axes))
    first = lengths[0]
    assert all(length == first for length in lengths)
    return first

120 

121 

def sum_nbytes(tree_arg: PyTree[Array | ShapeDtypeStruct]) -> int:
    """Return the total byte size of all the (dummy) arrays in `tree_arg`."""
    return sum(
        math.prod(leaf.shape) * leaf.dtype.itemsize for leaf in tree.leaves(tree_arg)
    )

127 

128 

def next_divisor_small(dividend: int, min_divisor: int) -> int:
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    Meant for ``min_divisor * min_divisor <= dividend``. It scans upward from
    `min_divisor` to ``isqrt(dividend)``. If nothing divides there, it scans
    the cofactors ``dividend // q`` for ``q <= isqrt(dividend)``. Every
    divisor above the square root has this form.

    Parameters
    ----------
    dividend
        The non-negative number to divide.
    min_divisor
        Lower bound (inclusive) on the divisor.

    Returns
    -------
    The smallest divisor of `dividend` not below `min_divisor`. Returns
    `dividend` itself if no such divisor exists, which is only possible
    outside the intended precondition.
    """
    # isqrt is exact for arbitrarily large ints, unlike int(math.sqrt(...))
    root = math.isqrt(dividend)

    # candidates up to the square root, in increasing order
    for divisor in range(min_divisor, root + 1):
        if dividend % divisor == 0:
            return divisor

    # Candidates above the square root are cofactors of small divisors.
    # Walking q downward makes dividend // q increase, so the first hit is
    # the smallest. Without this step, e.g. (10, 3) would return 10 instead
    # of 5, producing more batches than needed.
    for small in range(root, 0, -1):
        if dividend % small == 0 and dividend // small >= min_divisor:
            return dividend // small

    return dividend

134 

135 

def next_divisor_large(dividend: int, min_divisor: int) -> int:
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    Meant for ``min_divisor * min_divisor > dividend``. It searches the
    cofactor ``q = dividend // divisor`` downward from ``dividend //
    min_divisor``. The largest such `q` dividing `dividend` gives the
    smallest admissible divisor. Falls back to `dividend` when the search
    range is empty.
    """
    cofactors = range(dividend // min_divisor, 0, -1)
    return next(
        (dividend // q for q in cofactors if dividend % q == 0),
        dividend,
    )

142 

143 

def next_divisor(dividend: int, min_divisor: int) -> int:
    """Return divisor >= min_divisor such that dividend % divisor == 0.

    A zero `dividend` is divisible by anything, so `min_divisor` is returned
    as is. Otherwise the search is delegated by comparing `min_divisor` to
    the square root of `dividend`.
    """
    if not dividend:
        return min_divisor
    if min_divisor**2 <= dividend:
        search = next_divisor_small
    else:
        search = next_divisor_large
    return search(dividend, min_divisor)

151 

152 

def pull_nonbatched(
    axes: PyTree[int | None], tree_arg: PyTree
) -> tuple[PyTree, PyTree]:
    """Blank out the leaves of `tree_arg` that are not batched.

    Returns a pair. The first element is `tree_arg` with `None` in place of
    every leaf whose axis is `None`. The second is `tree_arg` itself,
    unchanged, kept so the unbatched leaves can be restored later.
    """
    batched_only = tree.map(
        lambda leaf, axis: None if axis is None else leaf, tree_arg, axes
    )
    return batched_only, tree_arg

163 

164 

def push_nonbatched(
    axes: PyTree[int | None], tree_arg: PyTree, original_tree: PyTree
) -> PyTree[Any]:
    """Put back the unbatched leaves removed by `pull_nonbatched`.

    Leaves with a `None` axis are taken from `original_tree`. The others are
    taken from `tree_arg`.
    """
    return tree.map(
        lambda original, current, axis: original if axis is None else current,
        original_tree,
        tree_arg,
        axes,
    )

175 

176 

def move_axes_out(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
    """Move the axis given by `axes` to position 0 in each array of `tree_arg`."""
    return tree.map(lambda arr, axis: jnp.moveaxis(arr, axis, 0), tree_arg, axes)

182 

183 

def move_axes_in(axes: PyTree[int], tree_arg: PyTree[Array]) -> PyTree[Array]:
    """Move axis 0 of each array in `tree_arg` to the position given by `axes`.

    This is the inverse of `move_axes_out` with the same `axes`.
    """
    return tree.map(lambda arr, axis: jnp.moveaxis(arr, 0, axis), tree_arg, axes)

189 

190 

def batch(tree_arg: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
    """Reshape the leading axis of every array into ``(nbatches, length // nbatches)``.

    Uses integer division, so the leading length is expected to be a
    multiple of `nbatches`.
    """

    def split_leading(arr: Shaped[Array, '...']) -> Shaped[Array, '...']:
        leading, *trailing = arr.shape
        return arr.reshape(nbatches, leading // nbatches, *trailing)

    return tree.map(split_leading, tree_arg)

197 return tree.map(batch, tree_arg) 

198 

199 

def unbatch(tree_arg: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
    """Flatten the first two axes of every array into one (inverse of `batch`)."""

    def merge_leading(arr: Shaped[Array, '...']) -> Shaped[Array, '...']:
        outer, inner, *trailing = arr.shape
        return arr.reshape(outer * inner, *trailing)

    return tree.map(merge_leading, tree_arg)

207 

208 

def reduce(
    ufunc: BinaryUfunc,
    x: PyTree[Array, ' T'],
    axes: PyTree[int, ' T'],
    initial: PyTree[Array, ' T'] | None,
) -> PyTree[Array, ' T']:
    """Fold each array in `x` along its axis in `axes` with `ufunc.reduce`.

    If `initial` is given, each reduced array is then combined with the
    matching array of `initial` as ``ufunc(initial, reduced)``.
    """
    if initial is not None:

        def fold_onto(
            arr: Shaped[Array, '...'], start: Shaped[Array, '...'], axis: int
        ) -> Shaped[Array, '...']:
            folded = ufunc.reduce(arr, axis=axis)
            return ufunc(start, folded)

        return tree.map(fold_onto, x, initial, axes)

    return tree.map(lambda arr, axis: ufunc.reduce(arr, axis=axis), x, axes)

232 

233 

def identity(
    ufunc: BinaryUfunc, x: PyTree[ShapeDtypeStruct, ' T']
) -> PyTree[Array, ' T']:
    """Build, for every dummy array in `x`, an array filled with `ufunc`'s identity.

    Each output has the shape of its dummy and the dtype given by
    `identity_for`.
    """

    def filled(struct: ShapeDtypeStruct) -> Shaped[Array, '...']:
        scalar = identity_for(ufunc, struct.dtype)
        return jnp.broadcast_to(scalar, struct.shape)

    return tree.map(filled, x)

244 

245 

def reduction_dtype(ufunc: BinaryUfunc, input_dtype: DTypeLike) -> DTypeLike:
    """Return the dtype that `ufunc.reduce` outputs for inputs of `input_dtype`.

    The dtype is found empirically by reducing a length-1 probe array.
    """
    probe = jnp.empty(1, input_dtype)
    reduced = ufunc.reduce(probe)
    return reduced.dtype

249 

250 

def identity_for(ufunc: BinaryUfunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
    """Return `ufunc.identity` as a 0-d array in the reduction dtype.

    The dtype is what `ufunc.reduce` yields for `input_dtype` inputs, which
    may be wider than the input (e.g., int8 is accumulated to int32). This
    lets the identity seed an accumulator of reduced values.
    """
    accumulator_dtype = reduction_dtype(ufunc, input_dtype)
    return jnp.array(ufunc.identity, accumulator_dtype)

258 

259 

def check_same(tree1: PyTree, tree2: PyTree) -> None:
    """Assert that two trees of (dummy) arrays agree leaf by leaf.

    `jax.tree.map` enforces the structural match. Each pair of leaves must
    then have equal shape and equal dtype.
    """

    def assert_match(
        left: Shaped[Array, '*shape'] | ShapeDtypeStruct,
        right: Shaped[Array, '*shape'] | ShapeDtypeStruct,
    ) -> None:
        assert left.shape == right.shape
        assert left.dtype == right.dtype

    tree.map(assert_match, tree1, tree2)

269 

270 

class NotDefined:
    """Sentinel type marking an argument as not provided.

    The class object itself, not an instance, is used as the default of
    `result_shape_dtype` and is compared by identity.
    """

273 

274 

def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    *,
    return_nbatches: bool = False,
    reduce_ufunc: BinaryUfunc | None = None,
    reduce_vary_axes: tuple[str, ...] = (),
    warn_on_overflow: bool = True,
    result_shape_dtype: PyTree[ShapeDtypeStruct] | type[NotDefined] = NotDefined,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axes each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.
    reduce_ufunc
        Function used to reduce the output along the batched axis (e.g.,
        `jax.numpy.add`).
    reduce_vary_axes
        Manual `jax.shard_map` mesh axes over which the reduction accumulator
        varies. Under a `shard_map`, the reduction seed is `pcast` to vary over
        these axes so its type matches the shard-varying loop body, satisfying
        the VMA checker. Ignored unless `reduce_ufunc` is set.
    warn_on_overflow
        If True, a warning is raised if the memory limit could not be
        respected.
    result_shape_dtype
        A pytree of dummy arrays matching the expected output of `func`. If
        not provided, the function is traced an additional time to determine
        the output structure.

    Returns
    -------
    A jitted function with the same signature as `func`. Its return value is
    reduced along `out_axes` if `reduce_ufunc` is set, and is paired with the
    number of batches if `return_nbatches` is set.

    Notes
    -----
    Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
    arguments is idempotent. Furthermore, `autobatch` can be applied multiple
    times over multiple axes with the same `max_io_nbytes` limit to work on
    multiple axes; in this case it won't unnecessarily loop over additional axes
    if one or more outer `autobatch` are already sufficient.

    To handle memory used in intermediate values: assuming all intermediate
    values have size that scales linearly with the axis batched over, say the
    batched input/output total size is ``batched_size * core_io_size``, and the
    intermediate values have size ``batched_size * core_int_size``, then to take
    them into account divide `max_io_nbytes` by ``(1 + core_int_size /
    core_io_size)``.
    """

    @jit
    @wraps(func)
    def autobatch_wrapper(*args: PyTree) -> PyTree:
        # all the configuration comes from the enclosing scope; only the
        # positional arguments of the call are forwarded
        return batched_func(
            func=func,
            max_io_nbytes=max_io_nbytes,
            in_axes=in_axes,
            out_axes=out_axes,
            return_nbatches=return_nbatches,
            reduce_ufunc=reduce_ufunc,
            reduce_vary_axes=reduce_vary_axes,
            warn_on_overflow=warn_on_overflow,
            result_shape_dtype=result_shape_dtype,
            args=args,
        )

    return autobatch_wrapper

359 

360 

def batched_func(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    return_nbatches: bool,
    reduce_ufunc: BinaryUfunc | None,
    reduce_vary_axes: tuple[str, ...],
    warn_on_overflow: bool,
    result_shape_dtype: PyTree[ShapeDtypeStruct] | type[NotDefined],
    args: tuple[PyTree[Array], ...],
) -> PyTree[Array] | tuple[PyTree[Array], int]:
    """Implement the wrapper used in `autobatch`.

    Parameters
    ----------
    func
    max_io_nbytes
    in_axes
    out_axes
    return_nbatches
    reduce_ufunc
    reduce_vary_axes
    warn_on_overflow
    result_shape_dtype
        Same meaning as the corresponding arguments of `autobatch`.
    args
        The positional arguments to pass to `func`.

    Returns
    -------
    result
        ``func(*args)``, computed in `nbatches` sequential chunks along the
        batched axes. If `reduce_ufunc` is set, the output is reduced along
        `out_axes` instead of concatenated.
    nbatches
        The number of chunks, only returned if `return_nbatches`.
    """
    # determine the output structure of the function
    if result_shape_dtype is NotDefined:
        example_result = eval_shape(func, *args)
    else:
        example_result = result_shape_dtype

    # expand the axes pytrees if they are prefixes; outputs must all be
    # batched, so a `None` in `out_axes` keeps its empty-subtree meaning
    in_axes = expand_axes(in_axes, args)
    out_axes = expand_axes(out_axes, example_result, none_is_leaf=False)

    # check the axes are valid (this also makes negative axes non-negative)
    in_axes = normalize_axes(in_axes, args)
    out_axes = normalize_axes(out_axes, example_result)

    # get the size of the batched axis, which must be the same for all
    # batched inputs and all outputs
    size = extract_size((in_axes, out_axes), (args, example_result))

    # split arguments in batched and not batched: `args` now has `None` in
    # place of unbatched leaves, while `nonbatched_args` is the whole
    # original tree, from which `push_nonbatched` restores those leaves
    original_args = args
    args, nonbatched_args = pull_nonbatched(in_axes, args)

    # determine the number of batches to respect the memory limit:
    # min_nbatches is ceil(total_nbytes / max_io_nbytes) (at least 1), then
    # nbatches is a divisor of `size` not below it, so all batches are equal
    total_nbytes = sum_nbytes((args, example_result))
    min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
    min_nbatches = max(1, min_nbatches)
    nbatches = next_divisor(size, min_nbatches)
    assert 1 <= nbatches <= max(1, size)
    assert size % nbatches == 0
    # each batched leaf has `size` elements along its batched axis, so its
    # byte count is a multiple of `size`, hence of `nbatches`
    assert total_nbytes % nbatches == 0

    # warn if the memory limit could not be respected; per the assertion,
    # this only happens when every batch already has length 1
    batch_nbytes = total_nbytes // nbatches
    if batch_nbytes > max_io_nbytes and warn_on_overflow:
        assert size == nbatches
        msg = f'batch_nbytes = {batch_nbytes:_} > max_io_nbytes = {max_io_nbytes:_}'
        warn(msg)

    # squeeze out the output dims that will be reduced
    if reduce_ufunc is not None:
        example_result = remove_axis(example_result, out_axes, reduce_ufunc)

    if nbatches > 1:
        # prepare arguments for looping: bring the batched axis to the front,
        # then split it into (nbatches, size // nbatches) so that the scan
        # below iterates over the leading axis
        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)

        # prepare carry for reduction
        if reduce_ufunc is None:
            initial = None
        else:
            # the reduction starts from the identity element of the ufunc
            initial = identity(reduce_ufunc, example_result)
            # under a `shard_map`, the loop body output varies over the manual
            # axes while this seed is replicated; mark it varying so the scan's
            # carry types match and the VMA checker is satisfied
            if reduce_vary_axes:
                initial = tree.map(
                    lambda x: lax.pcast(x, reduce_vary_axes, to='varying'), initial
                )

        # loop and invoke the function in batches
        loop = partial(
            batching_loop,
            func=func,
            nonbatched_args=nonbatched_args,
            in_axes=in_axes,
            out_axes=out_axes,
            reduce_ufunc=reduce_ufunc,
        )
        # the carry holds the running reduction (None when not reducing); the
        # stacked per-iteration outputs hold the batch results (None when
        # reducing)
        reduced_result, result = lax.scan(loop, initial, args)

        # remove auxiliary batching axis and reverse transposition
        if reduce_ufunc is None:
            assert reduced_result is None
            result = unbatch(result)
            result = move_axes_in(out_axes, result)
        else:
            assert result is None
            result = reduced_result

    # trivial case: no batching needed
    else:
        result = func(*original_args)
        if reduce_ufunc is not None:
            result = reduce(reduce_ufunc, result, out_axes, None)

    # the computed result must match the expected shapes and dtypes
    check_same(example_result, result)

    if return_nbatches:
        return result, nbatches
    return result

464 

465 

def batching_loop(
    initial: PyTree[Array] | None,
    args: PyTree[Array],
    *,
    func: Callable,
    nonbatched_args: PyTree,
    in_axes: PyTree[int | None],
    out_axes: PyTree[int],
    reduce_ufunc: BinaryUfunc | None,
) -> tuple[PyTree[Array], None] | tuple[None, PyTree[Array]]:
    """Scan body used by `batched_func`: evaluate `func` on a single batch.

    Parameters
    ----------
    initial
        The scan carry: the running reduction, or `None` if not reducing.
    args
        The current batch of the batched arguments, with the batched axis at
        position 0 and `None` for unbatched arguments.
    func, nonbatched_args, in_axes, out_axes, reduce_ufunc
        Fixed by `batched_func` through `functools.partial`.

    Returns
    -------
    ``(new_carry, None)`` when reducing. Otherwise ``(None, outputs)``, with
    the batched output axis moved to position 0 so the scan stacks the
    batches along it.
    """
    # restore the original axis positions and the unbatched arguments
    call_args = push_nonbatched(in_axes, move_axes_in(in_axes, args), nonbatched_args)
    output = func(*call_args)

    if reduce_ufunc is not None:
        # accumulate this batch into the carry
        return reduce(reduce_ufunc, output, out_axes, initial), None

    # no reduction: emit the batch output for concatenation
    return None, move_axes_out(out_axes, output)