Coverage for src/bartz/_interface.py: 95%
628 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/_interface.py
2#
3# Copyright (c) 2025-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Main high-level interface of the package."""
27import math
28import pickle
29from collections.abc import Collection, Hashable, Mapping, Sequence
30from dataclasses import replace
31from enum import Enum
32from functools import cached_property
33from os import PathLike, cpu_count
34from pathlib import Path
36# WORKAROUND(python<3.15): use frozendict instead of MappingProxyType
37from types import MappingProxyType
38from typing import Any, Literal, Protocol, TypedDict, overload, runtime_checkable
39from warnings import warn
41import jax
42import jax.numpy as jnp
43from equinox import Module, error_if, field, tree_at
44from jax import Device, debug_nans, device_put, lax, make_mesh, random, tree
45from jax.scipy.linalg import solve_triangular
46from jax.scipy.special import ndtr, ndtri
47from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec
48from jax.typing import DTypeLike
49from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real, Shaped, UInt
50from numpy import ndarray
52from bartz._jaxext import equal_shards, is_key, jit, split
53from bartz.grove import (
54 TreeHeaps,
55 TreesTrace,
56 check_trace,
57 evaluate_forest,
58 forest_depth_distr,
59 format_tree,
60 points_per_node_distr,
61)
62from bartz.mcmcloop import (
63 BurninTrace,
64 MainTrace,
65 RunMCMCResult,
66 compute_varcount,
67 evaluate_trace,
68 make_print_callback,
69 make_tqdm_callback,
70 run_mcmc,
71)
72from bartz.mcmcstep import DiagWishart, OutcomeType, Wishart, make_p_nonterminal
73from bartz.mcmcstep._axes import (
74 chain_to_axis,
75 chain_vmap_axes,
76 chainful_axis,
77 get_has_chains,
78 trace_sample_axes,
79)
80from bartz.mcmcstep._state import (
81 ArrayLike,
82 FloatLike,
83 State,
84 _inv_via_chol_with_gersh,
85 _leaf_partition_spec,
86 chol_with_gersh,
87 init,
88)
89from bartz.prepcovars import Binner, BinnerFactory, UniqueQuantileBinner
class PredictKind(Enum):
    """Selects what `Bart.predict` returns."""

    mean = 'mean'
    """Posterior mean of the conditional mean function: shape ``(m,)``, or
    ``(k, m)`` when the regression is multivariate."""

    mean_samples = 'mean_samples'
    """Conditional mean for every posterior sample: shape
    ``(num_chains * n_save, m)``, or ``(num_chains * n_save, k, m)``. With
    binary outcomes this is the sum-of-trees passed through the probit
    link."""

    outcome_samples = 'outcome_samples'
    """Draws of the outcome itself: shape ``(num_chains * n_save, m)``, or
    ``(num_chains * n_save, k, m)``. Binary outcomes give Bernoulli draws;
    continuous outcomes give Gaussian draws using the posterior error
    variance."""

    latent_samples = 'latent_samples'
    """The untransformed sum-of-trees for every posterior sample: shape
    ``(num_chains * n_save, m)``, or ``(num_chains * n_save, k, m)``."""
@runtime_checkable
class DataFrame(Protocol):
    """Structural (duck) type describing the dataframes `Bart` accepts."""

    def to_numpy(self) -> Shaped[ndarray, '*shape']:
        """Return the data as a 2d numpy array whose second axis indexes the columns."""

    @property
    def columns(self) -> Collection[str]:
        """Names of the columns."""
@runtime_checkable
class Series(Protocol):
    """Structural (duck) type describing the series `Bart` accepts."""

    def to_numpy(self) -> Shaped[ndarray, '*shape']:
        """Return the data as a 1d numpy array."""

    @property
    def name(self) -> Hashable:
        """Name of the series."""
class SparseConfig(Module):
    R"""
    Settings for the sparsity-inducing variable selection prior of [1]_.

    Passing an instance as the `sparse` argument of `Bart` turns on variable
    selection over the predictors. A priori, the probabilities of picking each
    predictor for a decision rule are distributed as

    .. math::
        (s_1, \ldots, s_p) \sim
        \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

    When `theta` is left unspecified, it gets its own prior, defined by

    .. math::
        \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
        \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

    References
    ----------
    .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection”. In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    """

    theta: FloatLike | None = None
    """Dirichlet concentration. Leave it as `None` to sample it from the Beta
    prior controlled by `a`, `b` and `rho`. When fixing it by hand, choose a
    value around the number of predictors p, or smaller."""

    a: FloatLike = 0.5
    """First shape parameter of the Beta prior on ``theta / (theta + rho)``."""

    b: FloatLike = 1.0
    """Second shape parameter of the Beta prior on ``theta / (theta + rho)``."""

    rho: FloatLike | None = None
    """Scale of the Beta prior on `theta`. `None` means the number of
    predictors p. Smaller values favor sparser selections."""

    augment: bool = field(static=True, default=True)
    """Whether to use data augmentation to take exact account of the decision
    rules that each node's ancestors rule out when the selection
    probabilities are updated. Enabled by default. Setting it to `False`
    ignores those forbidden rules, which is faster but only approximate. The
    approximation matters most with few predictors that each have few
    cutpoints, because then a predictor often cannot be reused further down
    a branch."""

    enabled: bool = field(static=True, default=True)
    """Whether variable selection is switched on."""
