Coverage for src / bartz / grove / _grove.py: 93%

187 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 18:11 +0000

1# bartz/src/bartz/grove/_grove.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions to create and manipulate binary decision trees.""" 

26 

27import math 

28from dataclasses import fields 

29from functools import partial 

30from typing import Literal, Protocol 

31 

32from equinox import Module 

33from jax import jit, lax, vmap 

34from jax import numpy as jnp 

35from jaxtyping import Array, Bool, Float32, Int32, Shaped, UInt 

36from numpy.lib.array_utils import normalize_axis_tuple 

37 

38from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc 

39 

40 

class TreeHeaps(Protocol):
    """Structural interface for dataclasses that store binary trees as heaps.

    Every tree is laid out in flat arrays following the heap convention: the
    root lives at index 1, and the node at index :math:`i` has its left child
    at :math:`2i` and its right child at :math:`2i + 1`. The slot at index 0
    is never used.

    The bottom level can hold only leaves, never decision nodes, so
    `var_tree` and `split_tree` are half the length of `leaf_tree`.

    Any array may carry extra leading axes to store several trees at once.
    """

    leaf_tree: (
        Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']
    )
    """Values stored in the leaves. Slots of unused nodes may contain any
    value (the array may be "dirty"). Multivariate leaves add an extra axis
    of length k before the node axis."""

    var_tree: UInt[Array, '*batch_shape 2**(d-1)']
    """Index of the axis each decision node splits on. Entries may be dirty,
    except the always-unused slot at index 0, which must be 0."""

    split_tree: UInt[Array, '*batch_shape 2**(d-1)']
    """Cutpoints of the decision nodes. Intervals are open on the right: a
    point goes to the left child iff x < split. A split value of 0 marks a
    leaf, and unused nodes also hold 0, so this array must never be dirty."""

71 

72 

class TreesTrace(Module):
    """Container implementing `bartz.grove.TreeHeaps` for the trees of an MCMC trace."""

    leaf_tree: (
        Float32[Array, '*trace_shape num_trees 2**d']
        | Float32[Array, '*trace_shape num_trees k 2**d']
    )
    var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
    split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: TreeHeaps) -> 'TreesTrace':
        """Build a `TreesTrace` from any `bartz.grove.TreeHeaps`.

        Only the dataclass fields declared on this class are read from `obj`;
        any other attribute of `obj` is ignored.
        """
        kwargs = {}
        for field in fields(cls):
            kwargs[field.name] = getattr(obj, field.name)
        return cls(**kwargs)

87 

88 

def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int:
    """
    Compute the maximum depth of a heap-ordered tree array.

    Parameters
    ----------
    tree
        Any tree array of a `TreeHeaps`. When the array has several axes, the
        last one is taken to index the tree nodes.

    Returns
    -------
    The maximum depth of the tree, i.e., the base-2 logarithm of the length
    of the last axis, rounded to the nearest integer.
    """
    num_slots = tree.shape[-1]
    return round(math.log2(num_slots))

104 

105 

def traverse_tree(
    x: UInt[Array, ' p'],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
) -> UInt[Array, '']:
    """
    Locate the leaf of a single tree that contains a point.

    Parameters
    ----------
    x
        The coordinates of the point.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The heap index of the leaf reached by `x`.
    """
    # wide enough to hold any child index, up to 2 * var_tree.size - 1
    index_dtype = minimal_unsigned_dtype(2 * var_tree.size - 1)

    def descend(
        state: tuple[Bool[Array, ''], UInt[Array, '']], _: None
    ) -> tuple[tuple[Bool[Array, ''], UInt[Array, '']], None]:
        stopped, node = state
        cutpoint = split_tree[node]
        # a split of 0 marks a leaf: from then on the node index is frozen
        stopped = stopped | (cutpoint == 0)
        # left child is 2i, right child is 2i + 1 (taken iff x >= cutpoint)
        goes_right = x[var_tree[node]] >= cutpoint
        node = jnp.where(stopped, node, (node << 1) + goes_right)
        return (stopped, node), None

    start = (jnp.zeros((), bool), jnp.ones((), index_dtype))
    # one step per level of decision nodes
    (_, leaf_index), _ = lax.scan(
        descend, start, None, tree_depth(var_tree), unroll=16
    )
    return leaf_index

