Simulated data
This notebook runs bartz on simulated data. It is meant to be run on a GPU. To try it out on Colab, use the following link: link
The next cell installs bartz:
%pip install bartz
The next cell defines configuration parameters for the script:
n_train = 100_000 # number of training points
p = 1000 # number of predictors/features
sigma = 0.1 # error standard deviation
n_test = 1000 # number of test points
n_tree = 10_000 # number of trees used by bartz
The next cell generates simulated data from a linear + quadratic test model that comes packaged with bartz:
from jax import random
from bartz.testing import gen_data
# list of independent seeds for random sampling
keys = list(random.split(random.key(2024_04_16_18_53), 2))
# simulate data with bartz's built-in testing data generating process
data = gen_data(
    keys.pop(),
    n=n_train + n_test,
    p=p,
    q=2, # number of interactions, each predictor interacts with other q predictors in the quadratic term
    sigma2_eps=sigma**2, # error variance
    sigma2_lin=0.5, # linear term variance
    sigma2_quad=0.5, # quadratic term variance
    k=1, # number of outcomes
    lam=1.0, # correlation between outcomes, unused in this case
)
# split data in train-test
train, test = data.split(n_train)
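For intuition about what a linear + quadratic model with these variance parameters can look like, here is a minimal NumPy sketch. It is not bartz's `gen_data` implementation: the function name `simulate_lin_quad`, the standard normal predictors, and the cyclic-neighbor interaction pattern are all assumptions made for illustration.

```python
import numpy as np


def simulate_lin_quad(rng, n, p, q, sigma2_lin, sigma2_quad, sigma2_eps):
    """Hypothetical linear + quadratic data generating process.

    Illustrative only, NOT the code used by bartz.testing.gen_data.
    Assumes 2 * q < p so that every interacting pair is distinct.
    """
    # predictors: independent standard normal (an assumption)
    x = rng.standard_normal((n, p))

    # linear term: random coefficients scaled so that the term has
    # variance close to sigma2_lin
    beta = rng.standard_normal(p) * np.sqrt(sigma2_lin / p)
    lin = x @ beta

    # quadratic term: predictor j interacts with predictors j+1, ..., j+q
    # (indices wrap around), scaled to variance close to sigma2_quad
    quad = np.zeros(n)
    scale = np.sqrt(sigma2_quad / (p * q))
    for shift in range(1, q + 1):
        a = rng.standard_normal(p) * scale
        quad += (x * np.roll(x, -shift, axis=1)) @ a

    # independent Gaussian error with variance sigma2_eps
    eps = rng.standard_normal(n) * np.sqrt(sigma2_eps)
    return x, lin + quad + eps


rng = np.random.default_rng(0)
x, y = simulate_lin_quad(
    rng, n=20_000, p=50, q=2,
    sigma2_lin=0.5, sigma2_quad=0.5, sigma2_eps=0.1**2,
)
print(x.shape, y.shape)
print(f"sample variance of y: {np.var(y):.2f}")
```

In this sketch the three terms are uncorrelated, so the variance of `y` should be near `sigma2_lin + sigma2_quad + sigma2_eps`, up to the random fluctuation of the coefficients.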
The next cell runs bartz:
from time import perf_counter
from bartz.BART import gbart
# clock bartz
start = perf_counter()
bart = gbart(train.x, train.y.squeeze(0), ntree=n_tree, printevery=10, seed=keys.pop())
end = perf_counter()
..........
Iteration 10/1100, grow prob: 54%, move acc: 37%, fill: 6% (burnin)
..........
Iteration 20/1100, grow prob: 53%, move acc: 36%, fill: 6% (burnin)
..........
Iteration 30/1100, grow prob: 53%, move acc: 35%, fill: 6% (burnin)
..........
Iteration 40/1100, grow prob: 54%, move acc: 34%, fill: 6% (burnin)
..........
Iteration 50/1100, grow prob: 53%, move acc: 34%, fill: 6% (burnin)
..........
Iteration 60/1100, grow prob: 54%, move acc: 34%, fill: 6% (burnin)
..........
Iteration 70/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)
..........
Iteration 80/1100, grow prob: 54%, move acc: 32%, fill: 6% (burnin)
..........
Iteration 90/1100, grow prob: 54%, move acc: 32%, fill: 6% (burnin)
..........
Iteration 100/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)
..........
Iteration 110/1100, grow prob: 54%, move acc: 32%, fill: 6%
..........
Iteration 120/1100, grow prob: 53%, move acc: 33%, fill: 6%
..........
Iteration 130/1100, grow prob: 53%, move acc: 32%, fill: 6%
..........
Iteration 140/1100, grow prob: 54%, move acc: 32%, fill: 6%
..........
Iteration 150/1100, grow prob: 54%, move acc: 31%, fill: 6%
..........
Iteration 160/1100, grow prob: 54%, move acc: 32%, fill: 6%
..........
Iteration 170/1100, grow prob: 54%, move acc: 31%, fill: 6%
..........
Iteration 180/1100, grow prob: 54%, move acc: 31%, fill: 6%
..........
Iteration 190/1100, grow prob: 54%, move acc: 30%, fill: 6%
..........
Iteration 200/1100, grow prob: 54%, move acc: 30%, fill: 6%
..........
Iteration 210/1100, grow prob: 53%, move acc: 31%, fill: 6%
..........
Iteration 220/1100, grow prob: 55%, move acc: 30%, fill: 6%
..........
Iteration 230/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 240/1100, grow prob: 54%, move acc: 30%, fill: 6%
..........
Iteration 250/1100, grow prob: 55%, move acc: 30%, fill: 6%
..........
Iteration 260/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 270/1100, grow prob: 53%, move acc: 30%, fill: 6%
..........
Iteration 280/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 290/1100, grow prob: 54%, move acc: 28%, fill: 6%
..........
Iteration 300/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 310/1100, grow prob: 54%, move acc: 29%, fill: 6%
..........
Iteration 320/1100, grow prob: 55%, move acc: 29%, fill: 6%
..........
Iteration 330/1100, grow prob: 53%, move acc: 29%, fill: 6%
..........
Iteration 340/1100, grow prob: 53%, move acc: 28%, fill: 6%
..........
Iteration 350/1100, grow prob: 53%, move acc: 29%, fill: 6%
..........
Iteration 360/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 370/1100, grow prob: 54%, move acc: 28%, fill: 6%
..........
Iteration 380/1100, grow prob: 54%, move acc: 28%, fill: 6%
..........
Iteration 390/1100, grow prob: 53%, move acc: 28%, fill: 6%
..........
Iteration 400/1100, grow prob: 53%, move acc: 28%, fill: 6%
..........
Iteration 410/1100, grow prob: 55%, move acc: 28%, fill: 6%
..........
Iteration 420/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 430/1100, grow prob: 53%, move acc: 28%, fill: 6%
..........
Iteration 440/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 450/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 460/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 470/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 480/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 490/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 500/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 510/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 520/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 530/1100, grow prob: 52%, move acc: 27%, fill: 6%
..........
Iteration 540/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 550/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 560/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 570/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 580/1100, grow prob: 53%, move acc: 27%, fill: 6%
..........
Iteration 590/1100, grow prob: 54%, move acc: 27%, fill: 6%
..........
Iteration 600/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 610/1100, grow prob: 53%, move acc: 25%, fill: 6%
..........
Iteration 620/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 630/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 640/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 650/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 660/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 670/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 680/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 690/1100, grow prob: 53%, move acc: 26%, fill: 6%
..........
Iteration 700/1100, grow prob: 54%, move acc: 26%, fill: 6%
..........
Iteration 710/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 720/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 730/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 740/1100, grow prob: 55%, move acc: 25%, fill: 6%
..........
Iteration 750/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 760/1100, grow prob: 53%, move acc: 25%, fill: 6%
..........
Iteration 770/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 780/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 790/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 800/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 810/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 820/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 830/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 840/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 850/1100, grow prob: 53%, move acc: 25%, fill: 6%
..........
Iteration 860/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 870/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 880/1100, grow prob: 53%, move acc: 25%, fill: 6%
..........
Iteration 890/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 900/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 910/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 920/1100, grow prob: 53%, move acc: 23%, fill: 6%
..........
Iteration 930/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 940/1100, grow prob: 54%, move acc: 25%, fill: 6%
..........
Iteration 950/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 960/1100, grow prob: 53%, move acc: 23%, fill: 6%
..........
Iteration 970/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 980/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 990/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 1000/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 1010/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1020/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1030/1100, grow prob: 53%, move acc: 24%, fill: 6%
..........
Iteration 1040/1100, grow prob: 54%, move acc: 24%, fill: 6%
..........
Iteration 1050/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1060/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1070/1100, grow prob: 53%, move acc: 23%, fill: 6%
..........
Iteration 1080/1100, grow prob: 54%, move acc: 23%, fill: 6%
..........
Iteration 1090/1100, grow prob: 53%, move acc: 23%, fill: 6%
..........
Iteration 1100/1100, grow prob: 54%, move acc: 24%, fill: 6%
Interpretation of the printout:
grow prob = the fraction of trees that BART tried to add leaves to, rather than remove leaves from
move acc = the fraction of attempted tree changes that were accepted in the last iteration
fill = how much of the fixed tree space is used; if it is more than ~50%, the trees are too deep and the number of trees should be increased
The fractions refer to the state of the trees after the last iteration; they are not averaged over multiple iterations.
A low acceptance means that the trees are changing very slowly. In the run above, the acceptance drifts down from 37% at the start of burn-in to 23-24% at the end.
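To track these diagnostics over a run instead of reading them by eye, the progress lines can be parsed. The following is a small sketch using only the standard library. The regular expression is written against the lines as they are printed above, so it is an assumption that may not hold for other bartz versions; `parse_progress` and `trees_too_deep` are hypothetical helper names, not part of bartz.

```python
import re

# pattern matching the progress lines exactly as printed in this notebook
PROGRESS_RE = re.compile(
    r"Iteration (\d+)/(\d+), grow prob: (\d+)%, "
    r"move acc: (\d+)%, fill: (\d+)%( \(burnin\))?"
)


def parse_progress(line):
    """Return the diagnostics in a progress line, or None if it is not one."""
    m = PROGRESS_RE.fullmatch(line.strip())
    if m is None:
        return None
    iteration, total, grow, acc, fill = (int(g) for g in m.groups()[:5])
    return {
        "iteration": iteration,
        "total": total,
        "grow_prob": grow,  # percent
        "move_acc": acc,  # percent
        "fill": fill,  # percent
        "burnin": m.group(6) is not None,
    }


def trees_too_deep(record, threshold=50):
    """Apply the rule of thumb above: fill beyond ~50% suggests more trees."""
    return record["fill"] > threshold


record = parse_progress(
    "Iteration 10/1100, grow prob: 54%, move acc: 37%, fill: 6% (burnin)"
)
print(record)
print("too deep:", trees_too_deep(record))
```

Lines that are not progress reports, such as the rows of dots, yield `None` and can be skipped.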
The next cell computes the predictions.
from jax import numpy as jnp
# compute predictions
yhat_test = bart.predict(test.x) # posterior samples, n_samples x n_test
yhat_test_mean = jnp.mean(yhat_test, axis=0) # posterior mean point-by-point
yhat_test_var = jnp.var(yhat_test, axis=0) # posterior variance point-by-point
# RMSE
rmse = jnp.sqrt(jnp.mean(jnp.square(yhat_test_mean - test.y)))
expected_error_variance = jnp.mean(jnp.square(bart.sigma))
expected_rmse = jnp.sqrt(jnp.mean(yhat_test_var + expected_error_variance))
avg_sigma = jnp.sqrt(expected_error_variance)
print(f'total sdev: {jnp.std(train.y):#.2g}')
print(f'error sdev: {sigma:#.2g}')
print(f'RMSE: {rmse:#.2g}')
print(f'expected RMSE: {expected_rmse:#.2g}')
print(f'model error sdev: {avg_sigma:#.2g}')
print(f'time: {(end - start) / 60:#.2g} min')
total sdev: 1.0
error sdev: 0.10
RMSE: 0.45
expected RMSE: 0.51
model error sdev: 0.42
time: 9.9 min
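The RMSE can at best be as low as the error standard deviation used to generate the data (0.10 here). The RMSE obtained above, 0.45, is well above that floor, so the fit has not recovered all of the signal.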
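The "expected RMSE" printed above is the model's own estimate of its test error. As a sketch of the reasoning, assume the usual BART regression model $y_i = f(x_i) + \varepsilon_i$ with $\varepsilon_i \sim \mathcal N(0, \sigma^2)$ independent of $f$. The law of total variance then splits the posterior predictive variance of a new observation into the posterior variance of $f(x_i)$ plus the posterior mean of $\sigma^2$, and the code averages this over the test points:

```latex
\mathrm{RMSE}_{\text{expected}}
= \sqrt{
    \frac{1}{n_{\text{test}}} \sum_{i=1}^{n_{\text{test}}}
    \Big(
        \operatorname{Var}\big[f(x_i) \mid \text{data}\big]
        + \operatorname{E}\big[\sigma^2 \mid \text{data}\big]
    \Big)
}
```

Here $\operatorname{Var}[f(x_i) \mid \text{data}]$ is estimated by `yhat_test_var` and $\operatorname{E}[\sigma^2 \mid \text{data}]$ by `expected_error_variance`, both as averages over the posterior samples.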