195class Bart(Module):
196 R"""
197 Nonparametric regression with Bayesian Additive Regression Trees (BART).
199 Regress `y_train` on `x_train` with a latent mean function represented as
200 a sum of decision trees [2]_. The inference is carried out by sampling the
201 posterior distribution of the tree ensemble with an MCMC.
203 Parameters
204 ----------
205 x_train
206 The training predictors.
207 y_train
208 The training responses. For univariate regression, a 1D array of shape
209 `(n,)`. For multivariate regression, a 2D array of shape `(k, n)` where
210 `k` is the number of response components, as introduced in [3]_. For
211 binary regression, the convention is that non-zero values mean 1, zero
212 mean 0, like booleans.
213 outcome_type
214 The type of regression. ``'continuous'`` for continuous regression,
215 ``'binary'`` for binary regression with probit link. For multivariate
216 regression, a scalar value applies to all components; alternatively, a
217 sequence of per-component types (e.g., ``['binary', 'continuous']``)
218 specifies mixed outcome types. Binary components in multivariate
219 outcomes follow the multivariate probit BART formulation of [4]_.
220 sparse
221 A `SparseConfig` for the sparsity-inducing variable selection prior of
222 [1]_. Disabled by default; pass a `SparseConfig` to enable it.
223 varprob
224 The probability distribution over the `p` predictors for choosing a
225 predictor to split on in a decision node a priori. Must be > 0. It does
226 not need to be normalized to sum to 1. If not specified, use a uniform
227 distribution. If `sparse` is enabled, this is used as initial value for
228 the MCMC.
229 binner
230 A callable that, given the training predictors and a random key,
231 returns a `~bartz.prepcovars.Binner` instance. The default is
232 `~bartz.prepcovars.UniqueQuantileBinner`, which places cutpoints at
233 the quantiles of each predictor. Other built-in options are
234 `~bartz.prepcovars.RangeEvenBinner` (evenly-spaced cutpoints over the
235 observed range) and `~bartz.prepcovars.GivenSplitsBinner` (R BART
236 ``xinfo`` format). To pass options, use `functools.partial`, e.g.
237 ``binner=partial(UniqueQuantileBinner, max_bins=128)``.
238 rm_const
239 How to treat predictors with no associated decision rules (i.e., there
240 are no available cutpoints for that predictor). If `True` (default),
241 they are ignored. If `False`, an error is raised if there are any.
242 sigma_df
243 The degrees of freedom of the prior on the error precision. For
244 multivariate regression with `k` components, the Wishart degrees of
245 freedom are set to ``sigma_df + k - 1``.
246 sigma_scale
247 Sets the scale of the prior on the error precision. If 'auto' (default),
248 the prior is scaled so that the error precision equals
249 ``diag(1 / var(y_train))`` in expectation, where with weights `error_scale`
250 the variance is a precision-weighted one that estimates the unit-weight error
251 variance. Otherwise, ``square(sigma_scale)`` is the prior harmonic mean of
252 the error variance; for multivariate regression a scalar is broadcast to
253 all components. For mixed outcome types, binary components are ignored.
254 sigma_init
255 The initial value of the error standard deviation in the MCMC. If 'auto'
256 (default), the initial error precision is set to ``diag(1 / var(y_train))``,
257 with the same precision-weighted variance as `sigma_scale` when weights are
258 given. Otherwise, the initial precision is ``diag(1 / square(sigma_init))``;
259 for multivariate regression a scalar is broadcast to all components. For
260 mixed outcome types, binary components are ignored.
261 k
262 The inverse scale of the prior standard deviation on the latent mean
263 function, relative to half the observed range of `y_train`. If `y_train`
264 has less than two elements, `k` is ignored and the scale is set to 1.
265 power
266 base
267 Parameters of the prior on tree node generation. The probability that a
268 node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
269 power``.
270 tau_num
271 The numerator in the expression that determines the prior standard
272 deviation of leaves. If not specified, default to ``(max(y_train) -
273 min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
274 continuous regression, and 3 for binary regression. For multivariate
275 regression, the range is computed per component. For mixed outcome
276 types, each component uses the default for its type.
277 offset
278 The prior mean of the latent mean function. If not specified, it is set
279 to the mean of `y_train` for continuous regression, and to
280 ``Phi^-1(mean(y_train != 0))`` for binary regression. If `y_train` is
281 empty, `offset` is set to 0. With binary regression, if `y_train` is
282 all zero or all non-zero, it is set to ``Phi^-1(1/(n+1))`` or
283 ``Phi^-1(n/(n+1))``, respectively. For multivariate regression, can be
284 a scalar (broadcast to all components) or a `(k,)` vector. If not
285 specified, it is set to the per-component mean of `y_train`. For mixed
286 outcome types, each component uses the default for its type.
287 error_scale
288 Coefficients that rescale the error standard deviation on each
289 datapoint. Not specifying `error_scale` is equivalent to setting it to 1
290 for all datapoints. Shape ``(n,)`` applies the same scalar weight to every
291 outcome component; for multivariate continuous regression, ``(k, n)``
292 instead supplies a per-component weight per datapoint.
293 missing
294 Boolean mask with the same shape as `y_train`; `True` marks entries
295 to be ignored by the MCMC. Values of `y_train` must be finite
296 everywhere, including at masked positions. If 2-D, the error
297 covariance must be diagonal.
298 num_trees
299 The number of trees used to represent the latent mean function.
300 n_save
301 The number of MCMC samples to save, after burn-in, per chain. The
302 total trace length across all chains is ``num_chains * n_save``.
303 n_burn
304 The number of initial MCMC samples to discard as burn-in. This number
305 of samples is discarded from each chain.
306 n_skip
307 The thinning factor for the MCMC samples, after burn-in.
308 printevery
309 The number of iterations (including thinned-away ones) between each log
310 line. Set to `None` to disable progress reporting entirely (this ignores
311 `pbar`). ^C interrupts the MCMC only every `printevery` iterations, so
312 with reporting disabled it's impossible to kill the MCMC conveniently.
313 pbar
314 If `True`, show a `tqdm` progress bar instead of printing log lines. The
315 bar advances every iteration and refreshes the acceptance statistics
316 every `printevery` iterations. Ignored if `printevery` is `None`.
317 num_chains
318 The number of independent Markov chains to run.
320 The difference between ``num_chains=None`` and ``num_chains=1`` is that
321 in the latter case in the object attributes and some methods there will
322 be an explicit chain axis of size 1.
323 num_chain_devices
324 The number of devices to spread the chains across. Must be a divisor of
325 `num_chains`. Each device will run a fraction of the chains. If 'auto'
326 (default) and running on cpu, the number of devices is picked
327 automatically based on the number of cores and the number of available
328 devices (all the virtual jax cpu devices, or the `devices` list if set).
329 num_data_devices
330 The number of devices to split datapoints across. Must be a divisor of
331 `n`. This is useful only with very high `n`, about > 1000_000. `predict`
332 parallelizes across the same devices, splitting the test points; the
333 number of test points must be a multiple of `num_data_devices` as well.
335 If both num_chain_devices and num_data_devices are specified, the total
336 number of devices used is the product of the two.
337 devices
338 One or more devices used to run the MCMC on. If not specified, the
339 computation will follow the placement of the input arrays. If a list of
340 devices, this argument can be longer than the number of devices needed.
341 seed
342 The seed for the random number generator.
343 maxdepth
344 The maximum depth of the trees. This is 1-based, so with the default
345 ``maxdepth=6``, the depths of the levels range from 0 to 5.
346 init_kw
347 Additional arguments passed to `bartz.mcmcstep.init`.
348 run_mcmc_kw
349 Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
351 References
352 ----------
353 .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
354 High-Dimensional Prediction and Variable Selection”. In: Journal of the
355 American Statistical Association 113.522, pp. 626-636.
356 .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART:
357 Bayesian additive regression trees," The Annals of Applied Statistics,
358 Ann. Appl. Stat. 4(1), 266-298, (March 2010).
359 .. [3] Um, Seungha, Antonio R. Linero, Debajyoti Sinha, and Dipankar
360 Bandyopadhyay (2023). "Bayesian additive regression trees for
361 multivariate skewed responses". In: Statistics in Medicine 42.3,
362 pp. 246-263.
363 .. [4] Goh, Yong Chen, Wuu Kuang Soh, Andrew C. Parnell, and Keefe
364 Murphy (2024). "Joint Models for Handling Non-Ignorable Missing
365 Data using Bayesian Additive Regression Trees: Application to
366 Leaf Photosynthetic Traits Data". arXiv:2412.14946 [stat.ME].
368 """
370 _main_trace: MainTrace
371 _burnin_trace: BurninTrace
372 _mcmc_state: State
373 _binner: Binner
374 _binary_mask: Bool[Array, ''] | Bool[Array, ' k']
375 # WORKAROUND(jax<0.9.1): use `jax.tree.static` instead of `field(static=True)`
376 _x_train_fmt: Any = field(static=True)
377 _device: Device | None = field(static=True)
379 _error_scale: Float32[Array, ' n'] | Float32[Array, 'k n'] | None = None
def __init__(
    self,
    x_train: Real[ArrayLike, 'p n'] | DataFrame,
    y_train: Float32[ArrayLike, ' n']
    | Float32[ArrayLike, 'k n']
    | Series
    | DataFrame,
    *,
    outcome_type: OutcomeType | str | Sequence[OutcomeType | str] = 'continuous',
    # NOTE: this default instance is shared by all calls; that is safe only as
    # long as SparseConfig stays immutable (it is an equinox Module)
    sparse: SparseConfig = SparseConfig(enabled=False),
    varprob: Float[ArrayLike, ' p'] | None = None,
    binner: BinnerFactory = UniqueQuantileBinner,
    rm_const: bool = True,
    sigma_df: FloatLike = 3.0,
    sigma_scale: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'] = 'auto',
    sigma_init: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'] = 'auto',
    k: FloatLike = 2.0,
    power: FloatLike = 2.0,
    base: FloatLike = 0.95,
    tau_num: FloatLike | None = None,
    offset: FloatLike | Float[ArrayLike, ' k'] | None = None,
    error_scale: Float[ArrayLike, ' n']
    | Float[ArrayLike, 'k n']
    | Series
    | DataFrame
    | None = None,
    missing: Bool[ArrayLike, ' n']
    | Bool[ArrayLike, 'k n']
    | Series
    | DataFrame
    | None = None,
    num_trees: int = 200,
    n_save: int = 1000,
    n_burn: int = 1000,
    n_skip: int = 1,
    printevery: int | None = 100,
    pbar: bool = True,
    num_chains: int | None = 4,
    num_chain_devices: int | None | Literal['auto'] = 'auto',
    num_data_devices: int | None = None,
    devices: Literal['cpu', 'gpu'] | Device | Sequence[Device] | None = None,
    seed: int | Key[Array, ''] = 0,
    maxdepth: int = 6,
    # read-only empty mappings avoid the mutable-default-argument pitfall
    init_kw: Mapping = MappingProxyType({}),
    run_mcmc_kw: Mapping = MappingProxyType({}),
) -> None:
    # Fit the model: validate and normalize the inputs, derive the prior
    # hyperparameters, bin the predictors, then run the MCMC and store the
    # traces and final state on the instance. The parameters are documented
    # in the class docstring.

    # check data and put it in the right format
    x_train, x_train_fmt = _process_predictor_input(x_train)
    y_train = _process_response_input(y_train)
    _check_same_length(x_train, y_train)

    if error_scale is not None:
        # keep=True because `error_scale` is donated downstream but also
        # retained as `self._error_scale` for prediction
        error_scale, self._error_scale = _process_response_input(
            error_scale, keep=True
        )
        _check_same_length(x_train, error_scale)

    if missing is not None:
        missing = _process_response_input(missing, dtype=jnp.bool_)
        _check_same_length(x_train, missing)

    # check data types are correct for continuous/binary/multivariate regression
    outcome_type, binary_mask = _check_type_settings(
        y_train, outcome_type, error_scale
    )

    # process "standardization" settings
    offset = _process_offset_settings(y_train, binary_mask, offset)
    leaf_prior_cov_inv = _process_leaf_variance_settings(
        y_train, binary_mask, k, num_trees, tau_num
    )
    error_cov_inv = _process_error_variance_settings(
        y_train,
        outcome_type,
        binary_mask,
        missing,
        sigma_df,
        sigma_scale,
        sigma_init,
        error_scale,
    )

    # split the user-provided seed into an mcmc key and a binner key
    if not is_key(seed):
        seed = random.key(seed)
    keys = split(seed)

    # construct the binner and bin x_train
    # NOTE: the binner's key is popped before the MCMC key further down;
    # swapping the two pops would change the random streams for a given seed
    binner_obj = binner(x_train, key=keys.pop())
    x_train = binner_obj.bin(x_train)
    # copy max_split because `mcmcstep.init` donates it
    max_split = jnp.array(binner_obj.max_split)

    # setup and run mcmc
    # `_setup_mcmc` returns the initial chain state, the key for the MCMC
    # loop, and the device recorded on the model
    initial_state, mcmc_key, device = _setup_mcmc(
        x_train,
        y_train,
        outcome_type,
        offset,
        error_scale,
        missing,
        max_split,
        leaf_prior_cov_inv,
        error_cov_inv,
        power,
        base,
        maxdepth,
        num_trees,
        init_kw,
        rm_const,
        sparse,
        varprob,
        num_chains,
        num_chain_devices,
        num_data_devices,
        devices,
        n_burn,
        keys.pop(),
    )
    result = _run_mcmc(
        initial_state,
        n_save,
        n_burn,
        n_skip,
        printevery,
        pbar,
        mcmc_key,
        run_mcmc_kw,
    )

    # set private attributes
    self._main_trace = result.main_trace
    self._burnin_trace = result.burnin_trace
    self._mcmc_state = result.final_state
    self._binner = binner_obj
    # remembered so `predict` can reject test data in a different format
    self._x_train_fmt = x_train_fmt
    self._binary_mask = binary_mask
    self._device = device
