Coverage for src / bartz / grove.py: 100%

97 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-13 00:35 +0000

1# bartz/src/bartz/grove.py 

2# 

3# Copyright (c) 2024-2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions to create and manipulate binary decision trees.""" 

26 

27import math 

28from functools import partial 

29from typing import Protocol 

30 

31from jax import jit, lax, vmap 

32from jax import numpy as jnp 

33from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Shaped, UInt 

34 

35try: 

36 from numpy.lib.array_utils import normalize_axis_tuple # numpy 2 

37except ImportError: 

38 from numpy.core.numeric import normalize_axis_tuple # numpy 1 

39 

40from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc 

41 

42 

43class TreeHeaps(Protocol): 

44 """A protocol for dataclasses that represent trees. 

45 

46 A tree is represented with arrays as a heap. The root node is at index 1. 

47 The children nodes of a node at index :math:`i` are at indices :math:`2i` 

48 (left child) and :math:`2i + 1` (right child). The array element at index 0 

49 is unused. 

50 

51 Arrays may have additional initial axes to represent multple trees. 

52 

53 Parameters 

54 ---------- 

55 leaf_tree 

56 The values in the leaves of the trees. This array can be dirty, i.e., 

57 unused nodes can have whatever value. It may have an additional axis 

58 for multivariate leaves. 

59 var_tree 

60 The axes along which the decision nodes operate. This array can be 

61 dirty but for the always unused node at index 0 which must be set to 0. 

62 split_tree 

63 The decision boundaries of the trees. The boundaries are open on the 

64 right, i.e., a point belongs to the left child iff x < split. Whether a 

65 node is a leaf is indicated by the corresponding 'split' element being 

66 0. Unused nodes also have split set to 0. This array can't be dirty. 

67 

68 Notes 

69 ----- 

70 Since the nodes at the bottom can only be leaves and not decision nodes, 

71 `var_tree` and `split_tree` are half as long as `leaf_tree`. 

72 """ 

73 

74 leaf_tree: ( 

75 Float32[Array, '*batch_shape 2**d'] | Float32[Array, '*batch_shape k 2**d'] 

76 ) 

77 var_tree: UInt[Array, '*batch_shape 2**(d-1)'] 

78 split_tree: UInt[Array, '*batch_shape 2**(d-1)'] 

79 

80 

81def make_tree( 

82 depth: int, dtype: DTypeLike, batch_shape: tuple[int, ...] = () 

83) -> Shaped[Array, '*batch_shape 2**{depth}']: 

84 """ 

85 Make an array to represent a binary tree. 

86 

87 Parameters 

88 ---------- 

89 depth 

90 The maximum depth of the tree. Depth 1 means that there is only a root 

91 node. 

92 dtype 

93 The dtype of the array. 

94 batch_shape 

95 The leading shape of the array, to represent multiple trees and/or 

96 multivariate trees. 

97 

98 Returns 

99 ------- 

100 An array of zeroes with the appropriate shape. 

101 """ 

102 shape = (*batch_shape, 2**depth) 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

103 return jnp.zeros(shape, dtype) 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

104 

105 

106def tree_depth(tree: Shaped[Array, '*batch_shape 2**d']) -> int: 

107 """ 

108 Return the maximum depth of a tree. 

109 

110 Parameters 

111 ---------- 

112 tree 

113 A tree created by `make_tree`. If the array is ND, the tree structure is 

114 assumed to be along the last axis. 

115 

116 Returns 

117 ------- 

118 The maximum depth of the tree. 

119 """ 

120 return round(math.log2(tree.shape[-1])) 2I J a u v w B ' C ( b ) l * c q r s j K k L M F G d m n e N O z A o h i + f p g t P Q R x H y = ? @ [ ] ^ _ ` { | } ~ abbbcbdb(b)b*b+beb,bS T U V W X Y Z 0 1 2 3 4 5 -b.b/bfb:b;b=b?b@b[b6 7 8 9 ! # $ %

121 

122 

