{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "g4ZqupCK0IP8" }, "source": [ "# Simulated data" ] }, { "cell_type": "markdown", "metadata": { "id": "v6mmaj9O0VLE" }, "source": [ "This notebook runs bartz on simulated data. It is meant to be run on a GPU. Use the following link to try it out on Colab: [link](https://colab.research.google.com/github/bartz-org/bartz/blob/main/docs/examples/basic_simdata.ipynb)\n", "\n", "The next cell installs bartz:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gAIYdaOqXP5g" }, "outputs": [], "source": [ "%pip install bartz" ] }, { "cell_type": "markdown", "metadata": { "id": "rJ387an92jww" }, "source": [ "The next cell defines configuration parameters for the notebook:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "hFK9jGWdJ8aH" }, "outputs": [], "source": [ "n_train = 100_000 # number of training points\n", "p = 1000 # number of predictors/features\n", "sigma = 0.1 # error standard deviation\n", "n_test = 1000 # number of test points\n", "n_tree = 10_000 # number of trees used by bartz" ] }, { "cell_type": "markdown", "metadata": { "id": "WwIKUa-E26po" }, "source": [ "The next cell generates simulated data from a linear + quadratic test model that comes packaged with bartz:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "xyYQxZ632yCQ" }, "outputs": [], "source": [ "from jax import random\n", "\n", "from bartz.testing import gen_data\n", "\n", "# list of independent seeds for random sampling\n", "keys = list(random.split(random.key(2024_04_16_18_53), 2))\n", "\n", "# simulate data with bartz's built-in testing data generating process\n", "data = gen_data(\n", " keys.pop(),\n", " n=n_train + n_test,\n", " p=p,\n", " q=2, # number of interactions: each predictor interacts with q other predictors in the quadratic term\n", " sigma2_eps=sigma**2, # error variance\n", " sigma2_lin=0.5, # linear term variance\n", " sigma2_quad=0.5, # quadratic term 
variance\n", " k=1, # number of outcomes\n", " lam=1.0, # correlation between outcomes, unused in this case\n", ")\n", "\n", "# split data in train-test\n", "train, test = data.split(n_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "CglltLyE6w8a" }, "source": [ "The next cell runs bartz:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Y0zr7gHR3DFw", "outputId": "6fa246b9-e17a-426c-a49c-8e49fec86ae2" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "..........\n", "Iteration 10/1100, grow prob: 54%, move acc: 37%, fill: 6% (burnin)\n", "..........\n", "Iteration 20/1100, grow prob: 53%, move acc: 36%, fill: 6% (burnin)\n", "..........\n", "Iteration 30/1100, grow prob: 53%, move acc: 35%, fill: 6% (burnin)\n", "..........\n", "Iteration 40/1100, grow prob: 54%, move acc: 34%, fill: 6% (burnin)\n", "..........\n", "Iteration 50/1100, grow prob: 53%, move acc: 34%, fill: 6% (burnin)\n", "..........\n", "Iteration 60/1100, grow prob: 54%, move acc: 34%, fill: 6% (burnin)\n", "..........\n", "Iteration 70/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)\n", "..........\n", "Iteration 80/1100, grow prob: 54%, move acc: 32%, fill: 6% (burnin)\n", "..........\n", "Iteration 90/1100, grow prob: 54%, move acc: 32%, fill: 6% (burnin)\n", "..........\n", "Iteration 100/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)\n", "..........\n", "Iteration 110/1100, grow prob: 54%, move acc: 32%, fill: 6%\n", "..........\n", "Iteration 120/1100, grow prob: 53%, move acc: 33%, fill: 6%\n", "..........\n", "Iteration 130/1100, grow prob: 53%, move acc: 32%, fill: 6%\n", "..........\n", "Iteration 140/1100, grow prob: 54%, move acc: 32%, fill: 6%\n", "..........\n", "Iteration 150/1100, grow prob: 54%, move acc: 31%, fill: 6%\n", "..........\n", "Iteration 160/1100, grow prob: 54%, move acc: 32%, fill: 6%\n", "..........\n", "Iteration 170/1100, grow prob: 54%, move 
acc: 31%, fill: 6%\n", "..........\n", "Iteration 180/1100, grow prob: 54%, move acc: 31%, fill: 6%\n", "..........\n", "Iteration 190/1100, grow prob: 54%, move acc: 30%, fill: 6%\n", "..........\n", "Iteration 200/1100, grow prob: 54%, move acc: 30%, fill: 6%\n", "..........\n", "Iteration 210/1100, grow prob: 53%, move acc: 31%, fill: 6%\n", "..........\n", "Iteration 220/1100, grow prob: 55%, move acc: 30%, fill: 6%\n", "..........\n", "Iteration 230/1100, grow prob: 54%, move acc: 29%, fill: 6%\n", "..........\n", "Iteration 240/1100, grow prob: 54%, move acc: 30%, fill: 6%\n", "..........\n", "Iteration 250/1100, grow prob: 55%, move acc: 30%, fill: 6%\n", "..........\n", "Iteration 260/1100, grow prob: 54%, move acc: 29%, fill: 6%\n", "..........\n", "Iteration 270/1100, grow prob: 53%, move acc: 30%, fill: 6%\n", "..........\n", "Iteration 280/1100, grow prob: 54%, move acc: 29%, fill: 6%\n", "..........\n", "Iteration 290/1100, grow prob: 54%, move acc: 28%, fill: 6%\n", "..........\n", "Iteration 300/1100, grow prob: 54%, move acc: 29%, fill: 6%\n", "..........\n", "Iteration 310/1100, grow prob: 54%, move acc: 29%, fill: 6%\n", "..........\n", "Iteration 320/1100, grow prob: 55%, move acc: 29%, fill: 6%\n", "..........\n", "Iteration 330/1100, grow prob: 53%, move acc: 29%, fill: 6%\n", "..........\n", "Iteration 340/1100, grow prob: 53%, move acc: 28%, fill: 6%\n", "..........\n", "Iteration 350/1100, grow prob: 53%, move acc: 29%, fill: 6%\n", "..........\n", "Iteration 360/1100, grow prob: 54%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 370/1100, grow prob: 54%, move acc: 28%, fill: 6%\n", "..........\n", "Iteration 380/1100, grow prob: 54%, move acc: 28%, fill: 6%\n", "..........\n", "Iteration 390/1100, grow prob: 53%, move acc: 28%, fill: 6%\n", "..........\n", "Iteration 400/1100, grow prob: 53%, move acc: 28%, fill: 6%\n", "..........\n", "Iteration 410/1100, grow prob: 55%, move acc: 28%, fill: 6%\n", "..........\n", "Iteration 
420/1100, grow prob: 53%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 430/1100, grow prob: 53%, move acc: 28%, fill: 6%\n", "..........\n", "Iteration 440/1100, grow prob: 53%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 450/1100, grow prob: 53%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 460/1100, grow prob: 54%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 470/1100, grow prob: 53%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 480/1100, grow prob: 54%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 490/1100, grow prob: 54%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 500/1100, grow prob: 54%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 510/1100, grow prob: 53%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 520/1100, grow prob: 53%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 530/1100, grow prob: 52%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 540/1100, grow prob: 54%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 550/1100, grow prob: 54%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 560/1100, grow prob: 53%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 570/1100, grow prob: 54%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 580/1100, grow prob: 53%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 590/1100, grow prob: 54%, move acc: 27%, fill: 6%\n", "..........\n", "Iteration 600/1100, grow prob: 53%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 610/1100, grow prob: 53%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 620/1100, grow prob: 53%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 630/1100, grow prob: 53%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 640/1100, grow prob: 54%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 650/1100, grow prob: 54%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 660/1100, grow prob: 54%, move acc: 26%, fill: 6%\n", 
"..........\n", "Iteration 670/1100, grow prob: 53%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 680/1100, grow prob: 53%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 690/1100, grow prob: 53%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 700/1100, grow prob: 54%, move acc: 26%, fill: 6%\n", "..........\n", "Iteration 710/1100, grow prob: 54%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 720/1100, grow prob: 54%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 730/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 740/1100, grow prob: 55%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 750/1100, grow prob: 54%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 760/1100, grow prob: 53%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 770/1100, grow prob: 54%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 780/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 790/1100, grow prob: 54%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 800/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 810/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 820/1100, grow prob: 53%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 830/1100, grow prob: 54%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 840/1100, grow prob: 53%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 850/1100, grow prob: 53%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 860/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 870/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 880/1100, grow prob: 53%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 890/1100, grow prob: 53%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 900/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 910/1100, grow prob: 54%, 
move acc: 24%, fill: 6%\n", "..........\n", "Iteration 920/1100, grow prob: 53%, move acc: 23%, fill: 6%\n", "..........\n", "Iteration 930/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 940/1100, grow prob: 54%, move acc: 25%, fill: 6%\n", "..........\n", "Iteration 950/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 960/1100, grow prob: 53%, move acc: 23%, fill: 6%\n", "..........\n", "Iteration 970/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 980/1100, grow prob: 53%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 990/1100, grow prob: 53%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 1000/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 1010/1100, grow prob: 54%, move acc: 23%, fill: 6%\n", "..........\n", "Iteration 1020/1100, grow prob: 54%, move acc: 23%, fill: 6%\n", "..........\n", "Iteration 1030/1100, grow prob: 53%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 1040/1100, grow prob: 54%, move acc: 24%, fill: 6%\n", "..........\n", "Iteration 1050/1100, grow prob: 54%, move acc: 23%, fill: 6%\n", "..........\n", "Iteration 1060/1100, grow prob: 54%, move acc: 23%, fill: 6%\n", "..........\n", "Iteration 1070/1100, grow prob: 53%, move acc: 23%, fill: 6%\n", "..........\n", "Iteration 1080/1100, grow prob: 54%, move acc: 23%, fill: 6%\n", "..........\n", "Iteration 1090/1100, grow prob: 53%, move acc: 23%, fill: 6%\n", "..........\n", "Iteration 1100/1100, grow prob: 54%, move acc: 24%, fill: 6%\n" ] } ], "source": [ "from time import perf_counter\n", "\n", "from bartz.BART import gbart\n", "\n", "# clock bartz\n", "start = perf_counter()\n", "bart = gbart(train.x, train.y.squeeze(0), ntree=n_tree, printevery=10, seed=keys.pop())\n", "end = perf_counter()" ] }, { "cell_type": "markdown", "metadata": { "id": "BpyizlIjMeCw" }, "source": [ "Interpretation of the printout:\n", "* grow prob = the fraction of trees 
that bart tried to add leaves to, rather than remove leaves from\n", "* move acc = in the last iteration, the fraction of attempted tree changes that were kept\n", "* fill = how much of the fixed tree space is used; if it's more than ~50%, the trees are too deep and the number of trees should be increased\n", "\n", "The fractions refer to the state of the trees after the last iteration; they are not averaged over multiple iterations.\n", "\n", "A low acceptance rate means that the trees are changing very slowly." ] }, { "cell_type": "markdown", "metadata": { "id": "ra9xOl3YNY2m" }, "source": [ "The next cell computes the predictions." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "W2v-A58BNbuX", "outputId": "e3d7204d-dff0-47bd-c66b-9a9572297c1b" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "total sdev: 1.0\n", "error sdev: 0.10\n", "RMSE: 0.45\n", "expected RMSE: 0.51\n", "model error sdev: 0.42\n", "time: 9.9 min\n" ] } ], "source": [ "from jax import numpy as jnp\n", "\n", "# compute predictions\n", "yhat_test = bart.predict(test.x) # posterior samples, n_samples x n_test\n", "yhat_test_mean = jnp.mean(yhat_test, axis=0) # posterior mean point-by-point\n", "yhat_test_var = jnp.var(yhat_test, axis=0) # posterior variance point-by-point\n", "\n", "# RMSE\n", "rmse = jnp.sqrt(jnp.mean(jnp.square(yhat_test_mean - test.y)))\n", "expected_error_variance = jnp.mean(jnp.square(bart.sigma))\n", "expected_rmse = jnp.sqrt(jnp.mean(yhat_test_var + expected_error_variance))\n", "avg_sigma = jnp.sqrt(expected_error_variance)\n", "\n", "print(f'total sdev: {jnp.std(train.y):#.2g}')\n", "print(f'error sdev: {sigma:#.2g}')\n", "print(f'RMSE: {rmse:#.2g}')\n", "print(f'expected RMSE: {expected_rmse:#.2g}')\n", "print(f'model error sdev: {avg_sigma:#.2g}')\n", "print(f'time: {(end - start) / 60:#.2g} min')" ] }, { "cell_type": "markdown", "metadata": { "id": "QigclpPkOVJq" }, "source": [ "The RMSE can 
at best be as low as the error standard deviation used to generate the data." ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.14.2" } }, "nbformat": 4, "nbformat_minor": 0 }