Coverage for src / bartz / _interface.py: 92%
333 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-06 15:16 +0000
1# bartz/src/bartz/_interface.py
2#
3# Copyright (c) 2025-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Main high-level interface of the package."""
27import math
28from collections.abc import Mapping, Sequence
29from functools import cached_property, partial
30from types import MappingProxyType
31from typing import Any, Literal, Protocol, TypedDict
33import jax
34import jax.numpy as jnp
35from equinox import Module, error_if, field
36from jax import Device, device_put, jit, lax, make_mesh
37from jax.scipy.special import ndtr
38from jax.sharding import AxisType, Mesh
39from jaxtyping import (
40 Array,
41 Bool,
42 Float,
43 Float32,
44 Int32,
45 Integer,
46 Key,
47 Real,
48 Shaped,
49 UInt,
50)
51from numpy import ndarray
53from bartz import mcmcloop, mcmcstep, prepcovars
54from bartz.jaxext import is_key
55from bartz.jaxext.scipy.special import ndtri
56from bartz.jaxext.scipy.stats import invgamma
57from bartz.mcmcloop import RunMCMCResult, compute_varcount, evaluate_trace, run_mcmc
58from bartz.mcmcstep import make_p_nonterminal
59from bartz.mcmcstep._state import get_num_chains
# A scalar hyperparameter: either a plain Python float or a 0-d floating array.
FloatLike = float | Float[Any, '']
class DataFrame(Protocol):
    """DataFrame duck-type for `Bart`.

    Structural type matched by pandas/polars-like dataframes: anything with a
    `columns` attribute and a `to_numpy` method is accepted.
    """

    columns: Sequence[str]
    """The names of the columns."""

    def to_numpy(self) -> ndarray:
        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
        ...
class Series(Protocol):
    """Series duck-type for `Bart`.

    Structural type matched by pandas/polars-like series: anything with a
    `name` attribute and a `to_numpy` method is accepted.
    """

    name: str | None
    """The name of the series."""

    def to_numpy(self) -> ndarray:
        """Convert the series to a 1d numpy array."""
        ...