def predict(
    self,
    x_test: Real[ArrayLike, 'p m'] | DataFrame | str,
    *,
    kind: PredictKind | str = 'mean',
    key: Key[Array, ''] | None = None,
    error_scale: Float[ArrayLike, ' m']
    | Float[ArrayLike, 'k m']
    | Series
    | DataFrame
    | None = None,
) -> (
    Float32[Array, ' m']
    | Float32[Array, 'k m']
    | Float32[Array, 'ndpost m']
    | Float32[Array, 'ndpost k m']
):
    """
    Evaluate the fitted model at new or training predictors.

    Parameters
    ----------
    x_test
        Test predictors in the same format used for `x_train`, or the
        string ``'train'`` to reuse the training predictors.
    kind
        What to return; see `PredictKind` for the available options.
    key
        Jax random key. Mandatory with ``kind='outcome_samples'``.
    error_scale
        Per-observation error scale, used only with
        ``kind='outcome_samples'``. It must be given when the model was fit
        with weights and `x_test` is new data. It must have the same number
        of dimensions as the training weights: ``(m,)`` for scalar weights,
        ``(k, m)`` for per-component weights.

    Returns
    -------
    The predictions, laid out as described by `kind`.

    Raises
    ------
    ValueError
        If `key` is missing with ``kind='outcome_samples'``, if `x_test` is
        not formatted like `x_train`, if `error_scale` is passed where it is
        not used or omitted where it is required, or if the model splits
        datapoints across devices (`num_data_devices`) and the number of test
        points is not a multiple of the number of data devices.

    Notes
    -----
    With `num_data_devices`, the test points and the returned predictions
    are split across devices in the same way as the training data.
    """
    kind = PredictKind(kind)
    if key is None and kind is PredictKind.outcome_samples:
        msg = '`key` not specified'
        raise ValueError(msg)

    on_train = isinstance(x_test, str) and x_test == 'train'
    scale = self._process_error_scale_test(x_test, kind, error_scale)
    binned_x = self._process_x_test(x_test, scale)
    if not on_train:
        # the training data was placed on the model's devices at fit time;
        # only fresh test data has to be moved there
        binned_x, scale = self._device_put_test(binned_x, scale)

    state = self._mcmc_state
    return predict(
        key,
        self._main_trace,
        binned_x,
        scale,
        state.binary_indices,
        state.binary_y is not None,
        kind,
        # the test points are sharded over the mesh 'data' axis when one
        # exists; `evaluate_trace` cannot infer that at trace time, so the
        # mode is declared explicitly
        'shard_and_autobatch',
    )
def _drop_device_info(self) -> 'Bart':
    """Copy the model, stripping the static device-placement metadata.

    The meshes stored in the MCMC state config and in both traces are set to
    `None`, and so is the explicitly requested device. Only this metadata is
    removed; the arrays themselves stay wherever they currently live.
    """
    state_config = replace(self._mcmc_state.config, mesh=None)
    traces = tuple(
        replace(trace, mesh=None) for trace in (self._main_trace, self._burnin_trace)
    )
    stripped = tree_at(
        lambda model: (model._mcmc_state.config, model._main_trace, model._burnin_trace),  # noqa: SLF001
        self,
        (state_config, *traces),
    )
    # `tree_at` cannot reach static fields such as `_device`, so patch the new
    # copy directly, bypassing the frozen-module __setattr__
    object.__setattr__(stripped, '_device', None)
    return stripped
def dump(self, path: str | PathLike) -> None:
    """Write the fitted model to `path` using `pickle`.

    Parameters
    ----------
    path
        Destination file.

    Notes
    -----
    The file is meant for short-term use, such as caching between processes,
    and not for archival: its format depends on the installed versions of
    bartz, jax and equinox. Arrays are copied to host memory and every
    device/sharding detail is discarded, so `load` returns a single-device
    model.
    """
    # `Device` objects cannot be pickled, so drop the device metadata first;
    # `device_get` then pulls any sharded arrays back to host memory
    host_copy = jax.device_get(self._drop_device_info())
    target = Path(path)
    with target.open('wb') as stream:
        pickle.dump(host_copy, stream, protocol=pickle.HIGHEST_PROTOCOL)
@classmethod
def load(cls, path: str | PathLike) -> 'Bart':
    """Read back a model written by `dump`.

    Parameters
    ----------
    path
        Source file.

    Returns
    -------
    The restored model, held in host memory without any device placement.

    Raises
    ------
    TypeError
        If the unpickled object is not an instance of this class.
    """
    # unpickling can execute arbitrary code: only load files you trust
    with Path(path).open('rb') as stream:
        loaded = pickle.load(stream)  # noqa: S301, the user owns the file
    if isinstance(loaded, cls):
        return loaded
    msg = f'unpickled a {type(loaded).__name__}, not a {cls.__name__}'
    raise TypeError(msg)
@property
def offset(self) -> Float32[Array, ''] | Float32[Array, ' k']:
    """Prior mean of the latent mean function (a scalar, or one value per component)."""
    state = self._mcmc_state
    return state.offset
@property
def n_save(self) -> int:
    """How many post-burn-in posterior samples each chain saved."""
    trace = self._main_trace
    # the sample axis position depends on whether the trace has a chain axis
    axis = trace_sample_axes(trace).grow_prop_count
    return trace.grow_prop_count.shape[axis]
@property
def num_chains(self) -> int | None:
    """Number of MCMC chains, or `None` when the model has no chain axis."""
    state = self._mcmc_state
    return state.num_chains()
@property
def ndpost(self) -> int:
    """Total count of post-burn-in posterior samples, summed over all chains."""
    counts = self._main_trace.grow_prop_count
    return counts.size
@property
def num_trees(self) -> int:
    """Number of trees in the ensemble."""
    forest = self._mcmc_state.forest
    # without chains `split_tree` is (num_trees, half_tree_size), so the tree
    # count sits on core axis 0; shift it past the chain axis if there is one
    tree_axis = chainful_axis(0, chain_vmap_axes(forest).split_tree)
    return forest.split_tree.shape[tree_axis]
def get_latent_prec(
    self, only_continuous: bool = False
) -> (
    Float32[Array, ' n_burn_plus_n_save']
    | Float32[Array, 'n_burn_plus_n_save k k']
    | Float32[Array, 'num_chains n_burn_plus_n_save']
    | Float32[Array, 'num_chains n_burn_plus_n_save k k']
):
    """Posterior samples of the precision matrix of the latent error.

    Parameters
    ----------
    only_continuous
        When `True` and the outcomes mix binary and continuous components,
        restrict the result to the continuous block of the matrix.

    Returns
    -------
    The MCMC samples of the error precision, burn-in included.

    Raises
    ------
    ValueError
        If `only_continuous` is `True` for a model whose outcomes are all
        binary, since then there is no continuous block.

    Notes
    -----
    The output is intended for convergence diagnostics: it is the whole MCMC
    trace, burn-in included, with chains kept on a separate axis. For probit
    outcomes it is the precision of the latent error term, not the Bernoulli
    precision. With heteroskedastic weights it is the global precision
    parameter, which must be divided by a datapoint's squared weight to give
    that datapoint's precision.
    """
    state = self._mcmc_state
    binary_indices = state.binary_indices
    # no per-component binary indices but a binary response: every outcome is binary
    all_binary = binary_indices is None and state.binary_y is not None
    if only_continuous and all_binary:
        msg = 'Model has only binary outcomes, so there is no continuous submatrix to return.'
        raise ValueError(msg)

    return get_latent_prec(
        self._burnin_trace,
        self._main_trace,
        binary_indices,
        only_continuous=only_continuous,
    )
def get_error_sdev(
    self, mean: bool = False
) -> (
    Float32[Array, ' ndpost']
    | Float32[Array, 'ndpost k']
    | Float32[Array, '']
    | Float32[Array, ' k']
):
    """Error standard deviation after burn-in, with chains concatenated.

    Parameters
    ----------
    mean
        If `True`, average the error covariance over the samples and then
        take the square root, giving one scalar or vector estimate instead
        of posterior samples.

    Returns
    -------
    The posterior samples, or the single estimate, of the error standard
    deviation. Binary components are filled with NaN.

    Notes
    -----
    Binary outcomes have a standard deviation too, but it is not returned
    here: the Bernoulli variance p(1-p) depends on predictions at some X.
    """
    # the NaNs placed on binary components are intentional, so jax's NaN
    # debugging must not trip on them
    with debug_nans(False):
        sdev = get_error_sdev(self._main_trace, self._binary_mask, mean=mean)
    return sdev
@cached_property
def varcount(self) -> Int32[Array, 'ndpost p']:
    """Per-sample count of how many decision rules use each predictor."""
    num_predictors = self._mcmc_state.forest.max_split.size
    return varcount(num_predictors, self._main_trace)
@cached_property
def varcount_mean(self) -> Float32[Array, ' p']:
    """`varcount` averaged over the MCMC samples."""
    counts = self.varcount
    return counts.mean(axis=0)
@cached_property
def varprob(self) -> Float32[Array, 'ndpost p']:
    """Posterior samples of each predictor's probability of being picked for a decision rule."""
    max_split = self._mcmc_state.forest.max_split
    return varprob(max_split, self._main_trace)
@cached_property
def varprob_mean(self) -> Float32[Array, ' p']:
    """Marginal posterior probability of each predictor being picked for a decision rule."""
    probs = self.varprob
    return probs.mean(axis=0)
def _process_error_scale_test(
    self,
    x_test: Real[ArrayLike, 'p m'] | DataFrame | str,
    kind: PredictKind,
    error_scale: Float[ArrayLike, ' m']
    | Float[ArrayLike, 'k m']
    | Series
    | DataFrame
    | None,
) -> Float32[Array, ' m'] | Float32[Array, 'k m'] | None:
    """Validate and resolve the error weights for prediction.

    Parameters
    ----------
    x_test
        The raw (not yet processed) test predictors, or ``'train'``.
    kind
        The prediction kind.
    error_scale
        User-provided per-observation error scale, or `None`.

    Returns
    -------
    The resolved error scale as a float32 array, or `None` if weights are not
    applicable. For ``x_test='train'`` these are the training weights.

    Raises
    ------
    ValueError
        If `error_scale` is specified when it should be `None`, or missing
        when required, or if its shape is incompatible with the training
        weights (different number of dimensions, or a different number of
        components for per-component ``(k, m)`` weights).
    """
    x_test_is_train = isinstance(x_test, str) and x_test == 'train'
    train_scale = self._error_scale
    is_binary = self._mcmc_state.binary_y is not None
    # weights only matter when drawing outcomes of a continuous model that
    # was fit with per-datapoint weights
    needs_weights = (
        kind is PredictKind.outcome_samples and not is_binary and train_scale is not None
    )

    if not needs_weights:
        if error_scale is not None:
            msg = (
                '`error_scale` must be `None` in this configuration'
                " (it is used only with kind='outcome_samples',"
                ' continuous regression fitted with weights)'
            )
            raise ValueError(msg)
        return None

    if x_test_is_train:
        if error_scale is not None:
            msg = (
                "`error_scale` must be `None` when x_test='train'"
                ' (training weights are used automatically)'
            )
            raise ValueError(msg)
        return train_scale

    # new test data, model was fit with weights
    if error_scale is None:
        msg = (
            '`error_scale` is required because the model was fit with'
            ' weights and x_test is new data'
        )
        raise ValueError(msg)
    # the name `error_scale_test` appears in the error message via `=`
    error_scale_test = _process_response_input(error_scale)
    assert train_scale is not None  # implied by needs_weights
    if error_scale_test.ndim != train_scale.ndim:
        msg = (
            f'`error_scale` shape mismatch with training weights: got '
            f'{error_scale_test.shape=}, expected {train_scale.ndim}D '
            f'(matching the training-weight shape).'
        )
        raise ValueError(msg)
    # per-component (k, m) weights must also have the same number of
    # components k as at fit time; otherwise the mismatch would surface
    # later as an obscure error or, with a leading size of 1, could
    # broadcast silently against all k components
    if error_scale_test.ndim == 2 and error_scale_test.shape[0] != train_scale.shape[0]:
        msg = (
            f'`error_scale` shape mismatch with training weights: got '
            f'{error_scale_test.shape=}, expected {train_scale.shape[0]} '
            f'components (matching the training-weight shape).'
        )
        raise ValueError(msg)
    return error_scale_test
