Coverage for src / bartz / jaxext / _autobatch.py: 100%

182 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-13 00:35 +0000

1# bartz/src/bartz/jaxext/_autobatch.py 

2# 

3# Copyright (c) 2025-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implementation of `autobatch`.""" 

26 

27import math 

28from collections.abc import Callable 

29from functools import partial, wraps 

30from warnings import warn 

31 

32from jax.typing import DTypeLike 

33 

34try: 

35 from numpy.lib.array_utils import normalize_axis_index # numpy 2 

36except ImportError: 

37 from numpy.core.numeric import normalize_axis_index # numpy 1 

38 

39from jax import ShapeDtypeStruct, eval_shape, jit 

40from jax import numpy as jnp 

41from jax.lax import scan 

42from jax.tree import flatten as tree_flatten 

43from jax.tree import map as tree_map 

44from jax.tree import reduce as tree_reduce 

45from jaxtyping import Array, PyTree, Shaped 

46 

47 

def expand_axes(axes, tree):
    """Broadcast each entry of `axes` over the matching subtree of `tree`.

    `axes` is walked as the outer structure, with ``None`` treated as a leaf.
    Each of its leaves is paired with the corresponding subtree of `tree`.
    That leaf value is then copied onto every leaf of the subtree. The result
    therefore mirrors the full structure of `tree`.
    """

    def is_none(node):
        return node is None

    def fill(axis, subtree):
        # replace every leaf of `subtree` with the same `axis` value
        def constant(_leaf):
            return axis

        return tree_map(constant, subtree)

    return tree_map(fill, axes, tree, is_leaf=is_none)

55 

56 

def normalize_axes(
    axes: PyTree[int | None, ' T'], tree: PyTree[Array, ' T']
) -> PyTree[int | None, ' T']:
    """Rewrite every axis in `axes` in its non-negative form.

    The function walks `axes` and `tree` together, treating ``None`` as a leaf
    of `axes`.
    - A ``None`` axis is returned unchanged.
    - An integer axis goes through numpy's ``normalize_axis_index``, using the
      number of dimensions of the paired array. This converts negative indices
      and rejects indices that are out of range for that array.
    """

    def to_nonnegative(axis: int | None, x: Array) -> int | None:
        if axis is not None:
            return normalize_axis_index(axis, len(x.shape))
        return None

    return tree_map(
        to_nonnegative, axes, tree, is_leaf=lambda node: node is None
    )

69 

70 

def check_no_nones(axes, tree):
    """Assert that every leaf of `tree` is paired with a non-``None`` axis.

    `axes` is matched leaf by leaf against the structure of `tree`. Leaves of
    `tree` that are ``None`` are themselves visited as leaves. An
    ``AssertionError`` is raised at the first ``None`` axis found. The
    function returns ``None``.
    """

    def is_none(node):
        return node is None

    def require_axis(_leaf, axis):
        assert axis is not None

    tree_map(require_axis, tree, axes, is_leaf=is_none)

76 

77 

def remove_axis(
    x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: jnp.ufunc
) -> PyTree[ShapeDtypeStruct, ' T']:
    """Drop one dimension from each dummy array and switch it to the reduction dtype.

    `x` and `axis` are walked together. For each ``ShapeDtypeStruct`` leaf:
    - the entry at position ``axis`` is sliced out of its shape;
    - the new dtype is whatever ``reduction_dtype(ufunc, old_dtype)`` returns.
      That helper is not shown here.
    """

    def drop_dim(struct: ShapeDtypeStruct, ax: int) -> ShapeDtypeStruct:
        # slice around `ax` (rather than deleting by index) to build the new shape
        head = struct.shape[:ax]
        tail = struct.shape[ax + 1 :]
        shape_without_ax = head + tail
        dtype_after_reduce = reduction_dtype(ufunc, struct.dtype)
        return ShapeDtypeStruct(shape_without_ax, dtype_after_reduce)

    return tree_map(drop_dim, x, axis)