86class Bart(Module):
87 R"""
88 Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
90 Regress `y_train` on `x_train` with a latent mean function represented as
91 a sum of decision trees. The inference is carried out by sampling the
92 posterior distribution of the tree ensemble with an MCMC.
94 Parameters
95 ----------
96 x_train
97 The training predictors.
98 y_train
99 The training responses.
100 x_test
101 The test predictors.
102 type
103 The type of regression. 'wbart' for continuous regression, 'pbart' for
104 binary regression with probit link.
105 sparse
106 Whether to activate variable selection on the predictors as done in
107 [1]_.
108 theta
109 a
110 b
111 rho
112 Hyperparameters of the sparsity prior used for variable selection.
114 The prior distribution on the choice of predictor for each decision rule
115 is
117 .. math::
118 (s_1, \ldots, s_p) \sim
119 \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
121 If `theta` is not specified, it's a priori distributed according to
123 .. math::
124 \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
125 \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
127 If not specified, `rho` is set to the number of predictors p. To tune
128 the prior, consider setting a lower `rho` to prefer more sparsity.
129 If setting `theta` directly, it should be in the ballpark of p or lower
130 as well.
131 varprob
132 The probability distribution over the `p` predictors for choosing a
133 predictor to split on in a decision node a priori. Must be > 0. It does
134 not need to be normalized to sum to 1. If not specified, use a uniform
135 distribution. If ``sparse=True``, this is used as initial value for the
136 MCMC.
137 xinfo
138 A matrix with the cutpoins to use to bin each predictor. If not
139 specified, it is generated automatically according to `usequants` and
140 `numcut`.
142 Each row shall contain a sorted list of cutpoints for a predictor. If
143 there are less cutpoints than the number of columns in the matrix,
144 fill the remaining cells with NaN.
146 `xinfo` shall be a matrix even if `x_train` is a dataframe.
147 usequants
148 Whether to use predictors quantiles instead of a uniform grid to bin
149 predictors. Ignored if `xinfo` is specified.
150 rm_const
151 How to treat predictors with no associated decision rules (i.e., there
152 are no available cutpoints for that predictor). If `True` (default),
153 they are ignored. If `False`, an error is raised if there are any.
154 sigest
155 An estimate of the residual standard deviation on `y_train`, used to set
156 `lamda`. If not specified, it is estimated by linear regression (with
157 intercept, and without taking into account `w`). If `y_train` has less
158 than two elements, it is set to 1. If n <= p, it is set to the standard
159 deviation of `y_train`. Ignored if `lamda` is specified.
160 sigdf
161 The degrees of freedom of the scaled inverse-chisquared prior on the
162 noise variance.
163 sigquant
164 The quantile of the prior on the noise variance that shall match
165 `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
166 k
167 The inverse scale of the prior standard deviation on the latent mean
168 function, relative to half the observed range of `y_train`. If `y_train`
169 has less than two elements, `k` is ignored and the scale is set to 1.
170 power
171 base
172 Parameters of the prior on tree node generation. The probability that a
173 node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
174 power``.
175 lamda
176 The prior harmonic mean of the error variance. (The harmonic mean of x
177 is 1/mean(1/x).) If not specified, it is set based on `sigest` and
178 `sigquant`.
179 tau_num
180 The numerator in the expression that determines the prior standard
181 deviation of leaves. If not specified, default to ``(max(y_train) -
182 min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
183 continuous regression, and 3 for binary regression.
184 offset
185 The prior mean of the latent mean function. If not specified, it is set
186 to the mean of `y_train` for continuous regression, and to
187 ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
188 `offset` is set to 0. With binary regression, if `y_train` is all
189 `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
190 ``Phi^-1(n/(n+1))``, respectively.
191 w
192 Coefficients that rescale the error standard deviation on each
193 datapoint. Not specifying `w` is equivalent to setting it to 1 for all
194 datapoints. Note: `w` is ignored in the automatic determination of
195 `sigest`, so either the weights should be O(1), or `sigest` should be
196 specified by the user.
197 num_trees
198 The number of trees used to represent the latent mean function.
199 numcut
200 If `usequants` is `False`: the exact number of cutpoints used to bin the
201 predictors, ranging between the minimum and maximum observed values
202 (excluded).
204 If `usequants` is `True`: the maximum number of cutpoints to use for
205 binning the predictors. Each predictor is binned such that its
206 distribution in `x_train` is approximately uniform across bins. The
207 number of bins is at most the number of unique values appearing in
208 `x_train`, or ``numcut + 1``.
210 Before running the algorithm, the predictors are compressed to the
211 smallest integer type that fits the bin indices, so `numcut` is best set
212 to the maximum value of an unsigned integer type, like 255.
214 Ignored if `xinfo` is specified.
215 ndpost
216 The number of MCMC samples to save, after burn-in. `ndpost` is the
217 total number of samples across all chains. `ndpost` is rounded up to the
218 first multiple of `mc_cores`.
219 nskip
220 The number of initial MCMC samples to discard as burn-in. This number
221 of samples is discarded from each chain.
222 keepevery
223 The thinning factor for the MCMC samples, after burn-in.
224 printevery
225 The number of iterations (including thinned-away ones) between each log
226 line. Set to `None` to disable logging. ^C interrupts the MCMC only
227 every `printevery` iterations, so with logging disabled it's impossible
228 to kill the MCMC conveniently.
229 num_chains
230 The number of independent Markov chains to run.
232 The difference between ``num_chains=None`` and ``num_chains=1`` is that
233 in the latter case in the object attributes and some methods there will
234 be an explicit chain axis of size 1.
235 num_chain_devices
236 The number of devices to spread the chains across. Must be a divisor of
237 `num_chains`. Each device will run a fraction of the chains.
238 num_data_devices
239 The number of devices to split datapoints across. Must be a divisor of
240 `n`. This is useful only with very high `n`, about > 1000_000.
242 If both num_chain_devices and num_data_devices are specified, the total
243 number of devices used is the product of the two.
244 devices
245 One or more devices used to run the MCMC on. If not specified, the
246 computation will follow the placement of the input arrays. If a list of
247 devices, this argument can be longer than the number of devices needed.
248 seed
249 The seed for the random number generator.
250 maxdepth
251 The maximum depth of the trees. This is 1-based, so with the default
252 ``maxdepth=6``, the depths of the levels range from 0 to 5.
253 init_kw
254 Additional arguments passed to `bartz.mcmcstep.init`.
255 run_mcmc_kw
256 Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
258 References
259 ----------
260 .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
261 High-Dimensional Prediction and Variable Selection”. In: Journal of the
262 American Statistical Association 113.522, pp. 626-636.
263 .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART:
264 Bayesian additive regression trees," The Annals of Applied Statistics,
265 Ann. Appl. Stat. 4(1), 266-298, (March 2010).
266 """
268 _main_trace: mcmcloop.MainTrace
269 _burnin_trace: mcmcloop.BurninTrace
270 _mcmc_state: mcmcstep.State
271 _splits: Real[Array, 'p max_num_splits']
272 _x_train_fmt: Any = field(static=True)
274 offset: Float32[Array, '']
275 """The prior mean of the latent mean function."""
277 sigest: Float32[Array, ''] | None = None
278 """The estimated standard deviation of the error used to set `lamda`."""
280 yhat_test: Float32[Array, 'ndpost m'] | None = None
281 """The conditional posterior mean at `x_test` for each MCMC iteration."""
    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        varprob: Float[Array, ' p'] | None = None,
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | Series | None = None,
        num_trees: int = 200,
        numcut: int = 255,
        ndpost: int = 1000,
        nskip: int = 1000,
        keepevery: int = 1,
        printevery: int | None = 100,
        num_chains: int | None = 4,
        num_chain_devices: int | None = None,
        num_data_devices: int | None = None,
        devices: Device | Sequence[Device] | None = None,
        seed: int | Key[Array, ''] = 0,
        maxdepth: int = 6,
        init_kw: Mapping = MappingProxyType({}),
        run_mcmc_kw: Mapping = MappingProxyType({}),
    ) -> None:
        """Validate the inputs, set the priors, and run the MCMC.

        See the class docstring for the meaning of all parameters.
        """
        # check data and put it in the right format
        x_train, x_train_fmt = self._process_predictor_input(x_train)
        y_train = self._process_response_input(y_train)
        self._check_same_length(x_train, y_train)
        if w is not None:
            # weights are a 1d series/array aligned with the datapoints
            w = self._process_response_input(w)
            self._check_same_length(x_train, w)

        # check data types are correct for continuous/binary regression
        self._check_type_settings(y_train, type, w)
        # from here onwards, the type is determined by y_train.dtype == bool

        # process sparsity settings
        theta, a, b, rho = self._process_sparsity_settings(
            x_train, sparse, theta, a, b, rho
        )

        # process "standardization" settings
        offset = self._process_offset_settings(y_train, offset)
        sigma_mu = self._process_leaf_sdev_settings(y_train, k, num_trees, tau_num)
        # lamda/sigest must be computed on the raw (unbinned) x_train, before
        # predictors are compressed to bin indices below
        lamda, sigest = self._process_error_variance_settings(
            x_train, y_train, sigest, sigdf, sigquant, lamda
        )

        # determine splits
        splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
        x_train = self._bin_predictors(x_train, splits)

        # setup and run mcmc
        initial_state = self._setup_mcmc(
            x_train,
            y_train,
            offset,
            w,
            max_split,
            lamda,
            sigma_mu,
            sigdf,
            power,
            base,
            maxdepth,
            num_trees,
            init_kw,
            rm_const,
            theta,
            a,
            b,
            rho,
            varprob,
            num_chains,
            num_chain_devices,
            num_data_devices,
            devices,
            sparse,
            nskip,
        )
        # NOTE(review): _run_mcmc is defined elsewhere in this file; presumably
        # it wraps bartz.mcmcloop.run_mcmc — confirm against the full source.
        result = self._run_mcmc(
            initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
        )

        # set public attributes
        # set offset from the state because of buffer donation
        self.offset = result.final_state.offset
        self.sigest = sigest

        # set private attributes
        self._main_trace = result.main_trace
        self._burnin_trace = result.burnin_trace
        self._mcmc_state = result.final_state
        self._splits = splits
        self._x_train_fmt = x_train_fmt

        # predict at test points
        # (when x_test is None, yhat_test keeps its class-level default None)
        if x_test is not None:
            self.yhat_test = self.predict(x_test)
