Coverage for src/bartz/grove/_grove.py: 94%

194 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-07-02 09:03 +0000

1# bartz/src/bartz/grove/_grove.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions to create and manipulate binary decision trees.""" 

26 

27import math 

28from dataclasses import fields 

29from functools import partial 

30from typing import Literal, Protocol, runtime_checkable 

31 

32from equinox import tree_at 

33from jax import numpy as jnp 

34from jax import vmap 

35from jaxtyping import Array, Bool, Float32, Int32, Shaped, UInt 

36from numpy.lib.array_utils import normalize_axis_tuple 

37 

38from bartz._jaxext import Module, autobatch, jit, minimal_unsigned_dtype, vmap_nodoc 

39 

40 

@runtime_checkable
class TreeHeaps(Protocol):
    """Structural interface for dataclasses that store trees as heap arrays.

    Each tree is laid out as a binary heap: index 1 holds the root, and the
    node at index :math:`i` has its left child at :math:`2i` and its right
    child at :math:`2i + 1`. Index 0 does not correspond to any node.

    The deepest level can hold only leaves, never decision nodes, which is
    why `var_tree` and `split_tree` are half as long as `leaf_tree`.

    Any number of leading axes may be present to hold multiple trees.
    """

    leaf_tree: (
        Float32[Array, '*batch_shape 2*half_tree_size']
        | Float32[Array, '*batch_shape k 2*half_tree_size']
    )
    """The leaf values of the trees. Entries of unused nodes are arbitrary
    (the array may be dirty). An extra axis `k` is present when the leaves
    are multivariate."""

    var_tree: UInt[Array, '*batch_shape half_tree_size']
    """The axis each decision node splits on. May be dirty, except that the
    never-used entry at index 0 must be 0."""

    split_tree: UInt[Array, '*batch_shape half_tree_size']
    """The cutpoints of the decision nodes. Intervals are open on the right:
    a point goes to the left child iff x < split. A 'split' of 0 marks a
    leaf; unused nodes also carry 0. This array must never be dirty."""

73 

74 

def is_multivariate(trees: TreeHeaps) -> bool:
    """
    Tell whether the leaves of the trees are vectors rather than scalars.

    Parameters
    ----------
    trees
        The trees to inspect.

    Returns
    -------
    True if `leaf_tree` has more axes than `var_tree`, i.e., it carries the extra `k` axis.
    """
    # the decision arrays never have the `k` axis, so they give the
    # reference rank of a univariate layout
    leaf_rank = trees.leaf_tree.ndim
    decision_rank = trees.var_tree.ndim
    return leaf_rank > decision_rank

89 

90 

class TreesTrace(Module):
    """Concrete `bartz.grove.TreeHeaps` container used for an MCMC trace."""

    # Field order matters: `var_tree`/`split_tree` come before `leaf_tree` so
    # that their union-free annotations bind the variadic `*batch_shape` axis
    # first. Otherwise the runtime typechecker (which evaluates union members
    # in a hash-randomized order) can mis-bind it against the `k` axis of the
    # `leaf_tree` union for a multivariate tree (the layouts are
    # rank-ambiguous). See `bartz.mcmcstep._state.Forest`. The leaf-bearing
    # axis is spelled `2*half_tree_size` rather than `tree_size`, so the
    # half-of-leaf relationship is still enforced: `half_tree_size` is bound
    # by the anchors, then `leaf_tree` is checked against twice it.
    var_tree: UInt[Array, '*batch_shape half_tree_size']
    """The axis each decision node splits on. May be dirty, except that the
    never-used entry at index 0 must be 0."""

    split_tree: UInt[Array, '*batch_shape half_tree_size']
    """The cutpoints of the decision nodes. Intervals are open on the right:
    a point goes to the left child iff x < split. A 'split' of 0 marks a
    leaf; unused nodes also carry 0. This array must never be dirty."""

    leaf_tree: (
        Float32[Array, '*batch_shape 2*half_tree_size']
        | Float32[Array, '*batch_shape k 2*half_tree_size']
    )
    """The leaf values of the trees. Entries of unused nodes are arbitrary
    (the array may be dirty). An extra axis `k` is present when the leaves
    are multivariate."""

    @classmethod
    def from_dataclass(cls, obj: TreeHeaps) -> 'TreesTrace':
        """Build a `TreesTrace` by copying the same-named fields of `obj`."""
        kwargs = {}
        for field in fields(cls):
            kwargs[field.name] = getattr(obj, field.name)
        return cls(**kwargs)

    def axes_from_dataclass(self, obj: TreeHeaps) -> 'TreesTrace':
        """Project the per-field vmap axis specs of `obj` onto this template.

        `self` provides the (array) pytree structure; each of its leaves is
        replaced by the same-named field of `obj`, which holds an axis spec
        (an int or `None`). The result is assembled with `equinox.tree_at`,
        which bypasses the type-checked `__init__`, so the deliberately
        off-type axis values are accepted.
        """
        field_names = [field.name for field in fields(type(self))]

        def select(template: 'TreesTrace') -> list:
            # the nodes of the template to be swapped out, in field order
            return [getattr(template, name) for name in field_names]

        replacements = [getattr(obj, name) for name in field_names]
        return tree_at(select, self, replacements)