89 

90 

def extract_size(axes, tree):
    """Return the common size of the arrays in `tree` along their mapped axes.

    Parameters
    ----------
    axes
        Pytree matched leaf by leaf against `tree`. Each leaf is either an
        integer axis index or ``None``. ``None`` means the paired array is not
        mapped and contributes no size.
    tree
        Pytree of arrays (anything with a ``shape`` attribute).

    Returns
    -------
    int
        The size shared by all mapped arrays along their respective axes.

    Raises
    ------
    ValueError
        If no array is mapped (every axis is ``None``, or `tree` has no
        leaves), or if the mapped arrays disagree on the size.
    """

    def get_size(x, axis):
        if axis is None:
            return None
        else:
            return x.shape[axis]

    sizes = tree_map(get_size, tree, axes)
    # `None` is an empty pytree node, so flattening drops the unmapped arrays
    sizes, _ = tree_flatten(sizes)
    if not sizes:
        # without this check `sizes[0]` would fail with an uninformative IndexError
        raise ValueError(
            'cannot determine the size: no array is mapped '
            '(all axes are None or the tree is empty)'
        )
    # explicit check instead of `assert`, so it also runs under `python -O`
    if any(s != sizes[0] for s in sizes):
        raise ValueError(
            f'mapped arrays have inconsistent sizes along their axes: {sizes}'
        )
    return sizes[0]

104 

105 

def sum_nbytes(tree):
    """Return the total size in bytes of all the leaves of `tree`.

    Each leaf only needs ``shape`` and ``dtype`` attributes, so concrete
    arrays and `jax.ShapeDtypeStruct` placeholders both work. ``None``
    subtrees contribute nothing because they have no leaves.
    """
    leaves, _ = tree_flatten(tree)
    total = 0
    for leaf in leaves:
        # number of elements times bytes per element
        total += math.prod(leaf.shape) * leaf.dtype.itemsize
    return total

111 

112 

