Basic regression with the BART3 wrapper

This example fits a small regression model through the R package BART3, driven from Python by rbartpackages. It requires R with the BART3 package installed.

On Google Colab you can install R and the package on the fly, which makes it easy to try the wrappers without a local R setup.

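```python
import numpy as np

from rbartpackages import BART3

# Synthetic data: only the first of the p predictors affects the response.
rng = np.random.default_rng(0)
n, p = 100, 5
x_train = rng.standard_normal((n, p))
y_train = x_train[:, 0] + 0.1 * rng.standard_normal(n)

# Fit BART3's gbart in R, keeping 200 posterior draws.
bart = BART3.gbart(x_train=x_train, y_train=y_train, ndpost=200)

# Posterior draws at the training points: one row per draw.
y_pred = bart.predict(x_train)
print(y_pred.shape)  # (ndpost, n), i.e. (200, 100)
```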
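Because each row of `y_pred` is one posterior draw, pointwise summaries come from reducing over axis 0. The sketch below uses only NumPy so it runs without R: `y_pred` and `f_true` are hypothetical synthetic stand-ins with the `(ndpost, n)` layout shown above, not output of BART3. With a real fit, substitute the actual `y_pred` (and compare against `y_train` instead of `f_true`).

```python
import numpy as np

# Hypothetical stand-ins: `f_true` plays the role of the true function at the
# n training points, and `y_pred` mimics posterior draws of shape (ndpost, n).
# Replace `y_pred` with the array returned by bart.predict(x_train).
rng = np.random.default_rng(1)
ndpost, n = 200, 100
f_true = rng.standard_normal(n)
y_pred = f_true + 0.05 * rng.standard_normal((ndpost, n))

# Posterior mean at each point: average over the draws (axis 0).
post_mean = y_pred.mean(axis=0)

# Pointwise 95% equal-tailed credible interval from quantiles of the draws.
lower, upper = np.quantile(y_pred, [0.025, 0.975], axis=0)

# Fraction of points whose interval contains the reference values.
coverage = np.mean((lower <= f_true) & (f_true <= upper))

print(post_mean.shape, lower.shape, upper.shape)
print(f"coverage: {coverage:.2f}")
```

Equal-tailed quantile intervals are just one convenient choice; any other summary over axis 0 works the same way.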

bart exposes the components of the fitted R object as attributes, e.g. bart.yhat_train, bart.varcount, and bart.sigma.
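As one example of using these attributes, `bart.varcount` can indicate which predictors the ensemble relies on. In the data above only `x_train[:, 0]` drives `y_train`, so that column should dominate. The sketch assumes `varcount` is an array of shape `(ndpost, p)` whose entry `[d, j]` counts the splitting rules that use predictor `j` in draw `d`, which is the layout documented for the R BART package; check it against the real attribute. The array here is a hypothetical synthetic stand-in so the code runs without R.

```python
import numpy as np

# Hypothetical stand-in for bart.varcount, assumed shape (ndpost, p):
# predictor 0 is used far more often than the others, as expected here.
rng = np.random.default_rng(2)
ndpost, p = 200, 5
varcount = rng.poisson(lam=[20, 3, 3, 3, 3], size=(ndpost, p))

# Share of all splits that use each predictor, per draw, averaged over draws.
totals = varcount.sum(axis=1, keepdims=True)
split_share = (varcount / np.maximum(totals, 1)).mean(axis=0)

# Posterior probability that each predictor appears in at least one split.
inclusion_prob = (varcount > 0).mean(axis=0)

# Report predictors from most to least used.
for j in np.argsort(split_share)[::-1]:
    print(f"x{j}: split share {split_share[j]:.2f}, "
          f"inclusion prob {inclusion_prob[j]:.2f}")
```

Normalizing within each draw before averaging keeps draws with many splits from dominating the summary.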