Coverage for src / bartz / debug / _prior.py: 100%

97 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-06 15:16 +0000

1# bartz/src/bartz/debug/_prior.py 

2# 

3# Copyright (c) 2026, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Debugging utilities. The main functionality is the class `debug_mc_gbart`.""" 

26 

27from dataclasses import replace 

28from functools import partial 

29 

30from equinox import Module 

31from jax import jit, lax, random 

32from jax import numpy as jnp 

33from jax.tree_util import tree_map 

34from jaxtyping import Array, Bool, Float32, Int32, Key, UInt 

35 

36from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc 

37from bartz.jaxext import split as split_key 

38from bartz.mcmcstep._moves import randint_masked 

39 

40 

class SamplePriorStack(Module):
    """Explicit stack standing in for the recursion of `sample_prior`.

    Level ``i`` of the stack stores the state of the recursion into a node at
    depth ``i`` of a binary tree of maximum depth `d`.
    """

    nonterminal: Bool[Array, ' d-1']
    """True if the node at this level is a valid node, False if the recursion
    is walking through unused node slots."""

    lower: UInt[Array, 'd-1 p']
    """Lower bounds of the cutpoints: along ``var`` the available cutpoints
    are the integers in ``[1 + lower[var], 1 + upper[var])``."""

    upper: UInt[Array, 'd-1 p']
    """Upper bounds of the cutpoints: along ``var`` the available cutpoints
    are the integers in ``[1 + lower[var], 1 + upper[var])``."""

    var: UInt[Array, ' d-1']
    """The variable of the decision node at each level."""

    split: UInt[Array, ' d-1']
    """The cutpoint of the decision node at each level."""

    @classmethod
    def initial(
        cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p']
    ) -> 'SamplePriorStack':
        """Build the stack in the state that starts the recursion.

        Every level starts flagged as nonterminal, with the full cutpoint
        range available along each variable and zeroed decision rules.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
            Only its size is used here, to fix the number of stack levels.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorStack` initialized to start the recursion.
        """
        num_levels = p_nonterminal.size
        num_vars = max_split.size
        bounds_shape = (num_levels, num_vars)
        split_dtype = max_split.dtype
        return cls(
            nonterminal=jnp.ones(num_levels, bool),
            lower=jnp.zeros(bounds_shape, split_dtype),
            upper=jnp.broadcast_to(max_split, bounds_shape),
            var=jnp.zeros(num_levels, minimal_unsigned_dtype(num_vars - 1)),
            split=jnp.zeros(num_levels, split_dtype),
        )

91 

92 

class SamplePriorTrees(Module):
    """Container for the trees produced by `sample_prior`."""

    leaf_tree: Float32[Array, '* 2**d']
    """Leaf values in heap layout, see `bartz.grove`."""

    var_tree: UInt[Array, '* 2**(d-1)']
    """Decision variables in heap layout, see `bartz.grove`."""

    split_tree: UInt[Array, '* 2**(d-1)']
    """Decision cutpoints in heap layout, see `bartz.grove`."""

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorTrees':
        """Create trees with final leaves and placeholder structure.

        The leaves are drawn here once and for all; only the decision rules
        (``var_tree`` and ``split_tree``, zero-filled here) are meant to be
        filled in afterwards.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
            Only its size is used here, to fix the heap size.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        Trees initialized with random leaves and stub tree structures.
        """
        # p_nonterminal has d - 1 entries, so the leaf heap has 2**d slots and
        # the decision-rule heaps half as many.
        num_leaf_slots = 2 ** (p_nonterminal.size + 1)
        num_rule_slots = num_leaf_slots // 2
        var_dtype = minimal_unsigned_dtype(max_split.size - 1)
        return cls(
            leaf_tree=sigma_mu * random.normal(key, (num_leaf_slots,)),
            var_tree=jnp.zeros(num_rule_slots, dtype=var_dtype),
            split_tree=jnp.zeros(num_rule_slots, dtype=max_split.dtype),
        )

141 

142 

class SamplePriorCarry(Module):
    """State threaded through the scan that implements `sample_prior`."""

    key: Key[Array, '']
    """A jax random key used to sample decision rules."""

    stack: SamplePriorStack
    """The stack used to manage the recursion."""

    trees: SamplePriorTrees
    """The output arrays."""

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorCarry':
        """Assemble the starting carry from a fresh stack and fresh trees.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorCarry` initialized to start the recursion.
        """
        keys = split_key(key)
        # The pop order matters: the first key is kept in the carry, the second
        # one is consumed to draw the leaves.
        carry_key = keys.pop()
        leaves_key = keys.pop()
        return cls(
            key=carry_key,
            stack=SamplePriorStack.initial(p_nonterminal, max_split),
            trees=SamplePriorTrees.initial(
                leaves_key, sigma_mu, p_nonterminal, max_split
            ),
        )

187 

188 

class SamplePriorX(Module):
    """Object representing the recursion scan in `sample_prior`.

    The sequence of nodes to visit is pre-computed recursively once, unrolling
    the recursion schedule into flat arrays that can be scanned over.
    """

    node: Int32[Array, ' 2**(d-1)-1']
    """The heap index of the node to visit."""

    depth: Int32[Array, ' 2**(d-1)-1']
    """The depth of the node."""

    next_depth: Int32[Array, ' 2**(d-1)-1']
    """The depth of the next node to visit, either the left child or the right
    sibling of the node or of an ancestor."""

    @classmethod
    def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
        """Initialize the sequence of nodes to visit.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
            Only its size is used here, as the number of levels to visit.

        Returns
        -------
        A `SamplePriorX` initialized with the sequence of nodes to visit.
        """
        max_depth = p_nonterminal.size
        seq = cls._sequence(max_depth)
        # internal invariant: a complete binary tree with max_depth levels
        assert len(seq) == 2**max_depth - 1
        nodes = [node for node, _ in seq]
        depths = [depth for _, depth in seq]
        # The step after the last one is given depth `max_depth`, one past the
        # deepest visited level.
        next_depths = [*depths[1:], max_depth]
        # Pin the dtype so the arrays match the declared Int32 fields whatever
        # the default integer width (e.g. with jax_enable_x64 turned on).
        return cls(
            node=jnp.array(nodes, dtype=jnp.int32),
            depth=jnp.array(depths, dtype=jnp.int32),
            next_depth=jnp.array(next_depths, dtype=jnp.int32),
        )

    @classmethod
    def _sequence(
        cls, max_depth: int, depth: int = 0, node: int = 1
    ) -> tuple[tuple[int, int], ...]:
        """Recursively generate a sequence [(node, depth), ...].

        Nodes are listed in pre-order (node, left subtree, right subtree),
        using heap indices where the children of ``node`` are ``2 * node`` and
        ``2 * node + 1``. Only depths strictly below `max_depth` are included.
        """
        if depth >= max_depth:
            return ()
        return (
            (node, depth),
            *cls._sequence(max_depth, depth + 1, 2 * node),
            *cls._sequence(max_depth, depth + 1, 2 * node + 1),
        )

242 

243 