def next_divisor_small(dividend, min_divisor):
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    The function first scans candidate divisors upward from `min_divisor`
    to ``isqrt(dividend)``. That scan is cheap when
    ``min_divisor**2 <= dividend``, which is the case `next_divisor` uses
    it for.

    If that range holds no divisor, any admissible divisor lies above the
    square root. It is therefore the cofactor ``dividend // k`` of some
    divisor ``k`` below `min_divisor`. The largest such ``k`` gives the
    smallest cofactor. Restricting ``k <= dividend // min_divisor``
    guarantees the cofactor is >= `min_divisor`, so the result is correct
    for any ``min_divisor >= 1``. If `min_divisor` exceeds `dividend`,
    `dividend` itself is returned.

    Parameters
    ----------
    dividend : int
        Non-negative integer to divide. Negative values raise `ValueError`
        from `math.isqrt`.
    min_divisor : int
        Lower bound on the divisor. Assumed >= 1 (TODO: confirm all callers
        guarantee this).

    Returns
    -------
    int
        The smallest ``d >= min_divisor`` with ``dividend % d == 0``, or
        `dividend` if no such ``d <= dividend`` exists.
    """
    # exact integer square root: no float rounding for large dividends
    root = math.isqrt(dividend)
    for divisor in range(min_divisor, root + 1):
        if dividend % divisor == 0:
            return divisor

    # No divisor in [min_divisor, root]. Every remaining divisor d > root
    # pairs with a cofactor k = dividend // d < min_divisor. Scanning k
    # downward, the first hit gives the smallest such d. Bounding
    # k <= dividend // min_divisor keeps d >= min_divisor.
    upper = min(min_divisor - 1, dividend // min_divisor)
    for k in range(upper, 0, -1):
        if dividend % k == 0:
            return dividend // k
    return dividend

118 

119 

def next_divisor_large(dividend, min_divisor):
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    The search runs over cofactors ``k = dividend // divisor`` downward from
    ``dividend // min_divisor``. The first ``k`` that divides `dividend`
    yields the smallest admissible divisor. `next_divisor` uses this when
    ``min_divisor**2 > dividend``, where the cofactor range is the shorter
    one. If `min_divisor` exceeds `dividend`, the range is empty and
    `dividend` is returned.
    """
    top = dividend // min_divisor
    hits = (dividend // k for k in range(top, 0, -1) if dividend % k == 0)
    return next(hits, dividend)

126 

127 

def next_divisor(dividend, min_divisor):
    """Return divisor >= min_divisor such that dividend % divisor == 0.

    A zero `dividend` is divisible by anything, so `min_divisor` is returned
    unchanged. Otherwise the search goes to `next_divisor_large` when
    ``min_divisor**2 > dividend`` and to `next_divisor_small` otherwise.
    """
    if dividend == 0:
        return min_divisor
    beyond_sqrt = min_divisor * min_divisor > dividend
    if beyond_sqrt:
        return next_divisor_large(dividend, min_divisor)
    return next_divisor_small(dividend, min_divisor)

135 

136 

def pull_nonbatched(axes, tree):
    """Replace the non-batched leaves of `tree` with ``None``.

    Parameters
    ----------
    axes
        Tree matching `tree` (or having it as a prefix). A ``None`` entry
        marks the corresponding leaf as not batched.
    tree
        The tree of values to strip.

    Returns
    -------
    A pair ``(stripped, tree)``. In ``stripped``, every leaf whose axis is
    ``None`` is replaced by ``None``. ``tree`` is returned untouched so that
    `push_nonbatched` can later restore the stripped leaves.
    """

    def drop_if_unbatched(leaf, axis):
        return None if axis is None else leaf

    stripped = tree_map(drop_if_unbatched, tree, axes)
    return stripped, tree

145 

146 

def push_nonbatched(axes, tree, original_tree):
    """Put back the non-batched leaves taken out by `pull_nonbatched`.

    For every position whose entry in `axes` is ``None``, the leaf from
    `original_tree` is used. Everywhere else the leaf from `tree` is kept.
    """

    def restore(original_leaf, leaf, axis):
        if axis is not None:
            return leaf
        return original_leaf

    return tree_map(restore, original_tree, tree, axes)

155 

156 

def move_axes_out(axes, tree):
    """Move the axis given by `axes` to the front of each leaf of `tree`.

    ``None`` subtrees in `tree` have no leaves and pass through unchanged.
    """
    return tree_map(lambda leaf, axis: jnp.moveaxis(leaf, axis, 0), tree, axes)

162 

163 

def move_axes_in(axes, tree):
    """Move the leading axis of each leaf of `tree` to the position given by `axes`.

    This is the inverse of `move_axes_out` for the same `axes`.
    """
    return tree_map(lambda leaf, axis: jnp.moveaxis(leaf, 0, axis), tree, axes)

169 

170 

def batch(tree: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
    """Split the first axis into two axes, the first of size `nbatches`.

    A leaf of shape ``(n, *rest)`` becomes ``(nbatches, n // nbatches, *rest)``.
    `n` must be a multiple of `nbatches`; otherwise the element counts
    disagree and the reshape fails.
    """

    def split_leading(x):
        per_batch = x.shape[0] // nbatches
        return x.reshape(nbatches, per_batch, *x.shape[1:])

    return tree_map(split_leading, tree)

178 

179 

def unbatch(tree: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
    """Merge the first two axes into a single axis.

    A leaf of shape ``(a, b, *rest)`` becomes ``(a * b, *rest)``. This is the
    inverse of `batch`.
    """

    def merge_leading(x):
        merged = x.shape[0] * x.shape[1]
        return x.reshape(merged, *x.shape[2:])

    return tree_map(merge_leading, tree)

187 

188 

def reduce(
    ufunc: jnp.ufunc,
    x: PyTree[Array, ' T'],
    axes: PyTree[int, ' T'],
    initial: PyTree[Array, ' T'] | None,
) -> PyTree[Array, ' T']:
    """Reduce each array in `x` along the axes in `axes` starting from `initial` using `ufunc.reduce`.

    Each leaf is collapsed with ``ufunc.reduce`` along its axis. When
    `initial` is given, that partial result is then combined with the
    matching leaf of `initial` as ``ufunc(initial_leaf, reduced)``.
    """
    if initial is None:
        return tree_map(lambda leaf, axis: ufunc.reduce(leaf, axis=axis), x, axes)

    def reduce_onto(leaf: Array, start: Array, axis: int) -> Array:
        partial_result = ufunc.reduce(leaf, axis=axis)
        return ufunc(start, partial_result)

    return tree_map(reduce_onto, x, initial, axes)

210 

211 

def identity(
    ufunc: jnp.ufunc, x: PyTree[ShapeDtypeStruct, ' T']
) -> PyTree[Array, ' T']:
    """Build a pytree of identity arrays for `ufunc`, one per leaf of `x`.

    Each leaf of `x` provides a shape and an input dtype. The matching output
    leaf is ``identity_for(ufunc, leaf.dtype)`` broadcast to ``leaf.shape``.
    """

    def fill_leaf(leaf: ShapeDtypeStruct) -> Array:
        neutral = identity_for(ufunc, leaf.dtype)
        return jnp.broadcast_to(neutral, leaf.shape)

    return tree_map(fill_leaf, x)

222 

223 

224def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike: 

225 """Return the output dtype for a reduction with `ufunc` on inputs of type `dtype`.""" 

226 return ufunc.reduce(jnp.empty(1, input_dtype)).dtype 2Fdg d h Gde Hdb a c IdJdKdLdMdNdOdPdQdRdSdi j k TdUdVdWdl m n o p q r s t u v w x y z A B C D E _f`f{f|fXdYdZd0d1d2d3d4dF G H I jckclcmcncocpcqcJ K L M }f~fagbg5d6d7d8d9d!d#d$d%d'd(d)d*d+d,d-d.d/d:d;d=d?d@d[dcgdgegfg]d^d_d`dN O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 gghgigjg{d|d}d~daebecede7 8 9 ! rcsctcucvcwcxcyc# $ % ' kglgmgngeefegeheiejekelemeneoepeqereseteuevewexeyezeAeBeogpgqgrgCeDeEeFe( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | sgtgugvgGeHeIeJeKeLeMeNe} ~ abbbzcAcBcCcDcEcFcGccbdbebfbwgxgygzgOePeQeReSeTeUeVeWeXeYeZe0e1e2e3e4e5e6e7e8e9e!e#eAgBgCgDg$e%e'e(egbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbEgFgGgHg)e*e+e,eYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbIgJgKgLg-e.e/e:e;e=e?e@e[e]e^e_e`e{e|e}e~eafbfcfdfefffgfMgNgOgPghfifjfkfQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9bQgRgSgTglfmfnfofpfqfrfsf!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*bUgVgWgXgtfufvfwfxfyfzfAfBfCfDfEfFfGfHfIfJfKfLfMfNfOfPfQfYgZg0g1gRfSfTfUf+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bac2g3g4g5gVfWfXfYfZf0f1f2fbcccdcecPcQcRcScTcUcVcWcfcgchcic6g7g8g9g3f4f5f6f7f8f9f!f#f$f%f'f(f)f*f+f,f-f.f/f:f;f=f?f!g#g$g%g@f[f]f^fXcf

227 

228 

def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
    """Return the identity of `ufunc` as an explicitly typed scalar array.

    Parameters
    ----------
    ufunc
        The ufunc whose identity element is requested.
    input_dtype
        The dtype of the arrays that are reduced with `ufunc`.

    Returns
    -------
    A 0-d array holding ``ufunc.identity``, with the dtype that
    ``ufunc.reduce`` produces for `input_dtype` (e.g., int8 inputs are
    accumulated to int32).

    Raises
    ------
    ValueError
        If `ufunc` does not define an identity element.
    """
    # without this check the failure would surface inside `jnp.array` with a
    # less informative message
    if ufunc.identity is None:
        msg = f'{ufunc!r} has no identity element, it cannot be used for the reduction'
        raise ValueError(msg)

    # get output type from input type, e.g., int8 is accumulated to int32
    dtype = reduction_dtype(ufunc, input_dtype)

    # return as explicitly typed array
    return jnp.array(ufunc.identity, dtype)

236 

237 

def check_same(tree1, tree2):
    """Assert that two pytrees have leaves with matching shapes and dtypes.

    Parameters
    ----------
    tree1, tree2
        Pytrees traversed in lockstep with `tree_map`. Their leaves must have
        ``shape`` and ``dtype`` attributes (e.g., arrays or
        `ShapeDtypeStruct`).

    Raises
    ------
    AssertionError
        If a pair of corresponding leaves differs in shape or dtype. The
        message reports which property mismatched and both values.
    """

    def check_leaf(x1, x2):
        assert x1.shape == x2.shape, f'shape mismatch: {x1.shape} != {x2.shape}'
        assert x1.dtype == x2.dtype, f'dtype mismatch: {x1.dtype} != {x2.dtype}'

    tree_map(check_leaf, tree1, tree2)

244 

245 

class NotDefined:
    """Sentinel marking an argument that was not provided.

    The class object itself, not an instance, is used as the default value and
    is tested with ``is``. This keeps it distinct from any value a caller could
    pass, including `None`.
    """

248 

249 

def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    *,
    return_nbatches: bool = False,
    reduce_ufunc: jnp.ufunc | None = None,
    warn_on_overflow: bool = True,
    result_shape_dtype: PyTree[ShapeDtypeStruct] = NotDefined,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function that takes positional arguments only, with inputs
        and outputs that are pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments).
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axis each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.
    reduce_ufunc
        Function used to reduce the output along the batched axis (e.g.,
        `jax.numpy.add`).
    warn_on_overflow
        If True, a warning is raised if the memory limit could not be
        respected.
    result_shape_dtype
        A pytree of dummy arrays matching the expected output. If not provided,
        the function is traced an additional time to determine the output
        structure.

    Returns
    -------
    A jitted function taking the same positional arguments as `func`. If
    `return_nbatches` is set, the number of batches is returned as a second
    output.

    Notes
    -----
    Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
    arguments is idempotent. Furthermore, `autobatch` can be applied multiple
    times over multiple axes with the same `max_io_nbytes` limit to work on
    multiple axes; in this case it won't unnecessarily loop over additional axes
    if one or more outer `autobatch` are already sufficient.

    To handle memory used in intermediate values: assuming all intermediate
    values have size that scales linearly with the axis batched over, say the
    batched input/output total size is ``batched_size * core_io_size``, and the
    intermediate values have size ``batched_size * core_int_size``, then to take
    them into account divide `max_io_nbytes` by ``(1 + core_int_size /
    core_io_size)``.
    """
    # bind the whole configuration once; only the positional arguments change
    # from call to call
    run_batched = partial(
        batched_func,
        func=func,
        max_io_nbytes=max_io_nbytes,
        in_axes=in_axes,
        out_axes=out_axes,
        return_nbatches=return_nbatches,
        reduce_ufunc=reduce_ufunc,
        warn_on_overflow=warn_on_overflow,
        result_shape_dtype=result_shape_dtype,
    )

    def autobatch_wrapper(*args):
        return run_batched(args=args)

    # copy the metadata of `func` onto the wrapper, then jit the result
    return jit(wraps(func)(autobatch_wrapper))

327 

328 

329def batched_func( 

330 func: Callable, 

331 max_io_nbytes: int, 

332 in_axes: PyTree[int | None], 

333 out_axes: PyTree[int], 

334 return_nbatches: bool, 

335 reduce_ufunc: jnp.ufunc | None, 

336 warn_on_overflow: bool, 

337 result_shape_dtype: PyTree[ShapeDtypeStruct] | NotDefined, 

338 args: tuple[PyTree[Array], ...], 

339) -> PyTree[Array]: 

340 """Implement the wrapper used in `autobatch`.""" 

341 # determine the output structure of the function 

342 if result_shape_dtype is NotDefined: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

343 example_result = eval_shape(func, *args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

344 else: 

345 example_result = result_shape_dtype 2Fdg d h Gde Hdb a c IdJdKdLdMdNdOdPdQdRdSdi j k TdUdVdWd

346 

347 # expand the axes pytrees if they are prefixes 

348 in_axes = expand_axes(in_axes, args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

349 out_axes = expand_axes(out_axes, example_result) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

350 check_no_nones(out_axes, example_result) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

351 

352 # check the axes are valid 

353 in_axes = normalize_axes(in_axes, args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

354 out_axes = normalize_axes(out_axes, example_result) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

355 

356 # get the size of the batched axis 

357 size = extract_size((in_axes, out_axes), (args, example_result)) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

358 

359 # split arguments in batched and not batched 

360 original_args = args 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

361 args, nonbatched_args = pull_nonbatched(in_axes, args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

362 

363 # determine the number of batches to respect the memory limit 

364 total_nbytes = sum_nbytes((args, example_result)) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

365 min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

366 min_nbatches = max(1, min_nbatches) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

367 nbatches = next_divisor(size, min_nbatches) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

368 assert 1 <= nbatches <= max(1, size) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

369 assert size % nbatches == 0 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

370 assert total_nbytes % nbatches == 0 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

371 

372 # warn if the memory limit could not be respected 

373 batch_nbytes = total_nbytes // nbatches 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

374 if batch_nbytes > max_io_nbytes and warn_on_overflow: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

375 assert size == nbatches 2b a c 2c3c4c7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I 'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! .c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbb[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bhd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecodfcgchcic

376 msg = f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}' 2b a c 2c3c4c7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I 'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! .c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbb[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bhd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecodfcgchcic

377 warn(msg) 2b a c 2c3c4c7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I 'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! .c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbb[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bhd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecodfcgchcic

378 

379 # squeeze out the output dims that will be reduced 

380 if reduce_ufunc is not None: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

381 example_result = remove_axis(example_result, out_axes, reduce_ufunc) 2Fdg d h Gde Hdb a c IdJdKdLdMdNdOdPdQdRdSdi j k TdUdVdWdl m n o p q r s t u v w x y z A B C D E _f`f{f|fXdYdZd0d1d2d3d4dF G H I jckclcmcncocpcqcJ K L M }f~fagbg5d6d7d8d9d!d#d$d%d'd(d)d*d+d,d-d.d/d:d;d=d?d@d[dcgdgegfg]d^d_d`dN O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 gghgigjg{d|d}d~daebecede7 8 9 ! rcsctcucvcwcxcyc# $ % ' kglgmgngeefegeheiejekelemeneoepeqereseteuevewexeyezeAeBeogpgqgrgCeDeEeFe( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | sgtgugvgGeHeIeJeKeLeMeNe} ~ abbbzcAcBcCcDcEcFcGccbdbebfbwgxgygzgOePeQeReSeTeUeVeWeXeYeZe0e1e2e3e4e5e6e7e8e9e!e#eAgBgCgDg$e%e'e(egbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbEgFgGgHg)e*e+e,eYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbIgJgKgLg-e.e/e:e;e=e?e@e[e]e^e_e`e{e|e}e~eafbfcfdfefffgfMgNgOgPghfifjfkfQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9bQgRgSgTglfmfnfofpfqfrfsf!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*bUgVgWgXgtfufvfwfxfyfzfAfBfCfDfEfFfGfHfIfJfKfLfMfNfOfPfQfYgZg0g1gRfSfTfUf+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bac2g3g4g5gVfWfXfYfZf0f1f2fbcccdcecPcQcRcScTcUcVcWcfcgchcic6g7g8g9g3f4f5f6f7f8f9f!f#f$f%f'f(f)f*f+f,f-f.f/f:f;f=f?f!g#g$g%g@f[f]f^fXcf

382 

383 if nbatches > 1: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

384 # prepare arguments for looping 

385 args = move_axes_out(in_axes, args) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c

386 args = batch(args, nbatches) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c

387 

388 # prepare carry for reduction 

389 if reduce_ufunc is None: 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c

390 initial = None 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8c9c!c#c$c%cudvd'c(c)c*c+c,c-cwdxd.c/c:c;c=c?c@cydzd[c]c^c_c`c{cEd|c}c~cadbdcdddedfdgdAdBdhdidjdkdldmdndCdDdod6c

391 else: 

392 initial = identity(reduce_ufunc, example_result) 2l m n o p q r s t u v w x y z A B C D E F G H I jckclcmcncocpcqcJ K L M N O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! rcsctcucvcwcxcyc# $ % ' ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbzcAcBcCcDcEcFcGccbdbebfbgbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*b+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bacbcccdcecPcQcRcScTcUcVcWcfcgchcicXcf

393 

394 # loop and invoke the function in batches 

395 loop = partial( 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c

396 batching_loop, 

397 func=func, 

398 nonbatched_args=nonbatched_args, 

399 in_axes=in_axes, 

400 out_axes=out_axes, 

401 reduce_ufunc=reduce_ufunc, 

402 ) 

403 reduced_result, result = scan(loop, initial, args) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c

404 

405 # remove auxiliary batching axis and reverse transposition 

406 if reduce_ufunc is None: 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c

407 assert reduced_result is None 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8c9c!c#c$c%cudvd'c(c)c*c+c,c-cwdxd.c/c:c;c=c?c@cydzd[c]c^c_c`c{cEd|c}c~cadbdcdddedfdgdAdBdhdidjdkdldmdndCdDdod6c

408 result = unbatch(result) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8c9c!c#c$c%cudvd'c(c)c*c+c,c-cwdxd.c/c:c;c=c?c@cydzd[c]c^c_c`c{cEd|c}c~cadbdcdddedfdgdAdBdhdidjdkdldmdndCdDdod6c

409 result = move_axes_in(out_axes, result) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8c9c!c#c$c%cudvd'c(c)c*c+c,c-cwdxd.c/c:c;c=c?c@cydzd[c]c^c_c`c{cEd|c}c~cadbdcdddedfdgdAdBdhdidjdkdldmdndCdDdod6c

410 else: 

411 assert result is None 2l m n o p q r s t u v w x y z A B C D E F G H I jckclcmcncocpcqcJ K L M N O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! rcsctcucvcwcxcyc# $ % ' ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbzcAcBcCcDcEcFcGccbdbebfbgbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*b+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bacbcccdcecPcQcRcScTcUcVcWcfcgchcicXcf

412 result = reduced_result 2l m n o p q r s t u v w x y z A B C D E F G H I jckclcmcncocpcqcJ K L M N O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! rcsctcucvcwcxcyc# $ % ' ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbzcAcBcCcDcEcFcGccbdbebfbgbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*b+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bacbcccdcecPcQcRcScTcUcVcWcfcgchcicXcf

413 

414 # trivial case: no batching needed 

415 else: 

416 result = func(*original_args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd)g*g+g,g`g8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d#hgghgigjgfh{d|d}d~dghaebecede$hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe'hsgtgugvgohGeHeIeJephKeLeMeNe(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e*hEgFgGgHgxh)e*e+e,e+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkf-hQgRgSgTgFhlfmfnfofGhpfqfrfsf.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUf:h2g3g4g5gOhVfWfXfYfPhZf0f1f2f;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^f'g?hXhYhZh0h1h2h3h4h5h6h7h

417 if reduce_ufunc is not None: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd)g*g+g,g`g8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d#hgghgigjgfh{d|d}d~dghaebecede$hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe'hsgtgugvgohGeHeIeJephKeLeMeNe(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e*hEgFgGgHgxh)e*e+e,e+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkf-hQgRgSgTgFhlfmfnfofGhpfqfrfsf.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUf:h2g3g4g5gOhVfWfXfYfPhZf0f1f2f;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^f'g?hXhYhZh0h1h2h3h4h5h6h7h

418 result = reduce(reduce_ufunc, result, out_axes, None) 2Fdg d h Gde Hdb a c IdJdKdLdMdNdOdPdQdRdSdi j k TdUdVdWd_f`f{f|fXdYdZd0d1d2d3d4d}f~fagbg5d6d7d8d9d!d#d$d%d'd(d)d*d+d,d-d.d/d:d;d=d?d@d[dcgdgegfg]d^d_d`dgghgigjg{d|d}d~daebecedekglgmgngeefegeheiejekelemeneoepeqereseteuevewexeyezeAeBeogpgqgrgCeDeEeFesgtgugvgGeHeIeJeKeLeMeNewgxgygzgOePeQeReSeTeUeVeWeXeYeZe0e1e2e3e4e5e6e7e8e9e!e#eAgBgCgDg$e%e'e(eEgFgGgHg)e*e+e,eIgJgKgLg-e.e/e:e;e=e?e@e[e]e^e_e`e{e|e}e~eafbfcfdfefffgfMgNgOgPghfifjfkfQgRgSgTglfmfnfofpfqfrfsfUgVgWgXgtfufvfwfxfyfzfAfBfCfDfEfFfGfHfIfJfKfLfMfNfOfPfQfYgZg0g1gRfSfTfUf2g3g4g5gVfWfXfYfZf0f1f2f6g7g8g9g3f4f5f6f7f8f9f!f#f$f%f'f(f)f*f+f,f-f.f/f:f;f=f?f!g#g$g%g@f[f]f^f

419 

420 check_same(example_result, result) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

421 

422 if return_nbatches: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h

423 return result, nbatches 2)gqd*grd+gsd,gtd`gXc?h

424 return result 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^ff 6c'gXhYhZh0h1h2h3h4h5h6h7h

425 

426 

def batching_loop(
    initial, args, *, func, nonbatched_args, in_axes, out_axes, reduce_ufunc
):
    """Process a single batch inside the `autobatch` loop.

    The batched `args` are passed through `move_axes_in`, the non-batched
    arguments are re-inserted with `push_nonbatched`, and `func` is called
    on the resulting argument list.

    Parameters
    ----------
    initial
        Running accumulator for the reduced case; only used when
        `reduce_ufunc` is not None.
    args
        The batched arguments for this step.
    func
        The function being batched.
    nonbatched_args
        Arguments that are not batched, merged back via `push_nonbatched`.
    in_axes, out_axes
        Axis specifications forwarded to the axis-moving helpers and to
        `reduce`.
    reduce_ufunc
        Reduction to fold each batch into `initial`, or None to return the
        per-batch outputs unreduced.

    Returns
    -------
    tuple
        ``(reduced, None)`` when `reduce_ufunc` is given, otherwise
        ``(None, output)`` where ``output`` is `func`'s result after
        `move_axes_out`. NOTE(review): this (carry, output) shape suggests
        the function is used as a `jax.lax.scan` body; confirm at the call
        site.
    """
    moved = move_axes_in(in_axes, args)
    call_args = push_nonbatched(in_axes, moved, nonbatched_args)
    out = func(*call_args)

    if reduce_ufunc is not None:
        # reduced case: combine this batch with the accumulator `initial`
        return reduce(reduce_ufunc, out, out_axes, initial), None

    # unreduced case: rearrange axes so batches can be concatenated later
    return None, move_axes_out(out_axes, out)