882 def _process_x_test(
883 self,
884 x_test: Real[ArrayLike, 'p m'] | DataFrame | str,
885 error_scale: Float32[Array, ' m'] | Float32[Array, 'k m'] | None,
886 ) -> UInt[Array, 'p m']:
887 """Convert x_test to binned format suitable for prediction."""
888 if isinstance(x_test, str):
889 if x_test != 'train': 889 ↛ 890line 889 didn't jump to line 890 because the condition on line 889 was never true
890 msg = (
891 f"x_test must be an array, a DataFrame, or 'train', got {x_test!r}"
892 )
893 raise ValueError(msg)
894 return self._mcmc_state.X
895 x_test, x_test_fmt = _process_predictor_input(x_test)
896 if x_test_fmt != self._x_train_fmt:
897 msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
898 raise ValueError(msg)
899 if error_scale is not None:
900 _check_same_length(error_scale, x_test)
901 return self._binner.bin(x_test)
903 def _device_put_test(
904 self,
905 x_test: UInt[Array, 'p m'],
906 error_scale: Float32[Array, ' m'] | Float32[Array, 'k m'] | None,
907 ) -> tuple[UInt[Array, 'p m'], Float32[Array, ' m'] | Float32[Array, 'k m'] | None]:
908 """Place new test data on the devices of the model.
910 Mirror the placement of the training data done at fit time: shard over
911 the mesh if there is one (the observation axis over 'data'), else move
912 to the device requested explicitly at construction, if any. The inputs
913 are donated, so they must not be used elsewhere.
914 """
915 mesh = self._mcmc_state.config.mesh
916 if mesh is not None:
917 put = lambda a: device_put(
918 a,
919 NamedSharding(mesh, _leaf_partition_spec(a.ndim, None, -1, mesh)),
920 donate=True,
921 )
922 elif self._device is not None:
923 put = lambda a: device_put(a, self._device, donate=True)
924 else:
925 return x_test, error_scale
926 if error_scale is None:
927 return put(x_test), None
928 else:
929 return put(x_test), put(error_scale)
931 def _check_trees(
932 self, error: bool = False
933 ) -> UInt[Array, 'num_chains n_save num_trees']:
934 """Apply `bartz.grove.check_trace` to all the tree draws.
936 Parameters
937 ----------
938 error
939 If `True`, throw an error if any invalid trees are found.
941 Returns
942 -------
943 An array where non-zero entries indicate invalid trees.
945 Raises
946 ------
947 RuntimeError
948 If `error` is `True` and any invalid trees are found.
949 """
950 out = check_trees(self._main_trace, self._mcmc_state.forest.max_split)
951 if error:
952 bad_count = jnp.count_nonzero(out).item()
953 if bad_count > 0:
954 msg = f'Found {bad_count} invalid trees in the MCMC trace.'
955 raise RuntimeError(msg)
956 return out
958 def _tree_goes_bad(self) -> Bool[Array, 'num_chains n_save num_trees']:
959 """Find iterations where a tree becomes invalid.
961 Returns
962 -------
963 An array where ``(i, j, k)`` is `True` if tree `k` is invalid at iteration `j` in chain `i` but not at iteration ``j - 1``.
964 """
965 return tree_goes_bad(self._main_trace, self._mcmc_state.forest.max_split)
967 def _check_replicated_trees(self) -> None:
968 """Check that the trees are equal across data-sharded devices.
970 If the data is sharded across devices, verify that the trees (which
971 should be replicated) are identical on all shards.
973 Raises
974 ------
975 RuntimeError
976 If the trees differ across devices.
977 """
978 state = self._mcmc_state
979 mesh = state.config.mesh
980 if mesh is not None and 'data' in mesh.axis_names:
981 # drop the data-sharded `leaf_indices` (not replicated) before the
982 # cross-shard equality check; `None` is a deliberately off-type
983 # placeholder, so use `tree_at`, which (unlike `dataclasses.replace`)
984 # bypasses the `__init__` type checks
985 replicated_forest = tree_at(lambda f: f.leaf_indices, state.forest, None)
986 equal = equal_shards(
987 replicated_forest, 'data', in_specs=PartitionSpec(), mesh=mesh
988 )
989 equal_array = jnp.stack(tree.leaves(equal))
990 all_equal = jnp.all(equal_array)
991 if not all_equal.item(): 991 ↛ 992line 991 didn't jump to line 992 because the condition on line 991 was never true
992 msg = 'The trees differ across data-sharded devices.'
993 raise RuntimeError(msg)
995 def _compare_resid(
996 self, y: Float32[Array, ' n'] | Float32[Array, 'k n'] | None = None
997 ) -> tuple[
998 Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
999 Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
1000 ]:
1001 """Re-compute residuals to compare them with the updated ones.
1003 Parameters
1004 ----------
1005 y
1006 The response variable. Required for continuous regression (since
1007 ``State`` does not store ``y`` in continuous mode). Ignored for
1008 binary regression (where ``State.z`` is used instead).
1010 Returns
1011 -------
1012 resid1
1013 The final state of the residuals updated during the MCMC.
1014 resid2
1015 The residuals computed from the final state of the trees.
1016 """
1017 state = self._mcmc_state
1018 if state.binary_indices is not None:
1019 assert y is not None, 'y is required for mixed regression'
1020 elif state.z is None:
1021 assert y is not None, 'y is required for continuous regression'
1022 y_arr = jnp.asarray(y) if y is not None else None
1023 return compare_resid(state, y_arr)
1025 def _depth_distr(self) -> Int32[Array, '*num_chains n_save d']:
1026 """Histogram of tree depths for each state of the trees.
1028 Returns
1029 -------
1030 A matrix where each row contains a histogram of tree depths.
1031 """
1032 return depth_distr(self._main_trace)
1034 def _points_per_node_distr(
1035 self, node_type: Literal['leaf', 'leaf-parent']
1036 ) -> Int32[Array, '*num_chains n_save n_plus_1']:
1037 return points_per_node_distr_trace(
1038 self._mcmc_state.X, self._main_trace, node_type
1039 )
1041 def _points_per_decision_node_distr(
1042 self,
1043 ) -> Int32[Array, '*num_chains n_save n_plus_1']:
1044 """Histogram of number of points belonging to parent-of-leaf nodes.
1046 Returns
1047 -------
1048 For each chain, a matrix where each row contains a histogram of number of points.
1049 """
1050 return self._points_per_node_distr('leaf-parent')
1052 def _points_per_leaf_distr(self) -> Int32[Array, '*num_chains n_save n_plus_1']:
1053 """Histogram of number of points belonging to leaves.
1055 Returns
1056 -------
1057 A matrix where each row contains a histogram of number of points.
1058 """
1059 return self._points_per_node_distr('leaf')
1061 def _print_tree(
1062 self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
1063 ) -> None:
1064 """Print a single tree in human-readable format.
1066 Parameters
1067 ----------
1068 i_chain
1069 The index of the MCMC chain.
1070 i_sample
1071 The index of the (post-burnin) sample in the chain.
1072 i_tree
1073 The index of the tree in the sample.
1074 print_all
1075 If `True`, also print the content of unused node slots.
1076 """
1077 trace = self._main_trace
1078 trees = _trees_chain_first(trace)
1079 chain_index = i_chain if trace.has_chains else ...
1080 trees = tree.map(lambda x: x[chain_index, i_sample, i_tree, :], trees)
1081 s = format_tree(trees, print_all=print_all)
1082 print(s) # noqa: T201, this method is intended for debug
def _process_predictor_input(
    x: Real[ArrayLike, 'p n'] | DataFrame,
) -> tuple[Shaped[Array, 'p n'], Any]:
    """Convert predictors to a ``(p, n)`` array and describe their format.

    Parameters
    ----------
    x
        The predictors, either array-like with shape ``(p, n)`` or a
        DataFrame with one column per predictor (transposed here).

    Returns
    -------
    x
        The predictors with shape ``(p, n)``.
    fmt
        A dict describing the input format: ``kind='dataframe'`` with the
        ``columns``, or ``kind='array'`` with the number of predictors
        ``num_covar``.
    """
    if isinstance(x, DataFrame):
        fmt = dict(kind='dataframe', columns=x.columns)
        x = x.to_numpy().T
    else:
        # convert before reading the shape, so that inputs without a `.shape`
        # (e.g. nested lists) work, and wrong-rank inputs reach the rank check
        # instead of failing on `shape[0]`
        x = jnp.asarray(x)
        assert x.ndim == 2, f'predictors must be 2D (p, n), got {x.ndim=}'
        fmt = dict(kind='array', num_covar=x.shape[0])
    assert x.ndim == 2
    return x, fmt
@overload
def _process_response_input(
    arr: Shaped[ArrayLike, ' n'] | Shaped[ArrayLike, 'k n'] | Series | DataFrame,
    /,
    *,
    keep: Literal[False] = False,
    dtype: DTypeLike = jnp.float32,
) -> Shaped[Array, ' n'] | Shaped[Array, 'k n']: ...


@overload
def _process_response_input(
    arr: Shaped[ArrayLike, ' n'] | Shaped[ArrayLike, 'k n'] | Series | DataFrame,
    /,
    *,
    keep: Literal[True],
    dtype: DTypeLike = jnp.float32,
) -> tuple[
    Shaped[Array, ' n'] | Shaped[Array, 'k n'],
    Shaped[Array, ' n'] | Shaped[Array, 'k n'],
]: ...


