Coverage for src / bartz / jaxext / _autobatch.py: 100%
182 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-13 00:35 +0000
1# bartz/src/bartz/jaxext/_autobatch.py
2#
3# Copyright (c) 2025-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implementation of `autobatch`."""
27import math
28from collections.abc import Callable
29from functools import partial, wraps
30from warnings import warn
32from jax.typing import DTypeLike
34try:
35 from numpy.lib.array_utils import normalize_axis_index # numpy 2
36except ImportError:
37 from numpy.core.numeric import normalize_axis_index # numpy 1
39from jax import ShapeDtypeStruct, eval_shape, jit
40from jax import numpy as jnp
41from jax.lax import scan
42from jax.tree import flatten as tree_flatten
43from jax.tree import map as tree_map
44from jax.tree import reduce as tree_reduce
45from jaxtyping import Array, PyTree, Shaped
def expand_axes(axes, tree):
    """Broadcast `axes` so that it has the same pytree structure as `tree`.

    `axes` may be a prefix of `tree`: every leaf of `axes` (an axis value, or
    None, which is treated as a leaf rather than as an empty subtree) is
    replicated over all the leaves of the subtree of `tree` found at the same
    position.

    Parameters
    ----------
    axes
        Pytree of axis specifications whose structure is a prefix of `tree`.
    tree
        Pytree that dictates the structure of the result.

    Returns
    -------
    A pytree shaped like `tree`, with each leaf replaced by the axis value that
    covers it.
    """

    def is_none(node):
        # Without this, tree_map would see None as a node with no children.
        return node is None

    def broadcast(axis, subtree):
        # Copy the single axis value onto every leaf below this position.
        return tree_map(lambda _leaf: axis, subtree)

    return tree_map(broadcast, axes, tree, is_leaf=is_none)
def normalize_axes(
    axes: PyTree[int | None, ' T'], tree: PyTree[Array, ' T']
) -> PyTree[int | None, ' T']:
    """Convert every axis in `axes` to its non-negative equivalent.

    Each integer axis is validated against the number of dimensions of the
    array found at the same position in `tree` and wrapped to a non-negative
    index by `normalize_axis_index` (which raises if the axis is out of
    range). None entries are passed through unchanged.

    Parameters
    ----------
    axes
        Pytree of axis indices (possibly negative) or None, matching `tree`.
    tree
        Pytree of objects exposing a `shape` attribute.

    Returns
    -------
    A pytree with the structure of `axes` holding the normalized axes.
    """

    def to_nonnegative(axis: int | None, array: Array) -> int | None:
        if axis is not None:
            return normalize_axis_index(axis, len(array.shape))
        # None is kept as a marker; there is nothing to normalize.
        return None

    # None must count as a leaf, otherwise tree_map would skip those entries.
    return tree_map(
        to_nonnegative, axes, tree, is_leaf=lambda node: node is None
    )
def check_no_nones(axes, tree):
    """Assert that no leaf of `tree` is paired with a None axis.

    `tree` drives the traversal; the entry of `axes` at the position of each
    leaf of `tree` is checked. Nothing is returned.

    Raises
    ------
    AssertionError
        If any axis is None (only when assertions are enabled).
    """

    def is_none(node):
        return node is None

    def require_axis(_leaf, axis):
        assert axis is not None

    # The mapped result is discarded: the traversal is run only for the check.
    tree_map(require_axis, tree, axes, is_leaf=is_none)
def remove_axis(
    x: PyTree[ShapeDtypeStruct, ' T'], axis: PyTree[int, ' T'], ufunc: jnp.ufunc
) -> PyTree[ShapeDtypeStruct, ' T']:
    """Describe the arrays obtained by reducing each leaf of `x` along an axis.

    Parameters
    ----------
    x
        Pytree of shape/dtype placeholders.
    axis
        Pytree, matching `x`, with the axis to drop from each placeholder.
        The slicing below is only correct for non-negative axes, so callers
        are expected to pass already normalized values.
    ufunc
        Forwarded to `reduction_dtype` together with each leaf's dtype to
        determine the dtype of the result.

    Returns
    -------
    A pytree of new `ShapeDtypeStruct` built from the shortened shape and the
    dtype returned by `reduction_dtype`.
    """

    def drop_axis(struct: ShapeDtypeStruct, ax: int) -> ShapeDtypeStruct:
        leading = struct.shape[:ax]
        trailing = struct.shape[ax + 1 :]
        result_dtype = reduction_dtype(ufunc, struct.dtype)
        # Only shape and dtype are carried over to the new placeholder.
        return ShapeDtypeStruct(leading + trailing, result_dtype)

    return tree_map(drop_axis, x, axis)
