Simulated data

This notebook runs bartz on simulated data. It is meant to be run on a GPU. Use the following link to try it out on Colab: link

The next cell installs bartz:

%pip install bartz

The next cell defines configuration parameters for the script:

n_train = 100_000  # number of training points
p = 1000           # number of predictors/features
sigma = 0.1        # error standard deviation
n_test = 1000      # number of test points
n_tree = 10_000    # number of trees used by bartz
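These sizes are large. As a rough guide to the memory they imply, the sketch below estimates the size of the simulated predictor matrix. It assumes the predictors are held as a dense `float32` array of shape `(p, n_train + n_test)`. bartz may store or bin them differently internally, so treat the result as an order-of-magnitude figure. `predictor_bytes` is a hypothetical helper, not part of bartz.

```python
def predictor_bytes(n: int, p: int, bytes_per_value: int = 4) -> int:
    """Size in bytes of a dense (p, n) predictor matrix.

    Assumes float32 storage (4 bytes per value) unless told otherwise.
    """
    return n * p * bytes_per_value


n_train = 100_000  # same values as the configuration cell
n_test = 1000
p = 1000

total = predictor_bytes(n_train + n_test, p)
print(f"{total / 1e9:.3f} GB")  # dense float32 predictors, train + test
```

With the settings above this prints `0.404 GB`, before counting any working memory used by the sampler itself.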

The next cell generates simulated data from a linear + quadratic test model that comes packaged with bartz:

from jax import random

from bartz.testing import gen_data

# list of independent seeds for random sampling
keys = list(random.split(random.key(2024_04_16_18_53), 2))

# simulate data with bartz's built-in testing data generating process
data = gen_data(
    keys.pop(),
    n=n_train + n_test,
    p=p,
    q=2,  # number of interactions: in the quadratic term, each predictor interacts with q other predictors
    sigma2_eps=sigma**2,  # error variance
    sigma2_lin=0.5,  # linear term variance
    sigma2_quad=0.5,  # quadratic term variance
    k=1,  # number of outcomes
    lam=1.0,  # correlation between outcomes, unused in this case
)

# split data in train-test
train, test = data.split(n_train)
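To make the shape of such a model concrete, here is a hypothetical, simplified stand-in for `gen_data` called `toy_linear_quadratic`. It is not bartz's implementation and does not reproduce its variance parameters. It only illustrates the structure described above: a linear term, pairwise products in which each predictor is paired with `q` others, and Gaussian noise, with observations laid out along the last axis as in `(p, n)`. The coefficient scalings and the helper `split_columns` are assumptions made for illustration.

```python
from jax import numpy as jnp
from jax import random


def toy_linear_quadratic(key, n, p, q, sigma):
    """Hypothetical linear + quadratic model with predictors of shape (p, n)."""
    k_x, k_lin, k_quad, k_pair, k_eps = random.split(key, 5)
    x = random.normal(k_x, (p, n))
    # linear term: one coefficient per predictor
    beta = random.normal(k_lin, (p,)) / jnp.sqrt(p)
    lin = beta @ x
    # quadratic term: predictor i is multiplied by q randomly chosen
    # predictors (possibly itself), each product with its own coefficient
    pairs = random.randint(k_pair, (p, q), 0, p)
    gamma = random.normal(k_quad, (p, q)) / jnp.sqrt(p * q)
    quad = jnp.einsum('iq,in,iqn->n', gamma, x, x[pairs])
    # additive Gaussian noise with standard deviation sigma
    y = lin + quad + sigma * random.normal(k_eps, (n,))
    return x, y


def split_columns(x, y, n_train):
    """Use the first n_train observations (last axis) for training, the rest for testing."""
    return (x[:, :n_train], y[:n_train]), (x[:, n_train:], y[n_train:])


x, y = toy_linear_quadratic(random.key(0), n=1100, p=20, q=2, sigma=0.1)
(x_train, y_train), (x_test, y_test) = split_columns(x, y, 1000)
print(x_train.shape, x_test.shape)
```

Splitting by position is valid here only because the simulated observations are i.i.d.; with real data one would shuffle first.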

The next cell runs bartz:

from time import perf_counter

from bartz.BART import gbart

# clock bartz
start = perf_counter()
bart = gbart(train.x, train.y.squeeze(0), ntree=n_tree, printevery=10, seed=keys.pop())
end = perf_counter()
..........
Iteration 10/1100, grow prob: 54%, move acc: 37%, fill: 6% (burnin)
..........
Iteration 20/1100, grow prob: 53%, move acc: 36%, fill: 6% (burnin)
..........
Iteration 30/1100, grow prob: 53%, move acc: 35%, fill: 6% (burnin)
..........
Iteration 40/1100, grow prob: 54%, move acc: 34%, fill: 6% (burnin)
..........
Iteration 50/1100, grow prob: 53%, move acc: 34%, fill: 6% (burnin)
..........
Iteration 60/1100, grow prob: 54%, move acc: 34%, fill: 6% (burnin)
..........
Iteration 70/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)
..........
Iteration 80/1100, grow prob: 54%, move acc: 32%, fill: 6% (burnin)
..........
Iteration 90/1100, grow prob: 54%, move acc: 32%, fill: 6% (burnin)
..........
Iteration 100/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)
..........
Iteration 110/1100, grow prob: 54%, move acc: 32%, fill: 6%
..........
Iteration 120/1100, grow prob: 53%, move acc: 33%, fill: 6%
..........
Iteration 130/1100, grow prob: 53%, move acc: 32%, fill: 6%
..........
Iteration 140/1100, grow prob: 54%, move acc: 32%, fill: 6%
..........
Iteration 150/1100, grow prob: 54%, move acc: 31%, fill: 6%
..........
Iteration 160/1100, grow prob: 54%, move acc: 32%, fill: 6%
..........
Iteration 170/1100, grow prob: 54%, move acc: 31%, fill: 6%
..........
Iteration 180/1100, grow prob: 54%, move acc: 31%, fill: 6%
..........
Iteration 190/1100, grow prob: 54%, move acc: 30%, fill: 6%
..........
Iteration 200/1100, grow prob: 54%, move acc: 30%, fill: 6%
..........
Iteration 210/1100, grow prob: 53%, move acc: 31%, fill: 6%
..........
Iteration 220/1100, grow prob: 55%, move acc: 30%, fill: 6%
..........
Iteration 230/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 240/1100, grow prob: 54%, move acc: 30%, fill: 6%
..........
Iteration 250/1100, grow prob: 55%, move acc: 30%, fill: 6%
..........
Iteration 260/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 270/1100, grow prob: 53%, move acc: 30%, fill: 6%
..........
Iteration 280/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 290/1100, grow prob: 54%, move acc: 28%, fill: 6%
..........
Iteration 300/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 310/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 320/1100, grow prob: 55%, move acc: 29%, fill: 6%
..........
Iteration 330/1100, grow prob: 53%, move acc: 29%, fill: 6%
..........
Iteration 340/1100, grow prob: 53%, move acc: 28%, fill: 6%
..........
Iteration 350/1100, grow prob: 53%, move acc: 29%, fill: 6%
..........
Iteration 360/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 370/1100, grow prob: 54%, move acc: 28%, fill: 6%
..........
Iteration 380/1100, grow prob: 54%, move acc: 28%, fill: 6%
..........
Iteration 390/1100, grow prob: 53%, move acc: 28%, fill: 6%
..........
Iteration 400/1100, grow prob: 53%, move acc: 28%, fill: 6%
..........
Iteration 410/1100, grow prob: 55%, move acc: 28%, fill: 6%
..........
Iteration 420/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 430/1100, grow prob: 53%, move acc: 28%, fill: 6%
..........
Iteration 440/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 450/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 460/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 470/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 480/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 490/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 500/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 510/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 520/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 530/1100, grow prob: 52%, move acc: 27%, fill: 6%
..........
Iteration 540/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 550/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 560/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 570/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 580/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 590/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 600/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 610/1100, grow prob: 53%, move acc: 25%, fill: 6%
..........
Iteration 620/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 630/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 640/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 650/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 660/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 670/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 680/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 690/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 700/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 710/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 720/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 730/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 740/1100, grow prob: 55%, move acc: 25%, fill: 6%
..........
Iteration 750/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 760/1100, grow prob: 53%, move acc: 25%, fill: 6%
..........
Iteration 770/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 780/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 790/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 800/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 810/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 820/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 830/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 840/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 850/1100, grow prob: 53%, move acc: 25%, fill: 6%
..........
Iteration 860/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 870/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 880/1100, grow prob: 53%, move acc: 25%, fill: 6%
..........
Iteration 890/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 900/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 910/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 920/1100, grow prob: 53%, move acc: 23%, fill: 6%
..........
Iteration 930/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 940/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 950/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 960/1100, grow prob: 53%, move acc: 23%, fill: 6%
..........
Iteration 970/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 980/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 990/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 1000/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 1010/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1020/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1030/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 1040/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 1050/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1060/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1070/1100, grow prob: 53%, move acc: 23%, fill: 6%
..........
Iteration 1080/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1090/1100, grow prob: 53%, move acc: 23%, fill: 6%
..........
Iteration 1100/1100, grow prob: 54%, move acc: 24%, fill: 6%

Interpretation of the printout:

  • grow prob = the fraction of trees that bartz tried to add leaves to, rather than remove leaves from

  • move acc = in the last iteration, the fraction of attempted tree changes that were kept

  • fill = how much of the fixed tree space is used; if it's more than ~50%, the trees are too deep, so increase the number of trees

The fractions refer to the state of the trees after the last iteration; they are not averaged over multiple iterations.
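Since each line is a single-iteration snapshot, any average has to be computed by hand. A minimal sketch, assuming the printout has been captured as text (for example, copied from the cell output). The regular expression matches the line format shown above, and `parse_log` is a hypothetical helper, not part of bartz. The sample text reuses a few lines from the output above.

```python
import re

LINE = re.compile(
    r"Iteration (\d+)/(\d+), grow prob: (\d+)%, move acc: (\d+)%, fill: (\d+)%"
    r"( \(burnin\))?"
)


def parse_log(text):
    """Return one dict per 'Iteration ...' line of the printout."""
    rows = []
    for m in LINE.finditer(text):
        rows.append({
            "iteration": int(m.group(1)),
            "grow_prob": int(m.group(3)) / 100,
            "move_acc": int(m.group(4)) / 100,
            "fill": int(m.group(5)) / 100,
            "burnin": m.group(6) is not None,  # True for burn-in iterations
        })
    return rows


log = """\
Iteration 100/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)
Iteration 110/1100, grow prob: 54%, move acc: 32%, fill: 6%
Iteration 1090/1100, grow prob: 53%, move acc: 23%, fill: 6%
Iteration 1100/1100, grow prob: 54%, move acc: 24%, fill: 6%
"""

rows = parse_log(log)
# average the acceptance over the logged post-burn-in iterations only
kept = [r["move_acc"] for r in rows if not r["burnin"]]
print(f"mean move acc after burn-in: {sum(kept) / len(kept):.3f}")
```

Averaging only over logged iterations (every 10th here) is a subsample, which is usually enough to see a trend such as the slow decline in acceptance visible above.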

A low acceptance rate means that the trees are changing very slowly.
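Why low acceptance means slow change is easiest to see in a generic Metropolis sampler: a rejected proposal leaves the state exactly where it was. The sketch below is a random-walk Metropolis chain on a one-dimensional standard normal target, written in JAX. It is not how bartz proposes or accepts tree moves; the function name and step sizes are illustrative assumptions.

```python
import jax
from jax import numpy as jnp
from jax import random


def rw_metropolis(key, step, n_iter=5000):
    """Random-walk Metropolis on a standard normal; returns (samples, acceptance rate)."""

    def log_target(x):
        return -0.5 * x**2  # log density of N(0, 1), up to a constant

    def one_step(x, key):
        k_prop, k_acc = random.split(key)
        proposal = x + step * random.normal(k_prop)
        log_ratio = log_target(proposal) - log_target(x)
        # accept with probability min(1, exp(log_ratio)); otherwise keep x
        accept = jnp.log(random.uniform(k_acc)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, (x_new, accept)

    keys = random.split(key, n_iter)
    _, (samples, accepted) = jax.lax.scan(one_step, jnp.float32(0.0), keys)
    return samples, jnp.mean(accepted)


for step in (0.5, 2.5, 25.0):
    _, rate = rw_metropolis(random.key(0), step)
    print(f"step {step:>5}: acceptance {float(rate):.2f}")
```

With the large step most proposals land in low-density regions and are rejected, so the chain spends long stretches stuck at the same value.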

The next cell computes the predictions.

from jax import numpy as jnp

# compute predictions
yhat_test = bart.predict(test.x)  # posterior samples, n_samples x n_test
yhat_test_mean = jnp.mean(yhat_test, axis=0)  # posterior mean point-by-point
yhat_test_var = jnp.var(yhat_test, axis=0)  # posterior variance point-by-point

# RMSE
rmse = jnp.sqrt(jnp.mean(jnp.square(yhat_test_mean - test.y)))
expected_error_variance = jnp.mean(jnp.square(bart.sigma))
expected_rmse = jnp.sqrt(jnp.mean(yhat_test_var + expected_error_variance))
avg_sigma = jnp.sqrt(expected_error_variance)

print(f'total sdev: {jnp.std(train.y):#.2g}')
print(f'error sdev: {sigma:#.2g}')
print(f'RMSE: {rmse:#.2g}')
print(f'expected RMSE: {expected_rmse:#.2g}')
print(f'model error sdev: {avg_sigma:#.2g}')
print(f'time: {(end - start) / 60:#.2g} min')
total sdev: 1.0
error sdev: 0.10
RMSE: 0.45
expected RMSE: 0.51
model error sdev: 0.42
time: 9.9 min
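Because `yhat_test` holds posterior samples (one row per sample), it can also be summarized with pointwise credible intervals. A minimal sketch with a hypothetical helper `interval_coverage`, run here on synthetic stand-in arrays with the same `(n_samples, n_test)` layout so that it is self-contained; in the notebook one would pass `yhat_test` and `test.y.squeeze(0)` (assuming `test.y` has shape `(1, n_test)`, as the `squeeze(0)` on the training outcomes suggests). These intervals describe the regression function, not new noisy observations, so their coverage of `test.y` will fall below the nominal level when the error variance is not negligible.

```python
from jax import numpy as jnp
from jax import random


def interval_coverage(samples, y, level=0.9):
    """Pointwise equal-tailed credible intervals and their empirical coverage of y.

    samples: (n_samples, n_points) posterior draws; y: (n_points,) reference values.
    """
    alpha = (1 - level) / 2
    lo = jnp.quantile(samples, alpha, axis=0)
    hi = jnp.quantile(samples, 1 - alpha, axis=0)
    covered = (y >= lo) & (y <= hi)
    return lo, hi, jnp.mean(covered)


# synthetic stand-in for a calibrated posterior
k_f, k_c, k_s = random.split(random.key(0), 3)
f = random.normal(k_f, (500,))  # true function values
center = f + 0.2 * random.normal(k_c, (500,))  # posterior centre, off by estimation error
samples = center + 0.2 * random.normal(k_s, (1000, 500))  # posterior draws around the centre
lo, hi, cov = interval_coverage(samples, f)
print(f"coverage of the 90% intervals: {float(cov):.2f}")
```

In this toy setup the posterior spread matches the actual estimation error, so coverage of the true function values should land near 90%.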

Since the test outcomes themselves contain noise, the RMSE cannot, in expectation, fall below the error standard deviation used to generate the data (here 0.10).
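The sketch below checks this reasoning, together with the formula behind the `expected RMSE` line, on synthetic numbers. If the posterior for the function value at a point has variance v and is consistent with the truth, and the outcome adds independent noise of variance σ², then the mean squared error of the posterior mean is v + σ² on average, while an oracle that knows the true function still has an RMSE of about σ. The value of v and the Gaussian setup are assumptions chosen for illustration, not output from bartz.

```python
from jax import numpy as jnp
from jax import random

sigma = 0.1  # noise sdev, as in the notebook
v = 0.2  # assumed posterior variance of f at each test point
n = 200_000

k_m, k_f, k_e = random.split(random.key(1), 3)
m = random.normal(k_m, (n,))  # posterior mean (point prediction)
f = m + jnp.sqrt(v) * random.normal(k_f, (n,))  # true function, consistent with the posterior
y = f + sigma * random.normal(k_e, (n,))  # noisy test outcome

rmse_model = jnp.sqrt(jnp.mean((y - m) ** 2))  # like `rmse` in the notebook
rmse_oracle = jnp.sqrt(jnp.mean((y - f) ** 2))  # best possible predictor
expected = jnp.sqrt(v + sigma**2)  # like `expected_rmse` in the notebook

print(f"model RMSE {float(rmse_model):.3f} vs expected {float(expected):.3f}")
print(f"oracle RMSE {float(rmse_oracle):.3f} vs sigma {sigma}")
```

In the notebook, the gap between the measured RMSE (0.45) and the error standard deviation (0.10) is therefore mostly estimation error in the function, not irreducible noise.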