def _process_response_input(
    arr: Shaped[ArrayLike, ' n'] | Shaped[ArrayLike, 'k n'] | Series | DataFrame,
    /,
    *,
    keep: bool = False,
    dtype: DTypeLike = jnp.float32,
) -> (
    Shaped[Array, ' n']
    | Shaped[Array, 'k n']
    | tuple[
        Shaped[Array, ' n'] | Shaped[Array, 'k n'],
        Shaped[Array, ' n'] | Shaped[Array, 'k n'],
    ]
):
    """Convert a response-like input to a 1D ``(n,)`` or 2D ``(k, n)`` array.

    Parameters
    ----------
    arr
        The input. A DataFrame is transposed, so its columns become the rows
        of the output; a Series becomes a 1D array.
    keep
        If `False`, return one fresh copy, safe to donate downstream. If
        `True`, return ``(disposable, kept)``: ``kept`` is converted without
        forcing a copy, and ``disposable`` is a separate copy of it.
    dtype
        The target dtype.

    Raises
    ------
    ValueError
        If the converted input is neither 1D nor 2D.
    """
    # tabular containers: go through numpy first
    if isinstance(arr, DataFrame):
        arr = arr.to_numpy().T
    elif isinstance(arr, Series):
        arr = arr.to_numpy()

    # force a copy only in normal mode; in keep mode the copy is made below
    arr = jnp.array(arr, dtype, copy=not keep)
    if not 1 <= arr.ndim <= 2:
        msg = f'response-like input must be 1D (n,) or 2D (k, n). Got {arr.ndim=}.'
        raise ValueError(msg)
    return (jnp.copy(arr), arr) if keep else arr
def _check_same_length(x1: Shaped[Array, '... n'], x2: Shaped[Array, '... n']) -> None:
    """Assert that `x1` and `x2` have the same size along their last axis."""
    assert x1.shape[-1] == x2.shape[-1]
def _check_type_settings(
    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    outcome_type: OutcomeType | str | Sequence[OutcomeType | str],
    error_scale: Float[Array, ' n'] | Float[Array, 'k n'] | None,
) -> tuple[OutcomeType | tuple[OutcomeType, ...], Bool[Array, ''] | Bool[Array, ' k']]:
    """Standardize `outcome_type`, validate it, and build the binary mask.

    Returns
    -------
    outcome_type
        A single `OutcomeType` when all components share it, else a tuple
        with one `OutcomeType` per component.
    binary_mask
        Which components are binary, broadcast to ``y_train.shape[:-1]``.

    Raises
    ------
    ValueError
        If a sequence `outcome_type` does not match the number of rows of
        `y_train`, if weights are given with any binary outcome, or if a 2D
        `error_scale` does not match the number of components.
    """
    is_sequence = isinstance(outcome_type, Sequence) and not isinstance(
        outcome_type, str
    )
    if is_sequence:
        per_component = tuple(map(OutcomeType, outcome_type))
        num_types = len(per_component)
        # collapse to a scalar type when all components agree
        if len(set(per_component)) == 1:
            outcome_type = per_component[0]
        else:
            outcome_type = per_component
    else:
        num_types = None
        outcome_type = OutcomeType(outcome_type)

    if num_types is not None:
        rows_match = y_train.ndim == 2 and num_types == y_train.shape[0]
        if not rows_match:
            msg = (
                f'Sequence outcome_type of length {num_types}'
                f' requires y_train.shape=({num_types}, n),'
                f' found {y_train.shape=}.'
            )
            raise ValueError(msg)

    if error_scale is not None:
        if outcome_type is not OutcomeType.continuous:
            msg = 'Weights are not supported when any outcome is binary.'
            raise ValueError(msg)
        if error_scale.ndim == 2 and not (
            y_train.ndim == 2 and error_scale.shape[0] == y_train.shape[0]
        ):
            msg = (
                f'2D error_scale (vector per-component weights) requires y_train of '
                f'shape (k, n) with matching k; got {error_scale.shape=}, '
                f'{y_train.shape=}.'
            )
            raise ValueError(msg)

    if isinstance(outcome_type, tuple):
        mask = jnp.array([t is OutcomeType.binary for t in outcome_type])
    else:
        mask = jnp.bool_(outcome_type is OutcomeType.binary)
    return outcome_type, jnp.broadcast_to(mask, y_train.shape[:-1])
def _process_sparsity_settings(
    x_train: Real[Array, 'p n'], sparse: SparseConfig
) -> (
    tuple[None, None, None, None]
    | tuple[FloatLike, None, None, None]
    | tuple[None, FloatLike, FloatLike, FloatLike]
):
    """Return (theta, a, b, rho).

    All `None` when sparsity is disabled. Otherwise either a fixed `theta`,
    or the ``(a, b, rho)`` hyperparameters, with `rho` defaulting to the
    number of predictors.
    """
    if not sparse.enabled:
        return None, None, None, None
    if sparse.theta is not None:
        return sparse.theta, None, None, None
    rho = sparse.rho
    if rho is None:
        num_predictors, _ = x_train.shape
        rho = float(num_predictors)
    return None, sparse.a, sparse.b, rho
def _process_offset_settings(
    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
    offset: FloatLike | Float[ArrayLike, ' k'] | None,
) -> Float32[Array, ''] | Float32[Array, ' k']:
    """Return offset.

    A user-provided `offset` is broadcast to one value per component. The
    default is computed per component. For continuous components it is the
    mean of `y_train`. For binary components it is the probit of the
    fraction of nonzero entries, clipped away from 0 and 1. Without any
    observation the default is 0.
    """
    lead_shape = y_train.shape[:-1]
    if offset is not None:
        return jnp.broadcast_to(jnp.asarray(offset, jnp.float32), lead_shape)

    n = y_train.shape[-1]
    if n < 1:
        return jnp.zeros(lead_shape)

    eps = 1 / (1 + n)
    frac_nonzero = (y_train != 0).mean(-1)
    binary_offset = ndtri(jnp.clip(frac_nonzero, eps, 1 - eps))
    continuous_offset = y_train.mean(-1)
    return jnp.where(binary_mask, binary_offset, continuous_offset)
def _process_leaf_variance_settings(
    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
    k: FloatLike,
    num_trees: int,
    tau_num: FloatLike | None,
) -> Float32[Array, ''] | Float32[Array, 'k k']:
    """Return `leaf_prior_cov_inv`.

    The leaf prior standard deviation is ``tau_num / (k * sqrt(num_trees))``,
    and the result is its inverse square, as a diagonal matrix when `y_train`
    is 2D.

    Parameters
    ----------
    y_train
        The processed training response.
    binary_mask
        Which components of the response are binary.
    k
        The divisor of `tau_num` in the leaf prior standard deviation.
    num_trees
        The number of trees.
    tau_num
        The numerator of the leaf prior standard deviation. If `None`, use 3
        for binary components and half the range of `y_train` for continuous
        ones, falling back to 1 when the range is zero or undefined.
    """
    # determine `tau_num` if not specified
    if tau_num is None:
        if y_train.shape[-1] < 2:
            continuous_tau = jnp.ones(y_train.shape[:-1])
        else:
            half_range = (y_train.max(-1) - y_train.min(-1)) / 2
            # a constant response has zero range, which would give a zero
            # prior sdev and an infinite leaf precision; fall back to 1, as in
            # the n < 2 case above
            continuous_tau = jnp.where(half_range > 0, half_range, 1.0)
        tau_num = jnp.where(binary_mask, 3.0, continuous_tau)

    # leaf prior standard deviation
    sigma_mu = tau_num / (k * math.sqrt(num_trees))

    # leaf prior precision matrix
    leaf_prior_cov_inv = jnp.reciprocal(jnp.square(sigma_mu))
    if y_train.ndim == 2:
        leaf_prior_cov_inv = jnp.diag(
            jnp.broadcast_to(leaf_prior_cov_inv, y_train.shape[:-1])
        )
    return leaf_prior_cov_inv
def _process_error_variance_settings(
    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    outcome_type: OutcomeType | tuple[OutcomeType, ...],
    binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
    missing: Bool[Array, ' n'] | Bool[Array, 'k n'] | None,
    sigma_df: FloatLike,
    sigma_scale: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'],
    sigma_init: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'],
    error_scale: Float32[Array, ' n'] | Float32[Array, 'k n'] | None,
) -> Wishart | None:
    """Build the error precision prior from the user settings.

    Returns `None` for pure binary regression, where `sigma_scale` and
    `sigma_init` must be left as strings (their default). Otherwise the
    degrees of freedom are ``sigma_df + k - 1``. The prior rate is
    ``nu * var`` per continuous component and 0 for binary ones. The initial
    precision is ``1 / var`` for continuous components and 1 for binary ones.
    """
    if outcome_type is OutcomeType.binary:
        if not (isinstance(sigma_scale, str) and isinstance(sigma_init, str)):
            msg = (
                'Do not set `sigma_scale` or `sigma_init` for binary regression, '
                'they are ignored'
            )
            raise ValueError(msg)
        return None

    *lead, _ = y_train.shape  # [] or [k]
    num_components = lead[0] if lead else 1
    nu = jnp.asarray(sigma_df, jnp.float32) + (num_components - 1)

    # the guarded variance of y_train is needed only by 'auto' specs; this
    # function is not jitted, so computing it unconditionally would not be
    # optimized away
    wants_auto = isinstance(sigma_scale, str) or isinstance(sigma_init, str)
    vary = _guarded_response_variance(y_train, error_scale, missing) if wants_auto else None

    # E[precision] = nu / rate, hence rate = nu * var
    scale_var = _resolve_error_variance(sigma_scale, vary, lead)
    rate_diag = jnp.where(binary_mask, 0.0, nu * scale_var)
    init_var = _resolve_error_variance(sigma_init, vary, lead)
    init_diag = jnp.where(binary_mask, 1.0, jnp.reciprocal(init_var))

    if y_train.ndim == 2:
        rate, init = jnp.diag(rate_diag), jnp.diag(init_diag)
    else:
        rate, init = rate_diag, init_diag
    return make_error_cov_prior(nu, rate, init, outcome_type, missing)
