Coverage for src / bartz / grove.py: 93%

184 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/grove.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions to create and manipulate binary decision trees.""" 

26 

27import math 

28from functools import partial 

29from typing import Literal, Protocol 

30 

31from jax import jit, lax, vmap 

32from jax import numpy as jnp 

33from jaxtyping import Array, Bool, Float32, Int32, Shaped, UInt 

34 

35try: 

36 from numpy.lib.array_utils import normalize_axis_tuple # numpy 2 

37except ImportError: 

38 from numpy.core.numeric import normalize_axis_tuple # numpy 1 

39 

40from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc 

41 

42 

class TreeHeaps(Protocol):
    """Structural interface for dataclasses that store trees as heap arrays.

    Each tree is laid out following the heap convention: the root lives at
    index 1, and the node at index :math:`i` has its left child at index
    :math:`2i` and its right child at index :math:`2i + 1`. The slot at
    index 0 is never used.

    Nodes on the bottom level can only be leaves, never decision nodes, so
    `var_tree` and `split_tree` are half as long as `leaf_tree`.

    Any of the arrays may carry extra leading axes to hold several trees.
    """

    leaf_tree: (
        Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d']
    )
    """The leaf values of the trees. The array is allowed to be dirty: slots
    of unused nodes may hold arbitrary values. An extra axis may be present
    for multivariate leaves."""

    var_tree: UInt[Array, '*batch_shape 2**(d-1)']
    """The axis each decision node splits on. The array is allowed to be
    dirty, except for the always unused slot at index 0, which must be 0."""

    split_tree: UInt[Array, '*batch_shape 2**(d-1)']
    """The cutpoints of the decision nodes. Intervals are open on the right:
    a point goes to the left child iff x < split. A 'split' value of 0 marks
    the node as a leaf; unused nodes have split 0 as well. This array must
    never be dirty."""

73 

74 

def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int:
    """
    Compute the maximum depth representable by a tree array.

    Parameters
    ----------
    tree
        A tree array like those in a `TreeHeaps`. For an ND array, the heap
        is taken to lie along the last axis.

    Returns
    -------
    The base-2 logarithm of the length of the last axis, rounded to an integer.
    """
    heap_length = tree.shape[-1]
    return round(math.log2(heap_length))

90 

91 

def traverse_tree(
    x: UInt[Array, ' p'],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
) -> UInt[Array, '']:
    """
    Walk down a tree to locate the leaf containing a point.

    Parameters
    ----------
    x
        The coordinates to evaluate the tree at.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The heap index of the leaf reached by `x`.
    """
    # the index must be able to address the full heap, which is twice as
    # long as `var_tree`
    index_dtype = minimal_unsigned_dtype(2 * var_tree.size - 1)

    # state: (leaf already reached?, current node index), starting at the root
    initial_state = (jnp.zeros((), bool), jnp.ones((), index_dtype))

    def descend(
        state: tuple[Bool[Array, ''], UInt[Array, '']], _: None
    ) -> tuple[tuple[Bool[Array, ''], UInt[Array, '']], None]:
        reached_leaf, node = state

        cutpoint = split_tree[node]
        axis = var_tree[node]

        # a zero cutpoint marks a leaf; once set, the flag stays set
        reached_leaf = reached_leaf | (cutpoint == 0)

        # left child is 2 * node, right child is 2 * node + 1
        goes_right = x[axis] >= cutpoint
        child = (node << 1) + goes_right

        # freeze the index once a leaf has been reached
        node = jnp.where(reached_leaf, node, child)

        return (reached_leaf, node), None

    num_steps = tree_depth(var_tree)
    (_, leaf_index), _ = lax.scan(descend, initial_state, None, num_steps, unroll=16)
    return leaf_index

135 

136 

@jit
@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)')
@partial(vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(
    X: UInt[Array, 'p n'],
    var_trees: UInt[Array, '*forest_shape 2**(d-1)'],
    split_trees: UInt[Array, '*forest_shape 2**(d-1)'],
) -> UInt[Array, '*forest_shape n']:
    """
    Locate, for every tree in a set, the leaf each point falls into.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    var_trees
        The decision axes of the trees.
    split_trees
        The decision boundaries of the trees.

    Returns
    -------
    The heap indices of the leaves.
    """
    # The decorators map this single-point, single-tree call over the columns
    # of `X` (axis 1) and over the leading axes of the tree arrays.
    return traverse_tree(x=X, var_tree=var_trees, split_tree=split_trees)

162 

163 

@partial(jit, static_argnames=('sum_batch_axis',))
def evaluate_forest(
    X: UInt[Array, 'p n'],
    trees: TreeHeaps,
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> (
    Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n']
):
    """
    Compute the values of a set of trees at an array of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    trees
        The trees.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Negative indices count from the end of the batch dimensions; the
        core dimensions n and k can't be summed over by this function.

    Returns
    -------
    The (sum of) the values of the trees at the points in `X`.
    """
    # shape (*forest_shape, n): leaf index of each point in each tree
    leaf_indices = traverse_forest(X, trees.var_tree, trees.split_tree)

    # multivariate leaves show up as one extra axis in `leaf_tree`
    multivariate = trees.leaf_tree.ndim != trees.var_tree.ndim

    # insert singleton axes so that indices and values broadcast against
    # each other in `take_along_axis`
    if multivariate:
        gather_indices = leaf_indices[..., None, :, None]  # (*forest, 1, n, 1)
        leaf_values = trees.leaf_tree[..., :, None, :]  # (*forest, k, 1, size)
    else:
        gather_indices = leaf_indices[..., None]  # (*forest, n, 1)
        leaf_values = trees.leaf_tree[..., None, :]  # (*forest, 1, size)

    gathered = jnp.take_along_axis(leaf_values, gather_indices, -1)
    values = jnp.squeeze(gathered, -1)  # (*forest, n) or (*forest, k, n)

    # axes are normalized against the batch dimensions only
    batch_ndim = trees.var_tree.ndim - 1
    summed_axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
    return jnp.sum(values, axis=summed_axes)

217 

218 

def is_actual_leaf(
    split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
    """
    Build a mask that marks the leaf nodes of a tree.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    add_bottom_level
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
    """
    num_nodes = split_tree.size
    has_no_split = split_tree == 0

    if add_bottom_level:
        # nodes on the extra bottom level never split
        has_no_split = jnp.concatenate([has_no_split, jnp.ones_like(has_no_split)])
        num_nodes *= 2

    node = jnp.arange(num_nodes, dtype=minimal_unsigned_dtype(num_nodes - 1))

    # a node without a split is a real leaf only if its parent does split
    parent_splits = split_tree[node >> 1].astype(bool)

    # the root has no real parent (its parent index is the unused slot 0),
    # so it is always treated as reachable
    parent_splits = parent_splits.at[1].set(True)

    return has_no_split & parent_splits

246 

247 

def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
    """
    Build a mask that marks the nodes whose children are both leaves.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The mask indicating which nodes have leaf children.
    """
    num_nodes = split_tree.size

    # child indices go up to 2 * num_nodes - 1, so size the dtype for that
    node = jnp.arange(num_nodes, dtype=minimal_unsigned_dtype(2 * num_nodes - 1))
    left_child = node << 1
    right_child = left_child + 1

    def child_is_leaf(child: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
        # children beyond the end of `split_tree` read as 0, i.e., as leaves
        return split_tree.at[child].get(mode='fill', fill_value=0) == 0

    left_is_leaf = child_is_leaf(left_child)
    right_is_leaf = child_is_leaf(right_child)
    splits = split_tree.astype(bool)

    # the 0-th item has split == 0, so it's not counted
    return splits & left_is_leaf & right_is_leaf

271 

272 

def tree_depths(tree_size: int) -> Int32[Array, ' {tree_size}']:
    """
    Compute the depth of every node slot in a binary tree heap.

    Parameters
    ----------
    tree_size
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    The depth of each node.

    Notes
    -----
    The root node (index 1) has depth 0. The depth is the position of the most
    significant non-zero bit in the index. The first element (the unused node)
    is marked as depth 0.

    The dtype of the result is picked by `minimal_unsigned_dtype` from the
    largest depth.
    """
    # position of the most significant set bit; index 0 yields -1 here
    depths = [node.bit_length() - 1 for node in range(tree_size)]
    # the unused slot at index 0 is assigned depth 0 by convention
    depths[0] = 0
    return jnp.array(depths, minimal_unsigned_dtype(max(depths)))

300 

301 

@partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)')
def is_used(
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
) -> Bool[Array, '*batch_shape 2**d']:
    """
    Build a mask that marks the node slots actually in use in a tree.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    A mask indicating which nodes are actually used.
    """
    # decision nodes: non-zero split, padded with False for the bottom level
    splits = split_tree.astype(bool)
    decision_mask = jnp.concatenate([splits, jnp.zeros_like(splits)])

    # leaves, including those on the bottom level
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)

    # a slot is used if it holds either a decision node or a leaf
    return decision_mask | leaf_mask

322 

323 

@jit
def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']:
    """
    Compute the fraction of node slots in use across a set of trees.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees.

    Returns
    -------
    Number of tree nodes over the maximum number that could be stored.
    """
    used_mask = is_used(split_tree)

    # number of trees = total elements / elements per tree
    num_trees = split_tree.size // split_tree.shape[-1]

    # one slot per tree is subtracted from the capacity
    capacity = used_mask.size - num_trees

    return jnp.count_nonzero(used_mask) / capacity

342 

343 

@partial(jit, static_argnames=('p', 'sum_batch_axis'))
def var_histogram(
    p: int,
    var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape {p}']:
    """
    Count the occurrences of each variable in the decision rules of a tree.

    Parameters
    ----------
    p
        The number of variables (the maximum value that can occur in `var_tree`
        is ``p - 1``).
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.
    sum_batch_axis
        The batch axes to sum over. By default, no summation is performed.
        Negative indices count from the end of the batch dimensions; the
        core dimension p can't be summed over by this function.

    Returns
    -------
    The histogram(s) of the variables used in the tree.
    """
    # only nodes with a non-zero split contribute to the counts
    decision_mask = split_tree.astype(bool)

    def histogram(
        var_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
        is_internal: Bool[Array, '*summed_batch_axes half_tree_size'],
    ) -> Int32[Array, ' p']:
        # scatter-add over all the axes it receives, so any batch axis that
        # is not vmapped away below gets summed
        return jnp.zeros(p, int).at[var_tree].add(is_internal)

    batch_ndim = var_tree.ndim - 1
    summed_axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
    kept_axes = [axis for axis in range(batch_ndim) if axis not in summed_axes]

    # vmap over the batch axes that are kept, innermost first; the axis is
    # given as a negative index counted from the end
    for axis in reversed(kept_axes):
        histogram = vmap(histogram, in_axes=axis - var_tree.ndim)

    return histogram(var_tree, decision_mask)

390 

391 

def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
    """Render a single tree as a multi-line text drawing.

    Parameters
    ----------
    tree
        A single tree to format.
    print_all
        If `True`, also print the contents of unused node slots in the arrays.

    Returns
    -------
    A string with one line per printed node. Each line shows the heap index
    of the node, the branch drawing, and the node content: ``x<var> < <split>``
    for decision nodes and the leaf value for leaves.
    """
    # Branch connectors and continuation indents. All four are three columns
    # wide, so that the children of a node line up under its `┐` link.
    tee = '├──'
    corner = '└──'
    join = '│  '
    space = '   '
    down = '┐'
    bottom = '╢'

    heap_size = len(tree.leaf_tree)
    # width needed to right-align the largest heap index
    index_width = len(str(heap_size - 1))

    def render_node(
        lines: list[str],
        index: int,
        indent: str,
        first_indent: str,
        next_indent: str,
        unused: bool,
    ) -> None:
        # Append the line for node `index`, then recurse into its children.
        if index >= heap_size:
            return

        # bottom-level indices are out of range for these arrays and read as 0
        var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
        split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()

        is_leaf = split == 0
        left_child = 2 * index
        right_child = left_child + 1

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})'
        else:
            # without print_all, the recursion never enters unused slots
            assert not unused
            if is_leaf:
                node_str = f'{tree.leaf_tree[index]:#.2g}'
            else:
                node_str = f'x{var} < {split}'

        if not is_leaf or (print_all and left_child < heap_size):
            # children will be printed below this node
            link = down
        elif not print_all and left_child >= heap_size:
            # leaf on the bottom level of the heap
            link = bottom
        else:
            link = ' '

        number = str(index).rjust(index_width)
        lines.append(f' {number} {indent}{first_indent}{link}{node_str}')

        indent += next_indent
        # everything below a leaf is an unused slot
        unused = unused or is_leaf

        if unused and not print_all:
            return

        render_node(lines, left_child, indent, tee, join, unused)
        render_node(lines, right_child, indent, corner, space, unused)

    lines: list[str] = []
    render_node(lines, 1, '', '', '', False)
    return '\n'.join(lines)

472 

473 

def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
    """Compute the depth actually reached by a tree.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules.

    Returns
    -------
    The depth of the deepest leaf in the tree. The root is at depth 0.
    """
    # this could be done just with split_tree != 0
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)
    node_depths = tree_depths(leaf_mask.size)
    # zero out the depth of everything that is not a leaf, then take the max
    leaf_depths = jnp.where(leaf_mask, node_depths, 0)
    return jnp.max(leaf_depths)

491 

492 

@jit
@partial(jnp.vectorize, signature='(nt,hts)->(d)')
def forest_depth_distr(
    split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'],
) -> Int32[Array, '*batch_shape d']:
    """Compute the distribution of depths over a set of trees.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules of the trees.

    Returns
    -------
    An integer vector where the i-th element counts how many trees have depth i.
    """
    # one bin for each depth from 0 up to the depth of the full heap
    num_bins = tree_depth(split_tree) + 1
    per_tree_depth = vmap(tree_actual_depth)(split_tree)
    return jnp.bincount(per_tree_depth, length=num_bins)

512 

513 

@partial(jit, static_argnames=('node_type', 'sum_batch_axis'))
def points_per_node_distr(
    X: UInt[Array, 'p n'],
    var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    node_type: Literal['leaf', 'leaf-parent'],
    *,
    sum_batch_axis: int | tuple[int, ...] = (),
) -> Int32[Array, '*reduced_batch_shape n+1']:
    """Compute the distribution of the number of points per node.

    For a chosen subset of nodes, count how many nodes of a tree select each
    possible amount of points.

    Parameters
    ----------
    X
        The set of points to count.
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    node_type
        The type of nodes to consider. Can be:

        'leaf'
            Count only leaf nodes.
        'leaf-parent'
            Count only parent-of-leaf nodes.
    sum_batch_axis
        Aggregate the histogram over these batch axes, counting how many nodes
        have each possible amount of points over subsets of trees instead of
        in each tree separately.

    Returns
    -------
    A vector where the i-th element counts how many nodes have i points.
    """
    batch_ndim = var_tree.ndim - 1
    axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)

    def func(
        var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
        split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
    ) -> Int32[Array, '*reduced_batch_shape n+1']:
        # leaf index of each point in each tree, shape (*batch_shape, n)
        point_leaf = traverse_forest(X, var_tree, split_tree)

        @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
        def count_points(
            split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
            indices: UInt[Array, '*batch_shape n'],
        ) -> (
            tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']]
            | tuple[
                UInt[Array, '*batch_shape 2**(d-1)'],
                Bool[Array, '*batch_shape 2**(d-1)'],
            ]
        ):
            if node_type == 'leaf':
                node_mask = is_actual_leaf(split_tree, add_bottom_level=True)
            elif node_type == 'leaf-parent':
                # attribute each point to the parent of its leaf
                indices = indices >> 1
                node_mask = is_leaves_parent(split_tree)
            else:
                raise ValueError(node_type)

            # number of points per node slot; the slot at index 0 is reset to 0
            points_in_node = jnp.zeros(node_mask.size, int).at[indices].add(1)
            points_in_node = points_in_node.at[0].set(0)
            return points_in_node, node_mask

        count_tree, predicate = count_points(split_tree, point_leaf)

        def count_nodes(
            count_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
            predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
        ) -> Int32[Array, ' n+1']:
            # histogram of the point counts, restricted to the selected nodes;
            # a node can hold between 0 and n points, hence n + 1 bins
            num_bins = X.shape[1] + 1
            return jnp.zeros(num_bins, int).at[count_tree].add(predicate)

        # vmap count_nodes over the batch axes that are not summed, innermost
        # first; axes are given as negative indices counted from the end
        for axis in reversed(range(batch_ndim)):
            if axis in axes:
                continue
            count_nodes = vmap(count_nodes, in_axes=axis - var_tree.ndim)

        return count_nodes(count_tree, predicate)

    # wrap with autobatch along every batch axis that is kept in the output
    max_io_nbytes = 2**27  # 128 MiB
    # number of summed axes located before the current one; these are missing
    # from the output, so output axis positions are shifted by this amount
    out_dim_shift = len(axes)
    for axis in reversed(range(batch_ndim)):
        if axis in axes:
            out_dim_shift -= 1
            continue
        func = autobatch(func, max_io_nbytes, axis, axis - out_dim_shift)
    assert out_dim_shift == 0

    return func(var_tree, split_tree)