140 

141 

def tree_depth(tree: Shaped[Array, '*batch_shape tree_size']) -> int:
    """
    Compute the maximum depth representable by a tree array.

    Parameters
    ----------
    tree
        A tree array like those in a `TreeHeaps`. For an ND array, the last
        axis is taken to be the one holding the tree structure.

    Returns
    -------
    The base-2 logarithm of the length of the last axis, rounded to an integer.
    """
    tree_size = tree.shape[-1]
    log_size = math.log2(tree_size)
    return round(log_size)

157 

158 

def traverse_tree(
    x: UInt[Array, ' p'],
    var_tree: UInt[Array, ' half_tree_size'],
    split_tree: UInt[Array, ' half_tree_size'],
) -> UInt[Array, '']:
    """
    Locate the leaf of a tree that a point ends up in.

    Parameters
    ----------
    x
        The coordinates to evaluate the tree at.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The index of the leaf.
    """
    # start from the root (index 1); the dtype is sized to hold the largest
    # index of the full (leaf) array, which is twice as long as `var_tree`
    index_dtype = minimal_unsigned_dtype(2 * var_tree.size - 1)
    node = jnp.ones((), index_dtype)
    done = jnp.zeros((), bool)

    # the number of levels is a small static integer, so a plain python loop
    # is equivalent to (and clearer than) a fully-unrolled lax.scan
    num_levels = tree_depth(var_tree)
    for _ in range(num_levels):
        cutpoint = split_tree[node]
        axis = var_tree[node]

        # a zero cutpoint marks a leaf; once reached, the index is frozen
        done = done | (cutpoint == 0)
        go_right = x[axis] >= cutpoint
        descended = (node << 1) + go_right
        node = jnp.where(done, node, descended)

    return node

194 

195 

@jit
def traverse_forest(
    X: UInt[Array, 'p n'],
    var_trees: UInt[Array, '*forest_shape half_tree_size'],
    split_trees: UInt[Array, '*forest_shape half_tree_size'],
) -> UInt[Array, '*forest_shape n']:
    """
    Find the leaves where points fall into for each tree in a set.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    var_trees
        The decision axes of the trees.
    split_trees
        The decision boundaries of the trees.

    Returns
    -------
    The indices of the leaves.
    """
    # jit-compiled entry point; the batching lives in `_traverse_forest`
    leaf_indices = _traverse_forest(X, var_trees, split_trees)
    return leaf_indices

219 

220 

@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
@partial(vmap_nodoc, in_axes=(1, None, None))
def _traverse_forest(
    X: UInt[Array, ' p'],
    var_trees: UInt[Array, ' half_tree_size'],
    split_trees: UInt[Array, ' half_tree_size'],
) -> UInt[Array, '']:
    """Implement `traverse_forest`.

    The body handles one point and one tree. The inner decorator maps it over
    axis 1 of `X`; the outer one vectorizes over the leading axes of the tree
    arrays while leaving `X` (argument 0) excluded.
    """
    leaf_index = traverse_tree(X, var_trees, split_trees)
    return leaf_index

230 

231 