400 @property
401 def ndpost(self) -> int:
402 """The total number of posterior samples after burn-in across all chains.
404 May be larger than the initialization argument `ndpost` if it was not
405 divisible by the number of chains.
406 """
407 return self._main_trace.grow_prop_count.size 1YX`{^}_mbgafqe-89.E5Aj
409 @property
410 def num_trees(self) -> int:
411 """Return the number of trees used in the model."""
412 return self._mcmc_state.forest.split_tree.shape[-2] 1K!G
414 @cached_property
415 def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
416 """The posterior probability of y being True at `x_test` for each MCMC iteration."""
417 if self.yhat_test is None or self._mcmc_state.y.dtype != bool: 1|imhbgaqdjc
418 return None 1ihbadc
419 else:
420 return ndtr(self.yhat_test) 1|mgqj
422 @cached_property
423 def prob_test_mean(self) -> Float32[Array, ' m'] | None:
424 """The marginal posterior probability of y being True at `x_test`."""
425 if self.prob_test is None: 1bgaqx
426 return None 1ba
427 else:
428 return self.prob_test.mean(axis=0) 1gqx
430 @cached_property
431 def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
432 """The posterior probability of y being True at `x_train` for each MCMC iteration."""
433 if self._mcmc_state.y.dtype == bool: 1|}imhbgaqdjc
434 return ndtr(self.yhat_train) 1|}mgqj
435 else:
436 return None 1ihbadc
438 @cached_property
439 def prob_train_mean(self) -> Float32[Array, ' n'] | None:
440 """The marginal posterior probability of y being True at `x_train`."""
441 if self.prob_train is None: 1bgaqxdjc
442 return None 1badc
443 else:
444 return self.prob_train.mean(axis=0) 1gqxj
446 @cached_property
447 def sigma(
448 self,
449 ) -> (
450 Float32[Array, ' nskip+ndpost']
451 | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
452 | None
453 ):
454 """The standard deviation of the error, including burn-in samples."""
455 if self._burnin_trace.error_cov_inv is None: 1^_imhbgafeu3rxlkdjc
456 return None 1mg3xj
457 assert self._main_trace.error_cov_inv is not None 1^_ihbafeurlkdc
458 return jnp.sqrt( 1^_ihbafeurlkdc
459 jnp.reciprocal(
460 jnp.concatenate(
461 [
462 self._burnin_trace.error_cov_inv.T,
463 self._main_trace.error_cov_inv.T,
464 ],
465 axis=0,
466 # error_cov_inv has shape (chains? samples) in the trace
467 )
468 )
469 )
471 @cached_property
472 def sigma_(self) -> Float32[Array, 'ndpost'] | None:
473 """The standard deviation of the error, only over the post-burnin samples and flattened."""
474 error_cov_inv = self._main_trace.error_cov_inv 1`{imhbgafelkdjc
475 if error_cov_inv is None: 1`{imhbgafexlkdjc
476 return None 1mgxj
477 else:
478 return jnp.sqrt(jnp.reciprocal(error_cov_inv)).reshape(-1) 1`{ihbafelkdc
480 @cached_property
481 def sigma_mean(self) -> Float32[Array, ''] | None:
482 """The mean of `sigma`, only over the post-burnin samples."""
483 if self.sigma_ is None: 1`{bgafexlkdjc
484 return None 1gxj
485 return self.sigma_.mean() 1`{bafelkdc
487 @cached_property
488 def varcount(self) -> Int32[Array, 'ndpost p']:
489 """Histogram of predictor usage for decision rules in the trees."""
490 p = self._mcmc_state.forest.max_split.size 1yZw|^_imhbgafqe89.djc
491 return varcount(p, self._main_trace) 1yZw|^_imhbgafqe89.djc
493 @cached_property
494 def varcount_mean(self) -> Float32[Array, ' p']:
495 """Average of `varcount` across MCMC iterations."""
496 return self.varcount.mean(axis=0) 1yZw`{bgafqedjc
498 @cached_property
499 def varprob(self) -> Float32[Array, 'ndpost p']:
500 """Posterior samples of the probability of choosing each predictor for a decision rule."""
501 max_split = self._mcmc_state.forest.max_split 1vYsX^_imhbgafqedjcLB
502 p = max_split.size 1vYsX^_imhbgafqedjcLB
503 varprob = self._main_trace.varprob 1vYsX^_imhbgafqedjcLB
504 if varprob is None: 1vYsX^_imhbgafqexdjcLB
505 peff = jnp.count_nonzero(max_split) 1YXmgqxj
506 varprob = jnp.where(max_split, 1 / peff, 0) 1YXmgqxj
507 varprob = jnp.broadcast_to(varprob, (self.ndpost, p)) 1YXmgqxj
508 else:
509 varprob = varprob.reshape(-1, p) 1vs^_ihbafedcLB
510 return varprob 1vYsX^_imhbgafqexdjcLB
512 @cached_property
513 def varprob_mean(self) -> Float32[Array, ' p']:
514 """The marginal posterior probability of each predictor being chosen for a decision rule."""
515 return self.varprob.mean(axis=0) 1vYsX`{bgafqedjcLB
517 @cached_property
518 def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
519 """The marginal posterior mean at `x_test`.
521 Not defined with binary regression because it's error-prone, typically
522 the right thing to consider would be `prob_test_mean`.
523 """
524 if self.yhat_test is None or self._mcmc_state.y.dtype == bool: 1`{bgafexlkdjc
525 return None 1gxj
526 else:
527 return self.yhat_test.mean(axis=0) 1`{bafelkdc
529 @cached_property
530 def yhat_train(self) -> Float32[Array, 'ndpost n']:
531 """The conditional posterior mean at `x_train` for each MCMC iteration."""
532 x_train = self._mcmc_state.X 1^}_imho7nO(J0+Mbgafqeu3rHxDlkE5Adjc
533 return self._predict(x_train) 1^}_imho7nO(J0+Mbgafqeu3rHxDlkE5Adjc
535 @cached_property
536 def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
537 """The marginal posterior mean at `x_train`.
539 Not defined with binary regression because it's error-prone, typically
540 the right thing to consider would be `prob_train_mean`.
541 """
542 if self._mcmc_state.y.dtype == bool: 1`{bgafexlkdjc
543 return None 1gxj
544 else:
545 return self.yhat_train.mean(axis=0) 1`{bafelkdc
547 def predict(
548 self, x_test: Real[Array, 'p m'] | DataFrame
549 ) -> Float32[Array, 'ndpost m']:
550 """
551 Compute the posterior mean at `x_test` for each MCMC iteration.
553 Parameters
554 ----------
555 x_test
556 The test predictors.
558 Returns
559 -------
560 The conditional posterior mean at `x_test` for each MCMC iteration.
562 Raises
563 ------
564 ValueError
565 If `x_test` has a different format than `x_train`.
566 """
567 x_test, x_test_fmt = self._process_predictor_input(x_test) 1vYsyZw]t#pimhS%NF6zI2Co7nO(JT)PU*Q0+MK!G1,Vbgafqeu3rHxDlkE5AdjcW/R
568 if x_test_fmt != self._x_train_fmt: 1vYsyZw]t#pimhS%NF6zI2Co7nO(JT)PU*Q0+MK!G1,Vbgafqeu3rHxDlkE5AdjcW/R
569 msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}' 1F6z
570 raise ValueError(msg) 1F6z
571 x_test = self._bin_predictors(x_test, self._splits) 1vYsyZw]t#pimhS%NF6zI2Co7nO(JT)PU*Q0+MK!G1,Vbgafqeu3rHxDlkE5AdjcW/R
572 return self._predict(x_test) 1vYsyZw]t#pimhS%NF6zI2Co7nO(JT)PU*Q0+MK!G1,Vbgafqeu3rHxDlkE5AdjcW/R
574 @staticmethod
575 def _process_predictor_input(
576 x: Real[Any, 'p n'] | DataFrame,
577 ) -> tuple[Shaped[Array, 'p n'], Any]:
578 if hasattr(x, 'columns'): 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
579 fmt = dict(kind='dataframe', columns=x.columns) 1F6zu3r
580 x = x.to_numpy().T 1F6zu3r
581 else:
582 fmt = dict(kind='array', num_covar=x.shape[0]) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
583 x = jnp.asarray(x) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
584 assert x.ndim == 2 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
585 return x, fmt 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
587 @staticmethod
588 def _process_response_input(y: Shaped[Array, ' n'] | Series) -> Shaped[Array, ' n']:
589 if hasattr(y, 'to_numpy'): 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
590 y = y.to_numpy() 1zu3r
591 y = jnp.asarray(y) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
592 assert y.ndim == 1 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
593 return y 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
595 @staticmethod
596 def _check_same_length(x1: Array, x2: Array) -> None:
597 get_length = lambda x: x.shape[-1] 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
598 assert get_length(x1) == get_length(x2) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
    @classmethod
    def _process_error_variance_settings(
        cls,
        x_train: Shaped[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        sigest: FloatLike | None,
        sigdf: FloatLike,
        sigquant: FloatLike,
        lamda: FloatLike | None,
    ) -> tuple[Float32[Array, ''] | None, ...]:
        """Return (lamda, sigest).

        For binary regression (bool response) there is no error variance, so
        both are None and specifying either argument is an error. Otherwise,
        a user-specified `lamda` is passed through; if `lamda` is None it is
        derived from `sigest` such that `sigest**2` falls at the `sigquant`
        quantile of the scaled inverse-chisquared prior with `sigdf` degrees
        of freedom.
        """
        if y_train.dtype == bool:
            if sigest is not None:
                msg = 'Let `sigest=None` for binary regression'
                raise ValueError(msg)
            if lamda is not None:
                msg = 'Let `lamda=None` for binary regression'
                raise ValueError(msg)
            return None, None
        elif lamda is not None:
            # user fixed lamda directly: sigest must not also be given
            if sigest is not None:
                msg = 'Let `sigest=None` if `lamda` is specified'
                raise ValueError(msg)
            return lamda, None
        else:
            # estimate the error variance sigest2
            if sigest is not None:
                sigest2 = jnp.square(sigest)
            elif y_train.size < 2:
                # too little data to estimate anything: unit variance
                sigest2 = 1
            elif y_train.size <= x_train.shape[0]:
                # n <= p: OLS is degenerate, fall back to the variance of y
                sigest2 = jnp.var(y_train)
            else:
                sigest2 = cls._linear_regression(x_train, y_train)
            # pick lamda so that sigest2 is the sigquant quantile of the
            # scaled inverse-chisquared prior
            alpha = sigdf / 2
            invchi2 = invgamma.ppf(sigquant, alpha) / 2
            invchi2rid = invchi2 * sigdf
            return sigest2 / invchi2rid, jnp.sqrt(sigest2)
    @staticmethod
    @jit
    def _linear_regression(
        x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n']
    ) -> Float32[Array, '']:
        """Return the error variance estimated with OLS with intercept.

        The residual sum of squares is divided by the residual degrees of
        freedom (n - rank), not by n. Callers ensure n > p before invoking
        this, so the residuals of `lstsq` are well defined.
        """
        x_centered = x_train.T - x_train.mean(axis=1)
        y_centered = y_train - y_train.mean()
        # centering is equivalent to adding an intercept column
        _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
        # lstsq returns the residual sum of squares as a 1-element array
        chisq = chisq.squeeze(0)
        dof = len(y_train) - rank
        return chisq / dof
652 @staticmethod
653 def _check_type_settings(
654 y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
655 type: str, # noqa: A002
656 w: Float[Array, ' n'] | None,
657 ) -> None:
658 match type: 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
659 case 'wbart': 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
660 if y_train.dtype != jnp.float32: 660 ↛ 661line 660 didn't jump to line 661 because the condition on line 660 was never true1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:
661 msg = (
662 'Continuous regression requires y_train.dtype=float32,'
663 f' got {y_train.dtype=} instead.'
664 )
665 raise TypeError(msg) 1vsXyw=?tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:
666 case 'pbart': 666 ↛ 676line 666 didn't jump to line 676 because the pattern on line 666 always matched1YZ@#m%627;()*+!,gq3x5j/
667 if w is not None: 667 ↛ 668line 667 didn't jump to line 668 because the condition on line 667 was never true1YZ@#m%627;()*+!,gq3x5j/
668 msg = 'Binary regression does not support weights, set `w=None`'
669 raise ValueError(msg)
670 if y_train.dtype != bool: 670 ↛ 671line 670 didn't jump to line 671 because the condition on line 670 was never true1YZ@#m%627;()*+!,gq3x5j/
671 msg = (
672 'Binary regression requires y_train.dtype=bool,'
673 f' got {y_train.dtype=} instead.'
674 )
675 raise TypeError(msg) 1YZ@#m%627;()*+!,gq3x5j/
676 case _:
677 msg = f'Invalid {type=}'
678 raise ValueError(msg)
680 @staticmethod
681 def _process_sparsity_settings(
682 x_train: Real[Array, 'p n'],
683 sparse: bool,
684 theta: FloatLike | None,
685 a: FloatLike,
686 b: FloatLike,
687 rho: FloatLike | None,
688 ) -> (
689 tuple[None, None, None, None]
690 | tuple[FloatLike, None, None, None]
691 | tuple[None, FloatLike, FloatLike, FloatLike]
692 ):
693 """Return (theta, a, b, rho)."""
694 if not sparse: 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
695 return None, None, None, None 1YXyZw@#m%627;()*+!,gq3x-89.5j/[:
696 elif theta is not None: 1vs=?]tpihSNFzICon'$OJTPUQ0MKG1VbafeurHDlkEAdcWRLB
697 return theta, None, None, None 1s?]phNzCn$JPQMGVaerDkAcRL
698 else:
699 if rho is None: 699 ↛ 702line 699 didn't jump to line 702 because the condition on line 699 was always true1v=tiSFIo'OTU0K1bfuHlEdWB
700 p, _ = x_train.shape 1v=tiSFIo'OTU0K1bfuHlEdWB
701 rho = float(p) 1v=tiSFIo'OTU0K1bfuHlEdWB
702 return None, a, b, rho 1v=tiSFIo'OTU0K1bfuHlEdWB
704 @staticmethod
705 def _process_offset_settings(
706 y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
707 offset: float | Float32[Any, ''] | None,
708 ) -> Float32[Array, '']:
709 """Return offset."""
710 if offset is not None: 710 ↛ 711line 710 didn't jump to line 711 because the condition on line 710 was never true1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
711 return jnp.asarray(offset)
712 elif y_train.size < 1: 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
713 return jnp.array(0.0) 10+M-89.[:
714 else:
715 mean = y_train.mean() 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*QK!G1,Vbgafqeu3rHxDlkE5AdjcW/RLB
717 if y_train.dtype == bool: 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*QK!G1,Vbgafqeu3rHxDlkE5AdjcW/RLB
718 bound = 1 / (1 + y_train.size) 1YZ@#m%627;()*!,gq3x5j/
719 mean = jnp.clip(mean, bound, 1 - bound) 1YZ@#m%627;()*!,gq3x5j/
720 return ndtri(mean) 1YZ@#m%627;()*!,gq3x5j/
721 else:
722 return mean 1vsXyw=?]tpihSNFzICon'$OJTPUQKG1VbafeurHDlkEAdcWRLB
724 @staticmethod
725 def _process_leaf_sdev_settings(
726 y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
727 k: FloatLike,
728 num_trees: int,
729 tau_num: FloatLike | None,
730 ) -> FloatLike:
731 """Return sigma_mu."""
732 if tau_num is None: 732 ↛ 740line 732 didn't jump to line 740 because the condition on line 732 was always true1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
733 if y_train.dtype == bool: 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
734 tau_num = 3.0 1YZ@#m%627;()*+!,gq3x5j/
735 elif y_train.size < 2: 1vsXyw=?]tpihSNFzICon'$OJTPUQ0MKG1VbafeurHD-89.lkEAdcWRLB[:
736 tau_num = 1.0 10M1V-89.[:
737 else:
738 tau_num = (y_train.max() - y_train.min()) / 2 1vsXyw=?]tpihSNFzICon'$OJTPUQKGbafeurHDlkEAdcWRLB
740 return tau_num / (k * math.sqrt(num_trees)) 1vYsXyZw=@?t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
742 @staticmethod
743 def _determine_splits(
744 x_train: Real[Array, 'p n'],
745 usequants: bool,
746 numcut: int,
747 xinfo: Float[Array, 'p n'] | None,
748 ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
749 if xinfo is not None: 1vYsXyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB[:
750 if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]: 1X0+M,V-89.[:
751 msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)' 1[
752 raise ValueError(msg) 1[
753 return prepcovars.parse_xinfo(xinfo) 1X0+M,V-89.:
754 elif usequants: 1vYsyZw=@?]t#pimhS%NF6zI2Co7n';$O(JT)PU*QK!G1bgafqeu3rHxDlkE5AdjcW/RLB
755 return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1) 1YsZw@?]#pmh%N6z2C7n;$(J)P*Q!Ggaqe3rxDk5Ajc/R
756 else:
757 return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1) 1vy=tiSFIo'OTUK1bfuHlEdWLB
    @staticmethod
    def _bin_predictors(
        x: Real[Array, 'p n'], splits: Real[Array, 'p max_num_splits']
    ) -> UInt[Array, 'p n']:
        """Quantize each predictor in `x` against its row of cutpoints in `splits`."""
        return prepcovars.bin_predictors(x, splits)
    @staticmethod
    def _setup_mcmc(
        x_train: Real[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
        offset: Float32[Array, ''],
        w: Float[Array, ' n'] | None,
        max_split: UInt[Array, ' p'],
        lamda: Float32[Array, ''] | None,
        sigma_mu: FloatLike,
        sigdf: FloatLike,
        power: FloatLike,
        base: FloatLike,
        maxdepth: int,
        num_trees: int,
        init_kw: Mapping[str, Any],
        rm_const: bool,
        theta: FloatLike | None,
        a: FloatLike | None,
        b: FloatLike | None,
        rho: FloatLike | None,
        varprob: Float[Any, ' p'] | None,
        num_chains: int | None,
        num_chain_devices: int | None,
        num_data_devices: int | None,
        devices: Device | Sequence[Device] | None,
        sparse: bool,
        nskip: int,
    ) -> mcmcstep.State:
        """Build the initial MCMC state from the processed settings.

        Assembles the keyword arguments for `mcmcstep.init` (tree prior,
        error variance prior, device placement, sparsity settings), applies
        the user overrides from `init_kw` last, and optionally moves the
        resulting state to an explicitly requested device.
        """
        p_nonterminal = make_p_nonterminal(maxdepth, base, power)

        # binary regression: no free error variance, so no prior for it
        if y_train.dtype == bool:
            error_cov_df = None
            error_cov_scale = None
        else:
            assert lamda is not None
            # inverse gamma prior: alpha = df / 2, beta = scale / 2
            error_cov_df = sigdf
            error_cov_scale = lamda * sigdf

        # process device settings
        device_kw, device = process_device_settings(
            y_train, num_chains, num_chain_devices, num_data_devices, devices
        )

        kw: dict = dict(
            X=x_train,
            # copy y_train because it's going to be donated in the mcmc loop
            y=jnp.array(y_train),
            offset=offset,
            error_scale=w,
            max_split=max_split,
            num_trees=num_trees,
            p_nonterminal=p_nonterminal,
            leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)),
            error_cov_df=error_cov_df,
            error_cov_scale=error_cov_scale,
            min_points_per_decision_node=10,
            log_s=process_varprob(varprob, max_split),
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            # if enabled, sparsity starts halfway through burn-in (nskip)
            sparse_on_at=nskip // 2 if sparse else None,
            **device_kw,
        )

        # count predictors with no available cutpoints (max_split == 0) and
        # pass the count to `init` so it can filter them
        if rm_const:
            n_empty = jnp.sum(max_split == 0).item()
            kw.update(filter_splitless_vars=n_empty)

        # user-provided overrides take precedence over everything above
        kw.update(init_kw)

        state = mcmcstep.init(**kw)

        # put state on device if requested explicitly by the user
        if device is not None:
            state = device_put(state, device, donate=True)

        return state