def extract_size(axes, tree):
    """Return the common length of the arrays in `tree` along their `axes`.

    For each leaf of `tree`, the entry of `axes` at the same position selects
    which dimension to read; leaves whose axis is None are ignored.

    Parameters
    ----------
    axes
        Pytree of axis indices or None, matching `tree`.
    tree
        Pytree of objects exposing a `shape` attribute.

    Returns
    -------
    The size shared by all the selected dimensions.

    Raises
    ------
    AssertionError
        If the selected dimensions are not all equal (only when assertions
        are enabled).
    IndexError
        If no leaf has a non-None axis, since then there is no size to return.
    """

    def size_along(array, axis):
        return None if axis is None else array.shape[axis]

    per_leaf = tree_map(size_along, tree, axes)
    # Flattening drops the None placeholders, leaving only actual sizes.
    batch_sizes, _ = tree_flatten(per_leaf)
    assert all(size == batch_sizes[0] for size in batch_sizes)
    return batch_sizes[0]
def sum_nbytes(tree):
    """Return the total number of bytes taken by all the leaves of `tree`.

    Each leaf must expose `shape` and `dtype.itemsize`; its size is the
    number of elements times the item size. An empty tree gives 0.
    """

    def accumulate(total, leaf):
        # math.prod(()) == 1, so a 0-dimensional leaf counts as one item
        return total + math.prod(leaf.shape) * leaf.dtype.itemsize

    return tree_reduce(accumulate, tree, 0)
def next_divisor_small(dividend, min_divisor):
    """Return the smallest divisor of `dividend` in [min_divisor, sqrt(dividend)].

    Candidates are tried in increasing order starting from `min_divisor`.
    If none of them up to the integer square root of `dividend` divides it,
    `dividend` itself is returned.
    """
    # math.isqrt is exact for arbitrarily large integers, while
    # int(math.sqrt(...)) goes through a float and can under-estimate the
    # root, skipping a divisor equal to the exact square root.
    for divisor in range(min_divisor, math.isqrt(dividend) + 1):
        if dividend % divisor == 0:
            return divisor
    return dividend
def next_divisor_large(dividend, min_divisor):
    """Find a divisor of `dividend` by scanning the complementary factors.

    The complementary factors are tried from `dividend // min_divisor` down
    to 1; the first one that divides `dividend` determines the returned
    divisor `dividend // factor`. If there is nothing to scan (i.e.,
    `dividend // min_divisor` is less than 1), `dividend` is returned.
    """
    largest_cofactor = dividend // min_divisor
    cofactor = next(
        (c for c in range(largest_cofactor, 0, -1) if dividend % c == 0),
        None,
    )
    if cofactor is None:
        return dividend
    return dividend // cofactor
def next_divisor(dividend, min_divisor):
    """Return divisor >= min_divisor such that dividend % divisor == 0.

    A zero `dividend` is divisible by anything, so `min_divisor` is returned
    as is. Otherwise the search is delegated to `next_divisor_small` when
    `min_divisor` does not exceed the square root of `dividend`, and to
    `next_divisor_large` when it does.
    """
    if dividend == 0:
        return min_divisor
    if min_divisor * min_divisor > dividend:
        search = next_divisor_large
    else:
        search = next_divisor_small
    return search(dividend, min_divisor)
def pull_nonbatched(axes, tree):
    """Blank out the leaves of `tree` whose entry in `axes` is None.

    Returns a pair: a copy of `tree` where each leaf with a None axis is
    replaced by None, and the unmodified `tree`.
    """

    def keep_if_batched(leaf, axis):
        return leaf if axis is not None else None

    batched_only = tree_map(keep_if_batched, tree, axes)
    return batched_only, tree
