rbartpackages.dbarts.rbart_vi

class rbartpackages.dbarts.rbart_vi(formula, data=None, *, group_by, test=None, subset=None, weights=None, offset=None, offset_test=None, group_by_test=None, prior=None, sigest=None, sigdf=3.0, sigquant=0.9, k=2.0, power=2.0, base=0.95, n_trees=75, n_samples=1500, n_burn=1500, n_chains=4, n_threads=None, combineChains=False, n_cuts=100, useQuantiles=False, n_thin=5, keepTrainingFits=True, printEvery=100, printCutoffs=0, verbose=True, keepTrees=True, keepCall=True, seed=None, keepSampler=None, keepTestFits=True, callback=None, **control)

Fit BART with additive group random intercepts.

Python interface to R’s dbarts::rbart_vi, which adds an i.i.d. random intercept per group_by level to a bart2 fit. A string formula argument is converted to an R formula. In addition to the bart components, the fit exposes the random-intercept outputs below. Arguments left as None are omitted from the R call, so R computes its own defaults.

Parameters:
  • formula (str | Float64[ndarray, 'n p'] | DataFrame) – A model formula (as a string), or, in matrix mode, the x_train predictor matrix.

  • data (Float64[ndarray, 'n'] | DataFrame | None, default: None) – The data frame the formula refers to, or, in matrix mode, the y_train response vector.

  • group_by (Integer[ndarray, 'n'] | String[ndarray, 'n']) – Grouping factor (one level per random intercept), as a vector or a reference to a column of data.

  • test (Float64[ndarray, 'm p'] | DataFrame | None, default: None) – Test predictors, with the same columns as the training data.

  • subset (Integer[ndarray, 'k'] | None, default: None) – Subset of observations to keep.

  • weights (Float64[ndarray, 'n'] | None, default: None) – Per-observation weights.

  • offset (Float64[ndarray, 'n'] | float | None, default: None) – Latent-scale offset for binary outcomes.

  • offset_test (Float64[ndarray, 'm'] | float | None, default: None) – The offset for the test data; defaults to offset when applicable.

  • group_by_test (Integer[ndarray, 'm'] | String[ndarray, 'm'] | None, default: None) – Grouping factor for the test data.

  • prior (object | None, default: None) – Prior over the random-effects SD, as an R function or built-in reference (cauchy or gamma).

  • sigest (float | None, default: None) – Rough estimate of the error SD, as in bart2. Continuous only.

  • sigdf (float, default: 3.0) – Degrees of freedom of the sigma prior, as in bart2. Continuous only.

  • sigquant (float, default: 0.9) – Quantile of the sigma prior placed at sigest, as in bart2.

  • k (float, default: 2.0) – Number of prior SDs between f and the data extremes, as in bart2 (but defaulting to 2).

  • power (float, default: 2.0) – Exponent of the tree depth prior, as in bart2.

  • base (float, default: 0.95) – Scale of the tree depth prior, as in bart2.

  • n_trees (int, default: 75) – Number of trees in the sum, as in bart2.

  • n_samples (int, default: 1500) – Number of posterior samples kept per chain, as in bart2.

  • n_burn (int, default: 1500) – Number of burn-in iterations, as in bart2.

  • n_chains (int, default: 4) – Number of independent chains, as in bart2.

  • n_threads (int | None, default: None) – Number of threads to use, as in bart2.

  • combineChains (bool, default: False) – Whether the chains are stacked into the draws axis, as in bart2.

  • n_cuts (int | Integer[ndarray, 'p'], default: 100) – Maximum number of decision rules per predictor, as in bart2.

  • useQuantiles (bool, default: False) – Whether the decision rules use empirical quantiles, as in bart2.

  • n_thin (int, default: 5) – Thinning: keep one sample out of n_thin (defaulting to 5).

  • keepTrainingFits (bool, default: True) – Whether to return the training-point function draws, as in bart2.

  • printEvery (int, default: 100) – Interval, in samples, of the progress messages, as in bart2.

  • printCutoffs (int, default: 0) – Number of a variable’s decision rules printed before the run, as in bart2.

  • verbose (bool, default: True) – Whether to print progress to the R console, as in bart2.

  • keepTrees (bool, default: True) – Whether the trees are kept (defaulting to True), as in bart2.

  • keepCall (bool, default: True) – Whether the originating R call is stored in call, as in bart2.

  • seed (int | None, default: None) – Seed of the chains’ RNG, as in bart2.

  • keepSampler (bool | None, default: None) – Whether to keep the underlying sampler, as in bart2.

  • keepTestFits (bool, default: True) – Whether the test fits are returned (useful to disable with callback).

  • callback (object | None, default: None) – An R function of trainFits, testFits, ranef, sigma, and tau called after each kept iteration, its results collected in callback.

  • **control (object) – Extra keyword arguments forwarded verbatim (R’s ...) to dbartsControl, by their R names, e.g. rngSeed or updateState.

Notes

split_probs, proposal_probs, and samplerOnly (in bart2) are not part of rbart_vi’s interface and are unavailable here.

R documentation

title
-----

Bayesian Additive Regression Trees with Random Effects

name
----

rbart

alias
-----

residuals.rbart

keyword
-------

randomeffects

description
-----------

   Fits a varying intercept/random effect BART model.


usage
-----


 rbart_vi(
     formula, data, test, subset, weights, offset, offset.test = offset,
     group.by, group.by.test, prior = cauchy,
     sigest = NA_real_, sigdf = 3.0, sigquant = 0.90,
     k = 2.0,
     power = 2.0, base = 0.95,
     n.trees = 75L,
     n.samples = 1500L, n.burn = 1500L,
     n.chains = 4L, n.threads = min(dbarts::guessNumCores(), n.chains),
     combineChains = FALSE,
     n.cuts = 100L, useQuantiles = FALSE,
     n.thin = 5L, keepTrainingFits = TRUE,
     printEvery = 100L, printCutoffs = 0L,
     verbose = TRUE,
     keepTrees = TRUE, keepCall = TRUE,
     seed = NA_integer_,
     keepSampler = keepTrees,
     keepTestFits = TRUE,
     callback = NULL,
     ...)

 ## S3 method for class 'rbart'
 plot(
     x, plquants = c(0.05, 0.95), cols = c('blue', 'black'), ...)

 ## S3 method for class 'rbart'
 fitted(
     object,
     type = c("ev", "ppd", "bart", "ranef"),
     sample = c("train", "test"),
     ...)

 ## S3 method for class 'rbart'
 extract(
     object,
     type = c("ev", "ppd", "bart", "ranef", "trees"),
     sample = c("train", "test"),
     combineChains = TRUE,
     ...)

 ## S3 method for class 'rbart'
 predict(
     object, newdata, group.by, offset,
     type = c("ev", "ppd", "bart", "ranef"),
     combineChains = TRUE,
     ...)

 ## S3 method for class 'rbart'
 residuals(object, ...)


arguments
---------


    group.by
     Grouping factor. Can be an integer vector/factor, or a reference to such in  data .

    group.by.test
     Grouping factor for test data, of the same type as  group.by . Can be missing.

    prior
     A function or symbolic reference to built-in priors. Determines the prior over the standard deviation of the random effects. Supplied functions take two arguments,  x  - the standard deviation, and  rel.scale  - the standard deviation of the response variable before random effects are fit. Built-in priors are  cauchy  with a scale of 2.5 times the relative scale and  gamma  with a shape of 2.5 and scale of 2.5 times the relative scale.

    n.thin
     The number of tree jumps taken for every stored sample, but also the number of samples from the posterior of the standard deviation of the random effects before one is kept.

    keepTestFits
     Logical where, if false, test fits are obtained while running but not returned. Useful with  callback .

    callback
     Optional function of  trainFits ,  testFits ,  ranef ,  sigma , and  tau . It is called after every post-burn-in iteration, and its results are collected and stored in the final object.

    formula, data, test, subset, weights, offset, offset.test, sigest, sigdf, sigquant, k, power, base, n.trees, n.samples, n.burn, n.chains, n.threads, combineChains, n.cuts, useQuantiles, keepTrainingFits, printEvery, printCutoffs, verbose, keepTrees, keepCall, seed, keepSampler, ...
     Same as in  bart2 .

    object
     A fitted  rbart  model.

    newdata
     Same as  test , but named to match  predict  generic.

    type
     One of  "ev" ,  "ppd" ,  "bart" ,  "ranef" , or  "trees"  for the posterior of the expected value, posterior predictive distribution, non-parametric/BART component, random effect, or saved trees respectively. The expected value is the sum of the BART component and the random effects, while the posterior predictive distribution is a response sampled with that mean. To synergize with  predict.glm ,  "response"  can be used as a synonym for  "ev"  and  "link"  can be used as a synonym for  "bart" . For additional details on tree extraction, see the corresponding subsection in  bart .

    sample
     One of  "train"  or  "test" , referring to the training or test samples respectively.

    x, plquants, cols
     Same as in  plot.bart .
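
The relationship between the  type  values for a continuous response can be sketched with plain NumPy arrays. This is only an illustration: the arrays below are random stand-ins for posterior draws rather than output of a fit, the sizes and 0-based group index are hypothetical, and expanding the random effects to one column per observation is an assumption about the layout.

```python
import numpy as np

rng = np.random.default_rng(0)

ndpost, n, n_groups = 200, 6, 3              # hypothetical sizes
g = np.array([0, 0, 1, 1, 2, 2])             # group index g[i] of each observation

# Random stand-ins for posterior draws; a real fit would supply these.
bart = rng.normal(size=(ndpost, n))          # "bart": draws of f(x_i)
ranef = rng.normal(size=(ndpost, n_groups))  # draws of alpha_j, one column per group
sigma = np.abs(rng.normal(1.0, 0.1, size=ndpost))  # draws of the error SD

ranef_obs = ranef[:, g]                      # "ranef": alpha_{g[i]} for each observation
ev = bart + ranef_obs                        # "ev": BART component plus random effect
ppd = rng.normal(ev, sigma[:, None])         # "ppd": response sampled with mean ev
```

Each row corresponds to one posterior draw, so averaging  ev  over rows gives a posterior-mean fit per observation.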



details
-------


   Fits a BART model with additive random intercepts, one for each factor level of  group.by . For continuous responses:


      y_i \sim N(f(x_i) + \alpha_{g[i]}, \sigma^2)
      \alpha_j \sim N(0, \tau^2) .


 For binary outcomes the response model is changed to  P(Y_i = 1) = \Phi(f(x_i) + \alpha_{g[i]}) . Here  i  indexes observations,  g[i]  is the group index of observation  i ,  f(x)  and  \sigma  come from a BART model, and  \alpha_j  are the independent and identically distributed random intercepts. Draws from the posterior of  \tau  are made using a slice sampler, with a width dynamically determined by assessing the curvature of the posterior distribution at its mode.

 Out Of Sample Groups
   Predicting random effects for groups not in the training sample is supported by sampling from their posterior predictive distribution; that is, a draw is taken from  p(\alpha \mid y) = \int p(\alpha \mid \tau)p(\tau \mid y)d\tau . For out-of-sample groups in the test data, these random effect draws can be kept with the saved object. For those supplied to  predict , they cannot and may change for subsequent calls.
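
   As a rough Monte Carlo illustration of that integral: given kept draws of  tau , one normal draw per  tau  value yields draws of the random effect for an unseen group. The  tau  values below are random stand-ins, not output of a fit, and the variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)

# Random stand-in for kept posterior draws of tau (a fit would supply these).
tau = np.abs(rng.normal(1.5, 0.2, size=1000))

# One draw of alpha per tau draw: alpha | tau ~ N(0, tau^2).
# Taken together, these are draws from p(alpha | y) for an unseen group.
alpha_new = rng.normal(0.0, tau)
```

   Because fresh normal draws are taken each time, repeating this step with a different RNG state gives different values, which matches the remark that effects for groups supplied to  predict  may change between calls.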


 Generics
   See the generics section of  bart .



value
-----


   An object of class  rbart . Contains all of the same elements of an object of class  bart , as well as the elements:

    ranef
     Samples from the posterior of the random effects. An array/matrix of posterior samples. The  (k, l, j)  value is the  l th draw of the posterior of the random effect for group  j  (i.e.  \alpha^*_j ) corresponding to chain  k . When  n.chains  is one or  combineChains  is  TRUE , the result is collapsed down to a matrix.

    ranef.mean
     Posterior mean of the random effects, obtained by averaging the posterior samples for each group.

    tau
     Matrix of posterior samples of  tau , the standard deviation of the random effects. Dimensions are equal to the number of chains times the number of samples unless  n.chains  is one or  combineChains  is  TRUE .

    first.tau
     Burn-in draws of  tau .

    callback
     Optional results of  callback  function.
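
The shapes described above can be sketched with a stand-in array indexed as (chain, draw, group). The values are random placeholders and the sizes are hypothetical; the chain-major row order of the collapsed matrix is an assumption, not something this page specifies.

```python
import numpy as np

n_chains, n_samples, n_groups = 4, 50, 10    # hypothetical sizes
rng = np.random.default_rng(2)

# Random stand-in for the (k, l, j) array of random-effect draws.
ranef = rng.normal(size=(n_chains, n_samples, n_groups))

# Collapsed form: one row per draw. Chain-major row order is assumed here.
ranef_combined = ranef.reshape(n_chains * n_samples, n_groups)

# Posterior mean per group: average over chains and draws.
ranef_mean = ranef.mean(axis=(0, 1))
```

With equal numbers of draws per chain, averaging the collapsed matrix over its rows gives the same per-group means.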



author
------


   Vincent Dorie:  vdorie@gmail.com


seealso
-------


    bart ,  dbarts


examples
--------


 f <- function(x) {
     10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
         10 * x[,4] + 5 * x[,5]
 }

 set.seed(99)
 sigma <- 1.0
 n     <- 100

 x  <- matrix(runif(n * 10), n, 10)
 Ey <- f(x)
 y  <- rnorm(n, Ey, sigma)

 n.g <- 10
 g <- sample(n.g, length(y), replace = TRUE)
 sigma.b <- 1.5
 b <- rnorm(n.g, 0, sigma.b)

 y <- y + b[g]

 df <- as.data.frame(x)
 colnames(df) <- paste0("x_", seq_len(ncol(x)))
 df$y <- y
 df$g <- g

 ## low numbers to reduce run time
 rbartFit <- rbart_vi(y ~ . - g, df, group.by = g,
                      n.samples = 40L, n.burn = 10L, n.thin = 2L,
                      n.chains = 1L,
                      n.trees = 25L, n.threads = 1L)
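
A Python analogue of the R example, as a sketch only. It uses matrix mode (predictor matrix and response vector as the first two arguments), assumes the import path  rbartpackages.dbarts  implied by the class name, and simulates data with NumPy, so the values will not match R's  set.seed(99)  stream. The helper names  friedman ,  simulate , and  fit  are hypothetical.

```python
import numpy as np


def friedman(x):
    """Same test function f as in the R example."""
    return (10 * np.sin(np.pi * x[:, 0] * x[:, 1]) + 20 * (x[:, 2] - 0.5) ** 2
            + 10 * x[:, 3] + 5 * x[:, 4])


def simulate(n=100, n_groups=10, sigma=1.0, sigma_b=1.5, seed=99):
    """Simulate predictors, grouped responses, and group labels 1..n_groups."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(size=(n, 10))
    y = rng.normal(friedman(x), sigma)
    g = rng.integers(1, n_groups + 1, size=n)
    b = rng.normal(0.0, sigma_b, size=n_groups)  # true random intercepts
    return x, y + b[g - 1], g


def fit(x, y, g):
    """Fit with low iteration counts to reduce run time (assumed import path)."""
    # Imported lazily: needs rbartpackages and a working R with dbarts.
    from rbartpackages.dbarts import rbart_vi
    return rbart_vi(x, y, group_by=g, n_samples=40, n_burn=10, n_thin=2,
                    n_chains=1, n_trees=25, n_threads=1)


x, y, g = simulate()
```

With the package and R available,  fit(x, y, g)  would return the fitted object, whose  ranef_mean  and  tau  attributes are described below.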

binaryOffset: Float64[ndarray, 'n'] | None = None

Per-observation offset on the latent probit scale (binary outcomes only).

extract(*, type=None, sample=None, combineChains=None)

Return the kept draws for the training (default) or test points.

Like predict, the draws are on the expected-value scale by default. With type='trees' (requires keepTrees=True) the tree structures are returned as a data frame instead. Arguments left as None are omitted from the R call.

Parameters:
  • type (Literal['ev', 'ppd', 'bart', 'trees'] | None, default: None) – Quantity returned: 'ev', 'ppd', 'bart' (see predict), or 'trees' for the tree structures.

  • sample (Literal['train', 'test'] | None, default: None) – Which points to extract: 'train' or 'test'.

  • combineChains (bool | None, default: None) – Whether the chains are stacked into the draws axis rather than kept on a leading nchain axis.

Returns:

Float64[ndarray, 'ndpost n'] | Float64[ndarray, 'nchain ndpost n'] | DataFrame – The draws at the requested points, or the tree-structure data frame with type='trees'.

first_k: Float64[ndarray, 'nskip'] | Float64[ndarray, 'nchain nskip'] | None = None

Burn-in draws of k (only when k is given a hyperprior).

first_sigma: Float64[ndarray, 'nskip'] | Float64[ndarray, 'nchain nskip'] | None = None

Burn-in error-SD draws (continuous outcomes only).

fitted(*, type=None, sample=None)

Return the posterior mean for the training (default) or test points.

Parameters:
  • type (Literal['ev', 'ppd', 'bart'] | None, default: None) – Quantity averaged: 'ev', 'ppd', or 'bart' (see predict).

  • sample (Literal['train', 'test'] | None, default: None) – Which points to use: 'train' or 'test'.

Returns:

Float64[ndarray, 'n'] – The posterior mean at the requested points.

k: Float64[ndarray, 'ndpost'] | Float64[ndarray, 'nchain ndpost'] | None = None

End-node-prior k draws (only when k is given a hyperprior).

n_chains: int | None = None

Number of MCMC chains; None when the sampler is kept in fit.

sigest: float | None = None

Rough residual SD used to set the sigma prior (continuous outcomes only).

sigma: Float64[ndarray, 'ndpost'] | Float64[ndarray, 'nchain ndpost'] | None = None

Kept error-SD draws, continuous outcomes only (burn-in is in first_sigma).

yhat_test: Float64[ndarray, 'ndpost m'] | Float64[ndarray, 'nchain ndpost m'] | None = None

Test-point posterior function draws; None without test data.

yhat_test_mean: Float64[ndarray, 'm'] | None = None

Posterior mean of yhat_test (continuous outcomes with test data only).

yhat_train: Float64[ndarray, 'ndpost n'] | Float64[ndarray, 'nchain ndpost n'] | None = None

Training-point posterior function draws (latent probit scale for binary).

None with keepTrainingFits=False.

yhat_train_mean: Float64[ndarray, 'n'] | None = None

Posterior mean of yhat_train (continuous outcomes only).

call: LangVector

The R call that created the fit.

With keepCall=False this is a dummy NULL() call, not None.

varcount: Int32[ndarray, 'ndpost p'] | Int32[ndarray, 'nchain ndpost p']

Per-draw count of splits on each variable, summed over trees.

callback: Float64[ndarray, 'ndpost values'] | Float64[ndarray, 'nchain ndpost values'] | None = None

Stacked per-draw results of the callback function, if given.

first_tau: Float64[ndarray, 'nskip'] | Float64[ndarray, 'nchain nskip']

Burn-in draws of the random-effects SD.

fit: tuple[dbarts, ...] | None = None

The per-chain samplers as dbarts objects.

Kept only with keepTrees or keepSampler (both on by default).

group_by: String[ndarray, 'n']

The training grouping factor, as the level name of each observation.

group_by_test: String[ndarray, 'm'] | None = None

The test grouping factor, if given.

ranef: Float64[ndarray, 'ndpost g'] | Float64[ndarray, 'nchain ndpost g']

Random-intercept draws for each of the g groups.

ranef_mean: Float64[ndarray, 'g']

Posterior mean of ranef per group.

seed: Int32[ndarray, 'state']

R RNG state used by predict to draw the effects of unseen groups.

tau: Float64[ndarray, 'ndpost'] | Float64[ndarray, 'nchain ndpost']

Random-effects SD draws.

y: Float64[ndarray, 'n'] = None

The training responses.

predict(newdata, *, group_by=None, offset=None, type=None, combineChains=None)

Compute predictions at new points; requires a keepTrees=True fit.

Each new point needs a group_by level. Arguments left as None are omitted from the R call, so R computes its own defaults.

Parameters:
  • newdata (Float64[ndarray, 'm p'] | DataFrame) – New predictors, with the same column structure as x_train.

  • group_by (Integer[ndarray, 'm'] | String[ndarray, 'm'] | None, default: None) – Grouping factor of the new points; out-of-sample groups draw fresh random effects.

  • offset (Float64[ndarray, 'm'] | float | None, default: None) – Offset added to the predictions.

  • type (Literal['ev', 'ppd', 'bart', 'ranef'] | None, default: None) – Quantity returned: 'ev' (expected value), 'ppd' (posterior predictive), 'bart' (the latent sum-of-trees), or 'ranef' (the random effects).

  • combineChains (bool | None, default: None) – Whether the chains are stacked into the draws axis rather than kept on a leading nchain axis.

Returns:

Float64[ndarray, 'ndpost m'] | Float64[ndarray, 'nchain ndpost m'] – The predictions at newdata, on the expected-value scale unless type says otherwise.