bartz.mcmcstep.State

class bartz.mcmcstep.State(_chain_anchor, X, binary_y, z, binary_indices, offset, resid, error_cov_inv, prec_scale, inv_sdev_scale, forest, config)

Represents the MCMC state of BART.

X: UInt[Array, 'p n']

The predictors.

binary_y: None | Bool[Array, 'n'] | Bool[Array, 'kb n']

The response as booleans for binary regression, None for continuous. In the mixed binary-continuous case, only the binary outcome components are stored, with shape (kb, n).

z: None | Float32[Array, '*chains n'] | Float32[Array, '*chains kb n']

The latent variable for binary regression. None in continuous regression. In the mixed binary-continuous case, only the binary outcome components are stored, with shape (*chains, kb, n).
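A minimal sketch of how a latent variable of this kind is commonly refreshed, assuming the standard probit data-augmentation scheme: the sign of z encodes the boolean response, and z has unit variance around the current prediction. The inputs `binary_y` and `mean` are hypothetical, and bartz's actual update may differ.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Hypothetical inputs for a single chain of a univariate binary model (n = 5).
binary_y = jnp.array([True, False, True, False, True])
mean = jnp.array([0.3, 0.3, -1.0, 2.0, 0.0], dtype=jnp.float32)  # assumed: offset + sum of trees

# Probit augmentation: z ~ Normal(mean, 1), truncated to z > 0 where y is True
# and to z < 0 where y is False. We sample the standardized deviation z - mean,
# so its truncation bounds are shifted by -mean.
lower = jnp.where(binary_y, -mean, -jnp.inf)
upper = jnp.where(binary_y, jnp.inf, -mean)
z = mean + jax.random.truncated_normal(key, lower, upper, dtype=jnp.float32)
```

With a chain axis, the same computation would be applied per chain, giving the '*chains n' shape listed above.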

binary_indices: None | Int32[Array, 'kb']

The indices of binary outcome components in the full list of outcome components. None when there are no binary components.
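For illustration, a mixed outcome array can be split with binary_indices as shown below. The full outcome array `y_full`, its values, and the 0/1 encoding of its binary rows are hypothetical. The sketch only shows the indexing implied by the (kb, n) layout of binary_y.

```python
import jax.numpy as jnp

# Hypothetical full outcome array: k = 3 components, n = 4 datapoints.
# Components 0 and 2 are binary (stored as 0/1), component 1 is continuous.
y_full = jnp.array([
    [1.0, 0.0, 1.0, 1.0],
    [0.3, -1.2, 2.5, 0.7],
    [0.0, 0.0, 1.0, 0.0],
])
binary_indices = jnp.array([0, 2], dtype=jnp.int32)  # shape (kb,)

# Gather the binary rows: shape (kb, n), the same layout as binary_y.
binary_y = y_full[binary_indices].astype(bool)

# Complementary boolean mask selecting the continuous components, shape (k - kb, n).
is_binary = jnp.zeros(y_full.shape[0], dtype=bool).at[binary_indices].set(True)
continuous_y = y_full[~is_binary]
```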

offset: Float32[Array, ''] | Float32[Array, 'k']

Constant shift added to the sum of trees.

resid: Float32[Array, '*chains n'] | Float32[Array, '*chains k n']

The residuals (y or z minus the sum of trees).
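A sketch of the residual bookkeeping. It assumes that the fitted value is offset plus the sum of trees, and that trees are updated one at a time as in Bayesian backfitting. The per-tree predictions `tree_preds` are hypothetical stand-ins for what the forest would produce. The code shows the shapes involved and why changing a single tree only requires an incremental residual update.

```python
import jax
import jax.numpy as jnp

chains, k, n, num_trees = 2, 3, 5, 4
key_y, key_t, key_new = jax.random.split(jax.random.PRNGKey(0), 3)

y = jax.random.normal(key_y, (k, n), dtype=jnp.float32)
offset = jnp.array([0.5, -1.0, 2.0], dtype=jnp.float32)  # shape (k,)
# Hypothetical per-tree predictions, shape (chains, num_trees, k, n).
tree_preds = 0.1 * jax.random.normal(key_t, (chains, num_trees, k, n), dtype=jnp.float32)

# Full computation: offset[:, None] broadcasts over datapoints, y broadcasts over chains.
resid = y - (offset[:, None] + tree_preds.sum(axis=1))  # shape (chains, k, n)

# Replacing tree 0: add back its old contribution, then subtract the new one.
new_tree0 = 0.1 * jax.random.normal(key_new, (chains, k, n), dtype=jnp.float32)
resid_updated = resid + tree_preds[:, 0] - new_tree0
tree_preds = tree_preds.at[:, 0].set(new_tree0)
```

The incremental form costs O(k n) per tree change, whereas recomputing the full sum costs O(num_trees k n).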

error_cov_inv: Wishart

The inverse error covariance together with its Wishart prior. The current value is error_cov_inv.value: a scalar in the univariate case, a matrix in the multivariate case. In binary regression it is the identity and has no prior.

prec_scale: Float32[Array, 'n'] | Float32[Array, 'k k n'] | None

The scale on the error precision. None in binary regression. With scalar per-datapoint weights, it has shape (n,) and value 1 / error_scale ** 2. With vector per-datapoint weights, it has shape (k, k, n), and the slice for datapoint i is 1 / outer(error_scale_i, error_scale_i), computed separately for each datapoint.
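The two weight cases can be written out as follows. The `error_scale_*` arrays are hypothetical, and `error_cov_inv_value` stands in for error_cov_inv.value. The last two lines reflect an assumption about how the scale is meant to be used. If errors are scaled by s, their precision is diag(1/s) @ inv(Sigma) @ diag(1/s). Multiplying the inverse covariance elementwise by prec_scale reproduces exactly that.

```python
import jax.numpy as jnp

# Hypothetical per-datapoint error scales.
error_scale_scalar = jnp.array([1.0, 2.0, 0.5], dtype=jnp.float32)   # shape (n,)
error_scale_vector = jnp.array([[1.0, 2.0, 4.0],
                                [0.5, 1.0, 2.0]], dtype=jnp.float32)  # shape (k, n)

# Scalar weights: prec_scale has shape (n,).
prec_scale_scalar = 1.0 / error_scale_scalar ** 2

# Vector weights: one outer product per datapoint, shape (k, k, n).
prec_scale_vector = 1.0 / jnp.einsum('in,jn->ijn', error_scale_vector, error_scale_vector)

# Assumed use: effective precision of datapoint i, elementwise product with the
# (k, k) inverse covariance broadcast over the trailing datapoint axis.
error_cov_inv_value = jnp.array([[2.0, 0.5], [0.5, 1.0]], dtype=jnp.float32)
effective_prec = error_cov_inv_value[:, :, None] * prec_scale_vector  # (k, k, n)
```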

inv_sdev_scale: Float32[Array, 'n'] | Float32[Array, 'k n'] | None

The reciprocal of the per-observation error standard-deviation scale. None in binary regression. Shape (n,) for scalar weights, or (k, n) for per-component vector weights.
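A short sketch, assuming that `error_scale` here is the same per-observation scale used for prec_scale. It shows how residuals of shape (*chains, k, n) can be rescaled by broadcasting against the trailing (k, n) axes. It also shows how the vector-weight prec_scale follows from inv_sdev_scale. All array values are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical per-component error scales, shape (k, n).
error_scale = jnp.array([[1.0, 2.0, 4.0],
                         [0.5, 1.0, 2.0]], dtype=jnp.float32)
inv_sdev_scale = 1.0 / error_scale  # shape (k, n)

# Residuals with a leading chain axis, shape (chains, k, n).
resid = jnp.ones((4, 2, 3), dtype=jnp.float32)
# Broadcasting aligns the trailing (k, n) axes, so every chain is rescaled.
scaled_resid = resid * inv_sdev_scale

# Outer product per datapoint recovers the (k, k, n) precision scale.
prec_scale = inv_sdev_scale[:, None, :] * inv_sdev_scale[None, :, :]
```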

forest: Forest

The sum-of-trees model.

config: StepConfig

Metadata and configurations for the MCMC step.

property has_chains: bool

Whether this state carries an explicit chain axis.

num_chains()

Return the number of chains, or None if not multichain.

Return type:

int | None
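To illustrate what the optional '*chains' axis means for array shapes, the hypothetical helper below infers the chain count from the shape of resid. It is not part of bartz, and it does not reflect how State actually tracks chains. It needs a `multivariate` flag because the shape alone is ambiguous: a univariate resid with chains, (chains, n), looks the same as a multivariate one without chains, (k, n).

```python
import jax.numpy as jnp


def num_chains_from_resid(resid, multivariate):
    """Hypothetical helper, not part of bartz.

    resid has shape '*chains n' (univariate) or '*chains k n' (multivariate),
    with at most one leading chain axis.
    """
    core_ndim = 2 if multivariate else 1
    extra = resid.ndim - core_ndim
    if extra == 0:
        return None  # no explicit chain axis
    if extra == 1:
        return resid.shape[0]
    raise ValueError(f'unexpected resid shape {resid.shape}')
```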