Coverage for src / bartz / mcmcstep / _state.py: 95%
454 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 18:11 +0000
1# bartz/src/bartz/mcmcstep/_state.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Module defining the BART MCMC state and initialization."""
27from collections.abc import Callable, Hashable, Sequence
28from dataclasses import fields, replace
29from enum import Enum
30from functools import partial, wraps
31from math import log2
32from typing import Any, Literal, TypedDict, TypeVar
34import numpy
35from equinox import Module, error_if, filter_jit
36from equinox import field as eqx_field
37from jax import (
38 NamedSharding,
39 device_put,
40 eval_shape,
41 jit,
42 lax,
43 make_mesh,
44 random,
45 tree,
46 vmap,
47)
48from jax import numpy as jnp
49from jax.scipy.linalg import solve_triangular
50from jax.sharding import AxisType, Mesh, PartitionSpec
51from jaxtyping import Array, Bool, Float32, Int32, Integer, PyTree, Shaped, UInt
53from bartz.grove import tree_depths
54from bartz.jaxext import get_default_device, is_key, minimal_unsigned_dtype
class OutcomeType(Enum):
    """Enumerates the supported kinds of regression outcome.

    `continuous` marks a numeric response; `binary` marks a 0/1 response
    modeled with a probit link.
    """

    continuous = 'continuous'
    binary = 'binary'
def field(*, chains: bool = False, data: bool = False, **kwargs: Any):  # noqa: ANN202
    """Wrap `equinox.field`, adding two boolean markers to the field metadata.

    Parameters
    ----------
    chains
        Mark the field as carrying an optional leading axis over independent
        Markov chains.
    data
        Mark the last axis of the field's arrays as indexing data units.
    **kwargs
        Forwarded to `equinox.field`.

    Returns
    -------
    A dataclass field descriptor; each marker appears in the metadata only
    when its flag is True.
    """
    meta = dict(kwargs.pop('metadata', {}))
    for marker, enabled in (('chains', chains), ('data', data)):
        assert marker not in meta
        if enabled:
            meta[marker] = True
    return eqx_field(metadata=meta, **kwargs)
