{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic regression with the BART3 wrapper\n", "\n", "This example fits a small regression model through the R package BART3, driven from Python by `rbartpackages`. It requires R with the `BART3` package installed.\n", "\n", "On [Google Colab](https://colab.research.google.com) you can install R and the package on the fly, which makes it easy to try the wrappers without a local R setup." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from rbartpackages import BART3\n", "\n", "rng = np.random.default_rng(0)\n", "n, p = 100, 5\n", "x_train = rng.standard_normal((n, p))\n", "y_train = x_train[:, 0] + 0.1 * rng.standard_normal(n)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bart = BART3.gbart(x_train=x_train, y_train=y_train, ndpost=200)\n", "y_pred = bart.predict(x_train)\n", "y_pred.shape # (ndpost, n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`bart` exposes the components of the fitted R object as attributes, e.g. `bart.yhat_train`, `bart.varcount`, and `bart.sigma`." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }