bartz.mcmcstep.StepConfig

class bartz.mcmcstep.StepConfig(steps_done, sparse_on_at, resid_reduction_config, count_reduction_config, prec_reduction_config, prec_count_num_trees, sequential_unroll, augment, mesh)

Options for the MCMC step.

steps_done: Int32[Array, '']

The number of MCMC steps completed so far.

sparse_on_at: Int32[Array, ''] | None

The number of steps after which variable selection is turned on.
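steps_done and sparse_on_at are 0-d int32 JAX arrays. The jaxtyping annotation Int32[Array, ''] denotes a scalar array with no axes. The sketch below builds such values and defines a hypothetical helper, sparse_active, which is not part of bartz. It assumes that variable selection turns on once steps_done reaches sparse_on_at, and that None means it never turns on. The actual bartz logic may differ.

```python
import jax.numpy as jnp

# Scalar (0-d) int32 arrays, matching the annotation Int32[Array, ''].
steps_done = jnp.int32(7)
sparse_on_at = jnp.int32(5)


def sparse_active(steps_done, sparse_on_at):
    """Hypothetical helper, not part of bartz.

    Assumes variable selection is on once steps_done >= sparse_on_at,
    and that sparse_on_at=None means it is never turned on.
    """
    if sparse_on_at is None:
        return jnp.bool_(False)
    return steps_done >= sparse_on_at


print(steps_done.shape)  # () : a 0-d array
print(bool(sparse_active(steps_done, sparse_on_at)))  # True, since 7 >= 5
```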

resid_reduction_config: ReductionConfig

How to sum the residuals in each leaf.

count_reduction_config: ReductionConfig

How to count the datapoints in each leaf.

prec_reduction_config: ReductionConfig

How to sum the likelihood precisions in each leaf.
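The three reduction configs above all concern per-leaf sums over datapoints. These are the residual sum, the datapoint count, and the precision sum for each leaf. The options of ReductionConfig are not documented here, so the sketch below does not model them. It shows only the quantities being reduced, computed with jax.ops.segment_sum as one possible strategy. The tree, its leaf assignments, and the data values are made up for illustration.

```python
import jax.numpy as jnp
from jax.ops import segment_sum

# Made-up single tree with 4 leaves; leaf_index[i] is the leaf datapoint i falls into.
leaf_index = jnp.array([0, 2, 2, 3, 0, 2])
resid = jnp.array([0.5, -1.0, 2.0, 0.25, 1.5, -0.5])  # per-datapoint residuals
prec = jnp.array([1.0, 2.0, 1.0, 4.0, 1.0, 1.0])  # per-datapoint likelihood precisions
num_leaves = 4

# Sum of residuals in each leaf.
resid_tree = segment_sum(resid, leaf_index, num_segments=num_leaves)
# Number of datapoints in each leaf (leaf 1 is empty, so its count is 0).
count_tree = segment_sum(jnp.ones_like(leaf_index), leaf_index, num_segments=num_leaves)
# Sum of likelihood precisions in each leaf.
prec_tree = segment_sum(prec, leaf_index, num_segments=num_leaves)
```

An equivalent formulation is a scatter-add such as jnp.zeros(num_leaves).at[leaf_index].add(resid). Different reduction strategies trade memory for speed depending on the hardware, which is presumably the kind of choice a ReductionConfig encodes.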

prec_count_num_trees: int | None

Batch size, in number of trees processed at a time, used when computing the per-leaf count and precision trees.
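The sketch below shows one way that batching over trees can work. It is not bartz's implementation. Trees within a batch are processed in parallel with jax.vmap, and jax.lax.map processes the batches sequentially, so peak memory scales with the batch size rather than with the total number of trees. The variable batch_size is a hypothetical stand-in for prec_count_num_trees. For simplicity, it assumes the batch size divides the number of trees, and the leaf assignments are synthetic.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_sum

num_trees, n, num_leaves = 6, 5, 4
batch_size = 2  # hypothetical stand-in for prec_count_num_trees; divides num_trees here

# Synthetic data: leaf_indices[t, i] is the leaf of tree t containing datapoint i.
leaf_indices = jnp.arange(num_trees * n).reshape(num_trees, n) % num_leaves


def count_one_tree(leaf_index):
    # Number of datapoints in each leaf of a single tree.
    return segment_sum(jnp.ones_like(leaf_index), leaf_index, num_segments=num_leaves)


def count_batch(batch):
    # All trees in one batch are processed in parallel.
    return jax.vmap(count_one_tree)(batch)


# Batches are processed one after another by lax.map.
batches = leaf_indices.reshape(num_trees // batch_size, batch_size, n)
counts = jax.lax.map(count_batch, batches).reshape(num_trees, num_leaves)
```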

sequential_unroll: int | bool

How much to unroll the sequential accept/reject loop over trees in step. See the unroll argument of jax.lax.scan.
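The example below illustrates what the unroll argument of jax.lax.scan controls. The loop body is a toy accept/reject update and is not bartz's actual tree move. Each iteration depends on the carry from the previous one, so the loop is inherently sequential. Unrolling only changes how many iterations are fused into one iteration of the compiled loop, trading compile time and code size for less loop overhead, so the results are the same for any setting. Passing a bool (True for fully unrolled, False for fully rolled) requires a reasonably recent JAX version.

```python
import jax
import jax.numpy as jnp


def step_one_tree(carry, tree_value):
    # Toy accept/reject: accept a tree if its value beats the current carry.
    accepted = tree_value > carry
    new_carry = jnp.where(accepted, tree_value, carry)
    return new_carry, accepted


tree_values = jnp.array([0.3, 0.1, 0.7, 0.5, 0.9], dtype=jnp.float32)
init = jnp.float32(0.0)


def run(unroll):
    final, accepted = jax.lax.scan(step_one_tree, init, tree_values, unroll=unroll)
    return float(final), accepted.tolist()


# Same loop, compiled with different unrolling factors.
outputs = [run(u) for u in (1, 2, True)]
```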

augment: bool

Whether to account exactly, via data augmentation, for the decision rules forbidden by the ancestors of each node when updating log_s.

mesh: Mesh | None

The mesh used to shard data and computation across multiple devices.
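For readers unfamiliar with JAX meshes, the sketch below builds a 1-D jax.sharding.Mesh over the available devices and splits a data array along it with NamedSharding. The axis name 'data' is an assumption, and so is the final data_sharded check, which treats the data axis as sharded only when it spans more than one device. This page does not specify how bartz defines the property. On a single-device machine, the code still runs, and the array simply lives on that one device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D mesh over all available devices; the axis name 'data' is assumed.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))

n = 8 * len(jax.devices())  # keep the data length divisible by the device count
X = jnp.arange(n, dtype=jnp.float32)

# Split the datapoints across devices along the 'data' mesh axis.
X_sharded = jax.device_put(X, NamedSharding(mesh, PartitionSpec('data')))

# Hypothetical check: the data axis counts as sharded only if it spans > 1 device.
data_sharded = mesh.shape['data'] > 1
```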

property data_sharded: bool

Whether the data axis is sharded across devices.