def sample_prior_onetree(
    key: Key[Array, ''],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Draw a single tree from the BART prior.

    The tree is built by scanning over a pre-computed sequence of nodes
    (`SamplePriorX`), with a manually managed stack (`SamplePriorStack`)
    tracking the cutpoints still available along each variable.

    Parameters
    ----------
    key
        A jax random key.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing a generated tree.
    """

    def visit(
        carry: SamplePriorCarry, x: SamplePriorX
    ) -> tuple[SamplePriorCarry, None]:
        """Process one node of the schedule and move the stack to the next."""
        # Four keys, consumed strictly in this order: variable, cutpoint,
        # grow decision, key handed over to the next step.
        keys = split_key(carry.key, 4)

        prev_stack = carry.stack
        prev_trees = carry.trees
        depth = x.depth
        next_depth = x.next_depth

        # cutpoint bounds at the current stack level
        lower = prev_stack.lower[depth, :]
        upper = prev_stack.upper[depth, :]

        # Draw a candidate decision rule among the variables that still have
        # cutpoints left.
        available: Bool[Array, ' p'] = lower < upper
        var = randint_masked(keys.pop(), available)
        split = 1 + random.randint(keys.pop(), (), lower[var], upper[var])

        # store the rule with the compact integer types used by the trees
        var = var.astype(prev_trees.var_tree.dtype)
        split = split.astype(prev_trees.split_tree.dtype)

        # The node is a decision node only if its stack slot is marked as
        # nonterminal, the coin flip says to grow, and some rule is available.
        try_nonterminal: Bool[Array, ''] = random.bernoulli(
            keys.pop(), p_nonterminal[depth]
        )
        nonterminal = prev_stack.nonterminal[depth] & (
            try_nonterminal & jnp.any(available)
        )

        # Write the rule into the trees; a zero cutpoint is stored when the
        # node is not a decision node.
        new_trees = replace(
            prev_trees,
            var_tree=prev_trees.var_tree.at[x.node].set(var),
            split_tree=prev_trees.split_tree.at[x.node].set(
                jnp.where(nonterminal, split, 0)
            ),
        )

        def descend() -> SamplePriorStack:
            """Update the stack to go to the left child."""
            # The left child keeps the cutpoints below `split` along `var`,
            # i.e. the range [1 + lower[var], split).
            # NOTE: on the last step `next_depth` equals the stack size (see
            # `SamplePriorX.initial`); JAX skips out-of-bounds `.at[].set`
            # updates by default, so those writes should be no-ops.
            left_upper = upper.at[var].set(split - 1)
            return replace(
                prev_stack,
                nonterminal=prev_stack.nonterminal.at[next_depth].set(nonterminal),
                lower=prev_stack.lower.at[next_depth, :].set(lower),
                upper=prev_stack.upper.at[next_depth, :].set(left_upper),
                var=prev_stack.var.at[depth].set(var),
                split=prev_stack.split.at[depth].set(split),
            )

        def backtrack() -> SamplePriorStack:
            """Update the stack to go to the right sibling, possibly at lower depth."""
            # Recover the rule and bounds of the parent of the next node; the
            # right child keeps the cutpoints above the parent's `split`.
            parent_depth = next_depth - 1
            parent_var = prev_stack.var[parent_depth]
            parent_split = prev_stack.split[parent_depth]
            parent_lower = prev_stack.lower[parent_depth, :]
            parent_upper = prev_stack.upper[parent_depth, :]
            right_lower = parent_lower.at[parent_var].set(parent_split)
            return replace(
                prev_stack,
                lower=prev_stack.lower.at[next_depth, :].set(right_lower),
                upper=prev_stack.upper.at[next_depth, :].set(parent_upper),
            )

        # Going deeper means moving to the left child; otherwise the next node
        # is a right sibling at the same or a shallower depth.
        going_deeper = next_depth > depth
        new_stack = lax.cond(going_deeper, descend, backtrack)

        new_carry = replace(carry, key=keys.pop(), stack=new_stack, trees=new_trees)
        return new_carry, None

    initial_carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split)
    schedule = SamplePriorX.initial(p_nonterminal)
    final_carry, _ = lax.scan(visit, initial_carry, schedule)
    return final_carry.trees

337 

338 

@partial(vmap_nodoc, in_axes=(0, None, None, None))
def sample_prior_forest(
    keys: Key[Array, ' num_trees'],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Draw a batch of independent trees from the BART prior.

    This is `sample_prior_onetree` mapped over the leading axis of `keys`;
    the other arguments are shared by all trees.

    Parameters
    ----------
    keys
        A sequence of jax random keys, one for each tree. This determines the
        number of trees sampled.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees.
    """
    # Inside the mapped function `keys` is the single key of one tree.
    return sample_prior_onetree(
        key=keys,
        max_split=max_split,
        p_nonterminal=p_nonterminal,
        sigma_mu=sigma_mu,
    )

366 

367 

@partial(jit, static_argnums=(1, 2))
def sample_prior(
    key: Key[Array, ''],
    trace_length: int,
    num_trees: int,
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Draw independent trees from the BART prior.

    Parameters
    ----------
    key
        A jax random key.
    trace_length
        The number of iterations.
    num_trees
        The number of trees for each iteration.
    max_split
        The number of cutpoints along each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
        This determines the maximum depth of the trees.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees, with batch shape (trace_length, num_trees).
    """
    batch_shape = (trace_length, num_trees)

    def unflatten_batch(x: Array) -> Array:
        """Split the flat leading tree axis into (trace_length, num_trees)."""
        return x.reshape(*batch_shape, -1)

    # All trees are sampled as one flat batch, one key each, and the batch
    # axis is folded into two axes afterwards.
    flat_keys = random.split(key, trace_length * num_trees)
    flat_trees = sample_prior_forest(flat_keys, max_split, p_nonterminal, sigma_mu)
    return tree_map(unflatten_batch, flat_trees)