123def traverse_tree( 

124 x: UInt[Array, ' p'], 

125 var_tree: UInt[Array, ' 2**(d-1)'], 

126 split_tree: UInt[Array, ' 2**(d-1)'], 

127) -> UInt[Array, '']: 

128 """ 

129 Find the leaf where a point falls into. 

130 

131 Parameters 

132 ---------- 

133 x 

134 The coordinates to evaluate the tree at. 

135 var_tree 

136 The decision axes of the tree. 

137 split_tree 

138 The decision boundaries of the tree. 

139 

140 Returns 

141 ------- 

142 The index of the leaf. 

143 """ 

144 carry = ( 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

145 jnp.zeros((), bool), 

146 jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)), 

147 ) 

148 

149 def loop(carry, _): 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

150 leaf_found, index = carry 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

151 

152 split = split_tree[index] 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

153 var = var_tree[index] 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

154 

155 leaf_found |= split == 0 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

156 child_index = (index << 1) + (x[var] >= split) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

157 index = jnp.where(leaf_found, index, child_index) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

158 

159 return (leaf_found, index), None 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

160 

161 depth = tree_depth(var_tree) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

162 (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

163 return index 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

164 

165 

166@jit 

167@partial(jnp.vectorize, excluded=(0,), signature='(hts),(hts)->(n)') 

168@partial(vmap_nodoc, in_axes=(1, None, None)) 

169def traverse_forest( 

170 X: UInt[Array, 'p n'], 

171 var_trees: UInt[Array, '*forest_shape 2**(d-1)'], 

172 split_trees: UInt[Array, '*forest_shape 2**(d-1)'], 

173) -> UInt[Array, '*forest_shape n']: 

174 """ 

175 Find the leaves where points falls into for each tree in a set. 

176 

177 Parameters 

178 ---------- 

179 X 

180 The coordinates to evaluate the trees at. 

181 var_trees 

182 The decision axes of the trees. 

183 split_trees 

184 The decision boundaries of the trees. 

185 

186 Returns 

187 ------- 

188 The indices of the leaves. 

189 """ 

190 return traverse_tree(X, var_trees, split_trees) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

191 

192 

193@partial(jit, static_argnames=('sum_batch_axis',)) 

194def evaluate_forest( 

195 X: UInt[Array, 'p n'], 

196 trees: TreeHeaps, 

197 *, 

198 sum_batch_axis: int | tuple[int, ...] = (), 

199) -> ( 

200 Float32[Array, '*reduced_batch_size n'] | Float32[Array, '*reduced_batch_size k n'] 

201): 

202 """ 

203 Evaluate an ensemble of trees at an array of points. 

204 

205 Parameters 

206 ---------- 

207 X 

208 The coordinates to evaluate the trees at. 

209 trees 

210 The trees. 

211 sum_batch_axis 

212 The batch axes to sum over. By default, no summation is performed. 

213 Note that negative indices count from the end of the batch dimensions, 

214 the core dimensions n and k can't be summed over by this function. 

215 

216 Returns 

217 ------- 

218 The (sum of) the values of the trees at the points in `X`. 

219 """ 

220 indices: UInt[Array, '*forest_shape n'] 

221 indices = traverse_forest(X, trees.var_tree, trees.split_tree) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

222 

223 is_mv = trees.leaf_tree.ndim != trees.var_tree.ndim 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

224 

225 bc_indices: UInt[Array, '*forest_shape n 1'] | UInt[Array, '*forest_shape 1 n 1'] 

226 bc_indices = indices[..., None, :, None] if is_mv else indices[..., None] 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

227 

228 bc_leaf_tree: ( 

229 Float32[Array, '*forest_shape 1 tree_size'] 

230 | Float32[Array, '*forest_shape k 1 tree_size'] 

231 ) 

232 bc_leaf_tree = ( 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

233 trees.leaf_tree[..., :, None, :] if is_mv else trees.leaf_tree[..., None, :] 

234 ) 

235 

236 bc_leaves: ( 

237 Float32[Array, '*forest_shape n 1'] | Float32[Array, '*forest_shape k n 1'] 

238 ) 

239 bc_leaves = jnp.take_along_axis(bc_leaf_tree, bc_indices, -1) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

240 

241 leaves: Float32[Array, '*forest_shape n'] | Float32[Array, '*forest_shape k n'] 

242 leaves = jnp.squeeze(bc_leaves, -1) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

243 

244 axis = normalize_axis_tuple(sum_batch_axis, trees.var_tree.ndim - 1) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

245 return jnp.sum(leaves, axis=axis) 1auvwB'C(b)l*cqrsjkdmnezAohifpgt

246 

247 

248def is_actual_leaf( 

249 split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False 

250) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']: 

251 """ 

252 Return a mask indicating the leaf nodes in a tree. 

253 

254 Parameters 

255 ---------- 

256 split_tree 

257 The splitting points of the tree. 

258 add_bottom_level 

259 If True, the bottom level of the tree is also considered. 

260 

261 Returns 

262 ------- 

263 The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True. 

264 """ 

265 size = split_tree.size 2I J a u v w b l c q r s E j K k L M F , G d m n e N O o h i + f p g P Q R x H y : 0b1b= ? @ [ ] ^ _ ` { | } ~ abbbcbdb2b3b4b5beb6bS T U V W X Y Z 0 1 2 3 4 5 7b8bfb9b!b#b$b%b'b6 7 8 9 ! # $ %

266 is_leaf = split_tree == 0 2I J a u v w b l c q r s E j K k L M F , G d m n e N O o h i + f p g P Q R x H y : 0b1b= ? @ [ ] ^ _ ` { | } ~ abbbcbdb2b3b4b5beb6bS T U V W X Y Z 0 1 2 3 4 5 7b8bfb9b!b#b$b%b'b6 7 8 9 ! # $ %

267 if add_bottom_level: 2I J a u v w b l c q r s E j K k L M F , G d m n e N O o h i + f p g P Q R x H y : 0b1b= ? @ [ ] ^ _ ` { | } ~ abbbcbdb2b3b4b5beb6bS T U V W X Y Z 0 1 2 3 4 5 7b8bfb9b!b#b$b%b'b6 7 8 9 ! # $ %

268 size *= 2 2a u v w b l c E F , G d m n e o h i + f p g x H y : 0b1b2b3b4b5beb6b7b8bfb9b!b#b$b%b'b

269 is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)]) 2a u v w b l c E F , G d m n e o h i + f p g x H y : 0b1b2b3b4b5beb6b7b8bfb9b!b#b$b%b'b

270 index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1)) 2I J a u v w b l c q r s E j K k L M F , G d m n e N O o h i + f p g P Q R x H y : 0b1b= ? @ [ ] ^ _ ` { | } ~ abbbcbdb2b3b4b5beb6bS T U V W X Y Z 0 1 2 3 4 5 7b8bfb9b!b#b$b%b'b6 7 8 9 ! # $ %

271 parent_index = index >> 1 2I J a u v w b l c q r s E j K k L M F , G d m n e N O o h i + f p g P Q R x H y : 0b1b= ? @ [ ] ^ _ ` { | } ~ abbbcbdb2b3b4b5beb6bS T U V W X Y Z 0 1 2 3 4 5 7b8bfb9b!b#b$b%b'b6 7 8 9 ! # $ %

272 parent_nonleaf = split_tree[parent_index].astype(bool) 2I J a u v w b l c q r s E j K k L M F , G d m n e N O o h i + f p g P Q R x H y : 0b1b= ? @ [ ] ^ _ ` { | } ~ abbbcbdb2b3b4b5beb6bS T U V W X Y Z 0 1 2 3 4 5 7b8bfb9b!b#b$b%b'b6 7 8 9 ! # $ %

273 parent_nonleaf = parent_nonleaf.at[1].set(True) 2I J a u v w b l c q r s E j K k L M F , G d m n e N O o h i + f p g P Q R x H y : 0b1b= ? @ [ ] ^ _ ` { | } ~ abbbcbdb2b3b4b5beb6bS T U V W X Y Z 0 1 2 3 4 5 7b8bfb9b!b#b$b%b'b6 7 8 9 ! # $ %

274 return is_leaf & parent_nonleaf 2I J a u v w b l c q r s E j K k L M F , G d m n e N O o h i + f p g P Q R x H y : 0b1b= ? @ [ ] ^ _ ` { | } ~ abbbcbdb2b3b4b5beb6bS T U V W X Y Z 0 1 2 3 4 5 7b8bfb9b!b#b$b%b'b6 7 8 9 ! # $ %

275 

276 

277def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']: 

278 """ 

279 Return a mask indicating the nodes with leaf (and only leaf) children. 

280 

281 Parameters 

282 ---------- 

283 split_tree 

284 The decision boundaries of the tree. 

285 

286 Returns 

287 ------- 

288 The mask indicating which nodes have leaf children. 

289 """ 

290 index = jnp.arange( 2I J a b l c q r s j K k L ; M F G d m n e N O o h i f p g P Q R x H y = ? @ [ ] ^ _ ` { | } ~ abbbcbdbS T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! # $ %

291 split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1) 

292 ) 

293 left_index = index << 1 # left child 2I J a b l c q r s j K k L ; M F G d m n e N O o h i f p g P Q R x H y = ? @ [ ] ^ _ ` { | } ~ abbbcbdbS T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! # $ %

294 right_index = left_index + 1 # right child 2I J a b l c q r s j K k L ; M F G d m n e N O o h i f p g P Q R x H y = ? @ [ ] ^ _ ` { | } ~ abbbcbdbS T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! # $ %

295 left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0 2I J a b l c q r s j K k L ; M F G d m n e N O o h i f p g P Q R x H y = ? @ [ ] ^ _ ` { | } ~ abbbcbdbS T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! # $ %

296 right_leaf = split_tree.at[right_index].get(mode='fill', fill_value=0) == 0 2I J a b l c q r s j K k L ; M F G d m n e N O o h i f p g P Q R x H y = ? @ [ ] ^ _ ` { | } ~ abbbcbdbS T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! # $ %

297 is_not_leaf = split_tree.astype(bool) 2I J a b l c q r s j K k L ; M F G d m n e N O o h i f p g P Q R x H y = ? @ [ ] ^ _ ` { | } ~ abbbcbdbS T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! # $ %

298 return is_not_leaf & left_leaf & right_leaf 2I J a b l c q r s j K k L ; M F G d m n e N O o h i f p g P Q R x H y = ? @ [ ] ^ _ ` { | } ~ abbbcbdbS T U V W X Y Z 0 1 2 3 4 5 6 7 8 9 ! # $ %

299 # the 0-th item has split == 0, so it's not counted 

300 

301 

302def tree_depths(tree_size: int) -> Int32[Array, ' {tree_size}']: 

303 """ 

304 Return the depth of each node in a binary tree. 

305 

306 Parameters 

307 ---------- 

308 tree_size 

309 The length of the tree array, i.e., 2 ** d. 

310 

311 Returns 

312 ------- 

313 The depth of each node. 

314 

315 Notes 

316 ----- 

317 The root node (index 1) has depth 0. The depth is the position of the most 

318 significant non-zero bit in the index. The first element (the unused node) 

319 is marked as depth 0. 

320 """ 

321 depths = [] 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

322 depth = 0 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

323 for i in range(tree_size): 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

324 if i == 2**depth: 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

325 depth += 1 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

326 depths.append(depth - 1) 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

327 depths[0] = 0 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

328 return jnp.array(depths, minimal_unsigned_dtype(max(depths))) 2I J a gbhbibb l c jbkblbmbnbobq r s pbE qbj K k L ; M F , G d m n e N O - . / rbsbtbubvbwbz xbA o h i + ybzbf p g AbBbt P Q R x H CbDbEby : FbGbHbIbJbKbLbMbNbObPbQbRbSbTbUbS T U V W X Y Z 0 1 2 3 4 5 VbWbXbYbZb6 7 8 9 ! # $ %

329 

330 

331@partial(jnp.vectorize, signature='(half_tree_size)->(tree_size)') 

332def is_used( 

333 split_tree: UInt[Array, '*batch_shape 2**(d-1)'], 

334) -> Bool[Array, '*batch_shape 2**d']: 

335 """ 

336 Return a mask indicating the used nodes in a tree. 

337 

338 Parameters 

339 ---------- 

340 split_tree 

341 The decision boundaries of the tree. 

342 

343 Returns 

344 ------- 

345 A mask indicating which nodes are actually used. 

346 """ 

347 internal_node = split_tree.astype(bool) 1abcEdefgy

348 internal_node = jnp.concatenate([internal_node, jnp.zeros_like(internal_node)]) 1abcEdefgy

349 actual_leaf = is_actual_leaf(split_tree, add_bottom_level=True) 1abcEdefgy

350 return internal_node | actual_leaf 1abcEdefgy

351 

352 

353@jit 

354def forest_fill(split_tree: UInt[Array, '*batch_shape 2**(d-1)']) -> Float32[Array, '']: 

355 """ 

356 Return the fraction of used nodes in a set of trees. 

357 

358 Parameters 

359 ---------- 

360 split_tree 

361 The decision boundaries of the trees. 

362 

363 Returns 

364 ------- 

365 Number of tree nodes over the maximum number that could be stored. 

366 """ 

367 used = is_used(split_tree) 1abcEdefgy

368 count = jnp.count_nonzero(used) 1abcEdefgy

369 batch_size = split_tree.size // split_tree.shape[-1] 1abcEdefgy

370 return count / (used.size - batch_size) 1abcEdefgy

371 

372 

373@partial(jit, static_argnames=('p', 'sum_batch_axis')) 

374def var_histogram( 

375 p: int, 

376 var_tree: UInt[Array, '*batch_shape 2**(d-1)'], 

377 split_tree: UInt[Array, '*batch_shape 2**(d-1)'], 

378 *, 

379 sum_batch_axis: int | tuple[int, ...] = (), 

380) -> Int32[Array, '*reduced_batch_shape {p}']: 

381 """ 

382 Count how many times each variable appears in a tree. 

383 

384 Parameters 

385 ---------- 

386 p 

387 The number of variables (the maximum value that can occur in `var_tree` 

388 is ``p - 1``). 

389 var_tree 

390 The decision axes of the tree. 

391 split_tree 

392 The decision boundaries of the tree. 

393 sum_batch_axis 

394 The batch axes to sum over. By default, no summation is performed. Note 

395 that negative indices count from the end of the batch dimensions, the 

396 core dimension p can't be summed over by this function. 

397 

398 Returns 

399 ------- 

400 The histogram(s) of the variables used in the tree. 

401 """ 

402 is_internal = split_tree.astype(bool) 1auvwBCbcjkde-./hifgtx

403 

404 def scatter_add( 1auvwBCbcjkde-./hifgtx

405 var_tree: UInt[Array, '*summed_batch_axes half_tree_size'], 

406 is_internal: Bool[Array, '*summed_batch_axes half_tree_size'], 

407 ) -> Int32[Array, ' p']: 

408 return jnp.zeros(p, int).at[var_tree].add(is_internal) 1auvwBCbcjkde-./hifgtx

409 

410 # vmap scatter_add over non-batched dims 

411 batch_ndim = var_tree.ndim - 1 1auvwBCbcjkde-./hifgtx

412 axes = normalize_axis_tuple(sum_batch_axis, batch_ndim) 1auvwBCbcjkde-./hifgtx

413 for i in reversed(range(batch_ndim)): 1auvwBCbcjkde-./hifgtx

414 neg_i = i - var_tree.ndim 1auvwBCbcjkde-./hifgtx

415 if i not in axes: 1auvwBCbcjkde-./hifgtx

416 scatter_add = vmap(scatter_add, in_axes=neg_i) 1uvwBC-./hit

417 

418 return scatter_add(var_tree, is_internal) 1auvwBCbcjkde-./hifgtx