# bartz/src/bartz/_interface.py
#
# Copyright (c) 2025-2026, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Main high-level interface of the package."""

import math
from collections.abc import Mapping, Sequence
from dataclasses import replace
from enum import Enum
from functools import cached_property, partial
from types import MappingProxyType
from typing import Any, Literal, Protocol, TypedDict

import jax
import jax.numpy as jnp
from equinox import Module, error_if, field
from jax import Device, debug_nans, device_put, jit, lax, make_mesh, random, tree
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import ndtr
from jax.sharding import AxisType, Mesh, PartitionSpec
from jaxtyping import (
    Array,
    Bool,
    Float,
    Float32,
    Int32,
    Integer,
    Key,
    Real,
    Shaped,
    UInt,
)
from numpy import ndarray

from bartz import mcmcloop, mcmcstep, prepcovars
from bartz.grove import (
    TreesTrace,
    check_trace,
    evaluate_forest,
    forest_depth_distr,
    points_per_node_distr,
)
from bartz.jaxext import equal_shards, is_key
from bartz.jaxext.scipy.special import ndtri
from bartz.jaxext.scipy.stats import invgamma
from bartz.mcmcloop import RunMCMCResult, compute_varcount, evaluate_trace, run_mcmc
from bartz.mcmcstep import OutcomeType, make_p_nonterminal
from bartz.mcmcstep._state import (
    _inv_via_chol_with_gersh,
    chol_with_gersh,
    get_num_chains,
)

FloatLike = float | Float[Any, '']


class PredictKind(Enum):
    """Kind of output of `Bart.predict`."""

    mean = 'mean'
    """The posterior mean of the conditional mean, shape ``(m,)`` (or
    ``(k, m)`` for multivariate regression)."""

    mean_samples = 'mean_samples'
    """Per-sample conditional mean, shape ``(ndpost, m)`` (or ``(ndpost,
    k, m)``). For binary regression, this is the probit-transformed
    sum-of-trees."""

    outcome_samples = 'outcome_samples'
    """Samples of the outcome variable, shape ``(ndpost, m)`` (or
    ``(ndpost, k, m)``). For binary regression, these are Bernoulli
    draws. For continuous regression, these are Gaussian draws with the
    posterior noise variance."""

    latent_samples = 'latent_samples'
    """Raw sum-of-trees values, shape ``(ndpost, m)`` (or ``(ndpost, k,
    m)``)."""


class DataFrame(Protocol):
    """DataFrame duck-type for `Bart`."""

    columns: Sequence[str]
    """The names of the columns."""

    def to_numpy(self) -> ndarray:
        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
        ...


class Series(Protocol):
    """Series duck-type for `Bart`."""

    name: str | None
    """The name of the series."""

    def to_numpy(self) -> ndarray:
        """Convert the series to a 1d numpy array."""
        ...


class Bart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses. For univariate regression, a 1D array of shape
        `(n,)`. For multivariate regression, a 2D array of shape `(k, n)` where
        `k` is the number of response components, as introduced in [3]_. For
        binary regression, the convention is that non-zero values mean 1, and
        zero means 0, like booleans.
    outcome_type
        The type of regression. ``'continuous'`` for continuous regression,
        ``'binary'`` for binary regression with probit link. For multivariate
        regression, a scalar value applies to all components; alternatively, a
        sequence of per-component types (e.g., ``['binary', 'continuous']``)
        specifies mixed outcome types.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision rule
        is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it's a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity.
        If setting `theta` directly, it should be in the ballpark of p or lower
        as well.
    varprob
        The prior probability distribution over the `p` predictors used to
        choose the predictor to split on in a decision node. Must be > 0. It
        does not need to be normalized to sum to 1. If not specified, use a
        uniform distribution. If ``sparse=True``, this is used as the initial
        value for the MCMC.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictor quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to set
        `lamda`. If not specified, it is estimated by linear regression (with
        intercept, and without taking into account `w`). If `y_train` has fewer
        than two elements, it is set to 1. If n <= p, it is set to the standard
        deviation of `y_train`. Ignored if `lamda` is specified. For
        multivariate regression, can be a scalar (broadcast to all components)
        or a `(k,)` vector of per-component estimates. For mixed outcome types,
        binary component values are ignored.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance. For multivariate regression, the Inverse-Wishart
        degrees of freedom are set to `sigdf + k - 1`.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If `y_train`
        has fewer than two elements, `k` is ignored and the scale is set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that a
        node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``. For example, with the defaults ``base=0.95`` and ``power=2``,
        the root is non-terminal with probability 0.95, and its children with
        probability 0.95 / 4.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`. For multivariate regression, can be a scalar (broadcast
        to all components) or a `(k,)` vector. For mixed outcome types, binary
        component values are ignored.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, defaults to ``(max(y_train) -
        min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
        continuous regression, and 3 for binary regression. For multivariate
        regression, the range is computed per component. For mixed outcome
        types, each component uses the default for its type.
    offset
        The prior mean of the latent mean function. If not specified, it is set
        to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train != 0))`` for binary regression. If `y_train` is
        empty, `offset` is set to 0. With binary regression, if `y_train` is
        all zero or all non-zero, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively. For multivariate regression, can be
        a scalar (broadcast to all components) or a `(k,)` vector; the defaults
        are computed per component, each following the rule for its outcome
        type.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user. Not supported for multivariate regression.
    num_trees
        The number of trees used to represent the latent mean function.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin the
        predictors, ranging between the minimum and maximum observed values
        (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best set
        to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains. `ndpost` is rounded up to the
        first multiple of `num_chains`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This number
        of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in.
    printevery
        The number of iterations (including thinned-away ones) between each log
        line. Set to `None` to disable logging. Ctrl-C interrupts the MCMC only
        every `printevery` iterations, so with logging disabled there is no
        convenient way to kill the MCMC.
    num_chains
        The number of independent Markov chains to run.

        The difference between ``num_chains=None`` and ``num_chains=1`` is that
        with the latter, the object attributes and some method outputs carry an
        explicit chain axis of size 1.
    num_chain_devices
        The number of devices to spread the chains across. Must be a divisor of
        `num_chains`. Each device will run a fraction of the chains.
    num_data_devices
        The number of devices to split datapoints across. Must be a divisor of
        `n`. This is useful only with very high `n`, roughly > 1,000,000.

        If both `num_chain_devices` and `num_data_devices` are specified, the
        total number of devices used is the product of the two.
    devices
        One or more devices used to run the MCMC on. If not specified, the
        computation will follow the placement of the input arrays. If a list of
        devices, this argument can be longer than the number of devices needed.
    seed
        The seed for the random number generator.
    maxdepth
        The maximum depth of the trees. This is 1-based, so with the default
        ``maxdepth=6``, the depths of the levels range from 0 to 5.
    init_kw
        Additional arguments passed to `bartz.mcmcstep.init`.
    run_mcmc_kw
        Additional arguments passed to `bartz.mcmcloop.run_mcmc`.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection". In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    .. [2] Chipman, Hugh A., Edward I. George, and Robert E. McCulloch (2010).
       "BART: Bayesian additive regression trees". In: The Annals of Applied
       Statistics 4.1, pp. 266-298.
    .. [3] Um, Seungha, Antonio R. Linero, Debajyoti Sinha, and Dipankar
       Bandyopadhyay (2023). "Bayesian additive regression trees for
       multivariate skewed responses". In: Statistics in Medicine 42.3,
       pp. 246-263.
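
    Examples
    --------
    A minimal usage sketch; the data here is synthetic, for illustration only:

    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> X = rng.standard_normal((5, 100))  # (p, n): 5 predictors, 100 points
    >>> y = X[0] + rng.standard_normal(100)
    >>> bart = Bart(X, y, ndpost=100, nskip=100, printevery=None)
    >>> yhat = bart.predict(X)  # posterior mean at the training points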
    """

    _main_trace: mcmcloop.MainTrace
    _burnin_trace: mcmcloop.BurninTrace
    _mcmc_state: mcmcstep.State
    _splits: Real[Array, 'p max_num_splits']
    _binary_mask: Bool[Array, ''] | Bool[Array, ' k']
    _x_train_fmt: Any = field(static=True)

    offset: Float32[Array, ''] | Float32[Array, ' k']
    """The prior mean of the latent mean function."""

    sigest: Float32[Array, ''] | Float32[Array, ' k'] | None = None
    """The estimated standard deviation of the error used to set `lamda`."""

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'] | Series,
        *,
        outcome_type: OutcomeType | str | Sequence[OutcomeType | str] = 'continuous',
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        varprob: Float[Array, ' p'] | None = None,
        xinfo: Float[Array, 'p n'] | None = None,
        usequants: bool = False,
        rm_const: bool = True,
        sigest: FloatLike | Float[Array, ' k'] | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | Float[Array, ' k'] | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | Float[Array, ' k'] | None = None,
        w: Float[Array, ' n'] | Series | None = None,
        num_trees: int = 200,
        numcut: int = 255,
        ndpost: int = 1000,
        nskip: int = 1000,
        keepevery: int = 1,
        printevery: int | None = 100,
        num_chains: int | None = 4,
        num_chain_devices: int | None = None,
        num_data_devices: int | None = None,
        devices: Device | Sequence[Device] | None = None,
        seed: int | Key[Array, ''] = 0,
        maxdepth: int = 6,
        init_kw: Mapping = MappingProxyType({}),
        run_mcmc_kw: Mapping = MappingProxyType({}),
    ) -> None:
379 # check data and put it in the right format
380 x_train, x_train_fmt = self._process_predictor_input(x_train) 1bca
381 y_train = self._process_response_input(y_train) 1bca
382 self._check_same_length(x_train, y_train) 1bca
384 if w is not None: 1bzca
385 w = self._process_response_input(w) 1z
386 self._check_same_length(x_train, w) 1z
388 # check data types are correct for continuous/binary/multivariate regression
389 outcome_type, binary_mask = self._check_type_settings(y_train, outcome_type, w) 1bca
391 # process sparsity settings
392 theta, a, b, rho = self._process_sparsity_settings( 1bca
393 x_train, sparse, theta, a, b, rho
394 )
396 # process "standardization" settings
397 offset = self._process_offset_settings(y_train, binary_mask, offset) 1bca
398 leaf_prior_cov_inv = self._process_leaf_variance_settings( 1bca
399 y_train, binary_mask, k, num_trees, tau_num
400 )
401 error_cov_df, error_cov_scale, sigest = self._process_error_variance_settings( 1bca
402 x_train, y_train, outcome_type, binary_mask, sigest, sigdf, sigquant, lamda
403 )
405 # determine splits
406 splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo) 1bca
407 x_train = self._bin_predictors(x_train, splits) 1bca
409 # setup and run mcmc
410 initial_state = self._setup_mcmc( 1bca
411 x_train,
412 y_train,
413 outcome_type,
414 offset,
415 w,
416 max_split,
417 leaf_prior_cov_inv,
418 error_cov_df,
419 error_cov_scale,
420 power,
421 base,
422 maxdepth,
423 num_trees,
424 init_kw,
425 rm_const,
426 theta,
427 a,
428 b,
429 rho,
430 varprob,
431 num_chains,
432 num_chain_devices,
433 num_data_devices,
434 devices,
435 sparse,
436 nskip,
437 )
438 result = self._run_mcmc( 1bca
439 initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
440 )
442 # set public attributes
443 # set offset from the state because of buffer donation
444 self.offset = result.final_state.offset 1bca
445 self.sigest = sigest 1bca
447 # set private attributes
448 self._main_trace = result.main_trace 1bca
449 self._burnin_trace = result.burnin_trace 1bca
450 self._mcmc_state = result.final_state 1bca
451 self._splits = splits 1bca
452 self._x_train_fmt = x_train_fmt 1bca
453 self._binary_mask = binary_mask 1bca

    def predict(
        self,
        x_test: Real[Array, 'p m'] | DataFrame | str,
        *,
        kind: PredictKind | str = 'mean',
        key: Key[Array, ''] | None = None,
        w: Float[Array, ' m'] | Series | None = None,
    ) -> (
        Float32[Array, ' m']
        | Float32[Array, 'k m']
        | Float32[Array, 'ndpost m']
        | Float32[Array, 'ndpost k m']
    ):
        """
        Compute predictions at `x_test`.

        Parameters
        ----------
        x_test
            The test predictors, or the string ``'train'`` to compute
            predictions on the training data.
        kind
            The kind of output. See `PredictKind` for details.
        key
            Jax random key, required when ``kind='outcome_samples'``.
        w
            Per-observation error scale for ``kind='outcome_samples'``.
            Required when the model was fit with weights and ``x_test`` is
            new data.

        Returns
        -------
        Predictions at `x_test` in the requested format.

        Raises
        ------
        ValueError
            If `x_test` has a different format than `x_train`, or if `w`
            is specified when it should be `None`, or if `w` is not
            specified when it is required.
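
        Examples
        --------
        An illustrative sketch; it assumes a fitted model ``bart`` and a
        test matrix ``x_test`` of shape ``(p, m)``:

        >>> yhat = bart.predict(x_test)  # posterior mean, shape (m,)
        >>> draws = bart.predict(x_test, kind='mean_samples')  # (ndpost, m)
        >>> post = bart.predict('train', kind='outcome_samples',
        ...                     key=jax.random.key(0))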
        """
        # parse arguments
        kind = PredictKind(kind)
        if kind is PredictKind.outcome_samples and key is None:
            msg = '`key` not specified'
            raise ValueError(msg)
        w = self._process_w_test(x_test, kind, w)
        x_test = self._process_x_test(x_test, w)

        # get latent i.e. bare sum-of-trees predictions
        latent = self._predict(x_test)
        if kind is PredictKind.latent_samples:
            return latent

        # sample posterior (uses latent directly, no probit squash needed)
        binary_indices = self._mcmc_state.binary_indices
        if kind is PredictKind.outcome_samples:
            return self._sample_outcome(key, latent, binary_indices, w)

        # squash predictions to (0, 1) if probit
        if binary_indices is not None:
            indexing = jnp.s_[..., binary_indices, :]
            mean_samples = latent.at[indexing].set(ndtr(latent[indexing]))
        elif self._mcmc_state.binary_y is not None:
            mean_samples = ndtr(latent)
        else:
            mean_samples = latent

        # take mean or return samples
        if kind is PredictKind.mean:
            return mean_samples.mean(axis=0)
        return mean_samples

    @property
    def ndpost(self) -> int:
        """The total number of posterior samples after burn-in across all chains.

        May be larger than the initialization argument `ndpost` if it was not
        divisible by the number of chains.
        """
        return self._main_trace.grow_prop_count.size

    @property
    def num_trees(self) -> int:
        """Return the number of trees used in the model."""
        return self._mcmc_state.forest.split_tree.shape[-2]

    def get_latent_prec(
        self, only_continuous: bool = False
    ) -> (
        Float32[Array, ' nskip+ndpost']
        | Float32[Array, 'nskip+ndpost k k']
        | Float32[Array, 'num_chains nskip+ndpost/num_chains']
        | Float32[Array, 'num_chains nskip+ndpost/num_chains k k']
    ):
        """Return the posterior samples of the latent error precision matrix.

        Parameters
        ----------
        only_continuous
            If `True` and the model has mixed binary-continuous outcomes,
            return only the submatrix for the continuous components.

        Returns
        -------
        MCMC samples of the error precision matrix.

        Notes
        -----
        This method is meant to check for convergence, so it returns the full
        MCMC trace and does not concatenate chains together. For probit
        regression, this returns the precision of the latent error term, not
        the Bernoulli precision for the binary outcome. For heteroskedastic
        regression, the returned precision is the global precision parameter,
        which would have to be divided by a squared weight to get the
        precision on a given datapoint.

        Raises
        ------
        ValueError
            If `only_continuous` is `True` but the model has only binary
            outcomes, so there is no continuous submatrix to return.
        """
        binary_indices = self._mcmc_state.binary_indices
        if (
            only_continuous
            and binary_indices is None
            and self._mcmc_state.binary_y is not None
        ):
            msg = 'Model has only binary outcomes, so there is no continuous submatrix to return.'
            raise ValueError(msg)

        burnin = self._burnin_trace.error_cov_inv
        main = self._main_trace.error_cov_inv
        # trace shape is (chains?, samples, ...) where chains is optional
        # first axis; samples is the axis to concatenate along
        num_chains = get_num_chains(self._mcmc_state)
        sample_axis = 1 if num_chains is not None else 0
        prec = jnp.concatenate([burnin, main], axis=sample_axis)

        if only_continuous and binary_indices is not None:
            *_, k, _ = prec.shape
            mask = jnp.ones(k, dtype=bool).at[binary_indices].set(False)
            cont_indices = jnp.arange(k)[mask]
            prec = prec[..., cont_indices[:, None], cont_indices[None, :]]

        return prec

    def get_error_sdev(
        self, mean: bool = False
    ) -> (
        Float32[Array, 'ndpost']
        | Float32[Array, 'ndpost k']
        | Float32[Array, '']
        | Float32[Array, ' k']
    ):
        """Return the error standard deviation, post-burnin, chains concatenated.

        Parameters
        ----------
        mean
            If `True`, average the precision matrix across samples first
            (harmonic mean at the covariance matrix level), returning a single
            scalar or vector instead of posterior samples.

        Returns
        -------
        Posterior samples (or single estimate) of the error standard deviation;
        NaN for binary outcomes.

        Notes
        -----
        Binary outcomes do have a standard deviation, of course, but this
        method does not return it because that would require evaluating
        predictions at given points, since the Bernoulli variance p(1-p)
        depends on the predicted probability.
        """
        # reshape operations
        error_cov_inv = self._main_trace.error_cov_inv
        if error_cov_inv.ndim in (2, 4):
            # shape (chains, samples) or (chains, samples, k, k), concatenate chains
            error_cov_inv = lax.collapse(error_cov_inv, 0, 2)
        is_uv = error_cov_inv.ndim == 1
        if mean:
            error_cov_inv = error_cov_inv.mean(0)
        if is_uv:
            # univariate case, reshape to 1x1 matrix
            error_cov_inv = error_cov_inv[..., None, None]

        # compute sdev and fill in nans for binary outcomes
        cov = _inv_via_chol_with_gersh(error_cov_inv)
        sdev = jnp.sqrt(jnp.diagonal(cov, axis1=-2, axis2=-1))
        if is_uv:
            sdev = sdev.squeeze(-1)
        with debug_nans(False):
            return jnp.where(self._binary_mask, jnp.nan, sdev)

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        p = self._mcmc_state.forest.max_split.size
        return varcount(p, self._main_trace)

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self.varcount.mean(axis=0)

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        max_split = self._mcmc_state.forest.max_split
        p = max_split.size
        varprob = self._main_trace.varprob
        if varprob is None:
            peff = jnp.count_nonzero(max_split)
            varprob = jnp.where(max_split, 1 / peff, 0)
            varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
        else:
            varprob = varprob.reshape(-1, p)
        return varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self.varprob.mean(axis=0)

    def _sample_outcome(
        self,
        key: Key[Array, ''],
        latent: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'],
        binary_indices: Int32[Array, ' kb'] | None,
        w: Float32[Array, ' m'] | None,
    ) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
        """Sample from the posterior predictive distribution."""
        if latent.ndim > 2:  # multivariate case
            error_cov_inv = self._main_trace.error_cov_inv
            error_cov_inv = lax.collapse(error_cov_inv, 0, -2)

            # Cholesky of precision: error_cov_inv = L @ L^T
            L = chol_with_gersh(error_cov_inv)  # (ndpost, k, k)

            # Sample z ~ N(0, I) and solve L^T @ error = z
            # so error = L^{-T} z ~ N(0, L^{-T} L^{-1}) = N(0, Sigma)
            z = random.normal(key, latent.shape)  # (ndpost, k, m)
            error = solve_triangular(L, z, trans='T', lower=True)
        elif self._mcmc_state.binary_y is not None:
            # pure binary UV: probit has sigma = 1
            error = random.normal(key, latent.shape)
        else:  # univariate continuous
            sigma = jnp.sqrt(jnp.reciprocal(self._main_trace.error_cov_inv)).reshape(-1)
            error = sigma[..., None] * random.normal(key, latent.shape)
            if w is not None:
                error *= w[None, :]

        outcome = latent + error

        # convert binary outcomes via latent probit thresholding
        if binary_indices is not None:
            idx = jnp.s_[..., binary_indices, :]
            outcome = outcome.at[idx].set(jnp.where(outcome[idx] > 0, 1.0, 0.0))
        elif self._mcmc_state.binary_y is not None:
            outcome = jnp.where(outcome > 0, 1.0, 0.0)

        return outcome

    def _process_w_test(
        self,
        x_test: Real[Array, 'p m'] | DataFrame | str,
        kind: PredictKind,
        w: Float[Array, ' m'] | Series | None,
    ) -> Float32[Array, ' m'] | None:
        """Validate and resolve the error weights for prediction.

        Parameters
        ----------
        x_test
            The raw (not yet processed) test predictors, or ``'train'``.
        kind
            The prediction kind.
        w
            User-provided per-observation error scale, or `None`.

        Returns
        -------
        The resolved error scale as a float32 array, or `None` if weights
        are not applicable.

        Raises
        ------
        ValueError
            If `w` is specified when it should be `None`, or missing when
            required.
        """
        x_test_is_train = isinstance(x_test, str) and x_test == 'train'
        has_train_weights = self._mcmc_state.prec_scale is not None
        is_binary = self._mcmc_state.binary_y is not None
        is_multivariate = self._mcmc_state.offset.ndim == 1
        needs_weights = (
            kind is PredictKind.outcome_samples
            and not is_binary
            and not is_multivariate
            and has_train_weights
        )

        if not needs_weights:
            if w is not None:
                msg = (
                    '`w` must be `None` in this configuration'
                    " (it is used only with kind='outcome_samples',"
                    ' univariate continuous regression fitted with'
                    ' weights)'
                )
                raise ValueError(msg)
            return None

        if x_test_is_train:
            if w is not None:
                msg = (
                    "`w` must be `None` when x_test='train'"
                    ' (training weights are used automatically)'
                )
                raise ValueError(msg)
            return jnp.reciprocal(jnp.sqrt(self._mcmc_state.prec_scale))

        # new test data, model was fit with weights
        if w is None:
            msg = (
                '`w` is required because the model was fit with'
                ' weights and x_test is new data'
            )
            raise ValueError(msg)
        return self._process_response_input(w)

    def _process_x_test(
        self,
        x_test: Real[Array, 'p m'] | DataFrame | str,
        w: Float32[Array, ' m'] | None,
    ) -> UInt[Array, 'p m']:
        """Convert x_test to binned format suitable for prediction."""
        if isinstance(x_test, str):
            if x_test != 'train':
                msg = (
                    f"x_test must be an array, a DataFrame, or 'train', got {x_test!r}"
                )
                raise ValueError(msg)
            return self._mcmc_state.X
        x_test, x_test_fmt = self._process_predictor_input(x_test)
        if x_test_fmt != self._x_train_fmt:
            msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
            raise ValueError(msg)
        if w is not None:
            self._check_same_length(w, x_test)
        return self._bin_predictors(x_test, self._splits)

    @staticmethod
    def _process_predictor_input(
        x: Real[Any, 'p n'] | DataFrame,
    ) -> tuple[Shaped[Array, 'p n'], Any]:
        if hasattr(x, 'columns'):
            fmt = dict(kind='dataframe', columns=x.columns)
            x = x.to_numpy().T
        else:
            fmt = dict(kind='array', num_covar=x.shape[0])
        x = jnp.asarray(x)
        assert x.ndim == 2
        return x, fmt

    @staticmethod
    def _process_response_input(
        y: Shaped[Array, ' n'] | Shaped[Array, 'k n'] | Series,
    ) -> Float32[Array, ' n'] | Float32[Array, 'k n']:
        if hasattr(y, 'to_numpy'):
            y = y.to_numpy()
        y = jnp.asarray(y, jnp.float32)
        if y.ndim < 1 or y.ndim > 2:
            msg = f'y_train must be 1D (n,) or 2D (k, n). Got {y.ndim=}.'
            raise ValueError(msg)
        return y

    @staticmethod
    def _check_same_length(x1: Array, x2: Array) -> None:
        get_length = lambda x: x.shape[-1]
        assert get_length(x1) == get_length(x2)

    @classmethod
    def _process_error_variance_settings(
        cls,
        x_train: Shaped[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        outcome_type: OutcomeType | tuple[OutcomeType, ...],
        binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
        sigest: FloatLike | Float[Array, ' k'] | None,
        sigdf: FloatLike,
        sigquant: FloatLike,
        lamda: FloatLike | Float[Array, ' k'] | None,
    ) -> tuple[
        Float32[Array, ''] | None,
        Float32[Array, ''] | Float32[Array, 'k k'] | None,
        Float32[Array, ''] | Float32[Array, ' k'] | None,
    ]:
        """Return (error_cov_df, error_cov_scale, sigest)."""
        if outcome_type is OutcomeType.binary:
            if sigest is not None or lamda is not None:
                msg = 'Let `sigest=None` and `lamda=None` for binary regression'
                raise ValueError(msg)
            return None, None, None

        if lamda is None:
            # estimate sigest²
            sigest2 = cls._estimate_sigest2(x_train, y_train, sigest, binary_mask)
            sigest = jnp.sqrt(sigest2)

            # lamda from sigest²
            alpha = sigdf / 2
            invchi2 = invgamma.ppf(sigquant, alpha) / 2
            invchi2rid = invchi2 * sigdf
            lamda = sigest2 / invchi2rid
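            # (explanatory note) this picks lamda so that the sigquant
            # quantile of the prior sigma² ~ lamda * sigdf / ChiSquared(sigdf)
            # equals sigest²; the inverse-chi-squared quantile is obtained
            # from the inverse-gamma ppf with shape sigdf/2 and scale 1/2.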

        elif sigest is not None:
            msg = 'Let `sigest=None` if `lamda` is specified'
            raise ValueError(msg)

        else:
            lamda = jnp.where(binary_mask, 0.0, lamda)

        # params written in multivariate form
        if y_train.ndim == 2:
            k = y_train.shape[0]
            lamda = jnp.broadcast_to(lamda, (k,))
            error_cov_df = jnp.asarray(sigdf) + k - 1
            error_cov_scale = jnp.diag(sigdf * lamda)
        else:
            error_cov_df = jnp.asarray(sigdf)
            error_cov_scale = jnp.asarray(sigdf * lamda)

        return error_cov_df, error_cov_scale, sigest

    @classmethod
    def _estimate_sigest2(
        cls,
        x_train: Shaped[Array, 'p n'],
        y_train: Float32[Array, '*k n'],
        sigest: float | Shaped[Array, '*k'] | None,
        binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
    ) -> Float32[Array, '*k']:
        n = y_train.shape[-1]
        if sigest is not None:
            sigest2 = jnp.square(jnp.asarray(sigest, dtype=jnp.float32))
            sigest2 = jnp.broadcast_to(sigest2, y_train.shape[:-1])
        elif n < 2:
            sigest2 = jnp.ones(y_train.shape[:-1])
        elif n <= x_train.shape[0]:
            sigest2 = jnp.var(y_train, axis=-1)
        else:
            sigest2 = cls._linear_regression(x_train, y_train)
        return jnp.where(binary_mask, 0.0, sigest2)

    @staticmethod
    @jit
    def _linear_regression(
        x_train: Shaped[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    ) -> Float32[Array, ''] | Float32[Array, ' k']:
        """Return the error variance estimated with OLS with intercept."""
        x_centered = x_train.T - x_train.mean(axis=1)
        y_centered = y_train.T - y_train.mean(axis=-1)
        # centering is equivalent to adding an intercept column
        _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
        chisq = chisq.reshape(y_train.shape[:-1])
        dof = y_train.shape[-1] - rank
        return chisq / dof

    @staticmethod
    def _check_type_settings(
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        outcome_type: OutcomeType | str | Sequence[OutcomeType | str],
        w: Float[Array, ' n'] | None,
    ) -> tuple[
        OutcomeType | tuple[OutcomeType, ...], Bool[Array, ''] | Bool[Array, ' k']
    ]:
        # standardize outcome_type to OutcomeType or tuple[OutcomeType, ...]
        if isinstance(outcome_type, Sequence) and not isinstance(outcome_type, str):
            outcome_type = tuple(OutcomeType(t) for t in outcome_type)
            num_types = len(outcome_type)
            if len(set(outcome_type)) == 1:
                outcome_type = outcome_type[0]
        else:
            num_types = None
            outcome_type = OutcomeType(outcome_type)

        # validation
        if num_types is not None and (
            y_train.ndim != 2 or num_types != y_train.shape[0]
        ):
            msg = (
                f'Sequence outcome_type of length {num_types}'
                f' requires y_train.shape=({num_types}, n),'
                f' found {y_train.shape=}.'
            )
            raise ValueError(msg)
        if w is not None and not (
            outcome_type is OutcomeType.continuous and y_train.ndim == 1
        ):
            msg = 'Weights are only supported for univariate continuous regression.'
            raise ValueError(msg)

        if isinstance(outcome_type, tuple):
            binary_mask = jnp.array([t is OutcomeType.binary for t in outcome_type])
        else:
            binary_mask = jnp.bool_(outcome_type is OutcomeType.binary)
        binary_mask = jnp.broadcast_to(binary_mask, y_train.shape[:-1])

        return outcome_type, binary_mask

    @staticmethod
    def _process_sparsity_settings(
        x_train: Real[Array, 'p n'],
        sparse: bool,
        theta: FloatLike | None,
        a: FloatLike,
        b: FloatLike,
        rho: FloatLike | None,
    ) -> (
        tuple[None, None, None, None]
        | tuple[FloatLike, None, None, None]
        | tuple[None, FloatLike, FloatLike, FloatLike]
    ):
        """Return (theta, a, b, rho)."""
        if not sparse:
            return None, None, None, None
        elif theta is not None:
            return theta, None, None, None
        else:
            if rho is None:
                p, _ = x_train.shape
                rho = float(p)
            return None, a, b, rho

    @staticmethod
    def _process_offset_settings(
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
        offset: float | Float32[Any, ''] | Float32[Any, ' k'] | None,
    ) -> Float32[Array, ''] | Float32[Array, ' k']:
        """Return offset."""
        if offset is not None:
            off = jnp.asarray(offset, jnp.float32)
            return jnp.broadcast_to(off, y_train.shape[:-1])
        if y_train.shape[-1] < 1:
            return jnp.zeros(y_train.shape[:-1])

        bound = 1 / (1 + y_train.shape[-1])
        binary_offset = ndtri(jnp.clip((y_train != 0).mean(-1), bound, 1 - bound))
        continuous_offset = y_train.mean(-1)
        return jnp.where(binary_mask, binary_offset, continuous_offset)

    @staticmethod
    def _process_leaf_variance_settings(
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
        k: FloatLike,
        num_trees: int,
        tau_num: FloatLike | None,
    ) -> Float32[Array, ''] | Float32[Array, 'k k']:
        """Return `leaf_prior_cov_inv`."""
        # determine `tau_num` if not specified
        if tau_num is None:
            if y_train.shape[-1] < 2:
                continuous_tau = jnp.ones(y_train.shape[:-1])
            else:
                continuous_tau = (y_train.max(-1) - y_train.min(-1)) / 2
            tau_num = jnp.where(binary_mask, 3.0, continuous_tau)

        # leaf prior standard deviation
        sigma_mu = tau_num / (k * math.sqrt(num_trees))

        # leaf prior precision matrix
        leaf_prior_cov_inv = jnp.reciprocal(jnp.square(sigma_mu))
        if y_train.ndim == 2:
            leaf_prior_cov_inv = jnp.diag(
                jnp.broadcast_to(leaf_prior_cov_inv, y_train.shape[:-1])
            )
        return leaf_prior_cov_inv

    @staticmethod
    def _determine_splits(
        x_train: Real[Array, 'p n'],
        usequants: bool,
        numcut: int,
        xinfo: Float[Array, 'p n'] | None,
    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
        if xinfo is not None:
            if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
                msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
                raise ValueError(msg)
            return prepcovars.parse_xinfo(xinfo)
        elif usequants:
            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
        else:
            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

    @staticmethod
    def _bin_predictors(
        x: Real[Array, 'p n'], splits: Real[Array, 'p max_num_splits']
    ) -> UInt[Array, 'p n']:
        return prepcovars.bin_predictors(x, splits)

    @staticmethod
    def _setup_mcmc(
        x_train: Real[Array, 'p n'],
        y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
        outcome_type: OutcomeType | tuple[OutcomeType, ...],
        offset: Float32[Array, ''] | Float32[Array, ' k'],
        w: Float[Array, ' n'] | None,
        max_split: UInt[Array, ' p'],
        leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
        error_cov_df: FloatLike | None,
        error_cov_scale: FloatLike | Float32[Array, 'k k'] | None,
        power: FloatLike,
        base: FloatLike,
        maxdepth: int,
        num_trees: int,
        init_kw: Mapping[str, Any],
        rm_const: bool,
        theta: FloatLike | None,
        a: FloatLike | None,
        b: FloatLike | None,
        rho: FloatLike | None,
        varprob: Float[Any, ' p'] | None,
        num_chains: int | None,
        num_chain_devices: int | None,
        num_data_devices: int | None,
        devices: Device | Sequence[Device] | None,
        sparse: bool,
        nskip: int,
    ) -> mcmcstep.State:
        p_nonterminal = make_p_nonterminal(maxdepth, base, power)

        # process device settings
        device_kw, device = process_device_settings(
            y_train, num_chains, num_chain_devices, num_data_devices, devices
        )

        kw: dict = dict(
            X=x_train,
            # copy y_train because it's going to be donated in the mcmc loop
            y=jnp.array(y_train),
            outcome_type=outcome_type,
            offset=offset,
            # copy w because it's going to be donated in init
            error_scale=None if w is None else jnp.array(w),
            max_split=max_split,
            num_trees=num_trees,
            p_nonterminal=p_nonterminal,
            leaf_prior_cov_inv=leaf_prior_cov_inv,
            error_cov_df=error_cov_df,
            error_cov_scale=error_cov_scale,
            min_points_per_decision_node=10,
            log_s=process_varprob(varprob, max_split),
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            sparse_on_at=nskip // 2 if sparse else None,
            **device_kw,
        )

        if rm_const:
            n_empty = jnp.sum(max_split == 0).item()
            kw.update(filter_splitless_vars=n_empty)

        kw.update(init_kw)

        state = mcmcstep.init(**kw)

        # put state on device if requested explicitly by the user
        if device is not None:
            state = device_put(state, device, donate=True)

        return state

    @classmethod
    def _run_mcmc(
        cls,
        mcmc_state: mcmcstep.State,
        ndpost: int,
        nskip: int,
        keepevery: int,
        printevery: int | None,
        seed: int | Integer[Array, ''] | Key[Array, ''],
        run_mcmc_kw: Mapping,
    ) -> RunMCMCResult:
        # prepare random generator seed
        if is_key(seed):
            key = jnp.copy(seed)
        else:
            key = jax.random.key(seed)

        # round up ndpost
        num_chains = get_num_chains(mcmc_state)
        if num_chains is None:
            num_chains = 1
        n_save = ndpost // num_chains + bool(ndpost % num_chains)
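        # (explanatory note) n_save is the per-chain count, i.e. a ceiling
        # division: e.g. ndpost=1000 with num_chains=3 gives n_save=334, so
        # 1002 samples are saved in total.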

        # prepare arguments
        kw: dict = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
        kw.update(
            mcmcloop.make_default_callback(
                mcmc_state,
                dot_every=None if printevery is None or printevery == 1 else 1,
                report_every=printevery,
            )
        )
        kw.update(run_mcmc_kw)

        return run_mcmc(key, mcmc_state, n_save, **kw)

    def _predict(
        self, x: UInt[Array, 'p m']
    ) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
        """Evaluate trees on already quantized `x`."""
        return predict(x, self._main_trace)

    def check_trees(
        self, error: bool = False
    ) -> UInt[Array, 'num_chains ndpost/num_chains num_trees']:
        """Apply `bartz.grove.check_trace` to all the tree draws.

        Parameters
        ----------
        error
            If `True`, throw an error if any invalid trees are found.

        Returns
        -------
        An array where non-zero entries indicate invalid trees.

        Raises
        ------
        RuntimeError
            If `error` is `True` and any invalid trees are found.
        """
        out: UInt[Array, '*chains samples num_trees']
        out = check_trace(self._main_trace, self._mcmc_state.forest.max_split)
        if out.ndim < 3:
            out = out[None, :, :]
        if error:
            bad_count = jnp.count_nonzero(out)
            if bad_count > 0:
                msg = f'Found {bad_count} invalid trees in the MCMC trace.'
                raise RuntimeError(msg)
        return out

    def check_replicated_trees(self) -> None:
        """Check that the trees are equal across data-sharded devices.

        If the data is sharded across devices, verify that the trees (which
        should be replicated) are identical on all shards.

        Raises
        ------
        RuntimeError
            If the trees differ across devices.
        """
        state = self._mcmc_state
        mesh = state.config.mesh
        if mesh is not None and 'data' in mesh.axis_names:
            replicated_forest = replace(state.forest, leaf_indices=None)
            equal = equal_shards(
                replicated_forest, 'data', in_specs=PartitionSpec(), mesh=mesh
            )
            equal_array = jnp.stack(tree.leaves(equal))
            all_equal = jnp.all(equal_array)
            if not all_equal.item():
                msg = 'The trees differ across data-sharded devices.'
                raise RuntimeError(msg)

    def compare_resid(
        self, y: Float32[Array, ' n'] | Float32[Array, 'k n'] | None = None
    ) -> tuple[
        Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
        Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
    ]:
        """Re-compute residuals to compare them with the updated ones.

        Parameters
        ----------
        y
            The response variable. Required for continuous regression (since
            ``State`` does not store ``y`` in continuous mode). Ignored for
            binary regression (where ``State.z`` is used instead).

        Returns
        -------
        resid1
            The final state of the residuals updated during the MCMC.
        resid2
            The residuals computed from the final state of the trees.
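
        Notes
        -----
        This is a consistency diagnostic: the two outputs should agree up to
        floating point error, e.g. (tolerance illustrative)
        ``jnp.allclose(*bart.compare_resid(y_train), atol=1e-4)``.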
        """
        state = self._mcmc_state
        resid1 = state.resid

        forests = TreesTrace.from_dataclass(state.forest)
        trees = evaluate_forest(state.X, forests, sum_batch_axis=-1)

        if state.binary_indices is not None:
            # mixed binary-continuous: z has only binary rows, y has all rows
            assert y is not None, 'y is required for mixed regression'
            ref = jnp.asarray(y)
            ref = jnp.broadcast_to(ref, state.resid.shape)
            ref = ref.at[..., state.binary_indices, :].set(state.z)
        elif state.z is not None:
            ref = state.z
        else:
            assert y is not None, 'y is required for continuous regression'
            ref = jnp.asarray(y)
        resid2 = ref - (trees + state.offset[..., None])

        return resid1, resid2

    def depth_distr(self) -> Int32[Array, '*num_chains ndpost/num_chains d']:
        """Histogram of tree depths for each state of the trees.

        Returns
        -------
        A matrix where each row contains a histogram of tree depths.
        """
        out: Int32[Array, '*chains samples d']
        out = forest_depth_distr(self._main_trace.split_tree)
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def _points_per_node_distr(
        self, node_type: str
    ) -> Int32[Array, '*num_chains ndpost/num_chains n+1']:
        out: Int32[Array, '*chains samples n+1']
        out = points_per_node_distr(
            self._mcmc_state.X,
            self._main_trace.var_tree,
            self._main_trace.split_tree,
            node_type,
            sum_batch_axis=-1,
        )
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def points_per_decision_node_distr(
        self,
    ) -> Int32[Array, '*num_chains ndpost/num_chains n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes.

        Returns
        -------
        For each chain, a matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf-parent')

    def points_per_leaf_distr(
        self,
    ) -> Int32[Array, '*num_chains ndpost/num_chains n+1']:
        """Histogram of number of points belonging to leaves.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf')


@partial(jit, static_argnames='p')
# this is jitted such that lax.collapse below does not create a copy
def varcount(p: int, trace: mcmcloop.MainTrace) -> Int32[Array, 'ndpost p']:
    """Histogram of predictor usage for decision rules in the trees, squashing chains."""
    varcount: Int32[Array, '*chains samples p']
    varcount = compute_varcount(p, trace)
    return lax.collapse(varcount, 0, -1)


@jit
# this is jitted such that lax.collapse below does not create a copy
def predict(
    x: UInt[Array, 'p m'], trace: mcmcloop.MainTrace
) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
    """Evaluate trees on already quantized `x`, and squash chains."""
    out = evaluate_trace(x, trace)
    # For MV, out has shape (*trace_shape, k, n); for UV, (*trace_shape, n).
    # We must collapse only the chain/sample dims, not k.
    # Detect MV: leaf_tree has an extra axis compared to split_tree.
    is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
    end = -2 if is_mv else -1
    return lax.collapse(out, 0, end)


class DeviceKwArgs(TypedDict):
    num_chains: int | None
    mesh: Mesh | None
    target_platform: Literal['cpu', 'gpu'] | None


def process_device_settings(
    y_train: Array,
    num_chains: int | None,
    num_chain_devices: int | None,
    num_data_devices: int | None,
    devices: Device | Sequence[Device] | None,
) -> tuple[DeviceKwArgs, Device | None]:
    """Return the arguments for `mcmcstep.init` related to devices, and an optional device where to put the state."""
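    # (explanatory note) e.g., with num_chain_devices=2 and num_data_devices=2
    # on a host with 4 devices, a 2x2 mesh with axis names ('chains', 'data')
    # is built below, so each device runs half of the chains on half of the
    # datapoints.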
    # determine devices
    if devices is not None:
        if not hasattr(devices, '__len__'):
            devices = (devices,)
        device = devices[0]
        platform = device.platform
    elif hasattr(y_train, 'platform'):
        platform = y_train.platform()
        device = None
        # set device=None because if the devices were not specified explicitly
        # we may be in the case where computation will follow data placement,
        # do not disturb jax as the user may be playing with vmap, jit, reshard...
        devices = jax.devices(platform)
    else:
        msg = 'not possible to infer device from `y_train`, please set `devices`'
        raise ValueError(msg)

    # create mesh
    if num_chain_devices is None and num_data_devices is None:
        mesh = None
    else:
        mesh = dict()
        if num_chain_devices is not None:
            mesh.update(chains=num_chain_devices)
        if num_data_devices is not None:
            mesh.update(data=num_data_devices)
        mesh = make_mesh(
            axis_shapes=tuple(mesh.values()),
            axis_names=tuple(mesh),
            axis_types=(AxisType.Auto,) * len(mesh),
            devices=devices,
        )
        device = None
        # set device=None because `mcmcstep.init` will `device_put` with the
        # mesh already, we don't want to undo its work

    # prepare arguments to `init`
    settings = DeviceKwArgs(
        num_chains=num_chains,
        mesh=mesh,
        target_platform=None
        if mesh is not None or hasattr(y_train, 'platform')
        else platform,
        # here we don't take into account the case where the user has set both
        # batch sizes; since the user has to be playing with `init_kw` to do
        # that, we'll let `init` throw the error and the user set
        # `target_platform` themselves so they have a clearer idea how the
        # thing works.
    )

    return settings, device


def process_varprob(
    varprob: Float[Any, ' p'] | None, max_split: UInt[Array, ' p']
) -> Float32[Array, ' p'] | None:
    """Convert varprob to log_s."""
    if varprob is None:
        return None
    varprob = jnp.asarray(varprob)
    assert varprob.shape == max_split.shape, 'varprob must have shape (p,)'
    varprob = error_if(varprob, varprob <= 0, 'varprob must be > 0')
    return jnp.log(varprob)