def push_nonbatched(axes, tree, original_tree):
    """Fill the leaves of `tree` that have a None axis from `original_tree`.

    For each position, the leaf of `original_tree` is used where the entry
    in `axes` is None, and the leaf of `tree` is used otherwise.
    """

    def restore(original_leaf, leaf, axis):
        return original_leaf if axis is None else leaf

    return tree_map(restore, original_tree, tree, axes)
def move_axes_out(axes, tree):
    """Move, in each leaf of `tree`, the axis given in `axes` to position 0."""
    return tree_map(
        lambda leaf, axis: jnp.moveaxis(leaf, source=axis, destination=0),
        tree,
        axes,
    )
def move_axes_in(axes, tree):
    """Move, in each leaf of `tree`, axis 0 to the position given in `axes`."""
    return tree_map(
        lambda leaf, axis: jnp.moveaxis(leaf, source=0, destination=axis),
        tree,
        axes,
    )
def batch(tree: PyTree[Array, ' T'], nbatches: int) -> PyTree[Array, ' T']:
    """Reshape the leading axis of each leaf into `(nbatches, size // nbatches)`."""

    def split_leading(leaf):
        dims = leaf.shape
        batch_size = dims[0] // nbatches
        return leaf.reshape(nbatches, batch_size, *dims[1:])

    return tree_map(split_leading, tree)
def unbatch(tree: PyTree[Array, ' T']) -> PyTree[Array, ' T']:
    """Collapse the two leading axes of each leaf into one axis."""

    def merge_leading(leaf):
        dims = leaf.shape
        merged = dims[0] * dims[1]
        return leaf.reshape(merged, *dims[2:])

    return tree_map(merge_leading, tree)
def reduce(
    ufunc: jnp.ufunc,
    x: PyTree[Array, ' T'],
    axes: PyTree[int, ' T'],
    initial: PyTree[Array, ' T'] | None,
) -> PyTree[Array, ' T']:
    """Apply `ufunc.reduce` to each leaf of `x` along its axis in `axes`.

    If `initial` is not None, each reduced leaf is further combined with the
    matching leaf of `initial` as `ufunc(initial_leaf, reduced_leaf)`.
    """
    if initial is not None:

        def reduce_from(leaf: Array, start: Array, axis: int) -> Array:
            return ufunc(start, ufunc.reduce(leaf, axis=axis))

        return tree_map(reduce_from, x, initial, axes)

    def reduce_leaf(leaf: Array, axis: int) -> Array:
        return ufunc.reduce(leaf, axis=axis)

    return tree_map(reduce_leaf, x, axes)
def identity(
    ufunc: jnp.ufunc, x: PyTree[ShapeDtypeStruct, ' T']
) -> PyTree[Array, ' T']:
    """Build a tree of arrays filled with the identity element of `ufunc`.

    Every leaf of `x` is read only for its ``shape`` and ``dtype``. The
    matching output leaf is the scalar given by `identity_for` for that dtype,
    broadcast to the leaf's shape.
    """

    def filled_leaf(leaf: ShapeDtypeStruct) -> Array:
        # scalar identity, typed by `identity_for`, expanded to the leaf shape
        neutral = identity_for(ufunc, leaf.dtype)
        return jnp.broadcast_to(neutral, leaf.shape)

    return tree_map(filled_leaf, x)