@jit
def _guarded_response_variance(
    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    error_scale: Float32[Array, ' n'] | Float32[Array, 'k n'] | None,
    missing: Bool[Array, ' n'] | Bool[Array, 'k n'] | None,
) -> Float32[Array, '*k']:
    """Per-component variance of `y_train`, used by the 'auto' error scale.

    With weights, the variance is precision-weighted, with precision
    ``1 / error_scale ** 2``, so that it estimates the unit-weight
    ``sigma ** 2``. Entries flagged in `missing` get zero weight. The result
    is replaced by 1 when it is not positive, and, in the weighted case, also
    when fewer than 2 points are valid.
    """
    if error_scale is None and missing is None:
        plain = jnp.var(y_train, axis=-1)
        return jnp.where(plain > 0, plain, 1.0)

    if error_scale is None:
        weight = jnp.ones(())
    else:
        weight = jnp.reciprocal(jnp.square(error_scale))
    if missing is not None:
        weight = jnp.where(missing, 0.0, weight)
        y_train = jnp.where(missing, 0.0, y_train)

    n_valid = jnp.count_nonzero(weight, axis=-1)
    wmean = jnp.sum(weight * y_train, axis=-1) / jnp.sum(weight, axis=-1)
    wsq = weight * jnp.square(y_train - wmean[..., None])
    wvar = jnp.sum(wsq, axis=-1) / n_valid
    # also require n_valid > 1: with one valid point the variance is 0 in
    # exact arithmetic, but rounding in wmean may leave a tiny positive value
    well_defined = (n_valid > 1) & (wvar > 0)
    return jnp.where(well_defined, wvar, 1.0)
def _resolve_error_variance(
    spec: FloatLike | Float[ArrayLike, ' k'] | Literal['auto'],
    vary: Float32[Array, '*k'] | None,
    shape: Sequence[int],
) -> Float32[Array, '*k']:
    """Per-component error variance from a scale spec ('auto' uses var(y)).

    A numeric `spec` is a standard deviation, squared and broadcast to
    `shape`; ``'auto'`` returns `vary`; other strings are rejected.
    """
    if not isinstance(spec, str):
        sdev = jnp.asarray(spec, jnp.float32)
        return jnp.broadcast_to(jnp.square(sdev), shape)
    if spec != 'auto':
        msg = f"unrecognized value {spec!r}, expected 'auto' or a number"
        raise ValueError(msg)
    assert vary is not None  # computed iff some spec is 'auto'
    return vary
def make_error_cov_prior(
    nu: Float32[Array, ''],
    rate: Float32[Array, ''] | Float32[Array, 'k k'],
    value: Float32[Array, ''] | Float32[Array, 'k k'],
    outcome_type: OutcomeType | tuple[OutcomeType, ...],
    missing: Bool[Array, ' n'] | Bool[Array, 'k n'] | None,
) -> Wishart:
    """Build the error precision prior, diagonal-constrained where required.

    Two cases restrict the error covariance to diagonal and get a
    `DiagWishart`: mixed binary-continuous regression, and partial
    missingness (a 2-D mask). All other cases get a dense `Wishart`.
    `init` re-checks this choice. `value` is the initial value of the
    precision.
    """
    if isinstance(outcome_type, tuple):
        n_binary = sum(t is OutcomeType.binary for t in outcome_type)
        is_mixed = 0 < n_binary < len(outcome_type)
    else:
        is_mixed = False
    # a 2-D missingness mask only occurs with multivariate y (checked in `init`)
    partial_missing = missing is not None and missing.ndim == 2
    prior_cls = DiagWishart if (is_mixed or partial_missing) else Wishart
    return prior_cls(nu=nu, rate=rate, value=value)