845 @classmethod
846 def _run_mcmc(
847 cls,
848 mcmc_state: mcmcstep.State,
849 ndpost: int,
850 nskip: int,
851 keepevery: int,
852 printevery: int | None,
853 seed: int | Integer[Array, ''] | Key[Array, ''],
854 run_mcmc_kw: Mapping,
855 ) -> RunMCMCResult:
856 # prepare random generator seed
857 if is_key(seed): 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:
858 key = jnp.copy(seed) 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB
859 else:
860 key = jax.random.key(seed) 1:
862 # round up ndpost
863 num_chains = get_num_chains(mcmc_state) 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:
864 if num_chains is None: 1vYsXyZw]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:
865 num_chains = 1 1vyiSFo'OTU0K1bfuH-89.lEdW
866 n_save = ndpost // num_chains + bool(ndpost % num_chains) 1vYsXyZw]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:
868 # prepare arguments
869 kw: dict = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery) 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:
870 kw.update( 1vYsXyZw]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:
871 mcmcloop.make_default_callback(
872 mcmc_state,
873 dot_every=None if printevery is None or printevery == 1 else 1,
874 report_every=printevery,
875 )
876 )
877 kw.update(run_mcmc_kw) 1vYsXyZw]t#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:
879 return run_mcmc(key, mcmc_state, n_save, **kw) 1vYsXyZwt#pimhS%NF6zI2Co7n';$O(JT)PU*Q0+MK!G1,Vbgafqeu3rHxD-89.lkE5AdjcW/RLB:
    def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']:
        """Evaluate trees on already quantized `x`."""
        # delegate to the module-level jitted `predict`, which also squashes
        # the chain axis into the posterior-sample axis
        return predict(x, self._main_trace)