224def reduction_dtype(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> DTypeLike:
225 """Return the output dtype for a reduction with `ufunc` on inputs of type `dtype`."""
226 return ufunc.reduce(jnp.empty(1, input_dtype)).dtype 2Fdg d h Gde Hdb a c IdJdKdLdMdNdOdPdQdRdSdi j k TdUdVdWdl m n o p q r s t u v w x y z A B C D E _f`f{f|fXdYdZd0d1d2d3d4dF G H I jckclcmcncocpcqcJ K L M }f~fagbg5d6d7d8d9d!d#d$d%d'd(d)d*d+d,d-d.d/d:d;d=d?d@d[dcgdgegfg]d^d_d`dN O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 gghgigjg{d|d}d~daebecede7 8 9 ! rcsctcucvcwcxcyc# $ % ' kglgmgngeefegeheiejekelemeneoepeqereseteuevewexeyezeAeBeogpgqgrgCeDeEeFe( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | sgtgugvgGeHeIeJeKeLeMeNe} ~ abbbzcAcBcCcDcEcFcGccbdbebfbwgxgygzgOePeQeReSeTeUeVeWeXeYeZe0e1e2e3e4e5e6e7e8e9e!e#eAgBgCgDg$e%e'e(egbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbEgFgGgHg)e*e+e,eYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbIgJgKgLg-e.e/e:e;e=e?e@e[e]e^e_e`e{e|e}e~eafbfcfdfefffgfMgNgOgPghfifjfkfQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9bQgRgSgTglfmfnfofpfqfrfsf!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*bUgVgWgXgtfufvfwfxfyfzfAfBfCfDfEfFfGfHfIfJfKfLfMfNfOfPfQfYgZg0g1gRfSfTfUf+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bac2g3g4g5gVfWfXfYfZf0f1f2fbcccdcecPcQcRcScTcUcVcWcfcgchcic6g7g8g9g3f4f5f6f7f8f9f!f#f$f%f'f(f)f*f+f,f-f.f/f:f;f=f?f!g#g$g%g@f[f]f^fXcf
def identity_for(ufunc: jnp.ufunc, input_dtype: DTypeLike) -> Shaped[Array, '']:
    """Return ``ufunc.identity`` as a scalar array typed like a reduction result.

    The dtype of the returned array is the one `reduction_dtype` reports for
    `ufunc` applied to inputs of type `input_dtype`.
    """
    # the accumulator type can differ from the input type, e.g., int8 is
    # accumulated to int32
    accumulator_dtype = reduction_dtype(ufunc, input_dtype)

    # build the array with an explicit dtype instead of relying on inference
    neutral = ufunc.identity
    return jnp.array(neutral, accumulator_dtype)
def check_same(tree1, tree2):
    """Assert that two pytrees have leaves with matching shape and dtype.

    The trees are walked together with `tree_map`; each pair of corresponding
    leaves must expose equal ``shape`` and ``dtype`` attributes.

    Raises
    ------
    AssertionError
        If a pair of leaves differs in shape or dtype. The message reports the
        two mismatching values.
    """

    def check_leaf(x1, x2):
        # include both values in the message so a failure is diagnosable
        assert x1.shape == x2.shape, f'shape mismatch: {x1.shape} != {x2.shape}'
        assert x1.dtype == x2.dtype, f'dtype mismatch: {x1.dtype} != {x2.dtype}'

    tree_map(check_leaf, tree1, tree2)
class NotDefined:
    """Sentinel class used as the default of `result_shape_dtype` in `autobatch`.

    The class object itself is the marker: `batched_func` tests for it with
    ``is NotDefined``. It is never instantiated in the visible code.
    """
def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    *,
    return_nbatches: bool = False,
    reduce_ufunc: jnp.ufunc | None = None,
    warn_on_overflow: bool = True,
    result_shape_dtype: PyTree[ShapeDtypeStruct] = NotDefined,
) -> Callable:
    """
    Wrap a function so that it runs in batches smaller than a size threshold.

    Parameters
    ----------
    func
        A jittable function taking positional arguments only, whose inputs and
        outputs are pytrees of arrays.
    max_io_nbytes
        Upper bound on the number of input + output bytes of each batch
        (unbatched arguments are not counted).
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        giving the axis along which each array is batched. A `None` axis means
        the argument is not batched.
    out_axes
        The same for the outputs (but `None`, i.e., non-batching, is not
        allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.
    reduce_ufunc
        Function used to reduce the output along the batched axis (e.g.,
        `jax.numpy.add`).
    warn_on_overflow
        If True, a warning is raised if the memory limit could not be
        respected.
    result_shape_dtype
        A pytree of dummy arrays matching the expected output. If not provided,
        the function is traced an additional time to determine the output
        structure.

    Returns
    -------
    A function with the same signature as `func`, save for the return value if `return_nbatches`.

    Notes
    -----
    Unless `return_nbatches` or `reduce_ufunc` are set, `autobatch` at given
    arguments is idempotent. Furthermore, `autobatch` can be applied multiple
    times over multiple axes with the same `max_io_nbytes` limit to work on
    multiple axes; in this case it won't unnecessarily loop over additional axes
    if one or more outer `autobatch` are already sufficient.

    To handle memory used in intermediate values: assuming all intermediate
    values have size that scales linearly with the axis batched over, say the
    batched input/output total size is ``batched_size * core_io_size``, and the
    intermediate values have size ``batched_size * core_int_size``, then to take
    them into account divide `max_io_nbytes` by ``(1 + core_int_size /
    core_io_size)``.
    """
    # bind the whole configuration once; only the call arguments stay open
    run_batched = partial(
        batched_func,
        func,
        max_io_nbytes,
        in_axes,
        out_axes,
        return_nbatches,
        reduce_ufunc,
        warn_on_overflow,
        result_shape_dtype,
    )

    @jit
    @wraps(func)
    def autobatch_wrapper(*args):
        # the positional arguments are forwarded as a single tuple
        return run_batched(args)

    return autobatch_wrapper