def _setup_mcmc(
    x_train: Real[Array, 'p n'],
    y_train: Float32[Array, ' n'] | Float32[Array, 'k n'],
    outcome_type: OutcomeType | tuple[OutcomeType, ...],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    error_scale: Float[Array, ' n'] | Float[Array, 'k n'] | None,
    missing: Bool[Array, ' n'] | Bool[Array, 'k n'] | None,
    max_split: UInt[Array, ' p'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    error_cov_inv: Wishart | None,
    power: FloatLike,
    base: FloatLike,
    maxdepth: int,
    num_trees: int,
    init_kw: Mapping[str, Any],
    rm_const: bool,
    sparse: SparseConfig,
    varprob: Float[ArrayLike, ' p'] | None,
    num_chains: int | None,
    num_chain_devices: int | None | Literal['auto'],
    num_data_devices: int | None,
    devices: Literal['cpu', 'gpu'] | Device | Sequence[Device] | None,
    n_burn: int,
    mcmc_key: Key[Array, ''],
) -> tuple[State, Key[Array, ''], Device | None]:
    """Build the initial MCMC state and place it where requested.

    Returns
    -------
    state
        The initial state, built by `init` (overrides in `init_kw` win).
    mcmc_key
        The random key, moved with the state if a device was requested.
    device
        The device explicitly requested by the user, or `None`.
    """
    p_nonterminal = make_p_nonterminal(maxdepth, base, power)
    theta, a, b, rho = _process_sparsity_settings(x_train, sparse)
    device_kw, device = process_device_settings(
        y_train, num_chains, num_chain_devices, num_data_devices, devices
    )

    init_args: dict = {
        'X': x_train,
        'y': y_train,
        'outcome_type': outcome_type,
        'offset': offset,
        'error_scale': error_scale,
        'missing': missing,
        'max_split': max_split,
        'num_trees': num_trees,
        'p_nonterminal': p_nonterminal,
        'leaf_prior_cov_inv': leaf_prior_cov_inv,
        'error_cov_inv': error_cov_inv,
        'min_points_per_decision_node': 10,
        'log_s': process_varprob(varprob, max_split),
        'theta': theta,
        'a': a,
        'b': b,
        'rho': rho,
        'sparse_on_at': n_burn // 2 if sparse.enabled else None,
        'augment': sparse.augment,
        **device_kw,
    }
    if rm_const:
        # number of predictors without any available split
        init_args['filter_splitless_vars'] = jnp.sum(max_split == 0).item()
    init_args.update(init_kw)

    state = init(**init_args)

    # move state and key only if the user asked for a device explicitly
    if device is not None:
        mcmc_key, state = device_put((mcmc_key, state), device, donate=True)
    return state, mcmc_key, device
def _run_mcmc(
    mcmc_state: State,
    n_save: int,
    n_burn: int,
    n_skip: int,
    printevery: int | None,
    pbar: bool,
    key: Key[Array, ''],
    run_mcmc_kw: Mapping,
) -> RunMCMCResult:
    """Run the MCMC, installing the progress callback selected by the settings.

    `printevery=None` installs no callback at all, so the loop traces without
    any `debug.callback` effect (a tqdm bar would otherwise advance every
    iteration regardless of `printevery`). Entries in `run_mcmc_kw` override
    the computed arguments.
    """
    kw: dict = {'n_burn': n_burn, 'n_skip': n_skip, 'inner_loop_length': printevery}
    if printevery is not None:
        if pbar:
            progress_kw = make_tqdm_callback(mcmc_state, report_every=printevery)
        else:
            progress_kw = make_print_callback(
                mcmc_state,
                dot_every=None if printevery == 1 else 1,
                report_every=printevery,
            )
        kw.update(progress_kw)
    kw.update(run_mcmc_kw)
    return run_mcmc(key, mcmc_state, n_save, **kw)
# jitted so that lax.collapse below does not create a copy
@jit(static_argnames='p')
def varcount(p: int, trace: MainTrace) -> Int32[Array, 'ndpost p']:
    """Histogram of predictor usage for decision rules in the trees, squashing chains."""
    counts: Int32[Array, '*chains samples p'] = compute_varcount(
        p, trace, out_chain_axis=0
    )
    # merge all the axes before the predictor one
    return lax.collapse(counts, 0, -1)
@jit(static_argnames='mean')
def get_error_sdev(
    trace: MainTrace,
    binary_mask: Bool[Array, ''] | Bool[Array, ' k'],
    *,
    mean: bool = False,
) -> (
    Float32[Array, ' ndpost']
    | Float32[Array, 'ndpost k']
    | Float32[Array, '']
    | Float32[Array, ' k']
):
    """Error standard deviation, post-burnin, chains concatenated.

    If `mean`, the variances are averaged over the draws before taking the
    square root. Binary components are set to NaN.
    """
    precision = trace.error_cov_inv
    if trace.has_chains:
        # (chains, samples[, k, k]) -> (chains * samples[, k, k])
        precision = chain_to_axis(precision, chain_vmap_axes(trace).error_cov_inv)
        precision = lax.collapse(precision, 0, 2)

    scalar_case = precision.ndim == 1
    if scalar_case:
        # treat the univariate case as 1x1 matrices
        precision = precision[..., None, None]

    covariance = _inv_via_chol_with_gersh(precision)
    variances = jnp.diagonal(covariance, axis1=-2, axis2=-1)
    if mean:
        variances = variances.mean(0)
    sdev = jnp.sqrt(variances)
    if scalar_case:
        sdev = sdev.squeeze(-1)
    return jnp.where(binary_mask, jnp.nan, sdev)
@jit(static_argnames='only_continuous')
def get_latent_prec(
    burnin_trace: BurninTrace,
    main_trace: MainTrace,
    binary_indices: Int32[Array, ' kb'] | None,
    *,
    only_continuous: bool = False,
) -> (
    Float32[Array, ' n_burn_plus_n_save']
    | Float32[Array, 'n_burn_plus_n_save k k']
    | Float32[Array, 'num_chains n_burn_plus_n_save']
    | Float32[Array, 'num_chains n_burn_plus_n_save k k']
):
    """Latent error precision trace, burn-in + main concatenated.

    With `only_continuous` and some binary components, only the rows and
    columns of the continuous components are kept.
    """
    sample_axis = trace_sample_axes(main_trace).error_cov_inv
    prec = jnp.concatenate(
        [burnin_trace.error_cov_inv, main_trace.error_cov_inv], axis=sample_axis
    )
    prec = chain_to_axis(prec, chain_vmap_axes(main_trace).error_cov_inv)
    if not only_continuous or binary_indices is None:
        return prec

    *_, k, _ = prec.shape
    is_continuous = jnp.ones(k, dtype=bool).at[binary_indices].set(False)
    (cont,) = jnp.nonzero(is_continuous, size=k - binary_indices.size)
    return prec[..., cont[:, None], cont[None, :]]
@jit
def varprob(
    max_split: UInt[Array, ' p'], trace: MainTrace
) -> Float32[Array, 'ndpost p']:
    """Posterior samples of predictor selection probability, chains concatenated.

    If the trace does not store them, every draw is uniform over the
    predictors with at least one available split.
    """
    p = max_split.size
    draws = trace.varprob
    if draws is not None:
        draws = chain_to_axis(draws, chain_vmap_axes(trace).varprob)
        return draws.reshape(-1, p)
    ndpost = trace.grow_prop_count.size
    uniform = jnp.where(max_split, 1 / jnp.count_nonzero(max_split), 0)
    return jnp.broadcast_to(uniform, (ndpost, p))
def _trees_chain_first(obj: TreeHeaps) -> TreesTrace:
    """Extract `obj`'s heap arrays, moving any chain axis to the front.

    The result is a `TreesTrace` whose leading axis is the chain axis if `obj`
    has chains, or the plain heap arrays of `obj` otherwise.
    """
    heaps = TreesTrace.from_dataclass(obj)
    if not get_has_chains(obj):
        return heaps
    axes = heaps.axes_from_dataclass(chain_vmap_axes(obj))
    # WORKAROUND(python<3.14): use operator.is_none
    return tree.map(chain_to_axis, heaps, axes, is_leaf=lambda x: x is None)
@jit
def check_trees(
    trace: MainTrace, max_split: UInt[Array, ' p']
) -> UInt[Array, 'num_chains n_save num_trees']:
    """Apply `bartz.grove.check_trace` to all the tree draws.

    A single-chain trace gets a leading chain axis of size 1.
    """
    flags: UInt[Array, '*chains samples num_trees'] = check_trace(
        _trees_chain_first(trace), max_split
    )
    return flags if flags.ndim >= 3 else flags[None, :, :]
@jit
def tree_goes_bad(
    trace: MainTrace, max_split: UInt[Array, ' p']
) -> Bool[Array, 'num_chains n_save num_trees']:
    """Find iterations where a tree becomes invalid."""
    invalid = check_trees(trace, max_split).astype(bool)
    # invalidity at the previous iteration, shifted by one along the sample
    # axis and padded with False before the first draw
    was_invalid = jnp.pad(invalid[:, :-1, :], ((0, 0), (1, 0), (0, 0)))
    return invalid & jnp.logical_not(was_invalid)
@jit
def compare_resid(
    state: State, y: Float32[Array, ' n'] | Float32[Array, 'k n'] | None
) -> tuple[
    Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
    Float32[Array, '*num_chains n'] | Float32[Array, '*num_chains k n'],
]:
    """Re-compute residuals to compare them with the updated ones.

    Returns
    -------
    resid1
        The residuals stored in `state`.
    resid2
        The reference (`y`, the latent `z`, or `y` with the binary rows taken
        from `z`) minus the sum of the trees and the offset.
    """
    axes = chain_vmap_axes(state)
    stored = chain_to_axis(state.resid, axes.resid)
    latent = None if state.z is None else chain_to_axis(state.z, axes.z)

    forest_sum = evaluate_forest(
        state.X, _trees_chain_first(state.forest), sum_batch_axis=-1
    )

    if state.binary_indices is not None:
        # mixed binary-continuous: z has only binary rows, y has all rows
        assert y is not None
        target = jnp.broadcast_to(y, stored.shape)
        target = target.at[..., state.binary_indices, :].set(latent)
    elif latent is None:
        assert y is not None
        target = y
    else:
        target = latent

    recomputed = target - (forest_sum + state.offset[..., None])
    return stored, recomputed
@jit
def depth_distr(trace: MainTrace) -> Int32[Array, '*num_chains n_save d']:
    """Histogram of tree depths for each state of the trees.

    A single-chain trace gets a leading chain axis of size 1.
    """
    axis = chain_vmap_axes(trace).split_tree
    hist: Int32[Array, '*chains samples d'] = forest_depth_distr(
        chain_to_axis(trace.split_tree, axis)
    )
    return hist if hist.ndim >= 3 else hist[None, :, :]
@jit(static_argnames='node_type')
def points_per_node_distr_trace(
    X: UInt[Array, 'p n'], trace: MainTrace, node_type: Literal['leaf', 'leaf-parent']
) -> Int32[Array, '*num_chains n_save n+1']:
    """Histogram of number of points per node, for every tree draw in the trace.

    A single-chain trace gets a leading chain axis of size 1.
    """
    axes = chain_vmap_axes(trace)
    hist: Int32[Array, '*chains samples n+1'] = points_per_node_distr(
        X,
        chain_to_axis(trace.var_tree, axes.var_tree),
        chain_to_axis(trace.split_tree, axes.split_tree),
        node_type,
        sum_batch_axis=-1,
    )
    return hist if hist.ndim >= 3 else hist[None, :, :]
class DeviceKwArgs(TypedDict):
    """Device-related keyword arguments for `mcmcstep.init`, built by `process_device_settings`."""

    # number of chains; `None` means a single chain, without a chain axis
    num_chains: int | None
    # device mesh to shard the state over, or `None` when not sharding
    mesh: Mesh | None
def process_device_settings(
    y_train: Shaped[Array, '...'],
    num_chains: int | None,
    num_chain_devices: int | None | Literal['auto'],
    num_data_devices: int | None,
    devices: Literal['cpu', 'gpu'] | Device | Sequence[Device] | None,
) -> tuple[DeviceKwArgs, Device | None]:
    """Return the arguments for `mcmcstep.init` related to devices, and an optional device where to put the state."""
    # did the user pin a concrete pool of devices (as opposed to a platform
    # name or nothing)? The automatic chain sharding must stay within it.
    explicit_devices = devices is not None and not isinstance(devices, str)
    platform, device, device_pool = _determine_devices(y_train, devices)
    chain_axis_size = _determine_num_chain_devices(
        platform,
        num_chains,
        num_chain_devices,
        num_data_devices,
        len(device_pool),
        explicit_devices,
    )
    mesh, device = _determine_mesh(
        chain_axis_size, num_data_devices, device, device_pool
    )
    return DeviceKwArgs(num_chains=num_chains, mesh=mesh), device
def _determine_devices(
    y_train: Shaped[Array, '...'],
    devices: Literal['cpu', 'gpu'] | Device | Sequence[Device] | None,
) -> tuple[str, Device | None, Sequence[Device]]:
    """Determine the target platform and set of devices for the MCMC, and possibly a single target device.

    Parameters
    ----------
    y_train
        The training response; its platform is used when `devices` is `None`.
    devices
        A platform name, a single device, a sequence of devices, or `None`.

    Returns
    -------
    platform
        The name of the target platform.
    device
        The first device when `devices` is specified, else `None`.
    devices
        The pool of candidate devices.

    Raises
    ------
    ValueError
        If `devices` is an empty sequence, or if it is `None` and the
        platform can not be inferred from `y_train`.
    """
    if isinstance(devices, str):
        platform = devices
        devices = jax.devices(platform)
        return platform, devices[0], devices
    elif devices is not None:
        if not hasattr(devices, '__len__'):
            devices = (devices,)
        # explicit check, instead of an obscure IndexError on `devices[0]`
        if len(devices) == 0:
            msg = '`devices` must not be an empty sequence'
            raise ValueError(msg)
        device = devices[0]
        return device.platform, device, devices
    elif hasattr(y_train, 'platform'):
        # set device=None because if the devices were not specified explicitly
        # we may be in the case where computation will follow data placement,
        # do not disturb jax as the user may be playing with vmap, jit, reshard...
        platform = y_train.platform()  # ty: ignore[call-non-callable]
        return platform, None, jax.devices(platform)
    else:
        msg = 'not possible to infer device from `y_train`, please set `devices`'
        raise ValueError(msg)
1735def _largest_divisor_at_most(n: int, cap: int) -> int:
1736 """Return the largest divisor of `n` in [1, cap]."""
1737 for d in range(cap, 0, -1): 1737 ↛ 1740line 1737 didn't jump to line 1740 because the loop on line 1737 didn't complete
1738 if n % d == 0:
1739 return d
1740 return 1 # unreachable: 1 always divides n
1743def _determine_num_chain_devices(
1744 platform: str,
1745 num_chains: int | None,
1746 num_chain_devices: int | None | Literal['auto'],
1747 num_data_devices: int | None,
1748 num_devices: int,
1749 explicit_devices: bool,
1750) -> int | None:
1751 """Resolve and validate `num_chain_devices`, returning the chain mesh axis size or `None`."""
1752 if num_chain_devices == 'auto':
1753 num_chain_devices = _auto_num_chain_devices(
1754 platform, num_chains, num_data_devices, num_devices, explicit_devices
1755 )
1757 # an explicit value must be a positive divisor of the number of chains
1758 if num_chain_devices is not None:
1759 effective_chains = 1 if num_chains is None else num_chains
1760 if num_chain_devices < 1 or effective_chains % num_chain_devices:
1761 chains_desc = (
1762 'a single chain (num_chains=None)'
1763 if num_chains is None
1764 else f'num_chains={num_chains}'
1765 )
1766 msg = (
1767 f'num_chain_devices={num_chain_devices} must be a positive '
1768 f'divisor of the number of chains ({chains_desc})'
1769 )
1770 raise ValueError(msg)
1772 # there is no chain axis to shard when the chains are scalar
1773 if num_chains is None:
1774 return None
1775 return num_chain_devices
1778def _auto_num_chain_devices(
1779 platform: str,
1780 num_chains: int | None,
1781 num_data_devices: int | None,
1782 num_devices: int,
1783 explicit_devices: bool,
1784) -> int | None:
1785 """Pick `num_chain_devices` automatically for multi-chain cpu runs.
1787 `num_data_devices` reserves devices for the data axis, so the chain axis can
1788 only use a fraction of them; this keeps the ``chains x data`` mesh within the
1789 `num_devices` available devices.
1790 """
1791 if num_chains is None or num_chains == 1 or platform != 'cpu':
1792 return None
1793 data_devices = num_data_devices or 1
1794 num_cores = cpu_count()
1795 assert num_cores is not None, 'could not determine number of cpu cores'
1797 # devices available for the chain axis after reserving for the data axis
1798 core_budget = max(1, num_cores // data_devices)
1799 num_shards = _largest_divisor_at_most(num_chains, core_budget)
1801 if num_shards > 1:
1802 # the mesh draws from `num_devices` devices, whether those are all the
1803 # platform's devices or an explicit subset passed by the user
1804 device_budget = max(1, num_devices // data_devices)
1805 if device_budget < num_shards:
1806 new_num_shards = _largest_divisor_at_most(num_chains, device_budget)
1807 warn(
1808 _auto_chain_devices_warning(
1809 num_chains,
1810 num_shards,
1811 new_num_shards,
1812 device_budget,
1813 num_devices,
1814 num_data_devices,
1815 explicit_devices,
1816 )
1817 )
1818 num_shards = new_num_shards
1820 return num_shards if num_shards > 1 else None
1823def _auto_chain_devices_warning(
1824 num_chains: int,
1825 desired: int,
1826 actual: int,
1827 device_budget: int,
1828 num_devices: int,
1829 num_data_devices: int | None,
1830 explicit_devices: bool,
1831) -> str:
1832 """Compose the warning shown when auto chain sharding is capped by the device count."""
1833 if explicit_devices:
1834 pool = f'the {num_devices} devices passed in `devices`'
1835 few = f'only {num_devices} devices were passed in `devices`'
1836 advice = ''
1837 else:
1838 pool = f'the {num_devices} jax cpu devices'
1839 few = f'jax is set up with only {num_devices} cpu devices'
1840 advice = (
1841 ' To enable more parallelization, increase the limit with '
1842 '`jax.config.update("jax_num_cpu_devices", <num_devices>)`.'
1843 )
1844 if num_data_devices:
1845 limit = (
1846 f'only {device_budget} of {pool} are free for chains '
1847 f'(num_data_devices={num_data_devices} reserves the rest)'
1848 )
1849 else:
1850 limit = few
1851 return (
1852 f'`Bart` would like to shard {num_chains} chains across {desired} '
1853 f'devices, but {limit}, so it will use {actual} devices for chains '
1854 f'instead.{advice}'
1855 )
1858def _determine_mesh(
1859 num_chain_devices: int | None,
1860 num_data_devices: int | None,
1861 device: Device | None,
1862 devices: Sequence[Device],
1863) -> tuple[Mesh | None, Device | None]:
1864 """Create a jax device mesh for `mcmcstep.init()`."""
1865 if num_chain_devices is None and num_data_devices is None:
1866 return None, device
1867 else:
1868 mesh = dict()
1869 if num_chain_devices is not None:
1870 mesh.update(chains=num_chain_devices)
1871 if num_data_devices is not None:
1872 mesh.update(data=num_data_devices)
1873 mesh = make_mesh(
1874 axis_shapes=tuple(mesh.values()),
1875 axis_names=tuple(mesh),
1876 axis_types=(AxisType.Auto,) * len(mesh),
1877 devices=devices,
1878 )
1879 return mesh, None
1880 # set device=None because `mcmcstep.init` will `device_put` with the
1881 # mesh already, we don't want to undo its work
def process_varprob(
    varprob: Float[ArrayLike, ' p'] | None, max_split: UInt[Array, ' p']
) -> Float32[Array, ' p'] | None:
    """Convert varprob to log_s.

    Parameters
    ----------
    varprob
        Positive weights, one per predictor, or `None`.
    max_split
        Only its shape is used, to validate the shape of `varprob`.

    Returns
    -------
    The elementwise logarithm of `varprob`, or `None` if `varprob` is `None`.

    Raises
    ------
    ValueError
        If `varprob` does not have the same shape as `max_split`.

    Notes
    -----
    Non-positive or NaN entries are rejected with `equinox.error_if`, which
    raises at runtime (also under `jit`).
    """
    if varprob is None:
        return None
    varprob = jnp.asarray(varprob)
    # explicit exception instead of `assert`: this validates user input and
    # must not disappear under `python -O`
    if varprob.shape != max_split.shape:
        msg = 'varprob must have shape (p,)'
        raise ValueError(msg)
    # `~(varprob > 0)` rather than `varprob <= 0` so NaN entries are rejected
    # too (every comparison with NaN is False)
    varprob = error_if(varprob, ~(varprob > 0), 'varprob must be > 0')
    return jnp.log(varprob)
def predict_latent(
    x: UInt[Array, 'p m'],
    trace: MainTrace,
    test_points: Literal['none', 'autobatch', 'shard_and_autobatch'] = 'none',
) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
    """Evaluate trees on already quantized `x`, and squash chains.

    Thin wrapper around `evaluate_trace` with ``flatten_chains=True``; judging
    by the return annotation, any chain axis ends up merged into the leading
    ``ndpost`` sample axis.

    Parameters
    ----------
    x
        Already quantized predictors.
    trace
        The MCMC trace whose trees are evaluated.
    test_points
        Forwarded unchanged to `evaluate_trace`.
    """
    samples = evaluate_trace(x, trace, test_points=test_points, flatten_chains=True)
    return samples
@jit(static_argnums=(5, 6, 7))
def predict(
    key: Key[Array, ''] | None,
    trace: MainTrace,
    x_test: UInt[Array, 'p m'],
    error_scale: Float[Array, ' m'] | Float[Array, 'k m'] | None,
    binary_indices: Int32[Array, ' kb'] | None,
    has_binary: bool,
    kind: PredictKind | str,
    test_points: Literal['none', 'autobatch', 'shard_and_autobatch'],
    /,
) -> (
    Float32[Array, ' m']
    | Float32[Array, 'k m']
    | Float32[Array, 'ndpost m']
    | Float32[Array, 'ndpost k m']
):
    """Implement `Bart.predict`.

    `has_binary`, `kind` and `test_points` are static arguments of `jit`.

    Parameters
    ----------
    key
        Random key; only needed when `kind` is `PredictKind.outcome_samples`.
    trace
        The MCMC trace to evaluate.
    x_test
        Already quantized test predictors.
    error_scale, binary_indices, has_binary
        Forwarded to `sample_outcome` for outcome samples. For the other
        kinds, `binary_indices` (or `has_binary` if it is `None`) selects
        what is mapped through the normal cdf `ndtr`.
    kind
        What to return. Compared by identity against `PredictKind` members:
        `latent_samples`, `outcome_samples`, `mean`; any other kind returns
        the per-sample values after the `ndtr` mapping.
        NOTE(review): a plain `str` never matches these identity checks, so
        callers presumably convert to `PredictKind` first — confirm.
    test_points
        Forwarded to `predict_latent`.
    """
    # bare sum-of-trees predictions with the chains squashed
    latent = predict_latent(x_test, trace, test_points)

    if kind is PredictKind.latent_samples:
        return latent

    if kind is PredictKind.outcome_samples:
        # noise and binary thresholding work on the raw latent values,
        # so no ndtr mapping happens on this path
        assert key is not None
        return sample_outcome(
            key, trace, latent, error_scale, binary_indices, has_binary
        )

    # map the binary components to (0, 1) through the normal cdf (probit)
    if binary_indices is not None:
        binary_rows = jnp.s_[..., binary_indices, :]
        squashed = latent.at[binary_rows].set(ndtr(latent[binary_rows]))
    elif has_binary:
        # no index list: the whole outcome is binary
        squashed = ndtr(latent)
    else:
        squashed = latent

    return squashed.mean(axis=0) if kind is PredictKind.mean else squashed
@jit(static_argnums=(5,))
def sample_outcome(
    key: Key[Array, ''],
    trace: MainTrace,
    latent: Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m'],
    error_scale: Float32[Array, ' m'] | Float32[Array, 'k m'] | None,
    binary_indices: Int32[Array, ' kb'] | None,
    has_binary: bool,
    /,
) -> Float32[Array, 'ndpost m'] | Float32[Array, 'ndpost k m']:
    """Sample from the posterior predictive distribution.

    Adds Gaussian error to `latent` and thresholds binary outcomes at zero.

    Parameters
    ----------
    key
        Random key, used for a single standard normal draw of shape
        ``latent.shape``.
    trace
        The MCMC trace; only `trace.error_cov_inv` is read.
    latent
        Sum-of-trees predictions. More than 2 dimensions selects the
        multivariate path.
    error_scale
        Optional multiplier of the error; not applied in the pure univariate
        binary case.
    binary_indices
        Indices of the binary components along the second-to-last axis, or
        `None`.
    has_binary
        Whether the outcome is binary; consulted when `binary_indices` is
        `None`.

    Returns
    -------
    Samples with the same shape as `latent`; binary entries are 0.0 or 1.0.
    """
    # error precision with the chain axis moved to the front
    prec = chain_to_axis(trace.error_cov_inv, chain_vmap_axes(trace).error_cov_inv)

    # the same draw is used by every branch below
    z = random.normal(key, latent.shape)

    if latent.ndim > 2:
        # multivariate: merge the leading axes into one draw axis, (ndpost, k, k)
        error_cov_inv = lax.collapse(prec, 0, -2)
        # Cholesky of precision: error_cov_inv = L @ L^T
        chol = chol_with_gersh(error_cov_inv)
        # solving L^T @ error = z gives error = L^{-T} z ~ N(0, L^{-T} L^{-1}) = N(0, Sigma)
        error = solve_triangular(chol, z, trans='T', lower=True)
        if error_scale is not None:
            # error_scale is (m,) or (k, m), so it broadcasts on (ndpost, k, m)
            error *= error_scale
    elif has_binary:
        # pure binary univariate: probit latent error has sigma = 1
        error = z
    else:
        # univariate continuous: one standard deviation per posterior sample
        sigma = jnp.sqrt(jnp.reciprocal(prec)).reshape(-1)
        error = sigma[..., None] * z
        if error_scale is not None:
            error *= error_scale[None, :]

    outcome = latent + error

    # binary outcomes are 1 where the latent probit variable is positive
    if binary_indices is not None:
        binary_rows = jnp.s_[..., binary_indices, :]
        thresholded = jnp.where(outcome[binary_rows] > 0, 1.0, 0.0)
        return outcome.at[binary_rows].set(thresholded)
    if has_binary:
        return jnp.where(outcome > 0, 1.0, 0.0)
    return outcome