@jit(static_argnames=('sum_batch_axis',))
def evaluate_forest(
    X: UInt[Array, 'p n'],
    trees: TreeHeaps,
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> (
    Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
):
    """
    Compute the values of a collection of trees at a set of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    trees
        The trees.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Note that negative indices count from the end of the batch dimensions,
        the core dimensions n and k can't be summed over by this function.

    Returns
    -------
    The (sum of) the values of the trees at the points in `X`.
    """
    leaf_index: UInt[Array, '*forest_shape n']
    leaf_index = traverse_forest(X, trees.var_tree, trees.split_tree)

    # insert singleton axes so that the indices and the leaf values broadcast
    # against each other, with the tree axis last for `take_along_axis`
    gather_index: UInt[Array, '*forest_shape n 1'] | UInt[Array, '*forest_shape 1 n 1']
    leaf_values: (
        Float32[Array, '*forest_shape 1 tree_size']
        | Float32[Array, '*forest_shape k 1 tree_size']
    )
    if is_multivariate(trees):
        gather_index = leaf_index[..., None, :, None]
        leaf_values = trees.leaf_tree[..., :, None, :]
    else:
        gather_index = leaf_index[..., None]
        leaf_values = trees.leaf_tree[..., None, :]

    gathered: (
        Float32[Array, '*forest_shape n 1'] | Float32[Array, '*forest_shape k n 1']
    )
    gathered = jnp.take_along_axis(leaf_values, gather_index, -1)

    values: Float32[Array, '*forest_shape n'] | Float32[Array, '*forest_shape k n']
    values = jnp.squeeze(gathered, -1)

    # normalize against the number of batch axes only, so that negative
    # entries can never address the trailing core axes
    batch_ndim = trees.var_tree.ndim - 1
    reduce_axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
    return jnp.sum(values, axis=reduce_axes)

285 

286 

def is_actual_leaf(
    split_tree: UInt[Array, ' half_tree_size'], *, add_bottom_level: bool = False
) -> Bool[Array, ' half_tree_size'] | Bool[Array, ' 2*half_tree_size']:
    """
    Build a boolean mask that flags the leaf nodes of a tree.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    add_bottom_level
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
    """
    no_split = split_tree == 0
    num_nodes = split_tree.size
    if add_bottom_level:
        # nodes of the bottom level have no entry in `split_tree` and can
        # only be leaves
        num_nodes = 2 * num_nodes
        no_split = jnp.concatenate([no_split, jnp.ones_like(no_split)])

    node_ids = jnp.arange(num_nodes, dtype=minimal_unsigned_dtype(num_nodes - 1))
    parent_ids = node_ids >> 1

    # a node exists only if its parent is a decision node; the root (index 1)
    # has no parent, so it is forced to count as existing
    parent_is_decision = split_tree[parent_ids].astype(bool)
    parent_is_decision = parent_is_decision.at[1].set(True)

    return no_split & parent_is_decision

314 

315 

def is_leaves_parent(
    split_tree: UInt[Array, ' half_tree_size'],
) -> Bool[Array, ' half_tree_size']:
    """
    Build a mask that flags the nodes whose children are both leaves.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The mask indicating which nodes have leaf children.
    """
    # child indices reach up to 2 * size - 1, so size the dtype accordingly
    index_dtype = minimal_unsigned_dtype(2 * split_tree.size - 1)
    node_ids = jnp.arange(split_tree.size, dtype=index_dtype)
    left_ids = node_ids << 1
    right_ids = left_ids + 1

    def child_is_leaf(child_ids: UInt[Array, ' half_tree_size']) -> Bool[Array, ' half_tree_size']:
        # children beyond the end of `split_tree` read as 0, i.e., as leaves
        return split_tree.at[child_ids].get(mode='fill', fill_value=0) == 0

    left_is_leaf = child_is_leaf(left_ids)
    right_is_leaf = child_is_leaf(right_ids)
    is_decision = split_tree.astype(bool)

    # the 0-th item has split == 0, so it's not counted
    return is_decision & left_is_leaf & right_is_leaf

341 

342 

def tree_depths(tree_size: int) -> UInt[Array, ' {tree_size}']:
    """
    Compute the depth of every node slot of a heap-ordered binary tree.

    Parameters
    ----------
    tree_size
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    The depth of each node.

    Notes
    -----
    The root node (index 1) has depth 0. The depth is the position of the most
    significant non-zero bit in the index. The first element (the unused node)
    is marked as depth 0.
    """
    # bit_length() - 1 is the position of the highest set bit; it yields -1
    # for index 0, which is overwritten right below
    depths = [index.bit_length() - 1 for index in range(tree_size)]
    depths[0] = 0
    deepest = max(depths)
    return jnp.array(depths, minimal_unsigned_dtype(deepest))

370 

371 

@jit
def forest_mean_leaves(
    split_tree: UInt[Array, '*batch_shape half_tree_size'],
) -> Float32[Array, '']:
    """
    Compute the mean number of leaves per tree over a set of trees.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees.

    Returns
    -------
    The mean number of leaves across the trees.
    """
    # a tree with k internal nodes (the nonzero entries of split_tree) has
    # k + 1 leaves; the maximum possible is split_tree.shape[-1]
    internal_per_tree = jnp.count_nonzero(split_tree, axis=-1)
    leaves_per_tree = internal_per_tree + 1
    return leaves_per_tree.mean()

392 

393 

@jit(static_argnames=('p', 'sum_batch_axis'))
def var_histogram(
    p: int,
    var_tree: UInt[Array, '*batch_shape half_tree_size'],
    split_tree: UInt[Array, '*batch_shape half_tree_size'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape {p}']:
    """
    Count the number of decision nodes that use each variable.

    Parameters
    ----------
    p
        The number of variables (the maximum value that can occur in `var_tree`
        is ``p - 1``).
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed. Note
        that negative indices count from the end of the batch dimensions, the
        core dimension p can't be summed over by this function.

    Returns
    -------
    The histogram(s) of the variables used in the tree.
    """
    # only nodes with a nonzero split are decision nodes and contribute
    is_decision = split_tree.astype(bool)

    def histogram(
        var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
        is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
    ) -> Int32[Array, ' p']:
        # every axis still present at this point is accumulated into a
        # single length-p vector
        return jnp.zeros(p, int).at[var_tree].add(is_internal)

    # wrap the kernel in one vmap per batch axis that is not summed over;
    # axes are addressed from the end, so stripping outer ones does not
    # invalidate the inner specs
    batch_ndim = var_tree.ndim - 1
    summed_axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
    for axis in reversed(range(batch_ndim)):
        if axis in summed_axes:
            continue
        histogram = vmap(histogram, in_axes=axis - var_tree.ndim)

    return histogram(var_tree, is_decision)

440 

441 

def _format_leaf(leaf: Float32[Array, ''] | Float32[Array, ' k'], is_mv: bool) -> str:
    """Render a leaf value with 2 significant digits.

    A multivariate leaf (`is_mv` true) is rendered as a bracketed,
    comma-separated list of its components.
    """
    if not is_mv:
        return f'{leaf:#.2g}'
    components = [f'{value:#.2g}' for value in leaf]
    return '[' + ', '.join(components) + ']'

447 

448 

def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
    """Convert a tree to a human-readable string.

    Parameters
    ----------
    tree
        A single tree to format.
    print_all
        If `True`, also print the contents of unused node slots in the arrays.

    Returns
    -------
    A string representation of the tree.
    """
    # Branch glyphs and their continuation indents. `join` and `space` are
    # appended to the indent of the descendants of a node drawn with `tee`
    # and `corner` respectively, so all four must span the same number of
    # columns (3) for nested levels to line up.
    tee = '├──'
    corner = '└──'
    join = '│  '
    space = '   '
    down = '┐'
    bottom = '╢'  # alternative glyph: '┨'

    *_, tree_size = tree.leaf_tree.shape
    is_mv = is_multivariate(tree)

    # width of the node-number column, the same for every line
    ndigits = len(str(tree_size - 1))

    lines: list[str] = []

    def visit(
        index: int,
        indent: str,
        first_indent: str,
        next_indent: str,
        unused: bool,
    ) -> None:
        """Append the line for node `index`, then recurse into its children."""
        if index >= tree_size:
            return

        # bottom-level nodes have no entry in the decision arrays; they read
        # as 0 and are thus treated as leaves
        var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
        split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()

        is_leaf = split == 0
        left_child = 2 * index
        right_child = left_child + 1

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {tree.leaf_tree[..., index]})'
        else:
            # without print_all the recursion stops at leaves, so unused
            # slots are never visited
            assert not unused
            if is_leaf:
                node_str = _format_leaf(tree.leaf_tree[..., index], is_mv)
            else:
                node_str = f'x{var} < {split}'

        # connector between the branch glyph and the node text
        if not is_leaf or (print_all and left_child < tree_size):
            link = down
        elif not print_all and left_child >= tree_size:
            link = bottom
        else:
            link = ' '

        number = str(index).rjust(ndigits)
        lines.append(f' {number} {indent}{first_indent}{link}{node_str}')

        # everything below a leaf is an unused slot
        unused = unused or is_leaf
        if unused and not print_all:
            return

        child_indent = indent + next_indent
        visit(left_child, child_indent, tee, join, unused)
        visit(right_child, child_indent, corner, space, unused)

    visit(1, '', '', '', False)
    return '\n'.join(lines)

532 

533 

def tree_actual_depth(split_tree: UInt[Array, ' half_tree_size']) -> UInt[Array, '']:
    """Compute how deep a tree actually grows.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules.

    Returns
    -------
    The depth of the deepest leaf in the tree. The root is at depth 0.
    """
    # this could be done just with split_tree != 0
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)
    node_depths = tree_depths(leaf_mask.size)
    # slots that are not leaves contribute depth 0 and so never win the max
    leaf_depths = jnp.where(leaf_mask, node_depths, 0)
    return jnp.max(leaf_depths)

551 

552 

@jit
@partial(jnp.vectorize, signature='(nt,hts)->(d)')
def forest_depth_distr(
    split_tree: UInt[Array, '*batch_shape num_trees half_tree_size'],
) -> Int32[Array, '*batch_shape d']:
    """Compute the histogram of tree depths over a set of trees.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules of the trees.

    Returns
    -------
    An integer vector where the i-th element counts how many trees have depth i.
    """
    # one bin per attainable depth, from 0 up to tree_depth(split_tree)
    num_bins = tree_depth(split_tree) + 1
    per_tree_depth = vmap(tree_actual_depth)(split_tree)
    return jnp.bincount(per_tree_depth, length=num_bins)

572 

573 

@jit(static_argnames=('node_type', 'sum_batch_axis'))
def points_per_node_distr(
    X: UInt[Array, 'p n'],
    var_tree: UInt[Array, '*batch_shape half_tree_size'],
    split_tree: UInt[Array, '*batch_shape half_tree_size'],
    node_type: Literal['leaf', 'leaf-parent'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape n+1']:
    """Histogram points-per-node counts in a set of trees.

    For a chosen subset of the nodes of a tree, count how many of those nodes
    contain each possible number of points.

    Parameters
    ----------
    X
        The set of points to count.
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    node_type
        The type of nodes to consider. Can be:

        'leaf'
            Count only leaf nodes.
        'leaf-parent'
            Count only parent-of-leaf nodes.
    sum_batch_axis
        Aggregate the histogram over these batch axes, counting how many nodes
        have each possible amount of points over subsets of trees instead of
        in each tree separately.

    Returns
    -------
    A vector where the i-th element counts how many nodes have i points.
    """
    batch_ndim = var_tree.ndim - 1
    summed_axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)

    def func(
        var_tree: UInt[Array, '*batch_shape half_tree_size'],
        split_tree: UInt[Array, '*batch_shape half_tree_size'],
    ) -> Int32[Array, '*reduced_batch_shape n_plus_1']:
        indices: UInt[Array, '*batch_shape n']
        indices = traverse_forest(X, var_tree, split_tree)

        @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
        def count_points(
            split_tree: UInt[Array, '*batch_shape half_tree_size'],
            indices: UInt[Array, '*batch_shape n'],
        ) -> (
            tuple[
                Int32[Array, '*batch_shape 2*half_tree_size'],
                Bool[Array, '*batch_shape 2*half_tree_size'],
            ]
            | tuple[
                Int32[Array, '*batch_shape half_tree_size'],
                Bool[Array, '*batch_shape half_tree_size'],
            ]
        ):
            if node_type == 'leaf':
                node_ids = indices
                selected = is_actual_leaf(split_tree, add_bottom_level=True)
            elif node_type == 'leaf-parent':
                # attribute each point to the parent of the leaf it falls in
                node_ids = indices >> 1
                selected = is_leaves_parent(split_tree)
            else:
                raise ValueError(node_type)

            # number of points per node; slot 0 is not a node (it collects
            # the points of a root leaf in 'leaf-parent' mode, since
            # 1 >> 1 == 0) and is cleared
            tally = jnp.zeros(selected.size, int).at[node_ids].add(1)
            tally = tally.at[0].set(0)
            return tally, selected

        count_tree, predicate = count_points(split_tree, indices)

        def count_nodes(
            count_tree: Int32[Array, '*summed_batch_axes half_tree_size'],
            predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
        ) -> Int32[Array, ' n_plus_1']:
            # a node can hold from 0 to n points, hence n + 1 bins; only the
            # selected nodes add to their bin
            return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate)

        # vmap count_nodes over the batch axes that are not summed over
        for axis in reversed(range(batch_ndim)):
            if axis in summed_axes:
                continue
            count_nodes = vmap(count_nodes, in_axes=axis - var_tree.ndim)

        return count_nodes(count_tree, predicate)

    # automatically batch over all batch dimensions that are kept
    max_io_nbytes = 2**27  # 128 MiB
    batched_func = func
    # number of summed axes at positions not yet visited (i.e. preceding the
    # current one); it is the shift between a kept input axis and its
    # position in the output
    out_dim_shift = len(summed_axes)
    for axis in reversed(range(batch_ndim)):
        if axis in summed_axes:
            out_dim_shift -= 1
            continue
        batched_func = autobatch(
            batched_func, max_io_nbytes, axis, axis - out_dim_shift
        )
    assert out_dim_shift == 0

    return batched_func(var_tree, split_tree)