@partial(jit, static_argnames='p')
# jitted so that the lax.collapse below does not materialize a copy
def varcount(p: int, trace: mcmcloop.MainTrace) -> Int32[Array, 'ndpost p']:
    """Histogram of predictor usage for decision rules in the trees, squashing chains."""
    counts: Int32[Array, '*chains samples p'] = compute_varcount(p, trace)
    # merge the leading chain axis into the sample axis
    return lax.collapse(counts, 0, -1)
@jit
# jitted so that the lax.collapse below does not materialize a copy
def predict(
    x: UInt[Array, 'p m'], trace: mcmcloop.MainTrace
) -> Float32[Array, 'ndpost m']:
    """Evaluate trees on already quantized `x`, and squash chains."""
    evaluations = evaluate_trace(x, trace)
    # merge the leading chain axis into the sample axis
    return lax.collapse(evaluations, 0, -1)
class DeviceKwArgs(TypedDict):
    """Device-related keyword arguments passed through to `mcmcstep.init`."""

    # number of MCMC chains (None: presumably an unbatched single-chain
    # state — `get_num_chains` may return None; confirm in `mcmcstep.init`)
    num_chains: int | None
    # sharding mesh built by `process_device_settings` (None: no mesh)
    mesh: Mesh | None
    # platform hint, set only when no mesh is given and the training data
    # carries no placement information
    target_platform: Literal['cpu', 'gpu'] | None