def chain_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for chains.

    Build the `in_axes`/`out_axes` argument for `jax.vmap` that maps over
    all and only the chain axes present in the pytree `x`.

    Parameters
    ----------
    x
        A pytree. Subpytrees that are Module attributes marked with
        ``field(..., chains=True)`` are considered to have a leading chain axis.

    Returns
    -------
    A pytree with the same structure as `x` with 0 or None in the leaves.
    """
    chain_axis, no_axis = 0, None
    return _find_metadata(x, 'chains', chain_axis, no_axis)
def data_vmap_axes(x: PyTree[Module | Any, 'T']) -> PyTree[int | None, 'T']:
    """Determine vmapping axes for data.

    Analogous to `chain_vmap_axes`, but yields -1 (the trailing axis) for
    every field marked with ``field(..., data=True)``.
    """
    data_axis, no_axis = -1, None
    return _find_metadata(x, 'data', data_axis, no_axis)
121T = TypeVar('T')
def _find_metadata(
    x: PyTree[Any, ' S'], key: Hashable, if_true: T, if_false: T
) -> PyTree[T, ' S']:
    """Replace all subtrees of x marked with a metadata key.

    Walk `x` recursively: leaves under Module fields whose metadata has
    `key` set become `if_true`, all other leaves become `if_false`.
    `_LazyArray` instances are treated as leaves rather than Modules.
    """

    def is_lazy_array(x: object) -> bool:
        return isinstance(x, _LazyArray)

    def is_module(x: object) -> bool:
        # _LazyArray subclasses Module but must be treated as a leaf
        return isinstance(x, Module) and not is_lazy_array(x)

    if is_module(x):
        # rebuild the Module positionally, so field declaration order must
        # match the constructor's argument order
        args = []
        for f in fields(x):
            v = getattr(x, f.name)
            if f.metadata.get('static', False):
                # static fields are structure, not arrays: pass through as-is
                args.append(v)
            elif f.metadata.get(key, False):
                # marked field: replace every leaf of the subtree with if_true
                subtree = tree.map(lambda _: if_true, v, is_leaf=is_lazy_array)
                args.append(subtree)
            else:
                # unmarked field: recurse, it may contain nested Modules
                args.append(_find_metadata(v, key, if_true, if_false))
        return x.__class__(*args)

    def get_axes(x: object) -> PyTree[T]:
        if is_module(x):
            return _find_metadata(x, key, if_true, if_false)
        else:
            return tree.map(lambda _: if_false, x, is_leaf=is_lazy_array)

    def is_leaf(x: object) -> bool:
        return isinstance(x, Module)  # this catches _LazyArray as well

    # x is not itself a Module: map over it, stopping at embedded Modules so
    # get_axes can dispatch on them
    return tree.map(get_axes, x, is_leaf=is_leaf)
class Forest(Module):
    """Represents the MCMC state of a sum of trees.

    Per the shape annotations below, per-node arrays are sized ``2**d``
    (all nodes, leaf values) or ``2**(d-1)`` (decision nodes only). Fields
    declared with ``field(chains=True)`` carry an optional leading axis over
    independent Markov chains; see `chain_vmap_axes`.
    """

    # NOTE: field declaration order matters: `_find_metadata` rebuilds this
    # Module positionally via `x.__class__(*args)`.

    leaf_tree: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """The leaf values."""

    var_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The decision axes."""

    split_tree: UInt[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """The decision boundaries."""

    affluence_tree: Bool[Array, '*chains num_trees 2**(d-1)'] = field(chains=True)
    """Marks leaves that can be grown."""

    max_split: UInt[Array, ' p']
    """The maximum split index for each predictor."""

    blocked_vars: UInt[Array, ' q'] | None
    """Indices of variables that are not used. This shall include at least
    the `i` such that ``max_split[i] == 0``, otherwise behavior is
    undefined."""

    p_nonterminal: Float32[Array, ' 2**d']
    """The prior probability of each node being nonterminal, conditional on
    its ancestors. Includes the nodes at maximum depth which should be set
    to 0."""

    p_propose_grow: Float32[Array, ' 2**(d-1)']
    """The unnormalized probability of picking a leaf for a grow proposal."""

    leaf_indices: UInt[Array, '*chains num_trees n'] = field(chains=True, data=True)
    """The index of the leaf each datapoints falls into, for each tree."""

    min_points_per_decision_node: Int32[Array, ''] | None
    """The minimum number of data points in a decision node."""

    min_points_per_leaf: Int32[Array, ''] | None
    """The minimum number of data points in a leaf node."""

    log_trans_prior: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """The log transition and prior Metropolis-Hastings ratio for the
    proposed move on each tree."""

    log_likelihood: Float32[Array, '*chains num_trees'] | None = field(chains=True)
    """The log likelihood ratio."""

    grow_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of grow proposals made during one full MCMC cycle."""

    prune_prop_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of prune proposals made during one full MCMC cycle."""

    grow_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of grow moves accepted during one full MCMC cycle."""

    prune_acc_count: Int32[Array, '*chains'] = field(chains=True)
    """The number of prune moves accepted during one full MCMC cycle."""

    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """The prior precision matrix of a leaf, conditional on the tree structure.
    For the univariate case (k=1), this is a scalar (the inverse variance).
    The prior covariance of the sum of trees is
    ``num_trees * leaf_prior_cov_inv^-1``."""

    log_s: Float32[Array, '*chains p'] | None = field(chains=True)
    """The logarithm of the prior probability for choosing a variable to split
    along in a decision rule, conditional on the ancestors. Not normalized.
    If `None`, use a uniform distribution."""

    theta: Float32[Array, '*chains'] | None = field(chains=True)
    """The concentration parameter for the Dirichlet prior on the variable
    distribution `s`. Required only to update `log_s`."""

    a: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    b: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    rho: Float32[Array, ''] | None
    """Parameter of the prior on `theta`. Required only to sample `theta`.
    See `step_theta`."""

    def num_chains(self) -> int | None:
        """Return the number of chains, or `None` if not multichain."""
        # maybe this should be replaced by chain_shape() -> () | (int,)
        # without chains var_tree is (num_trees, 2**(d-1)), so ndim == 2
        if self.var_tree.ndim == 2:
            return None
        else:
            return self.var_tree.shape[0]
class StepConfig(Module):
    """Options for the MCMC step.

    Fields declared with ``field(static=True)`` are treated as pytree
    structure rather than array data: `_find_metadata` passes them through
    unchanged instead of assigning vmap axes.
    """

    steps_done: Int32[Array, '']
    """The number of MCMC steps completed so far."""

    sparse_on_at: Int32[Array, ''] | None
    """After how many steps to turn on variable selection."""

    resid_num_batches: int | None = field(static=True)
    """The number of batches for computing the sum of residuals. If
    `None`, they are computed with no batching."""

    count_num_batches: int | None = field(static=True)
    """The number of batches for computing counts. If
    `None`, they are computed with no batching."""

    prec_num_batches: int | None = field(static=True)
    """The number of batches for computing precision scales. If
    `None`, they are computed with no batching."""

    prec_count_num_trees: int | None = field(static=True)
    """Batch size for processing trees to compute count and prec trees."""

    mesh: Mesh | None = field(static=True)
    """The mesh used to shard data and computation across multiple devices."""
class State(Module):
    """Represents the MCMC state of BART.

    Aggregates the data (`X`, `binary_y`), the latent/residual arrays, the
    error-covariance parameters, the tree ensemble (`forest`), and the step
    configuration (`config`). Fields marked ``field(data=True)`` have their
    last axis indexing the n datapoints; fields marked ``field(chains=True)``
    carry an optional leading chain axis.
    """

    X: UInt[Array, 'p n'] = field(data=True)
    """The predictors."""

    binary_y: None | Bool[Array, ' n'] | Bool[Array, 'k n'] = field(data=True)
    """The response as booleans for binary regression, `None` for continuous.
    In the mixed binary-continuous case, only the binary outcome components
    are stored, with shape ``(kb, n)``."""

    z: None | Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=True, data=True
    )
    """The latent variable for binary regression. `None` in continuous
    regression. In the mixed binary-continuous case, only the binary outcome
    components are stored, with shape ``(*chains, kb, n)``."""

    binary_indices: None | Int32[Array, ' kb']
    """The indices of binary outcome components in the full list of outcome
    components. `None` when there are no binary components. Filled in by
    `init` and used by `step_z` to update only the binary rows of `resid`."""

    offset: Float32[Array, ''] | Float32[Array, ' k']
    """Constant shift added to the sum of trees."""

    resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n'] = field(
        chains=True, data=True
    )
    """The residuals (`y` or `z` minus sum of trees)."""

    error_cov_inv: Float32[Array, '*chains'] | Float32[Array, '*chains k k'] = field(
        chains=True
    )
    """The inverse error covariance (scalar for univariate, matrix for multivariate).
    Identity in binary regression."""

    prec_scale: Float32[Array, ' n'] | None = field(data=True)
    """The scale on the error precision, i.e., ``1 / error_scale ** 2``.
    `None` in binary regression."""

    error_cov_df: Float32[Array, ''] | None
    """The df parameter of the inverse Wishart prior on the noise
    covariance. For the univariate case, the relationship to the inverse
    gamma prior parameters is ``alpha = df / 2``.
    `None` in binary regression."""

    error_cov_scale: Float32[Array, ''] | Float32[Array, 'k k'] | None
    """The scale parameter of the inverse Wishart prior on the noise
    covariance. For the univariate case, the relationship to the inverse
    gamma prior parameters is ``beta = scale / 2``.
    `None` in binary regression."""

    forest: Forest
    """The sum of trees model."""

    config: StepConfig
    """Metadata and configurations for the MCMC step."""
def _init_shape_shifting_parameters(
    y: Float32[Array, ' n'] | Float32[Array, 'k n'],
    outcome_type: OutcomeType | list[OutcomeType],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
    error_scale: Float32[Any, ' n'] | None,
    error_cov_df: float | Float32[Any, ''] | None,
    error_cov_scale: float | Float32[Any, ''] | Float32[Any, 'k k'] | None,
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
) -> tuple[
    bool,
    tuple[()] | tuple[int],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
    None | Float32[Array, ''],
    None | Int32[Array, ' kb'],
]:
    """
    Check and initialize parameters that change array type/shape based on outcome kind.

    Parameters
    ----------
    y
        The response variable (used only for shape checks).
    outcome_type
        Whether the regression is continuous or binary. Can be a list of
        `OutcomeType` for per-component specification in the multivariate case.
    offset
        The offset to add to the predictions.
    error_scale
        Per-observation error scale (univariate only).
    error_cov_df
        The error covariance degrees of freedom.
    error_cov_scale
        The error covariance scale.
    leaf_prior_cov_inv
        The inverse of the leaf prior covariance.

    Returns
    -------
    is_binary
        Whether all outcomes are binary.
    kshape
        The outcome shape, empty for univariate, (k,) for multivariate.
    error_cov_inv
        The initialized error covariance inverse.
    error_cov_df
        The error covariance degrees of freedom (as array).
    error_cov_scale
        The error covariance scale (as array).
    binary_indices
        The indices of binary outcome components, or `None` if there are none.
    """
    # the offset shape encodes the outcome shape: () univariate, (k,) multivariate
    kshape = offset.shape

    # determine per-component outcome kinds
    if isinstance(outcome_type, list):
        assert kshape, 'per-component outcome_type requires multivariate y'
        (k,) = kshape
        assert len(outcome_type) == k
        binary_mask = [t is OutcomeType.binary for t in outcome_type]
        is_binary = all(binary_mask)
        # "mixed" = at least one binary component but not all of them
        is_mixed = any(binary_mask) and not is_binary
    else:
        is_binary = outcome_type is OutcomeType.binary
        is_mixed = False

    if is_mixed:
        binary_indices = jnp.array([i for i, b in enumerate(binary_mask) if b])
    else:
        binary_indices = None

    # All-binary
    if is_binary:
        # these parameters are meaningless under the probit model
        assert error_scale is None
        assert error_cov_df is None
        assert error_cov_scale is None
        if kshape:
            error_cov_inv = jnp.eye(kshape[0])
        else:
            error_cov_inv = jnp.array(1.0)

    # Mixed binary-continuous (multivariate, diagonal error covariance)
    elif is_mixed:
        assert error_scale is None, (
            'error_scale is not supported for mixed binary-continuous'
        )
        error_cov_df = jnp.asarray(error_cov_df)
        error_cov_scale = jnp.asarray(error_cov_scale)
        # 2 * kshape is tuple repetition: (k,) -> (k, k)
        assert error_cov_scale.shape == 2 * kshape

        # enforce diagonal error_cov_scale
        diag = jnp.diag(jnp.diag(error_cov_scale))
        error_cov_scale = error_if(
            error_cov_scale,
            jnp.any(error_cov_scale != diag),
            'error_cov_scale must be diagonal for mixed binary-continuous',
        )

        # initialize diagonal error_cov_inv: use inv-gamma mode for continuous
        # components, 1.0 for binary components
        scale_diag = jnp.diag(error_cov_scale)
        inv_diag = jnp.where(
            jnp.array(binary_mask),
            1.0,
            # inner where replaces zero diagonal entries with 1 to avoid
            # division by zero (the result is discarded for binary components)
            error_cov_df / jnp.where(scale_diag, scale_diag, 1.0),
        )
        error_cov_inv = jnp.diag(inv_diag)

    # All-continuous
    else:
        error_cov_df = jnp.asarray(error_cov_df)
        error_cov_scale = jnp.asarray(error_cov_scale)
        # 2 * kshape is tuple repetition: () -> (), (k,) -> (k, k)
        assert error_cov_scale.shape == 2 * kshape

        # Multivariate vs univariate
        if kshape:
            error_cov_inv = error_cov_df * _inv_via_chol_with_gersh(error_cov_scale)
        else:
            # inverse gamma prior: alpha = df / 2, beta = scale / 2
            error_cov_inv = error_cov_df / error_cov_scale

    # final shape coherence checks against the outcome shape
    assert y.shape[:-1] == kshape
    assert leaf_prior_cov_inv.shape == 2 * kshape

    return (
        is_binary,
        kshape,
        error_cov_inv,
        error_cov_df,
        error_cov_scale,
        binary_indices,
    )
def _check_splitless_vars(
    filter_splitless_vars: int,
    max_split: UInt[Array, ' p'],
    offset: Float32[Array, ''] | Float32[Array, ' k'],
) -> Float32[Array, ''] | Float32[Array, ' k']:
    """Check there aren't too many deactivated predictors.

    Attaches a runtime error check to `offset` that fires when the number of
    predictors without splits exceeds `filter_splitless_vars`.
    """
    num_splitless = jnp.sum(max_split == 0)
    too_many = num_splitless > filter_splitless_vars
    return error_if(
        offset,
        too_many,
        f'there are more than {filter_splitless_vars=} predictors with no splits, '
        'please increase `filter_splitless_vars` or investigate the missing splits',
    )
def _parse_outcome_type(
    outcome_type: 'OutcomeType | str | Sequence[OutcomeType | str]',
) -> 'OutcomeType | list[OutcomeType]':
    """Normalize outcome_type to enum (or list of enums)."""
    # a plain string is a Sequence, so treat it as a single outcome type
    if isinstance(outcome_type, str) or not isinstance(outcome_type, Sequence):
        return OutcomeType(outcome_type)
    return [OutcomeType(item) for item in outcome_type]
def _parse_p_nonterminal(
    p_nonterminal: Float32[Any, ' d_minus_1'],
) -> Float32[Array, ' d_minus_1+1']:
    """Check it's in (0, 1) and pad with a 0 at the end."""
    probs = jnp.asarray(p_nonterminal)
    out_of_range = ~((probs > 0) & (probs < 1))
    probs = error_if(probs, out_of_range, 'p_nonterminal must be in (0, 1)')
    # append a 0: nodes at maximum depth are never nonterminal
    return jnp.pad(probs, (0, 1))
def make_p_nonterminal(
    d: int,
    alpha: float | Float32[Array, ''] = 0.95,
    beta: float | Float32[Array, ''] = 2.0,
) -> Float32[Array, ' {d}-1']:
    """Prepare the `p_nonterminal` argument to `init`.

    It is calculated according to the formula:

        P_nt(depth) = alpha / (1 + depth)^beta, with depth 0-based

    Parameters
    ----------
    d
        The maximum depth of the trees (d=1 means tree with only root node)
    alpha
        The a priori probability of the root node having children, conditional
        on it being possible
    beta
        The exponent of the power decay of the probability of having children
        with depth.

    Returns
    -------
    An array of probabilities, one per tree level but the last.

    Raises
    ------
    ValueError
        If `d` is less than 1.
    """
    # raise instead of `assert`: input validation must survive `python -O`
    if d < 1:
        msg = f'd must be at least 1, got {d=}'
        raise ValueError(msg)
    depth = jnp.arange(d - 1)
    return alpha / (1 + depth).astype(float) ** beta
class _LazyArray(Module):
    """Like `functools.partial` but specialized to array-creating functions like `jax.numpy.zeros`."""

    array_creator: Callable  # e.g. `jnp.zeros`/`jnp.full`; invoked as array_creator(shape, *args, **kwargs)
    shape: tuple[int, ...]  # first positional argument for `array_creator`
    args: tuple  # remaining positional arguments (e.g. fill value, dtype)

    def __init__(
        self, array_creator: Callable, shape: tuple[int, ...], *args: Any
    ) -> None:
        """Store the creator, the target shape, and extra positional args."""
        self.array_creator = array_creator
        self.shape = shape
        self.args = args

    def __call__(self, **kwargs: Any) -> T:
        """Materialize the array; keyword args are forwarded verbatim to the creator."""
        return self.array_creator(self.shape, *self.args, **kwargs)

    @property
    def ndim(self) -> int:
        """Number of axes, mirroring the `ndim` attribute of a real array."""
        return len(self.shape)
566def init(
567 *,
568 X: UInt[Any, 'p n'],
569 y: Float32[Any, ' n'] | Float32[Any, ' k n'],
570 outcome_type: OutcomeType | str | Sequence[OutcomeType | str] = 'continuous',
571 offset: float | Float32[Any, ''] | Float32[Any, ' k'],
572 max_split: UInt[Any, ' p'],
573 num_trees: int,
574 p_nonterminal: Float32[Any, ' d_minus_1'],
575 leaf_prior_cov_inv: float | Float32[Any, ''] | Float32[Array, 'k k'],
576 error_cov_df: float | Float32[Any, ''] | None = None,
577 error_cov_scale: float | Float32[Any, ''] | Float32[Array, 'k k'] | None = None,
578 error_scale: Float32[Any, ' n'] | None = None,
579 min_points_per_decision_node: int | Integer[Any, ''] | None = None,
580 resid_num_batches: int | None | Literal['auto'] = 'auto',
581 count_num_batches: int | None | Literal['auto'] = 'auto',
582 prec_num_batches: int | None | Literal['auto'] = 'auto',
583 prec_count_num_trees: int | None | Literal['auto'] = 'auto',
584 save_ratios: bool = False,
585 filter_splitless_vars: int = 0,
586 min_points_per_leaf: int | Integer[Any, ''] | None = None,
587 log_s: Float32[Any, ' p'] | None = None,
588 theta: float | Float32[Any, ''] | None = None,
589 a: float | Float32[Any, ''] | None = None,
590 b: float | Float32[Any, ''] | None = None,
591 rho: float | Float32[Any, ''] | None = None,
592 sparse_on_at: int | Integer[Any, ''] | None = None,
593 num_chains: int | None = None,
594 mesh: Mesh | dict[str, int] | None = None,
595 target_platform: Literal['cpu', 'gpu'] | None = None,
596) -> State:
597 """
598 Make a BART posterior sampling MCMC initial state.
600 Parameters
601 ----------
602 X
603 The predictors. Note this is trasposed compared to the usual convention.
604 y
605 The response. If two-dimensional, the outcome is multivariate with the
606 first axis indicating the component. For binary data, non-zero means 1,
607 zero means 0.
608 outcome_type
609 Whether the regression is continuous or binary (probit). Can also be a
610 sequence of `OutcomeType` values, one per outcome component, for mixed
611 binary-continuous multivariate regression.
612 offset
613 Constant shift added to the sum of trees. 0 if not specified.
614 max_split
615 The maximum split index for each variable. All split ranges start at 1.
616 num_trees
617 The number of trees in the forest.
618 p_nonterminal
619 The probability of a nonterminal node at each depth. The maximum depth
620 of trees is fixed by the length of this array. Use `make_p_nonterminal`
621 to set it with the conventional formula.
622 leaf_prior_cov_inv
623 The prior precision matrix of a leaf, conditional on the tree structure.
624 For the univariate case (k=1), this is a scalar (the inverse variance).
625 The prior covariance of the sum of trees is
626 ``num_trees * leaf_prior_cov_inv^-1``. The prior mean of leaves is
627 always zero.
628 error_cov_df
629 error_cov_scale
630 The df and scale parameters of the inverse Wishart prior on the error
631 covariance. For the univariate case, the relationship to the inverse
632 gamma prior parameters is ``alpha = df / 2``, ``beta = scale / 2``.
633 Leave unspecified for binary regression.
634 error_scale
635 Each error is scaled by the corresponding factor in `error_scale`, so
636 the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
637 Not supported for binary regression. If not specified, defaults to 1 for
638 all points, but potentially skipping calculations.
639 min_points_per_decision_node
640 The minimum number of data points in a decision node. 0 if not
641 specified.
642 resid_num_batches
643 count_num_batches
644 prec_num_batches
645 The number of batches, along datapoints, for summing the residuals,
646 counting the number of datapoints in each leaf, and computing the
647 likelihood precision in each leaf, respectively. `None` for no batching.
648 If 'auto', it's chosen automatically based on the target platform; see
649 the description of `target_platform` below for how it is determined.
650 prec_count_num_trees
651 The number of trees to process at a time when counting datapoints or
652 computing the likelihood precision. If `None`, do all trees at once,
653 which may use too much memory. If 'auto' (default), it's chosen
654 automatically.
655 save_ratios
656 Whether to save the Metropolis-Hastings ratios.
657 filter_splitless_vars
658 The maximum number of variables without splits that can be ignored. If
659 there are more, `init` raises an exception.
660 min_points_per_leaf
661 The minimum number of datapoints in a leaf node. 0 if not specified.
662 Unlike `min_points_per_decision_node`, this constraint is not taken into
663 account in the Metropolis-Hastings ratio because it would be expensive
664 to compute. Grow moves that would violate this constraint are vetoed.
665 This parameter is independent of `min_points_per_decision_node` and
666 there is no check that they are coherent. It makes sense to set
667 ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
668 log_s
669 The logarithm of the prior probability for choosing a variable to split
670 along in a decision rule, conditional on the ancestors. Not normalized.
671 If not specified, use a uniform distribution. If not specified and
672 `theta` or `rho`, `a`, `b` are, it's initialized automatically.
673 theta
674 The concentration parameter for the Dirichlet prior on `s`. Required
675 only to update `log_s`. If not specified, and `rho`, `a`, `b` are
676 specified, it's initialized automatically.
677 a
678 b
679 rho
680 Parameters of the prior on `theta`. Required only to sample `theta`.
681 sparse_on_at
682 After how many MCMC steps to turn on variable selection.
683 num_chains
684 The number of independent MCMC chains to represent in the state. Single
685 chain with scalar values if not specified.
686 mesh
687 A jax mesh used to shard data and computation across multiple devices.
688 If it has a 'chains' axis, that axis is used to shard the chains. If it
689 has a 'data' axis, that axis is used to shard the datapoints.
691 As a shorthand, if a dictionary mapping axis names to axis size is
692 passed, the corresponding mesh is created, e.g., ``dict(chains=4,
693 data=2)`` will let jax pick 8 devices to split chains (which must be a
694 multiple of 4) across 4 pairs of devices, where in each pair the data is
695 split in two.
697 Note: if a mesh is passed, the arrays are always sharded according to
698 it. In particular even if the mesh has no 'chains' or 'data' axis, the
699 arrays will be replicated on all devices in the mesh.
700 target_platform
701 Platform ('cpu' or 'gpu') used to determine the number of batches
702 automatically. If `mesh` is specified, the platform is inferred from the
703 devices in the mesh. Otherwise, if `y` is a concrete array (i.e., `init`
704 is not invoked in a `jax.jit` context), the platform is set to the
705 platform of `y`. Otherwise, use `target_platform`.
707 To avoid confusion, in all cases where the `target_platform` argument
708 would be ignored, `init` raises an exception if `target_platform` is
709 set.
711 Returns
712 -------
713 An initialized BART MCMC state.
715 Raises
716 ------
717 ValueError
718 If arguments unused in binary regression are set.
720 Notes
721 -----
722 In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
723 of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
724 child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
725 integers in the range ``[0, 1, ..., max_split[i]]``.
727 In general the arrays passed to this function as arguments may be donated,
728 invalidating them. Create copies before passing them to `init` if this
729 happens and you need them again.
730 """
731 # convert to array all array-like arguments that are used in other
732 # configurations but don't need further processing themselves
733 X = jnp.asarray(X) 1ab
734 y = jnp.asarray(y) 1ab
735 assert y.dtype == jnp.float32 1ab
736 offset = jnp.asarray(offset) 1ab
737 leaf_prior_cov_inv = jnp.asarray(leaf_prior_cov_inv) 1ab
738 max_split = jnp.asarray(max_split) 1ab
740 # normalize outcome_type to enum (or list of enums)
741 outcome_type = _parse_outcome_type(outcome_type) 1ab
743 # check p_nonterminal and pad it with a 0 at the end (still not final shape)
744 p_nonterminal = _parse_p_nonterminal(p_nonterminal) 1ab
746 # process arguments that change depending on outcome type
747 is_binary, kshape, error_cov_inv, error_cov_df, error_cov_scale, binary_indices = ( 1ab
748 _init_shape_shifting_parameters(
749 y,
750 outcome_type,
751 offset,
752 error_scale,
753 error_cov_df,
754 error_cov_scale,
755 leaf_prior_cov_inv,
756 )
757 )
759 # extract array sizes from arguments
760 (max_depth,) = p_nonterminal.shape 1ab
761 p, n = X.shape 1ab
763 # check and initialize sparsity parameters
764 if not _all_none_or_not_none(rho, a, b): 764 ↛ 765line 764 didn't jump to line 765 because the condition on line 764 was never true1ab
765 msg = 'rho, a, b are not either all `None` or all set'
766 raise ValueError(msg)
767 if theta is None and rho is not None: 1satCpqb
768 theta = rho 1apq
769 if log_s is None and theta is not None: 1sadtDECpqb
770 log_s = jnp.zeros(max_split.size) 1apq
771 if not _all_none_or_not_none(theta, sparse_on_at): 771 ↛ 772line 771 didn't jump to line 772 because the condition on line 771 was never true1sadDEb
772 msg = 'sparsity params (either theta or rho,a,b) and sparse_on_at must be either all None or all set'
773 raise ValueError(msg)
775 # process multichain settings
776 chain_shape = () if num_chains is None else (num_chains,) 1areb
777 resid_shape = chain_shape + y.shape 1areb
778 add_chains = partial(_add_chains, chain_shape=chain_shape) 1ab
780 # determine batch sizes for reductions
781 mesh = _parse_mesh(num_chains, mesh) 1ab
782 target_platform = _parse_target_platform( 1ab
783 y, mesh, target_platform, resid_num_batches, count_num_batches, prec_num_batches
784 )
785 red_cfg = _parse_reduction_configs( 1ab
786 resid_num_batches,
787 count_num_batches,
788 prec_num_batches,
789 prec_count_num_trees,
790 y,
791 num_trees,
792 mesh,
793 target_platform,
794 )
796 # check there aren't too many deactivated predictors
797 offset = _check_splitless_vars(filter_splitless_vars, max_split, offset) 1ab
799 # determine shapes for trees
800 tree_shape = (*chain_shape, num_trees) 1ab
801 tree_size = 2**max_depth 1ab
803 # initialize all remaining stuff and put it in an unsharded state
804 state = State( 1adpqb
805 X=X,
806 binary_y=y, # temporary to be sharded together with everything else
807 z=(
808 _LazyArray(jnp.full, resid_shape, offset[..., None])
809 if is_binary
810 else _LazyArray(
811 jnp.full,
812 (*chain_shape, binary_indices.size, n),
813 offset[binary_indices, None],
814 )
815 if binary_indices is not None
816 else None
817 ),
818 binary_indices=binary_indices,
819 offset=offset,
820 resid=(
821 _LazyArray(jnp.zeros, resid_shape)
822 if is_binary
823 else None # resid is created later after y and offset are sharded
824 ),
825 error_cov_inv=add_chains(error_cov_inv),
826 prec_scale=error_scale, # temporarily set to error_scale, fix after sharding
827 error_cov_df=error_cov_df,
828 error_cov_scale=error_cov_scale,
829 forest=Forest(
830 leaf_tree=_LazyArray(
831 jnp.zeros, (*tree_shape, *kshape, tree_size), jnp.float32
832 ),
833 var_tree=_LazyArray(
834 jnp.zeros, (*tree_shape, tree_size // 2), minimal_unsigned_dtype(p - 1)
835 ),
836 split_tree=_LazyArray(
837 jnp.zeros, (*tree_shape, tree_size // 2), max_split.dtype
838 ),
839 affluence_tree=_LazyArray(
840 _initial_affluence_tree,
841 (*tree_shape, tree_size // 2),
842 n,
843 min_points_per_decision_node,
844 ),
845 blocked_vars=_get_blocked_vars(filter_splitless_vars, max_split),
846 max_split=max_split,
847 grow_prop_count=_LazyArray(jnp.zeros, chain_shape, int),
848 grow_acc_count=_LazyArray(jnp.zeros, chain_shape, int),
849 prune_prop_count=_LazyArray(jnp.zeros, chain_shape, int),
850 prune_acc_count=_LazyArray(jnp.zeros, chain_shape, int),
851 p_nonterminal=p_nonterminal[tree_depths(tree_size)],
852 p_propose_grow=p_nonterminal[tree_depths(tree_size // 2)],
853 leaf_indices=_LazyArray(
854 jnp.ones, (*tree_shape, n), minimal_unsigned_dtype(tree_size - 1)
855 ),
856 min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
857 min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
858 log_trans_prior=_LazyArray(jnp.zeros, (*chain_shape, num_trees))
859 if save_ratios
860 else None,
861 log_likelihood=_LazyArray(jnp.zeros, (*chain_shape, num_trees))
862 if save_ratios
863 else None,
864 leaf_prior_cov_inv=leaf_prior_cov_inv,
865 log_s=add_chains(_asarray_or_none(log_s)),
866 theta=add_chains(_asarray_or_none(theta)),
867 rho=_asarray_or_none(rho),
868 a=_asarray_or_none(a),
869 b=_asarray_or_none(b),
870 ),
871 config=StepConfig(
872 steps_done=jnp.int32(0),
873 sparse_on_at=_asarray_or_none(sparse_on_at),
874 mesh=mesh,
875 **red_cfg,
876 ),
877 )
879 # delete big input arrays such that they can be deleted as soon as they
880 # are sharded, only those arrays that contain an (n,) sized axis
881 del X, error_scale, y 1adpqb
883 # move all arrays to the appropriate device
884 state = _shard_state(state) 1ab
886 # calculate initial resid in the continuous outcome case, such that y and
887 # offset are already sharded if needed
888 if state.resid is None: 1adbk
889 resid = _LazyArray( 1ab
890 _initial_resid,
891 resid_shape,
892 state.binary_y, # this is actually y
893 state.offset,
894 binary_indices,
895 )
896 resid = _shard_leaf(resid, 0, -1, state.config.mesh) 1ab
897 state = replace(state, resid=resid) 1ab
899 # calculate initial binary_y
900 if is_binary or binary_indices is not None: 1adfgbhk
901 binary_y = _LazyArray( 1dfghk
902 _initial_binary_y,
903 state.binary_y.shape
904 if binary_indices is None
905 else (binary_indices.size, n),
906 state.binary_y, # this is actually y
907 binary_indices,
908 )
909 binary_y = _shard_leaf(binary_y, None, -1, state.config.mesh) 1dfghk
910 else:
911 binary_y = None 1ab
912 state = replace(state, binary_y=binary_y) 1ab
914 # calculate prec_scale after sharding to do the calculation on the right
915 # devices
916 if state.prec_scale is not None: 1atb
917 prec_scale = _compute_prec_scale(state.prec_scale) 1t
918 state = replace(state, prec_scale=prec_scale) 1t
920 # make all types strong to avoid unwanted recompilations
921 return _remove_weak_types(state) 1ab
924def _initial_resid(
925 shape: tuple[int, ...],
926 y: Float32[Array, ' n'] | Float32[Array, 'k n'],
927 offset: Float32[Array, ''] | Float32[Array, ' k'],
928 binary_indices: Int32[Array, ' kb'] | None,
929) -> Float32[Array, ' n'] | Float32[Array, 'k n']:
930 """Calculate the initial value for `State.resid` in the continuous outcome case.
932 In the mixed binary-continuous case, binary rows are zeroed out (their
933 residual starts at ``z - trees - offset = 0``).
934 """
935 resid = jnp.broadcast_to(y - offset[..., None], shape) 1ab
936 if binary_indices is not None: 1afgbh
937 resid = resid.at[..., binary_indices, :].set(0.0) 1fgh
938 return resid 1ab
941def _initial_binary_y(
942 shape: tuple[int, ...],
943 y: Float32[Array, 'k n'] | Float32[Array, ' n'],
944 binary_indices: Int32[Array, ' kb'] | None,
945) -> Bool[Array, 'kb n'] | Bool[Array, ' n']:
946 """Extract and convert the binary outcome components from ``y``."""
947 if binary_indices is None: 1dfghk
948 out = y.astype(bool) 1dk
949 else:
950 out = y[binary_indices, :].astype(bool) 1fgh
951 assert out.shape == shape 1dk
952 return out 1dk
955def _initial_affluence_tree(
956 shape: tuple[int, ...], n: int, min_points_per_decision_node: int | None
957) -> Array:
958 """Create the initial value of `Forest.affluence_tree`."""
959 return ( 1aLbk
960 jnp.zeros(shape, bool)
961 .at[..., 1]
962 .set(
963 True
964 if min_points_per_decision_node is None
965 else n >= min_points_per_decision_node
966 )
967 )
970@partial(jit, donate_argnums=(0,))
971def _compute_prec_scale(error_scale: Float32[Array, ' n']) -> Float32[Array, ' n']:
972 """Compute 1 / error_scale**2.
974 This is a separate function to use donate_argnums to avoid intermediate
975 copies.
976 """
977 return jnp.reciprocal(jnp.square(error_scale)) 1t
980def _get_blocked_vars(
981 filter_splitless_vars: int, max_split: UInt[Array, ' p']
982) -> None | UInt[Array, ' q']:
983 """Initialize the `blocked_vars` field."""
984 if filter_splitless_vars: 1avwb
985 (p,) = max_split.shape 1vw
986 (blocked_vars,) = jnp.nonzero( 1vw
987 max_split == 0, size=filter_splitless_vars, fill_value=p
988 )
989 return blocked_vars.astype(minimal_unsigned_dtype(p)) 1vw
990 # see `fully_used_variables` for the type cast
991 else:
992 return None 1ab
995def _add_chains(
996 x: Shaped[Array, '*shape'] | None, chain_shape: tuple[int, ...]
997) -> Shaped[Array, '*shape'] | Shaped[Array, ' num_chains *shape'] | None:
998 """Broadcast `x` to all chains."""
999 if x is None: 1sab
1000 return None 1sb
1001 else:
1002 return jnp.broadcast_to(x, chain_shape + x.shape) 1ab
1005def _parse_mesh(
1006 num_chains: int | None, mesh: Mesh | dict[str, int] | None
1007) -> Mesh | None:
1008 """Parse the `mesh` argument."""
1009 if mesh is None: 1jadbi
1010 return None 1ab
1012 # convert dict format to actual mesh
1013 if isinstance(mesh, dict): 1jdFi
1014 assert set(mesh).issubset({'chains', 'data'}) 1i
1015 mesh = make_mesh( 1i
1016 tuple(mesh.values()), tuple(mesh), axis_types=(AxisType.Auto,) * len(mesh)
1017 )
1019 # check there's no chain mesh axis if there are no chains
1020 if num_chains is None: 1jdFin
1021 assert 'chains' not in mesh.axis_names 1n
1023 # check the axes we use are in auto mode
1024 assert 'chains' not in mesh.axis_names or 'chains' in mesh.auto_axes 1jdfin
1025 assert 'data' not in mesh.axis_names or 'data' in mesh.auto_axes 1jdfin
1027 return mesh 1jdfin
1030def _parse_target_platform(
1031 y: Array,
1032 mesh: Mesh | None,
1033 target_platform: Literal['cpu', 'gpu'] | None,
1034 resid_num_batches: int | None | Literal['auto'],
1035 count_num_batches: int | None | Literal['auto'],
1036 prec_num_batches: int | None | Literal['auto'],
1037) -> Literal['cpu', 'gpu'] | None:
1038 if mesh is not None: 1jadbi
1039 assert target_platform is None, 'mesh provided, do not set target_platform' 1jdi
1040 return mesh.devices.flat[0].platform 1jdi
1041 elif hasattr(y, 'platform'): 1axbk
1042 assert target_platform is None, 'device inferred from y, unset target_platform' 1ak
1043 return y.platform() 1ak
1044 elif ( 1xyb
1045 resid_num_batches == 'auto'
1046 or count_num_batches == 'auto'
1047 or prec_num_batches == 'auto'
1048 ):
1049 assert target_platform in ('cpu', 'gpu') 1yb
1050 return target_platform 1yb
1051 else:
1052 assert target_platform is None, 'target_platform not used, unset it' 1x
1053 return target_platform 1x
@partial(filter_jit, donate='all')
# jit and donate because otherwise type conversion would create copies
def _remove_weak_types(x: PyTree[Array, 'T']) -> PyTree[Array, 'T']:
    """Make all types strong.

    This is to avoid recompilation in `run_mcmc` or `step`.
    """

    def strongify(leaf: T) -> T:
        # astype to the same dtype drops the weak_type flag
        if isinstance(leaf, Array) and leaf.weak_type:
            return leaf.astype(leaf.dtype)
        return leaf

    return tree.map(strongify, x)
def _shard_state(state: State) -> State:
    """Place all arrays on the appropriate devices, and instantiate lazily defined arrays."""
    shard = partial(_shard_leaf, mesh=state.config.mesh)
    # None and _LazyArray nodes must be treated as leaves so _shard_leaf sees them
    return tree.map(
        shard,
        state,
        chain_vmap_axes(state),
        data_vmap_axes(state),
        is_leaf=lambda leaf: leaf is None or isinstance(leaf, _LazyArray),
    )
1086def _shard_leaf(
1087 x: Array | None | _LazyArray,
1088 chain_axis: int | None,
1089 data_axis: int | None,
1090 mesh: Mesh | None,
1091) -> Array | None:
1092 """Create `x` if it's lazy and shard it."""
1093 if x is None: 1ab
1094 return None 1ab
1096 if mesh is None: 1jadbi
1097 sharding = None 1ab
1098 else:
1099 spec = [None] * x.ndim 1jdi
1100 if chain_axis is not None and 'chains' in mesh.axis_names: 1jdfin
1101 spec[chain_axis] = 'chains' 1jfi
1102 if data_axis is not None and 'data' in mesh.axis_names: 1jdfin
1103 spec[data_axis] = 'data' 1dn
1105 # remove trailing Nones to be consistent with jax's output, it's useful
1106 # for comparing shardings during debugging
1107 while spec and spec[-1] is None: 1jdfi
1108 spec.pop() 1jdi
1110 spec = PartitionSpec(*spec) 1jdi
1111 sharding = NamedSharding(mesh, spec) 1jdi
1113 if isinstance(x, _LazyArray): 1ab
1114 x = _concretize_lazy_array(x, sharding) 1ab
1115 elif sharding is not None: 1jadbi
1116 x = device_put(x, sharding, donate=True) 1jdi
1118 return x 1ab
@filter_jit
# jit such that in recent jax versions the shards are created on the right
# devices immediately instead of being created on the wrong device and then
# copied
def _concretize_lazy_array(x: _LazyArray, sharding: NamedSharding | None) -> Array:
    """Create an array from an abstract spec on the appropriate devices."""
    concrete = x()
    if sharding is None:
        return concrete
    return lax.with_sharding_constraint(concrete, sharding)
1133def _all_none_or_not_none(*args: object) -> bool:
1134 is_none = [x is None for x in args] 1ab
1135 return all(is_none) or not any(is_none) 1sapqb
1138def _asarray_or_none(x: object) -> Array | None:
1139 if x is None: 1aGb
1140 return None 1Gb
1141 return jnp.asarray(x) 1ab
1144def _get_platform(mesh: Mesh | None) -> str:
1145 if mesh is None:
1146 return get_default_device().platform
1147 else:
1148 return mesh.devices.flat[0].platform
class _ReductionConfig(TypedDict):
    """Fields of `StepConfig` related to reductions."""

    # number of batches for the 'resid' indexed reduce; None disables batching
    # (see `_parse_num_batches` / `_final_round`)
    resid_num_batches: int | None
    # number of batches for the 'count' indexed reduce; None disables batching
    count_num_batches: int | None
    # number of batches for the 'prec' indexed reduce; None disables batching
    prec_num_batches: int | None
    # number of trees to process at a time; None processes all trees at once
    # (see `_parse_prec_count_num_trees`)
    prec_count_num_trees: int | None
def _parse_reduction_configs(
    resid_num_batches: int | None | Literal['auto'],
    count_num_batches: int | None | Literal['auto'],
    prec_num_batches: int | None | Literal['auto'],
    prec_count_num_trees: int | None | Literal['auto'],
    y: Float32[Array, ' n'] | Float32[Array, ' k n'] | Bool[Array, ' n'],
    num_trees: int,
    mesh: Mesh | None,
    target_platform: Literal['cpu', 'gpu'] | None,
) -> _ReductionConfig:
    """Determine settings for indexed reduces."""
    # use the per-device number of datapoints
    n = y.shape[-1] // get_axis_size(mesh, 'data')
    parse = partial(_parse_num_batches, target_platform, n)
    return dict(
        resid_num_batches=parse(resid_num_batches, 'resid'),
        count_num_batches=parse(count_num_batches, 'count'),
        prec_num_batches=parse(prec_num_batches, 'prec'),
        prec_count_num_trees=_parse_prec_count_num_trees(
            prec_count_num_trees, num_trees, n
        ),
    )
1184def _parse_num_batches(
1185 target_platform: Literal['cpu', 'gpu'] | None,
1186 n: int,
1187 num_batches: int | None | Literal['auto'],
1188 which: Literal['resid', 'count', 'prec'],
1189) -> int | None:
1190 """Return the number of batches or determine it automatically."""
1191 final_round = partial(_final_round, n) 1ab
1192 if num_batches != 'auto': 1adbz
1193 nb = num_batches 1az
1194 elif target_platform == 'cpu': 1194 ↛ 1196line 1194 didn't jump to line 1196 because the condition on line 1194 was always true1db
1195 nb = final_round(16) 1db
1196 elif target_platform == 'gpu':
1197 nb = dict(resid=1024, count=2048, prec=1024)[which] # on an A4000
1198 nb = final_round(nb)
1199 return nb 1adbz
1202def _final_round(n: int, num: float) -> int | None:
1203 """Bound batch size, round number of batches to a power of 2, and disable batching if there's only 1 batch."""
1204 # at least some elements per batch
1205 num = min(n // 32, num) 1db
1207 # round to the nearest power of 2 because I guess XLA and the hardware
1208 # will like that (not sure about this, maybe just multiple of 32?)
1209 num = 2 ** round(log2(num)) if num else 0 1dHbIn
1211 # disable batching if the batch is as large as the whole dataset
1212 return num if num > 1 else None 1dHbIkn
1215def _parse_prec_count_num_trees(
1216 prec_count_num_trees: int | None | Literal['auto'], num_trees: int, n: int
1217) -> int | None:
1218 """Return the number of trees to process at a time or determine it automatically."""
1219 if prec_count_num_trees != 'auto': 1adpqb
1220 return prec_count_num_trees 1apq
1221 max_n_by_ntree = 2**27 # about 100M 1db
1222 pcnt = max_n_by_ntree // max(1, n) 1db
1223 pcnt = min(num_trees, pcnt) 1db
1224 pcnt = max(1, pcnt) 1db
1225 pcnt = _search_divisor( 1db
1226 pcnt, num_trees, max(1, pcnt // 2), max(1, min(num_trees, pcnt * 2))
1227 )
1228 if pcnt >= num_trees: 1228 ↛ 1230line 1228 didn't jump to line 1230 because the condition on line 1228 was always true1db
1229 pcnt = None 1db
1230 return pcnt 1db
1233def _search_divisor(target_divisor: int, dividend: int, low: int, up: int) -> int:
1234 """Find the divisor closest to `target_divisor` in [low, up] if `target_divisor` is not already.
1236 If there is none, give up and return `target_divisor`.
1237 """
1238 assert target_divisor >= 1 1db
1239 assert 1 <= low <= up <= dividend 1db
1240 if dividend % target_divisor == 0: 1240 ↛ 1242line 1240 didn't jump to line 1242 because the condition on line 1240 was always true1db
1241 return target_divisor 1db
1242 candidates = numpy.arange(low, up + 1)
1243 divisors = candidates[dividend % candidates == 0]
1244 if divisors.size == 0:
1245 return target_divisor
1246 penalty = numpy.abs(divisors - target_divisor)
1247 closest = numpy.argmin(penalty)
1248 return divisors[closest].item()
1251def get_axis_size(mesh: Mesh | None, axis_name: str) -> int:
1252 if mesh is None or axis_name not in mesh.axis_names: 1adbin
1253 return 1 1abi
1254 else:
1255 i = mesh.axis_names.index(axis_name) 1dn
1256 return mesh.axis_sizes[i] 1dn
def chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool = False
) -> Float32[Array, '*batch_shape k k']:
    """Cholesky decomposition with Gershgorin stabilization; supports batching.

    Thin wrapper delegating to the vectorized `_chol_with_gersh_impl`.
    """
    return _chol_with_gersh_impl(mat, absolute_eps)
1266@partial(jnp.vectorize, signature='(k,k)->(k,k)', excluded=(1,))
1267def _chol_with_gersh_impl(
1268 mat: Float32[Array, '*batch_shape k k'], absolute_eps: bool
1269) -> Float32[Array, '*batch_shape k k']:
1270 rho = jnp.max(jnp.sum(jnp.abs(mat), axis=1), initial=0.0) 1lm
1271 eps = jnp.finfo(mat.dtype).eps 1lm
1272 u = mat.shape[0] * rho * eps 1lm
1273 if absolute_eps: 1lJm
1274 u += eps 1lJ
1275 mat = mat.at[jnp.diag_indices_from(mat)].add(u) 1lm
1276 return jnp.linalg.cholesky(mat) 1lm
def _inv_via_chol_with_gersh(
    mat: Float32[Array, '*batch_shape k k'],
) -> Float32[Array, '*batch_shape k k']:
    """Compute matrix inverse via Cholesky with Gershgorin stabilization.

    DO NOT USE THIS FUNCTION UNLESS YOU REALLY NEED TO.

    Assumes `mat` is symmetric positive definite (required for the Cholesky
    decomposition to succeed).
    """
    # mat = L L^T
    # mat^-1 = L^-T L^-1 = L^-T I L^-1 = L^-T (L^-T I)^T
    # I suspect this to be more accurate than (L^-1 I)^T (L^-1 I)
    L = chol_with_gersh(mat)
    # identity broadcast to match mat's batch shape
    eye = jnp.broadcast_to(jnp.eye(mat.shape[-1]), mat.shape)
    # trans='T' solves L^T x = b, so this computes L^-T applied to the identity
    Ltinv = solve_triangular(L, eye, trans='T', lower=True)
    return solve_triangular(L, Ltinv.mT, trans='T', lower=True)
def get_num_chains(x: PyTree) -> int | None:
    """Get the number of chains of a pytree.

    Find all nodes in the structure that define 'num_chains()', stopping
    traversal at nodes that define it. Check all values obtained invoking
    `num_chains` are equal, then return it.
    """

    def defines_it(node: object) -> bool:
        return hasattr(node, 'num_chains')

    nodes, _ = tree.flatten(x, is_leaf=defines_it)
    values = [node.num_chains() for node in nodes if defines_it(node)]
    first = values[0]
    assert all(v == first for v in values)
    return first
def _chain_axes_with_keys(x: PyTree) -> PyTree[int | None]:
    """Return `chain_vmap_axes(x)` but also set to 0 for random keys."""

    def pick_axis(leaf: object, axis: int | None) -> int | None:
        # random keys are always split along their leading axis
        return 0 if is_key(leaf) else axis

    return tree.map(pick_axis, x, chain_vmap_axes(x))
def _get_mc_out_axes(
    fun: Callable[[tuple, dict], PyTree], args: PyTree, in_axes: PyTree[int | None]
) -> PyTree[int | None]:
    """Decide chain vmap axes for outputs."""
    # trace the vmapped function without executing it to get output structure
    out_shapes = eval_shape(vmap(fun, in_axes=in_axes), *args)
    return chain_vmap_axes(out_shapes)
def _find_mesh(x: PyTree) -> Mesh | None:
    """Find the mesh used for chains.

    Parameters
    ----------
    x
        A pytree expected to contain at least one `State`.

    Returns
    -------
    The mesh of the first `State` found in `x`.

    Raises
    ------
    ValueError
        If `x` contains no `State`.
    """

    class MeshFound(Exception):
        # short-circuits the traversal as soon as a State is found
        pass

    def find_mesh(node: object) -> None:
        if isinstance(node, State):
            raise MeshFound(node.config.mesh)

    try:
        tree.map(find_mesh, x, is_leaf=lambda n: isinstance(n, State))
    except MeshFound as e:
        return e.args[0]
    else:
        # previously raised a bare ValueError; include a message for debugging
        msg = 'cannot determine the mesh: no State found in the pytree'
        raise ValueError(msg)
def _split_all_keys(x: PyTree, num_chains: int) -> PyTree:
    """Split all random keys in `num_chains` keys."""
    mesh = _find_mesh(x)

    def split_if_key(leaf: object) -> object:
        if not is_key(leaf):
            return leaf
        keys = random.split(leaf, num_chains)
        if mesh is not None and 'chains' in mesh.axis_names:
            # place each chain's key on the devices that run that chain
            keys = device_put(keys, NamedSharding(mesh, PartitionSpec('chains')))
        return keys

    return tree.map(split_if_key, x)
def vmap_chains(fun: Callable[..., T]) -> Callable[..., T]:
    """Apply vmap on chain axes automatically if the inputs are multichain."""

    @wraps(fun)
    def auto_vmapped_fun(*args: Any, **kwargs: Any) -> T:
        bundle = args, kwargs
        num_chains = get_num_chains(bundle)
        if num_chains is None:
            # single-chain inputs: call straight through without vmapping
            return fun(*args, **kwargs)

        bundle = _split_all_keys(bundle, num_chains)

        def packed_fun(args: tuple[Any, ...], kwargs: dict[str, Any]) -> T:
            return fun(*args, **kwargs)

        in_axes = _chain_axes_with_keys(bundle)
        out_axes = _get_mc_out_axes(packed_fun, bundle, in_axes)
        return vmap(packed_fun, in_axes=in_axes, out_axes=out_axes)(*bundle)

    return auto_vmapped_fun