rbartpackages.bartMachine.bartMachine

class rbartpackages.bartMachine.bartMachine(X=None, y=None, *, Xy=None, num_trees=50, num_burn_in=250, num_iterations_after_burn_in=1000, alpha=0.95, beta=2.0, k=2.0, q=0.9, nu=3.0, prob_rule_class=0.5, mh_prob_steps=None, debug_log=False, run_in_sample=True, s_sq_y='mse', sig_sq_est=None, print_tree_illustrations=False, cov_prior_vec=None, interaction_constraints=None, use_missing_data=False, num_rand_samps_in_library=10000, use_missing_data_dummies_as_covars=False, replace_missing_data_with_x_j_bar=False, impute_missingness_with_rf_impute=False, impute_missingness_with_x_j_bar_for_lm=True, mem_cache_for_speed=True, flush_indices_to_save_RAM=True, serialize=False, seed=None, use_xoshiro=False, verbose=True)

Fit BART to continuous or binary outcomes.

Python interface to R’s bartMachine::bartMachine. The predictors X must be a data frame (factor columns are expanded internally); y is a numeric vector for regression or a two-level factor for classification. Pass X and y separately, or combined as Xy. The number of fitting threads is a package-global setting; see set_bart_machine_num_cores. Arguments left as None are omitted from the R call, so R computes its own defaults, described below.

Parameters:
  • X (DataFrame | None, default: None) – Data frame of predictors; rows are observations. Factors are expanded into indicator columns internally.

  • y (Float64[ndarray, 'n'] | String[ndarray, 'n'] | None, default: None) – Response: numeric for regression, two-level categorical for classification. A numeric numpy array or pandas/polars Series works for regression; classification needs a factor, so pass a string/object numpy array (whose levels are then ordered alphabetically, the first being the positive one) or a categorical Series (which controls the level order).

  • Xy (DataFrame | None, default: None) – Predictors and response combined in one data frame, the response in a column named 'y'; an alternative to passing X and y.

  • num_trees (int, default: 50) – Number of trees in the sum-of-trees model.

  • num_burn_in (int, default: 250) – Number of burn-in MCMC iterations discarded.

  • num_iterations_after_burn_in (int, default: 1000) – Number of posterior draws kept after burn-in.

  • alpha (float, default: 0.95) – Base of the nonterminal-node probability in the tree prior.

  • beta (float, default: 2.0) – Power of the nonterminal-node probability in the tree prior.

  • k (float, default: 2.0) – Number of prior SDs of E[y|x] in half the response range; larger shrinks more.

  • q (float, default: 0.9) – Quantile of the error-variance prior at which the data-based estimate is placed (regression only).

  • nu (float, default: 3.0) – Degrees of freedom of the inverse-chi-squared error-variance prior (regression only).

  • prob_rule_class (float, default: 0.5) – Probability threshold above which a class prediction gets the first (positive) level (classification only).

  • mh_prob_steps (Float64[ndarray, '3'] | None, default: None) – Prior probabilities of the grow/prune/change Metropolis-Hastings tree proposals; default (2.5, 2.5, 4) / 9.

  • debug_log (bool, default: False) – Whether the Java backend logs to a file in the working directory.

  • run_in_sample (bool, default: True) – Whether the in-sample (*_train) statistics are computed.

  • s_sq_y (Literal['mse', 'var'], default: 'mse') – How the error-variance estimate is computed, 'mse' (least-squares residuals) or 'var' (response variance); regression only.

  • sig_sq_est (float | None, default: None) – Data-based error-variance estimate anchoring the prior; default a linear-model estimate (regression only).

  • print_tree_illustrations (bool, default: False) – Whether every Gibbs iteration prints a side-by-side tree illustration; extremely slow.

  • cov_prior_vec (Float64[ndarray, 'p'] | None, default: None) – Relative split-proposal weight of each predictor (after dummification and missingness augmentation); internally normalized.

  • interaction_constraints (dict[str, Integer[ndarray, 'k'] | String[ndarray, 'k']] | None, default: None) – Groups of predictors allowed to interact, as a dict of vectors of 1-based column indices or column names (e.g. {'a': [1, 2], 'b': ['nox']}); rpy2 converts a dict, not a list, to the R list bartMachine wants.

  • use_missing_data (bool, default: False) – Whether missing entries are handled natively by the splits, without imputation.

  • num_rand_samps_in_library (int, default: 10000) – Size of the pre-drawn normal/chi-squared sample library passed to Java.

  • use_missing_data_dummies_as_covars (bool, default: False) – Whether per-predictor missingness indicators are added to the design matrix.

  • replace_missing_data_with_x_j_bar (bool, default: False) – Whether missing entries are imputed with column averages/modes.

  • impute_missingness_with_rf_impute (bool, default: False) – Whether missing entries are filled with randomForest::rfImpute.

  • impute_missingness_with_x_j_bar_for_lm (bool, default: True) – Whether the linear model behind sig_sq_est imputes missing entries with column averages/modes.

  • mem_cache_for_speed (bool, default: True) – Whether the Java backend caches the candidate split values at each node; faster but memory-hungry.

  • flush_indices_to_save_RAM (bool, default: True) – Whether the Java backend flushes internal indices to save memory (disables node_prediction_training_data_indices and get_projection_weights).

  • serialize (bool, default: False) – Whether the Java model is serialized into the R object so it survives saving and reloading; memory-hungry.

  • seed (int | None, default: None) – Seed of the R and Java RNGs; None does not seed. Deterministic only when fitting single-threaded.

  • use_xoshiro (bool, default: False) – Whether the Java backend uses the Xoshiro256PlusPlus RNG rather than the legacy MersenneTwister.

  • verbose (bool, default: True) – Whether fitting progress is printed to the screen.
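To make the roles of alpha and beta concrete: in the BART tree prior of Chipman, George and McCulloch (2010), a node at depth d (root at depth 0) is nonterminal with probability alpha * (1 + d)^(-beta). The following is a minimal sketch in plain Python, independent of this package; the helper name p_nonterminal is hypothetical.

```python
def p_nonterminal(depth: int, alpha: float = 0.95, beta: float = 2.0) -> float:
    """Prior probability that a node at `depth` (root = 0) is split further."""
    return alpha * (1.0 + depth) ** (-beta)


# Tabulate the prior split probability for the first few depths
# under the default alpha=0.95, beta=2.0.
for d in range(4):
    print(d, round(p_nonterminal(d), 4))
```

With the defaults the probability falls off quickly with depth, which is what keeps the individual trees shallow; raising beta or lowering alpha strengthens that effect.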

Notes

The private R argument covariates_to_permute (used internally by cov_importance_test) is not exposed.

R documentation

title
-----

Build a BART Model

name
----

bartMachine

alias
-----

build_bart_machine

description
-----------

 Builds a BART model for regression or classification.


usage
-----


 bartMachine(
   X = NULL,
   y = NULL,
   Xy = NULL,
   num_trees = 50,
   num_burn_in = 250,
   num_iterations_after_burn_in = 1000,
   alpha = 0.95,
   beta = 2,
   k = 2,
   q = 0.9,
   nu = 3,
   prob_rule_class = 0.5,
   mh_prob_steps = c(2.5, 2.5, 4)/9,
   debug_log = FALSE,
   run_in_sample = TRUE,
   s_sq_y = "mse",
   sig_sq_est = NULL,
   print_tree_illustrations = FALSE,
   cov_prior_vec = NULL,
   interaction_constraints = NULL,
   use_missing_data = FALSE,
   covariates_to_permute = NULL,
   num_rand_samps_in_library = 10000,
   use_missing_data_dummies_as_covars = FALSE,
   replace_missing_data_with_x_j_bar = FALSE,
   impute_missingness_with_rf_impute = FALSE,
   impute_missingness_with_x_j_bar_for_lm = TRUE,
   mem_cache_for_speed = TRUE,
   flush_indices_to_save_RAM = TRUE,
   serialize = FALSE,
   seed = NULL,
   use_xoshiro = FALSE,
   verbose = TRUE
 )

 build_bart_machine(
   X = NULL,
   y = NULL,
   Xy = NULL,
   num_trees = 50,
   num_burn_in = 250,
   num_iterations_after_burn_in = 1000,
   alpha = 0.95,
   beta = 2,
   k = 2,
   q = 0.9,
   nu = 3,
   prob_rule_class = 0.5,
   mh_prob_steps = c(2.5, 2.5, 4)/9,
   debug_log = FALSE,
   run_in_sample = TRUE,
   s_sq_y = "mse",
   sig_sq_est = NULL,
   print_tree_illustrations = FALSE,
   cov_prior_vec = NULL,
   interaction_constraints = NULL,
   use_missing_data = FALSE,
   covariates_to_permute = NULL,
   num_rand_samps_in_library = 10000,
   use_missing_data_dummies_as_covars = FALSE,
   replace_missing_data_with_x_j_bar = FALSE,
   impute_missingness_with_rf_impute = FALSE,
   impute_missingness_with_x_j_bar_for_lm = TRUE,
   mem_cache_for_speed = TRUE,
   flush_indices_to_save_RAM = TRUE,
   serialize = FALSE,
   seed = NULL,
   use_xoshiro = FALSE,
   verbose = TRUE
 )
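The default mh_prob_steps = c(2.5, 2.5, 4)/9 in the usage above is a set of relative weights divided by their sum. A small sketch of the same arithmetic in plain Python (the variable names are illustrative only):

```python
# Relative weights of the (GROW, PRUNE, CHANGE) proposals from the R default.
weights = (2.5, 2.5, 4.0)

# Dividing by the total (9) turns the weights into probabilities.
total = sum(weights)
probs = [w / total for w in weights]

mh_prob_steps = dict(zip(("grow", "prune", "change"), probs))
print(mh_prob_steps)
```

Under this default, CHANGE proposals are made somewhat more often than GROW or PRUNE, which are equally likely.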


arguments
---------


 X Data frame of predictors. Factors are automatically converted to dummies internally.

 y Vector of response variable. If  y  is  numeric  or  integer , a BART model for regression is built. If  y  is a factor with two levels, a BART model for classification is built.

 Xy A data frame of predictors and the response. The response column must be named ``y''.

 num_trees The number of trees to be grown in the sum-of-trees model.

 num_burn_in Number of MCMC samples to be discarded as ``burn-in''.

 num_iterations_after_burn_in Number of MCMC samples to draw from the posterior distribution of  \hat{f}(x) .

 alpha Base hyperparameter in tree prior for whether a node is nonterminal or not.

 beta Power hyperparameter in tree prior for whether a node is nonterminal or not.

 k For regression,  k  determines the prior probability that  E(Y|X)  is contained in the interval  (y_{min}, y_{max}) , based on a normal distribution. For example, when  k=2 , the prior probability is 95%. For classification,  k  determines the prior probability that  E(Y|X)  is between  (-3,3) . Note that a larger value of  k  results in more shrinkage and a more conservative fit.

 q Quantile of the prior on the error variance at which the data-based estimate is placed. Note that the larger the value of  q , the more aggressive the fit as you are placing more prior weight on values lower than the data-based estimate. Not used for classification.

 nu Degrees of freedom for the inverse  \chi^2  prior. Not used for classification.

 prob_rule_class Threshold for classification. Any observation with a conditional probability greater than  prob_rule_class  is assigned the ``positive'' outcome. Note that the first level of the response is treated as the ``positive'' outcome and the second is treated as the ``negative'' outcome.

 mh_prob_steps Vector of prior probabilities for proposing changes to the tree structures: (GROW, PRUNE, CHANGE)

 debug_log If TRUE, additional information about the model construction is printed to a file in the working directory.

 run_in_sample If TRUE, in-sample statistics such as  \hat{f}(x) , Pseudo- R^2 , and RMSE are computed. Setting this to FALSE when not needed can decrease computation time.

 s_sq_y If ``mse'', a data-based estimate of the error variance is computed as the MSE from ordinary least squares regression. If ``var'', the data-based estimate is computed as the variance of the response. Not used in classification.

 sig_sq_est Pass in an estimate of the maximum sig_sq of the model. This is useful to cache somewhere and then pass in during cross-validation since the default method of estimation is a linear model. In large dimensions, linear model estimation is slow.

 print_tree_illustrations For every Gibbs iteration, print out an illustration of the trees side-by-side. This is excruciatingly SLOW!

 cov_prior_vec Vector assigning relative weights to how often a particular variable should be proposed as a candidate for a split. The vector is internally normalized so that the weights sum to 1. Note that the length of this vector must equal the number of columns of the design matrix after dummification and augmentation of indicators of missingness (if used). To see what the dummified matrix looks like, use  dummify_data . See Bleich et al. (2013) for more details on when this feature is most appropriate.

 interaction_constraints A list of vectors indicating where the vectors are sets of elements allowed to interact with one another. The elements in each
 vector correspond to features in the data frame  X  specified by either the column number as a numeric value or the column
 name as a string e.g.  list(c(1, 2), c("nox", "rm")) . The elements of the vectors can be reused among components for any
 level of interaction complexity you wish. Default is  NULL  which corresponds to the vanilla modeling procedure where
 all interactions are legal. For a pure generalized additive model, use  as.list(seq(1 : p))  where  p
 is the number of columns in the design matrix  X .

 use_missing_data If TRUE, the missing data feature is used to automatically handle missing data without imputation. See Kapelner and Bleich (2013) for details.

 covariates_to_permute Private argument for  cov_importance_test . Not needed by user.

 num_rand_samps_in_library Before building a BART model, samples from the Standard Normal and  \chi^2(\nu)  are drawn to be used in the MCMC steps. This parameter determines the number of samples to be taken.

 use_missing_data_dummies_as_covars If TRUE, additional indicator variables for whether or not an observation in a particular column is missing are included. See Kapelner and Bleich (2013) for details.

 replace_missing_data_with_x_j_bar If TRUE, missing entries in  X  are imputed with the average value or modal category.

 impute_missingness_with_rf_impute If TRUE, missing entries are filled in using the rfImpute() function from the  randomForest  library.

 impute_missingness_with_x_j_bar_for_lm If TRUE, when computing the data-based estimate of  \sigma^2 , missing entries are imputed with average value or modal category.

 mem_cache_for_speed Speed enhancement that caches the predictors and the split values that are available at each node for selecting new rules. If the number
 of predictors is large, the memory requirements become large. We recommend keeping this on (default) and turning it off if you experience out-of-memory errors.

 flush_indices_to_save_RAM Setting this flag to  TRUE  saves memory with the downside that you cannot use the functions  node_prediction_training_data_indices  nor  get_projection_weights .

 serialize Setting this option to  TRUE  will allow serialization of bartMachine objects which allows for persistence between
 R sessions if the object is saved and reloaded. Note that serialized objects can take up a large amount of memory.
 Thus, the default is  FALSE .

 seed Optional: sets the seed in both R and Java. Default is  NULL  which sets the seed in neither R nor Java.
 Setting the seed enforces deterministic behavior only in the case when one core is used (the default before
 set_bart_machine_num_cores() was invoked).

 use_xoshiro If TRUE, use the Xoshiro256PlusPlus random number generator; if FALSE, use the legacy MersenneTwister
 random number generator (default is FALSE).

 verbose Prints information about progress of the algorithm to the screen.
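The statement under k that "when k=2, the prior probability is 95%" follows from the normal distribution: the chance that a normal variable lies within k standard deviations of its mean is erf(k / sqrt(2)). A minimal sketch using only the standard library (the function name prior_coverage is hypothetical):

```python
import math


def prior_coverage(k: float) -> float:
    """P(|Z| < k) for a standard normal Z, i.e. the prior mass inside the range."""
    return math.erf(k / math.sqrt(2.0))


# Larger k puts more prior mass inside the response range,
# i.e. shrinks the fitted function more strongly towards the centre.
for k in (1.0, 2.0, 3.0):
    print(k, round(prior_coverage(k), 4))
```

For k = 2 this gives roughly 0.954, which the documentation rounds to 95%.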


value
-----


 Returns an object of class ``bartMachine''. The ``bartMachine'' object contains a list of the following components:

    java_bart_machine A pointer to the BART Java object.
    training_data_features The names of the variables used in the training data.
    training_data_features_with_missing_features The names of the variables used in the training data. If  use_missing_data_dummies_as_covars = TRUE , this also includes dummies for any predictors that contain at least one missing entry (named ``M_<feature>'').
    y The values of the response for the training data.
    y_levels The levels of the response (for classification only).
    pred_type Whether the model was built for regression or classification.
    model_matrix_training_data The training data with factors converted to dummies.
    num_cores The number of cores used to build the BART model.
    sig_sq_est The data-based estimate of  \sigma^2  used to create the prior on the error variance for the BART model.
    time_to_build Total time to build the BART model.
    y_hat_train The posterior means of  \hat{f}(x)  for each observation. Only returned if  run_in_sample = TRUE .
    residuals The model residuals given by  y  -  y_hat_train . Only returned if  run_in_sample = TRUE .
    L1_err_train L1 error on the training set. Only returned if  run_in_sample = TRUE .
    L2_err_train L2 error on the training set. Only returned if  run_in_sample = TRUE .
    PseudoRsq Calculated as 1 - SSE / SST where SSE is the sum of square errors in the training data and SST is the sample variance of the response times  n-1 . Only returned if  run_in_sample = TRUE .
    rmse_train Root mean square error on the training set. Only returned if  run_in_sample = TRUE .

 Additionally, the parameters passed to the function  bartMachine  are also components of the list.
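The in-sample regression statistics listed above can be reproduced from y and y_hat_train. The sketch below assumes that the L1 and L2 errors are the sums of absolute and squared residuals and that the RMSE divides the L2 error by n; PseudoRsq follows the 1 - SSE / SST definition given above. The function name and toy numbers are illustrative only.

```python
import math


def in_sample_stats(y, y_hat):
    """Recompute the training-set summaries from responses and fitted values."""
    n = len(y)
    residuals = [yi - fi for yi, fi in zip(y, y_hat)]
    l1 = sum(abs(r) for r in residuals)       # assumed: sum of |residuals|
    l2 = sum(r * r for r in residuals)        # assumed: sum of squared residuals (SSE)
    y_bar = sum(y) / n
    sst = sum((yi - y_bar) ** 2 for yi in y)  # sample variance times (n - 1)
    return {
        "L1_err_train": l1,
        "L2_err_train": l2,
        "rmse_train": math.sqrt(l2 / n),
        "PseudoRsq": 1.0 - l2 / sst,
    }


stats = in_sample_stats([1.0, 2.0, 3.0, 4.0], [1.5, 2.0, 2.5, 4.0])
print(stats)
```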


note
----


 This function is parallelized by the number of cores set by  set_bart_machine_num_cores . Each core will create an
 independent MCMC chain of size
 num_burn_in + num_iterations_after_burn_in / bart_machine_num_cores .
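As a worked instance of the chain-size formula in this note, the sketch below evaluates it for the default settings on one and on four cores. The helper name is hypothetical and the division is taken literally from the note.

```python
def iterations_per_chain(num_burn_in=250, num_iterations_after_burn_in=1000, num_cores=1):
    """Length of each independent chain: full burn-in plus a share of the kept draws."""
    return num_burn_in + num_iterations_after_burn_in / num_cores


single = iterations_per_chain(num_cores=1)  # one chain of 1250 iterations
four = iterations_per_chain(num_cores=4)    # four chains of 500 iterations each
print(single, four, 4 * four)
```

Every chain pays the full burn-in, so with four cores the total work is 2000 iterations rather than 1250, although the wall-clock length of each chain drops from 1250 to 500.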


author
------


 Adam Kapelner and Justin Bleich


references
----------


 Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning
 with Bayesian Additive Regression Trees. Journal of Statistical
 Software, 70(4), 1-40. doi:10.18637/jss.v070.i04 <https://doi.org/10.18637/jss.v070.i04>

 HA Chipman, EI George, and RE McCulloch. BART: Bayesian Additive Regression Trees.
 The Annals of Applied Statistics, 4(1): 266--298, 2010.

 A Kapelner and J Bleich. Prediction with Missing Data via Bayesian Additive Regression
 Trees. Canadian Journal of Statistics, 43(2): 224-239, 2015

 J Bleich, A Kapelner, ST Jensen, and EI George. Variable Selection Inference for Bayesian
 Additive Regression Trees. ArXiv e-prints, 2013.


seealso
-------


 bartMachineCV


examples
--------



 ##regression example

 ##generate Friedman data
 set.seed(11)
 n  = 200
 p = 5
 X = data.frame(matrix(runif(n * p), ncol = p))
 y = 10 * sin(pi* X[ ,1] * X[,2]) +20 * (X[,3] -.5)^2 + 10 * X[ ,4] + 5 * X[,5] + rnorm(n)

 ##build BART regression model
 bart_machine = bartMachine(X, y)
 summary(bart_machine)

 ##Build another BART regression model
 bart_machine = bartMachine(X,y, num_trees = 200, num_burn_in = 500,
 num_iterations_after_burn_in = 1000)

 ##Classification example

 #get data and only use 2 of the 3 species (a two-level factor)
 data(iris)
 iris2 = iris[51:150,]
 iris2$Species = factor(iris2$Species)

 #build BART classification model
 bart_machine = build_bart_machine(iris2[ ,1:4], iris2$Species)

 ##get estimated probabilities
 phat = bart_machine$p_hat_train
 ##look at in-sample confusion matrix
 bart_machine$confusion_matrix
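A rough Python counterpart of the R regression example, as a sketch under stated assumptions: the import path is inferred from the page title, pandas and numpy are assumed to be installed, and the fitting function is only defined, not called, because it needs R, Java and the R package at run time. Python's random number generator differs from R's, so the simulated data will not match the R example numerically.

```python
import math
import random


def friedman_data(n=200, p=5, seed=11):
    """Simulate the Friedman regression data used in the R example."""
    rng = random.Random(seed)
    rows = [[rng.random() for _ in range(p)] for _ in range(n)]
    y = [
        10 * math.sin(math.pi * r[0] * r[1]) + 20 * (r[2] - 0.5) ** 2
        + 10 * r[3] + 5 * r[4] + rng.gauss(0.0, 1.0)
        for r in rows
    ]
    return rows, y


def fit_regression(rows, y):
    """Fit BART through the wrapper; requires R, Java and bartMachine (not run here)."""
    import numpy as np
    import pandas as pd
    from rbartpackages.bartMachine import bartMachine  # assumed import path

    # X must be a data frame; y a float array for regression.
    X = pd.DataFrame(rows, columns=[f"X{j + 1}" for j in range(len(rows[0]))])
    return bartMachine(X, np.asarray(y, dtype=float), num_trees=200,
                       num_burn_in=500, num_iterations_after_burn_in=1000)


rows, y = friedman_data()
print(len(rows), len(rows[0]), len(y))
```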
L1_err_train: float | None = None

In-sample L1 error (regression with run_in_sample only).

L2_err_train: float | None = None

In-sample L2 error (regression with run_in_sample only).

PseudoRsq: float | None = None

In-sample pseudo-R^2, computed as 1 - L2_err_train / SST, where SST is the L2 error of predicting the mean response (regression with run_in_sample only).

X: DataFrame

Training predictors as supplied (factors not expanded).

A polars frame if polars is installed, else a pandas one.

alpha: float

Base of the nonterminal-node probability in the tree prior.

beta: float

Power of the nonterminal-node probability in the tree prior.

confusion_matrix: DataFrame | None = None

In-sample confusion matrix with error rates (classification with run_in_sample only).

A polars frame (so without the row labels) if polars is installed, else a pandas one.

cov_prior_vec: Float64[ndarray, 'p'] | None = None

Relative split-proposal weight of each predictor.

None unless given, except that expanded factors get a default down-weighting their indicator columns by the number of levels.

debug_log: bool

Whether the Java backend logged to a file.

flush_indices_to_save_RAM: bool

Whether the Java backend flushed internal indices to save memory.

impute_missingness_with_rf_impute: bool

Whether missing training entries got randomForest::rfImpute imputations added.

impute_missingness_with_x_j_bar_for_lm: bool

Whether the linear model behind sig_sq_est imputed missing entries with column averages.

java_bart_machine: RS4

RJava reference to the Java model; opaque to Python.

k: float

Number of prior SDs of E[y|x] in half the response range; larger shrinks more.

mem_cache_for_speed: bool

Whether the Java backend cached predictor-index sets at the nodes.

mh_prob_steps: Float64[ndarray, '3']

Probabilities of the grow/prune/change Metropolis-Hastings proposals (normalized).

misclassification_error: float | None = None

In-sample misclassification rate (classification with run_in_sample only).

model_matrix_training_data: Float64[ndarray, 'n p+1']

Preprocessed training matrix; the response is the last column (1 = first level).

n: int

Number of training observations.

nu: float

Degrees of freedom of the inverse-chi-squared error-variance prior (regression).

num_burn_in: int

Number of burn-in MCMC iterations discarded.

num_cores: int

Number of threads used to fit the model.

num_gibbs: int

Total number of MCMC iterations, burn-in included.

num_iterations_after_burn_in: int

Number of posterior draws kept.

num_rand_samps_in_library: int

Size of the pre-drawn normal/chi-squared sample library passed to Java.

num_trees: int

Number of trees in the sum-of-trees model.

p: int

Number of predictors after preprocessing (factors expanded, missingness dummies included).

p_hat_train: Float64[ndarray, 'n'] | None = None

In-sample probability of the first level (classification with run_in_sample only).

pred_type: str

'regression' or 'classification'.

prob_rule_class: float

Probability threshold above which class predictions get the first level.

q: float

Quantile of the error-variance prior at which sig_sq_est is placed (regression).

replace_missing_data_with_x_j_bar: bool

Whether missing entries were imputed with column averages/modes.

residuals: Float64[ndarray, 'n'] | None = None

In-sample y - y_hat_train (regression with run_in_sample only).

rmse_train: float | None = None

In-sample root-mean-square error (regression with run_in_sample only).

run_in_sample: bool

Whether the in-sample (*_train) outputs were computed.

s_sq_y: str

How sig_sq_est is estimated, 'mse' (linear model) or 'var' (sample variance).

seed: int | None = None

Seed of the Java RNG; None if not given.

serialize: bool

Whether the Java model was serialized into the R object (to survive saving).

sig_sq_est: float | None = None

Data-based error-variance estimate anchoring the prior (regression only).

training_data_features: String[ndarray, '<=p']

Names of the design-matrix columns, excluding the missingness dummies.

training_data_features_with_missing_features: String[ndarray, 'p']

Names of all design-matrix columns, missingness dummies included (if used).

use_missing_data: bool

Whether missing entries were handled natively by the splits.

use_missing_data_dummies_as_covars: bool

Whether per-predictor missingness dummies were added to the design matrix.

use_xoshiro: bool

Whether the Java backend used the xoshiro RNG.

verbose: bool

Whether fitting messages were printed.

y: Float64[ndarray, 'n'] | String[ndarray, 'n']

Training response, numeric for regression and the labels for classification.

y_hat_train: Float64[ndarray, 'n'] | String[ndarray, 'n'] | None = None

In-sample posterior means (regression) or thresholded labels (classification).

Computed with run_in_sample only.

y_levels: String[ndarray, '2'] | None = None

The response levels, the first being the target one (classification only).

time_to_build: float

Wall-clock seconds taken to fit the model.

interaction_constraints: tuple[Float64[ndarray, 'group[i]'], ...] | None = None

Groups of predictors allowed to interact, as 0-based column indices; None if not given.
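The cov_prior_vec attribute above mentions a default that down-weights the indicator columns of expanded factors by the number of levels, and the weights are normalized internally. The sketch below illustrates one reading of that description; the exact default computed by the R package is not spelled out on this page, so the even split of one unit of weight across a factor's levels is an assumption, as is the helper name.

```python
def default_cov_prior(column_groups):
    """Split-proposal probabilities for a design matrix.

    column_groups: for each original predictor, the number of design-matrix
    columns it expands to (1 for numeric, number of levels for a factor).
    Assumption: each indicator column gets weight 1 / number of levels.
    """
    weights = []
    for width in column_groups:
        weights.extend([1.0 / width] * width)
    total = sum(weights)
    # Normalize so the weights sum to 1, as the package does internally.
    return [w / total for w in weights]


# One numeric predictor and one three-level factor: four design columns.
probs = default_cov_prior([1, 3])
print(probs)
```

Under this reading, a factor as a whole is proposed as often as a numeric predictor, rather than once per level.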

predict(new_data, *, type=None, prob_rule_class=None, verbose=None)

Posterior-mean predictions at the rows of new_data.

For regression fits, the posterior mean of f(x). For classification fits, the probability of the first level (type='prob', the default) or the corresponding labels (type='class'). Arguments left as None are omitted from the R call, so R computes its own defaults.

Parameters:
  • new_data (DataFrame) – Predictors to predict at, with the same columns as the training data.

  • type (Literal['prob', 'class'] | None, default: None) – For classification fits, whether to return the first-level probability ('prob') or the predicted labels ('class'); ignored for regression.

  • prob_rule_class (float | None, default: None) – Probability threshold for a 'class' prediction; default the fit’s prob_rule_class.

  • verbose (bool | None, default: None) – Whether to print prediction messages to the R console.

Returns:

Float64[ndarray, 'm'] | String[ndarray, 'm'] – The posterior means (or labels with type='class') at new_data.
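The relation between type='prob' and type='class' can be sketched in plain Python: a row gets the first level when its first-level probability exceeds prob_rule_class, otherwise the second. The helper name and the toy probabilities are illustrative; the alphabetical level order applies to string arrays, as described for y.

```python
def classify(p_first, levels, prob_rule_class=0.5):
    """Threshold first-level probabilities into class labels."""
    first, second = levels
    # Strictly greater than the threshold yields the first (positive) level.
    return [first if p > prob_rule_class else second for p in p_first]


# String responses are ordered alphabetically, so 'versicolor' is the first level.
levels = sorted({"virginica", "versicolor"})
labels = classify([0.9, 0.5, 0.2], levels)
print(levels, labels)
```

Lowering prob_rule_class makes the first level easier to predict, which is the usual way to trade one kind of misclassification for the other.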