329def batched_func(
330 func: Callable,
331 max_io_nbytes: int,
332 in_axes: PyTree[int | None],
333 out_axes: PyTree[int],
334 return_nbatches: bool,
335 reduce_ufunc: jnp.ufunc | None,
336 warn_on_overflow: bool,
337 result_shape_dtype: PyTree[ShapeDtypeStruct] | NotDefined,
338 args: tuple[PyTree[Array], ...],
339) -> PyTree[Array]:
340 """Implement the wrapper used in `autobatch`."""
341 # determine the output structure of the function
342 if result_shape_dtype is NotDefined: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
343 example_result = eval_shape(func, *args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
344 else:
345 example_result = result_shape_dtype 2Fdg d h Gde Hdb a c IdJdKdLdMdNdOdPdQdRdSdi j k TdUdVdWd
347 # expand the axes pytrees if they are prefixes
348 in_axes = expand_axes(in_axes, args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
349 out_axes = expand_axes(out_axes, example_result) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
350 check_no_nones(out_axes, example_result) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
352 # check the axes are valid
353 in_axes = normalize_axes(in_axes, args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
354 out_axes = normalize_axes(out_axes, example_result) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
356 # get the size of the batched axis
357 size = extract_size((in_axes, out_axes), (args, example_result)) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
359 # split arguments in batched and not batched
360 original_args = args 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
361 args, nonbatched_args = pull_nonbatched(in_axes, args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
363 # determine the number of batches to respect the memory limit
364 total_nbytes = sum_nbytes((args, example_result)) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
365 min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
366 min_nbatches = max(1, min_nbatches) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
367 nbatches = next_divisor(size, min_nbatches) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
368 assert 1 <= nbatches <= max(1, size) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
369 assert size % nbatches == 0 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
370 assert total_nbytes % nbatches == 0 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
372 # warn if the memory limit could not be respected
373 batch_nbytes = total_nbytes // nbatches 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
374 if batch_nbytes > max_io_nbytes and warn_on_overflow: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
375 assert size == nbatches 2b a c 2c3c4c7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I 'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! .c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbb[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bhd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecodfcgchcic
376 msg = f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}' 2b a c 2c3c4c7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I 'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! .c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbb[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bhd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecodfcgchcic
377 warn(msg) 2b a c 2c3c4c7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I 'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! .c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbb[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bhd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecodfcgchcic
379 # squeeze out the output dims that will be reduced
380 if reduce_ufunc is not None: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
381 example_result = remove_axis(example_result, out_axes, reduce_ufunc) 2Fdg d h Gde Hdb a c IdJdKdLdMdNdOdPdQdRdSdi j k TdUdVdWdl m n o p q r s t u v w x y z A B C D E _f`f{f|fXdYdZd0d1d2d3d4dF G H I jckclcmcncocpcqcJ K L M }f~fagbg5d6d7d8d9d!d#d$d%d'd(d)d*d+d,d-d.d/d:d;d=d?d@d[dcgdgegfg]d^d_d`dN O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 gghgigjg{d|d}d~daebecede7 8 9 ! rcsctcucvcwcxcyc# $ % ' kglgmgngeefegeheiejekelemeneoepeqereseteuevewexeyezeAeBeogpgqgrgCeDeEeFe( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | sgtgugvgGeHeIeJeKeLeMeNe} ~ abbbzcAcBcCcDcEcFcGccbdbebfbwgxgygzgOePeQeReSeTeUeVeWeXeYeZe0e1e2e3e4e5e6e7e8e9e!e#eAgBgCgDg$e%e'e(egbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbEgFgGgHg)e*e+e,eYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbIgJgKgLg-e.e/e:e;e=e?e@e[e]e^e_e`e{e|e}e~eafbfcfdfefffgfMgNgOgPghfifjfkfQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9bQgRgSgTglfmfnfofpfqfrfsf!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*bUgVgWgXgtfufvfwfxfyfzfAfBfCfDfEfFfGfHfIfJfKfLfMfNfOfPfQfYgZg0g1gRfSfTfUf+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bac2g3g4g5gVfWfXfYfZf0f1f2fbcccdcecPcQcRcScTcUcVcWcfcgchcic6g7g8g9g3f4f5f6f7f8f9f!f#f$f%f'f(f)f*f+f,f-f.f/f:f;f=f?f!g#g$g%g@f[f]f^fXcf
383 if nbatches > 1: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
384 # prepare arguments for looping
385 args = move_axes_out(in_axes, args) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c
386 args = batch(args, nbatches) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c
388 # prepare carry for reduction
389 if reduce_ufunc is None: 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c
390 initial = None 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8c9c!c#c$c%cudvd'c(c)c*c+c,c-cwdxd.c/c:c;c=c?c@cydzd[c]c^c_c`c{cEd|c}c~cadbdcdddedfdgdAdBdhdidjdkdldmdndCdDdod6c
391 else:
392 initial = identity(reduce_ufunc, example_result) 2l m n o p q r s t u v w x y z A B C D E F G H I jckclcmcncocpcqcJ K L M N O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! rcsctcucvcwcxcyc# $ % ' ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbzcAcBcCcDcEcFcGccbdbebfbgbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*b+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bacbcccdcecPcQcRcScTcUcVcWcfcgchcicXcf
394 # loop and invoke the function in batches
395 loop = partial( 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c
396 batching_loop,
397 func=func,
398 nonbatched_args=nonbatched_args,
399 in_axes=in_axes,
400 out_axes=out_axes,
401 reduce_ufunc=reduce_ufunc,
402 )
403 reduced_result, result = scan(loop, initial, args) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c
405 # remove auxiliary batching axis and reverse transposition
406 if reduce_ufunc is None: 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E %cF G H I udjckclcmcvdncocpcqc'cJ K L M (cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 -c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' /c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | @c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzbEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPbbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9bgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*bid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bacndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcicXcf 6c
407 assert reduced_result is None 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8c9c!c#c$c%cudvd'c(c)c*c+c,c-cwdxd.c/c:c;c=c?c@cydzd[c]c^c_c`c{cEd|c}c~cadbdcdddedfdgdAdBdhdidjdkdldmdndCdDdod6c
408 result = unbatch(result) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8c9c!c#c$c%cudvd'c(c)c*c+c,c-cwdxd.c/c:c;c=c?c@cydzd[c]c^c_c`c{cEd|c}c~cadbdcdddedfdgdAdBdhdidjdkdldmdndCdDdod6c
409 result = move_axes_in(out_axes, result) 2g d h e b a c i j k 2c3c4cpd5cqdrdsdtd7c8c9c!c#c$c%cudvd'c(c)c*c+c,c-cwdxd.c/c:c;c=c?c@cydzd[c]c^c_c`c{cEd|c}c~cadbdcdddedfdgdAdBdhdidjdkdldmdndCdDdod6c
410 else:
411 assert result is None 2l m n o p q r s t u v w x y z A B C D E F G H I jckclcmcncocpcqcJ K L M N O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! rcsctcucvcwcxcyc# $ % ' ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbzcAcBcCcDcEcFcGccbdbebfbgbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*b+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bacbcccdcecPcQcRcScTcUcVcWcfcgchcicXcf
412 result = reduced_result 2l m n o p q r s t u v w x y z A B C D E F G H I jckclcmcncocpcqcJ K L M N O P Q R S T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! rcsctcucvcwcxcyc# $ % ' ( ) * + , - . / : ; = ? @ [ ] ^ _ ` { | } ~ abbbzcAcBcCcDcEcFcGccbdbebfbgbhbibjbkblbmbnbobpbqbrbsbtbubvbwbxbybzbYcZc0c1cAbBbCbDbEbFbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbVbWbXbYbZb0b1b2b3b4b5b6b7b8b9b!b#b$b%bHcIcJcKcLcMcNcOc'b(b)b*b+b,b-b.b/b:b;b=b?b@b[b]b^b_b`b{b|b}b~bacbcccdcecPcQcRcScTcUcVcWcfcgchcicXcf
414 # trivial case: no batching needed
415 else:
416 result = func(*original_args) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd)g*g+g,g`g8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d#hgghgigjgfh{d|d}d~dghaebecede$hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe'hsgtgugvgohGeHeIeJephKeLeMeNe(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e*hEgFgGgHgxh)e*e+e,e+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkf-hQgRgSgTgFhlfmfnfofGhpfqfrfsf.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUf:h2g3g4g5gOhVfWfXfYfPhZf0f1f2f;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^f'g?hXhYhZh0h1h2h3h4h5h6h7h
417 if reduce_ufunc is not None: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd)g*g+g,g`g8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d#hgghgigjgfh{d|d}d~dghaebecede$hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe'hsgtgugvgohGeHeIeJephKeLeMeNe(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e*hEgFgGgHgxh)e*e+e,e+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkf-hQgRgSgTgFhlfmfnfofGhpfqfrfsf.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUf:h2g3g4g5gOhVfWfXfYfPhZf0f1f2f;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^f'g?hXhYhZh0h1h2h3h4h5h6h7h
418 result = reduce(reduce_ufunc, result, out_axes, None) 2Fdg d h Gde Hdb a c IdJdKdLdMdNdOdPdQdRdSdi j k TdUdVdWd_f`f{f|fXdYdZd0d1d2d3d4d}f~fagbg5d6d7d8d9d!d#d$d%d'd(d)d*d+d,d-d.d/d:d;d=d?d@d[dcgdgegfg]d^d_d`dgghgigjg{d|d}d~daebecedekglgmgngeefegeheiejekelemeneoepeqereseteuevewexeyezeAeBeogpgqgrgCeDeEeFesgtgugvgGeHeIeJeKeLeMeNewgxgygzgOePeQeReSeTeUeVeWeXeYeZe0e1e2e3e4e5e6e7e8e9e!e#eAgBgCgDg$e%e'e(eEgFgGgHg)e*e+e,eIgJgKgLg-e.e/e:e;e=e?e@e[e]e^e_e`e{e|e}e~eafbfcfdfefffgfMgNgOgPghfifjfkfQgRgSgTglfmfnfofpfqfrfsfUgVgWgXgtfufvfwfxfyfzfAfBfCfDfEfFfGfHfIfJfKfLfMfNfOfPfQfYgZg0g1gRfSfTfUf2g3g4g5gVfWfXfYfZf0f1f2f6g7g8g9g3f4f5f6f7f8f9f!f#f$f%f'f(f)f*f+f,f-f.f/f:f;f=f?f!g#g$g%g@f[f]f^f
420 check_same(example_result, result) 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
422 if return_nbatches: 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd`g7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^fXcf 6c'g?hXhYhZh0h1h2h3h4h5h6h7h
423 return result, nbatches 2)gqd*grd+gsd,gtd`gXc?h
424 return result 2Fdg d h Gde Hdb a c IdJdKdLdMd-g.g/g:g;g=gNdOdPdQd?g@gRdSdi j k [gTdUdVdWd]g^g_g2c3c4c(gpd5c)gqd*grd+gsd,gtd7c8cl m n o 9cp q r s !ct u v w #cx y z A $cB C D E 8h_f`f{f|f{gXdYdZd0d|g1d2d3d4d%cF G H I udjckclcmcvdncocpcqc'cJ K L M 9h}f~fagbg}g5d6d7d8d~g9d!d#d$dah%d'd(d)dbh*d+d,d-dch.d/d:d;ddh=d?d@d[d!hcgdgegfgeh]d^d_d`d(cN O P Q )cR S T U *cV W X Y +cZ 0 1 2 ,c3 4 5 6 #hgghgigjgfh{d|d}d~dghaebecede-c7 8 9 ! wdrcsctcucxdvcwcxcyc.c# $ % ' $hkglgmgnghheefegeheihiejekelejhmeneoepekhqeresetelhuevewexemhyezeAeBe%hogpgqgrgnhCeDeEeFe/c( ) * + :c, - . / ;c: ; = ? =c@ [ ] ^ ?c_ ` { | 'hsgtgugvgohGeHeIeJephKeLeMeNe@c} ~ abbbydzcAcBcCczdDcEcFcGc[ccbdbebfb(hwgxgygzgqhOePeQeRerhSeTeUeVeshWeXeYeZeth0e1e2e3euh4e5e6e7evh8e9e!e#e)hAgBgCgDgwh$e%e'e(e]cgbhbibjb^ckblbmbnb_cobpbqbrb`csbtbubvb{cwbxbybzb*hEgFgGgHgxh)e*e+e,eEdYcZc0c1c|cAbBbCbDb}cEbFbGbHb~cIbJbKbLbadMbNbObPb+hIgJgKgLgyh-e.e/e:ezh;e=e?e@eAh[e]e^e_eBh`e{e|e}eCh~eafbfcfDhdfefffgf,hMgNgOgPgEhhfifjfkfbdQbRbSbTbcdUbVbWbXbddYbZb0b1bed2b3b4b5bfd6b7b8b9b-hQgRgSgTgFhlfmfnfofGhpfqfrfsfgd!b#b$b%bAdHcIcJcKcBdLcMcNcOchd'b(b)b*b.hUgVgWgXgHhtfufvfwfIhxfyfzfAfJhBfCfDfEfKhFfGfHfIfLhJfKfLfMfMhNfOfPfQf/hYgZg0g1gNhRfSfTfUfid+b,b-b.bjd/b:b;b=bkd?b@b[b]bld^b_b`b{bmd|b}b~bac:h2g3g4g5gOhVfWfXfYfPhZf0f1f2fndbcccdcecCdPcQcRcScDdTcUcVcWcodfcgchcic;h6g7g8g9gQh3f4f5f6fRh7f8f9f!fSh#f$f%f'fTh(f)f*f+fUh,f-f.f/fVh:f;f=f?f=h!g#g$g%gWh@f[f]f^ff 6c'gXhYhZh0h1h2h3h4h5h6h7h
def batching_loop(
    initial, args, *, func, nonbatched_args, in_axes, out_axes, reduce_ufunc
):
    """Implement the batching loop in `autobatch`.

    Evaluate `func` on one batch of arguments and either hand back its
    output for later concatenation or fold it into a running reduction.

    Parameters
    ----------
    initial
        Starting value of the reduction, forwarded to `reduce`. Only used
        when `reduce_ufunc` is not `None`.
    args
        The batched arguments for the current step. They are first passed
        through `move_axes_in` together with `in_axes`.
    func
        The function to evaluate; called as ``func(*args)`` once the
        non-batched arguments have been reinserted.
    nonbatched_args
        Arguments that are not batched, reinserted into `args` by
        `push_nonbatched`.
    in_axes
        Axis specification for the inputs, forwarded to `move_axes_in` and
        `push_nonbatched`.
    out_axes
        Axis specification for the outputs, forwarded to `move_axes_out`
        (unreduced case) or to `reduce` (reduced case).
    reduce_ufunc
        If `None`, the outputs are not reduced. Otherwise it is forwarded
        to `reduce` to combine the outputs with `initial`.

    Returns
    -------
    carry
        The reduced result if `reduce_ufunc` is not `None`, else `None`.
    output
        The result with its axes rearranged by `move_axes_out` if
        `reduce_ufunc` is `None`, else `None`.

    Notes
    -----
    The ``(carry, output)`` return pair and the ``(initial, args)`` leading
    parameters presumably make this suitable as the body of `jax.lax.scan`
    after binding the keyword arguments -- TODO confirm against the caller.
    """
    # evaluate the function
    args = move_axes_in(in_axes, args)
    args = push_nonbatched(in_axes, args, nonbatched_args)
    result = func(*args)

    # unreduced case: transpose for concatenation and return
    if reduce_ufunc is None:
        result = move_axes_out(out_axes, result)
        return None, result

    # reduced case: reduce starting from initial
    # (`reduce` is this module's four-argument helper, not functools.reduce)
    reduced_result = reduce(reduce_ufunc, result, out_axes, initial)
    return reduced_result, None