Quickstart

Basics

Import and use the bartz.BART.gbart class:

from bartz.BART import gbart
bart = gbart(X, y, ...)
y_pred = bart.predict(X_test)

The interface hews to that of the R package BART3, with a few differences explained in the documentation of bartz.BART.gbart. This interface has the longest ancestry, descending from the original BART implementation.

There is also a newer interface, bartz.Bart, where new features are added first. However, it is unstable and continuously evolving.

Getting a GPU

bartz works decently on CPU, but it shines on GPU, where it can be up to 200x faster. Your personal computer may already have an Nvidia GPU. The integrated GPUs on Apple Silicon Macs are not supported and won’t ever be. If you don’t have a GPU, you can rent one in the cloud starting from about $0.10/hour, e.g., on vast.ai. If you are not familiar with connecting to remote machines, you can use Google Colab.

JAX

bartz is implemented using JAX, a Google library for machine learning. JAX makes it possible to run the same code on CPU, GPU, or TPU, and provides other features such as automatic differentiation and just-in-time compilation.

For basic usage, JAX is just an alternative implementation of NumPy. The arrays returned by gbart are JAX arrays instead of NumPy arrays, but there is no perceptible difference in their functionality. If you pass NumPy arrays to bartz, they are converted automatically, so you don’t have to deal with JAX in any way.
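For example, this short snippet (nothing here is bartz-specific) shows that a JAX array supports the usual NumPy-style operations, and that converting between the two kinds of array is a single call in either direction:

```python
import numpy as np
import jax.numpy as jnp

# A JAX array behaves like a NumPy array for everyday operations.
a = jnp.arange(4.0)
total = a.sum()     # reductions work like NumPy
window = a[1:3]     # slicing works the same way

# Converting between the two is a single call in either direction.
a_np = np.asarray(a)        # JAX -> NumPy
a_jax = jnp.asarray(a_np)   # NumPy -> JAX
```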

For advanced usage, refer to the jax documentation.

Advanced

bartz exposes the functions that implement the MCMC of BART. You can use them to build your own variant of BART. See the rest of the documentation for reference; the main entry points are bartz.mcmcstep.init and bartz.mcmcloop.run_mcmc.