149 

150 

@jit
@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
@partial(vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(
    X: UInt[Array, 'p n'],
    var_trees: UInt[Array, '*forest_shape 2**(d-1)'],
    split_trees: UInt[Array, '*forest_shape 2**(d-1)'],
) -> UInt[Array, '*forest_shape n']:
    """
    Find the leaves that points fall into, for each tree in a set.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at, one point per column.
    var_trees
        The decision axes of the trees.
    split_trees
        The decision boundaries of the trees.

    Returns
    -------
    The heap indices of the leaves, one per tree and point.
    """
    # The decorators map `traverse_tree` over the columns of `X` (vmap on
    # axis 1) and broadcast it over the leading forest axes of the tree
    # arrays (vectorize, with `X` excluded from broadcasting).
    return traverse_tree(X, var_trees, split_trees)

176 

177 

@partial(jit, static_argnames=('sum_batch_axis',))
def evaluate_forest(
    X: UInt[Array, 'p n'],
    trees: TreeHeaps,
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> (
    Float32[Array, '*reduced_batch_shape n']
    | Float32[Array, '*reduced_batch_shape k n']
):
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at, one point per column.
    trees
        The trees. If `trees.leaf_tree` has one more axis than
        `trees.var_tree`, the leaves are treated as multivariate, with the
        extra axis (of length k) placed just before the node axis.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Negative indices count from the end of the batch dimensions; the core
        dimensions n and k can't be summed over by this function.

    Returns
    -------
    The values of the trees at the points in `X`, summed over the batch axes
    listed in `sum_batch_axis`.
    """
    indices: UInt[Array, '*forest_shape n']
    indices = traverse_forest(X, trees.var_tree, trees.split_tree)

    # multivariate leaves carry an extra `k` axis before the node axis
    is_mv = trees.leaf_tree.ndim != trees.var_tree.ndim

    # Insert singleton axes so that `take_along_axis` gathers along the node
    # axis while broadcasting the points against the leaf components.
    bc_indices: UInt[Array, '*forest_shape n 1'] | UInt[Array, '*forest_shape 1 n 1']
    bc_indices = indices[..., None, :, None] if is_mv else indices[..., None]

    bc_leaf_tree: (
        Float32[Array, '*forest_shape 1 tree_size']
        | Float32[Array, '*forest_shape k 1 tree_size']
    )
    bc_leaf_tree = (
        trees.leaf_tree[..., :, None, :] if is_mv else trees.leaf_tree[..., None, :]
    )

    bc_leaves: (
        Float32[Array, '*forest_shape n 1'] | Float32[Array, '*forest_shape k n 1']
    )
    bc_leaves = jnp.take_along_axis(bc_leaf_tree, bc_indices, -1)

    leaves: Float32[Array, '*forest_shape n'] | Float32[Array, '*forest_shape k n']
    leaves = jnp.squeeze(bc_leaves, -1)

    # normalize against the number of batch axes, so negative indices count
    # from the end of the batch dimensions rather than of the full array
    axis = normalize_axis_tuple(sum_batch_axis, trees.var_tree.ndim - 1)
    return jnp.sum(leaves, axis=axis)

231 

232 

def is_actual_leaf(
    split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
    """
    Flag the nodes of a tree that are actual leaves.

    A node is an actual leaf if its split is 0 (it is not a decision node)
    and it is reachable, i.e., it is the root or its parent is a decision
    node.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    add_bottom_level
        If True, also cover the bottom level of the tree, whose nodes have no
        split entry and can only be leaves.

    Returns
    -------
    A boolean mask over the nodes, of length ``2 * split_tree.size`` if
    `add_bottom_level` is True and ``split_tree.size`` otherwise.
    """
    split_is_zero = split_tree == 0
    num_nodes = split_tree.size
    if add_bottom_level:
        split_is_zero = jnp.concatenate([split_is_zero, jnp.ones_like(split_is_zero)])
        num_nodes *= 2
    nodes = jnp.arange(num_nodes, dtype=minimal_unsigned_dtype(num_nodes - 1))
    parent_is_decision = split_tree[nodes >> 1].astype(bool)
    # The root (index 1) has no parent: 1 >> 1 reads the unused slot 0, whose
    # split is 0, so force the root to count as reachable.
    parent_is_decision = parent_is_decision.at[1].set(True)
    return split_is_zero & parent_is_decision

260 

261 

def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
    """
    Flag the decision nodes whose children are both leaves.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    A mask that is True at decision nodes whose two children are leaves.
    """
    num_nodes = split_tree.size
    nodes = jnp.arange(num_nodes, dtype=minimal_unsigned_dtype(2 * num_nodes - 1))

    def child_is_leaf(child: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
        # children past the end of split_tree are on the bottom level, which
        # holds only leaves: reading them as 0 marks them as leaves
        return split_tree.at[child].get(mode='fill', fill_value=0) == 0

    left = nodes << 1
    both_children_leaves = child_is_leaf(left) & child_is_leaf(left + 1)
    # slot 0 has split == 0, so it is never reported as a decision node
    return split_tree.astype(bool) & both_children_leaves

285 

286 

def tree_depths(tree_size: int) -> UInt[Array, ' {tree_size}']:
    """
    Return the depth of each node in a binary tree.

    Parameters
    ----------
    tree_size
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    The depth of each node, as an unsigned-integer array whose dtype is
    chosen by `minimal_unsigned_dtype` from the maximum depth.

    Notes
    -----
    The root node (index 1) has depth 0. The depth is the position of the most
    significant non-zero bit in the index. The first element (the unused node)
    is marked as depth 0.
    """
    # position of the most significant set bit; 0 .bit_length() is 0, giving
    # -1 for the unused slot, which is then overwritten with 0
    depths = [i.bit_length() - 1 for i in range(tree_size)]
    depths[0] = 0
    return jnp.array(depths, minimal_unsigned_dtype(max(depths)))

314 

315 

@partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)')
def is_used(
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
) -> Bool[Array, '*batch_shape 2**d']:
    """
    Flag the nodes of a tree that are actually in use.

    A node is in use if it is either a decision node or an actual leaf.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    A mask over all node slots, bottom level included, marking used nodes.
    """
    is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
    is_decision = split_tree.astype(bool)
    # the bottom level has no decision nodes: pad the mask with False
    is_decision = jnp.concatenate([is_decision, jnp.zeros_like(is_decision)])
    return is_decision | is_leaf

336 

337 

@jit
def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']:
    """
    Compute the fraction of node slots in use across a set of trees.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees.

    Returns
    -------
    The number of used nodes divided by the number of nodes the arrays could
    hold.
    """
    num_trees = split_tree.size // split_tree.shape[-1]
    mask = is_used(split_tree)
    # slot 0 of each tree is never used, so it does not count as capacity
    capacity = mask.size - num_trees
    return jnp.count_nonzero(mask) / capacity

356 

357 

@partial(jit, static_argnames=('p', 'sum_batch_axis'))
def var_histogram(
    p: int,
    var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape {p}']:
    """
    Count how often each variable is used by the decision nodes of a tree.

    Parameters
    ----------
    p
        The number of variables (the maximum value that can occur in `var_tree`
        is ``p - 1``).
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed. Note
        that negative indices count from the end of the batch dimensions, the
        core dimension p can't be summed over by this function.

    Returns
    -------
    The histogram(s) of the variables used in the tree.
    """
    batch_ndim = var_tree.ndim - 1
    summed = normalize_axis_tuple(sum_batch_axis, batch_ndim)

    def histogram(
        var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
        is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
    ) -> Int32[Array, ' p']:
        # nodes that are not decision nodes add 0 to their (dirty) var entry
        return jnp.zeros(p, int).at[var_tree].add(is_internal)

    # Map over every batch axis that is kept, innermost (last) axis first.
    # Axes are negative so they stay valid inside each nested vmap.
    kept = [i for i in range(batch_ndim) if i not in summed]
    for axis in reversed(kept):
        histogram = vmap(histogram, in_axes=axis - var_tree.ndim)

    return histogram(var_tree, split_tree.astype(bool))

404 

405 

def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
    """Convert a tree to a human-readable string.

    Parameters
    ----------
    tree
        A single tree to format. `leaf_tree` may have an extra leading axis
        for multivariate leaves; the node axis is always the last one.
    print_all
        If `True`, also print the contents of unused node slots in the arrays.

    Returns
    -------
    A string representation of the tree, with one line per printed node,
    prefixed by its heap index.
    """
    tee = '├──'
    corner = '└──'
    join = '│ '
    space = ' '
    down = '┐'
    bottom = '╢'  # '┨' #

    # Use the last axis: with multivariate leaves, `len(tree.leaf_tree)` would
    # be the number of leaf components instead of the number of node slots.
    tree_size = tree.leaf_tree.shape[-1]
    ndigits = len(str(tree_size - 1))
    lines: list[str] = []

    def format_leaf_value(index: int) -> str:
        """Format the value(s) of leaf `index` compactly."""
        value = tree.leaf_tree[..., index]
        if value.ndim == 0:
            return f'{value:#.2g}'
        # multivariate leaf: format each component separately, since the
        # `#.2g` spec does not apply to a whole vector
        return '[' + ', '.join(f'{v:#.2g}' for v in value) + ']'

    def visit(
        index: int,
        indent: str,
        first_indent: str,
        next_indent: str,
        unused: bool,
    ) -> None:
        """Append the line for node `index`, then recurse into its children."""
        if index >= tree_size:
            return

        # bottom-level nodes have no entry in var_tree/split_tree: read as 0
        var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
        split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()

        is_leaf = split == 0
        left_child = 2 * index
        right_child = 2 * index + 1

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {tree.leaf_tree[..., index]})'
        else:
            # without print_all the recursion stops at leaves, so unused
            # nodes are never reached
            assert not unused
            if is_leaf:
                node_str = format_leaf_value(index)
            else:
                node_str = f'x{var} < {split}'

        if not is_leaf or (print_all and left_child < tree_size):
            link = down
        elif not print_all and left_child >= tree_size:
            link = bottom
        else:
            link = ' '

        number = str(index).rjust(ndigits)
        lines.append(f' {number} {indent}{first_indent}{link}{node_str}')

        indent += next_indent
        unused = unused or is_leaf

        if unused and not print_all:
            return

        visit(left_child, indent, tee, join, unused)
        visit(right_child, indent, corner, space, unused)

    visit(1, '', '', '', False)
    return '\n'.join(lines)

486 

487 

def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> UInt[Array, '']:
    """Measure the depth of the tree.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules.

    Returns
    -------
    The depth of the deepest leaf in the tree. The root is at depth 0. The
    result has the unsigned dtype of the array returned by `tree_depths`.
    """
    # this could be done just with split_tree != 0
    is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
    depth = tree_depths(is_leaf.size)
    # nodes that are not actual leaves get depth 0 so they can't win the max
    depth = jnp.where(is_leaf, depth, 0)
    return jnp.max(depth)

505 

506 

@jit
@partial(jnp.vectorize, signature='(nt,hts)->(d)')
def forest_depth_distr(
    split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'],
) -> Int32[Array, '*batch_shape d']:
    """Count how many trees of a set have each possible depth.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules of the trees.

    Returns
    -------
    An integer vector whose i-th element is the number of trees of depth i.
    """
    # split_tree holds 2**(d-1) slots, so the possible depths are 0 .. d-1
    num_bins = tree_depth(split_tree) + 1
    per_tree = vmap(tree_actual_depth)(split_tree)
    return jnp.bincount(per_tree, length=num_bins)

526 

527 

@partial(jit, static_argnames=('node_type', 'sum_batch_axis'))
def points_per_node_distr(
    X: UInt[Array, 'p n'],
    var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    node_type: Literal['leaf', 'leaf-parent'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape n+1']:
    """Histogram points-per-node counts in a set of trees.

    Count how many nodes in a tree select each possible amount of points,
    over a certain subset of nodes.

    Parameters
    ----------
    X
        The set of points to count, one point per column.
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    node_type
        The type of nodes to consider. Can be:

        'leaf'
            Count only leaf nodes.
        'leaf-parent'
            Count only parent-of-leaf nodes, i.e., decision nodes whose
            children are both leaves.
    sum_batch_axis
        Aggregate the histogram over these batch axes, counting how many nodes
        have each possible amount of points over subsets of trees instead of
        in each tree separately. Negative indices count from the end of the
        batch dimensions.

    Returns
    -------
    A vector where the i-th element counts how many nodes have i points, for
    i from 0 to n.

    Raises
    ------
    ValueError
        If `node_type` is not one of the accepted values (raised while the
        function is traced).
    """
    # every axis but the last (node) one is a batch axis
    batch_ndim = var_tree.ndim - 1
    axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)

    def func(
        var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
        split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    ) -> Int32[Array, '*reduced_batch_shape n+1']:
        # heap index of the leaf each point falls into, for every tree
        indices: UInt[Array, '*batch_shape n']
        indices = traverse_forest(X, var_tree, split_tree)

        # per tree: number of points in each node, plus the mask of the nodes
        # to include in the histogram
        @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
        def count_points(
            split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
            indices: UInt[Array, '*batch_shape n'],
        ) -> (
            tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']]
            | tuple[
                UInt[Array, '*batch_shape 2**(d-1)'],
                Bool[Array, '*batch_shape 2**(d-1)'],
            ]
        ):
            if node_type == 'leaf-parent':
                # move each point from its leaf to the leaf's parent
                indices >>= 1
                predicate = is_leaves_parent(split_tree)
            elif node_type == 'leaf':
                predicate = is_actual_leaf(split_tree, add_bottom_level=True)
            else:
                raise ValueError(node_type)
            # slot 0 is not a real node (a root leaf's index 1 shifts to 0 in
            # 'leaf-parent' mode), so its count is discarded
            count_tree = jnp.zeros(predicate.size, int).at[indices].add(1).at[0].set(0)
            return count_tree, predicate

        count_tree, predicate = count_points(split_tree, indices)

        # histogram of per-node point counts, each node weighted by whether
        # it is of the requested type
        def count_nodes(
            count_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
            predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
        ) -> Int32[Array, ' n+1']:
            return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate)

        # vmap count_nodes over non-batched dims
        # (negative axes stay valid inside each nested vmap)
        for i in reversed(range(batch_ndim)):
            neg_i = i - var_tree.ndim
            if i not in axes:
                count_nodes = vmap(count_nodes, in_axes=neg_i)

        return count_nodes(count_tree, predicate)

    # automatically batch over all batch dimensions
    # NOTE(review): presumably `autobatch(f, max_io_nbytes, in_axis, out_axis)`
    # splits the work along `in_axis` to bound memory traffic, and places the
    # result on `out_axis` -- confirm against bartz.jaxext.autobatch.
    max_io_nbytes = 2**27  # 128 MiB
    # out_dim_shift counts the summed axes that precede axis i; those axes are
    # removed from the output, so a kept axis i lands at i - out_dim_shift
    out_dim_shift = len(axes)
    for i in reversed(range(batch_ndim)):
        if i in axes:
            out_dim_shift -= 1
        else:
            func = autobatch(func, max_io_nbytes, i, i - out_dim_shift)
    # every summed axis has been accounted for
    assert out_dim_shift == 0

    return func(var_tree, split_tree)