def process_device_settings(
    y_train: Array,
    num_chains: int | None,
    num_chain_devices: int | None,
    num_data_devices: int | None,
    devices: Device | Sequence[Device] | None,
) -> tuple[DeviceKwArgs, Device | None]:
    """Return the arguments for `mcmcstep.init` related to devices, and an optional device where to put the state."""
    # determine the device list and the platform, and whether the state
    # needs to be moved explicitly afterwards
    if devices is not None:
        if not hasattr(devices, '__len__'):
            devices = (devices,)
        target_device = devices[0]
        platform = target_device.platform
    elif hasattr(y_train, 'platform'):
        platform = y_train.platform()
        # leave target_device unset: if the devices were not specified
        # explicitly we may be in the case where computation will follow
        # data placement, do not disturb jax as the user may be playing
        # with vmap, jit, reshard...
        target_device = None
        devices = jax.devices(platform)
    else:
        msg = 'not possible to infer device from `y_train`, please set `devices`'
        raise ValueError(msg)

    # collect the requested mesh axis sizes, in (chains, data) order
    axis_sizes: dict = {}
    if num_chain_devices is not None:
        axis_sizes['chains'] = num_chain_devices
    if num_data_devices is not None:
        axis_sizes['data'] = num_data_devices
    if axis_sizes:
        mesh = make_mesh(
            axis_shapes=tuple(axis_sizes.values()),
            axis_names=tuple(axis_sizes),
            axis_types=(AxisType.Auto,) * len(axis_sizes),
            devices=devices,
        )
        # unset target_device because `mcmcstep.init` will `device_put`
        # with the mesh already, we don't want to undo its work
        target_device = None
    else:
        mesh = None

    # prepare arguments to `init`
    settings = DeviceKwArgs(
        num_chains=num_chains,
        mesh=mesh,
        target_platform=platform
        if mesh is None and not hasattr(y_train, 'platform')
        else None,
        # here we don't take into account the case where the user has set both
        # batch sizes; since the user has to be playing with `init_kw` to do
        # that, we'll let `init` throw the error and the user set
        # `target_platform` themselves so they have a clearer idea how the
        # thing works.
    )

    return settings, target_device
def process_varprob(
    varprob: Float[Any, ' p'] | None, max_split: UInt[Array, ' p']
) -> Float32[Array, ' p'] | None:
    """Convert varprob to log_s.

    Parameters
    ----------
    varprob
        Per-predictor weights, strictly positive, one per predictor.
        `None` is passed through unchanged.
    max_split
        Per-predictor cutpoint counts; used only to check that `varprob`
        has the matching shape `(p,)`.

    Returns
    -------
    The elementwise logarithm of `varprob`, or `None` if `varprob` is `None`.

    Raises
    ------
    ValueError
        If `varprob` does not have the same shape as `max_split`.
    """
    if varprob is None:
        return None
    varprob = jnp.asarray(varprob)
    # raise explicitly instead of `assert`, which is stripped under
    # `python -O`; ValueError matches the shape checks elsewhere in the file
    if varprob.shape != max_split.shape:
        msg = f'varprob must have shape (p,), got {varprob.shape=} != {max_split.shape=}'
        raise ValueError(msg)
    # positivity depends on traced values, so it is checked via error_if
    varprob = error_if(varprob, varprob <= 0, 'varprob must be > 0')
    return jnp.log(varprob)