rbartpackages.dbarts.dbarts
- class rbartpackages.dbarts.dbarts(formula, data=None, *, test=None, subset=None, weights=None, offset=None, offset_test=None, verbose=False, n_samples=None, tree_prior=None, node_prior=None, resid_prior=None, proposal_probs=None, control=None, sigma=None)
Create a low-level dbarts sampler. This is a Python interface to R's dbarts::dbarts, a mutable reference-class sampler that can be run, stopped, restarted, and modified in place.

A string formula argument is converted to an R formula. The backwards-compatible matrix form, where formula and data are the x_train/y_train pair, passes through unchanged. The named numeric vector form of proposal_probs is given as a dictionary in Python. Arguments left to None are omitted from the R call, so R computes its own defaults; these are described in the R documentation below.

The methods below either modify the sampler in place or return results. The sampler's fields are exposed as read-only properties. Each access reads them from the R object, so they reflect in-place updates.
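The argument handling described above can be sketched in plain Python. This is a minimal sketch, not the rbartpackages implementation. The helper to_r_kwargs is hypothetical. It assumes each Python name maps to its R name by replacing underscores with dots (n_samples to n.samples, offset_test to offset.test), which matches the Python and R signatures shown on this page.

```python
def to_r_kwargs(**kwargs):
    """Drop arguments left to None and rename snake_case to R's dot.case.

    Hypothetical helper for illustration only; not part of rbartpackages.
    Only top-level argument names are renamed; dict values such as
    proposal_probs are passed through with their keys untouched.
    """
    r_kwargs = {}
    for name, value in kwargs.items():
        if value is None:
            continue  # omitted from the call, so R falls back to its own default
        r_kwargs[name.replace("_", ".")] = value
    return r_kwargs


args = to_r_kwargs(
    n_samples=1000,
    offset_test=None,
    proposal_probs={"birth_death": 0.5, "change": 0.4, "swap": 0.1, "birth": 0.5},
    sigma=None,
)
print(args)
```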
- Parameters:
  - formula (str | Float64[ndarray, 'n p'] | DataFrame) – A model formula (as a string), or, in matrix mode, the x_train predictor matrix.
  - data (Float64[ndarray, 'n'] | DataFrame | None, default: None) – The data frame that formula refers to, or, in matrix mode, the y_train response vector.
  - test (Float64[ndarray, 'm p'] | DataFrame | None, default: None) – Test predictors, with the same columns as the training data.
  - subset (Integer[ndarray, 'k'] | None, default: None) – Subset of observations to use in the fit.
  - weights (Float64[ndarray, 'n'] | None, default: None) – Per-observation weights; the model becomes y | x ~ N(f(x), sigma^2 / w).
  - offset (Float64[ndarray, 'n'] | float | None, default: None) – Offset added to f(x); useful for binary responses, where P(Y = 1 | x) = Phi(f(x) + offset).
  - offset_test (Float64[ndarray, 'm'] | float | None, default: None) – The offset for the test data; defaults to offset when applicable.
  - verbose (bool, default: False) – Whether additional output is printed to the R console.
  - n_samples (int | None, default: None) – Default number of posterior samples per run, which each call to run can override. None keeps the value in control (R's default of 800 otherwise); passing a value overrides control.
  - tree_prior (object | None, default: None) – Tree-structure prior, as an R expression of the form cgm or cgm(power, base, split.probs).
  - node_prior (object | None, default: None) – End-node prior, as an R expression of the form normal or normal(k).
  - resid_prior (object | None, default: None) – Residual-variance prior, as an R expression of the form chisq or chisq(df, quant).
  - proposal_probs (dict[str, float] | Float64[ndarray, '4'] | None, default: None) – Tree-proposal probabilities, as a dict with keys 'birth_death', 'change', and 'swap' (the proposal frequencies) and 'birth' (the relative frequency of births within the birth/death step).
  - control (dbartsControl | None, default: None) – A dbartsControl configuring the sampler.
  - sigma (float | None, default: None) – Estimate of the residual standard deviation; None derives it from a linear model fit on all the predictors.
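When sigma is None, R estimates the residual standard deviation from a linear fit. The sketch below assumes that fit is ordinary least squares with an intercept on all predictors. It also assumes the estimate uses n - p - 1 degrees of freedom, as summary(lm(...))$sigma does in R. linear_fit_sigma is a hypothetical helper, and the exact correction dbarts applies may differ.

```python
import numpy as np


def linear_fit_sigma(X, y):
    """Residual SD of an OLS fit with intercept on all columns of X.

    Assumption: uses n - p - 1 degrees of freedom, like summary(lm(...))$sigma
    in R; the exact choice made by dbarts is not documented on this page.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n, p = X.shape
    design = np.column_stack([np.ones(n), X])  # intercept plus every predictor
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    resid = y - design @ coef
    return float(np.sqrt(resid @ resid / (n - p - 1)))


x = np.array([[0.0], [1.0], [2.0], [3.0]])
# y = 1 + 2x plus residuals orthogonal to both the intercept and x,
# so the residual sum of squares is 4 with 4 - 1 - 1 = 2 degrees of freedom.
y = 1.0 + 2.0 * x[:, 0] + np.array([1.0, -1.0, -1.0, 1.0])
print(linear_fit_sigma(x, y))
```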
Notes
The tree_prior, node_prior, and resid_prior arguments are evaluated by R with non-standard scoping. To depart from the defaults, pass an R language object (e.g. from rpy2.robjects.r).

R documentation
title
-----
Discrete Bayesian Additive Regression Trees Sampler

name
----
dbarts

alias
-----
dbarts

description
-----------
Creates a sampler object for a given problem which fits a Bayesian Additive Regression Trees model. Internally stores state in such a way as to be mutable.

usage
-----
dbarts(formula, data, test, subset, weights, offset, offset.test = offset,
       verbose = FALSE, n.samples = 800L,
       tree.prior = cgm, node.prior = normal, resid.prior = chisq,
       proposal.probs = c(birth_death = 0.5, swap = 0.1, change = 0.4, birth = 0.5),
       control = dbarts::dbartsControl(), sigma = NA_real_)

arguments
---------
formula
    An object of class formula following a model description syntax analogous to that of lm. For backwards compatibility, can also be the bart matrix x.train.
data
    An optional data frame, list, or environment containing predictors to be used with the model. For backwards compatibility, can also be the bart vector y.train.
test
    An optional matrix or data frame with the same number of predictors as data, or formula in backwards compatibility mode. If column names are present, a matching algorithm is used.
subset
    An optional vector specifying a subset of observations to be used in the fitting process.
weights
    An optional vector of weights to be used in the fitting process. When present, BART fits a model with observations y | x ~ N(f(x), sigma^2 / w), where f(x) is the unknown function.
offset
    An optional vector specifying an offset from 0 for the relationship between the underlying function, f(x), and the response y. Only useful for binary responses, in which case the model fit assumes P(Y = 1 | X = x) = Phi(f(x) + offset), where Phi is the standard normal cumulative distribution function.
offset.test
    The equivalent of offset for test observations. Will attempt to use offset when applicable.
verbose
    A logical determining if additional output is printed to the console. See dbartsControl.
n.samples
    A positive integer setting the default number of posterior samples to be returned for each run of the sampler. Can be overridden at run-time. See dbartsControl.
tree.prior
    An expression of the form cgm or cgm(power, base, split.probs) setting the tree prior used in fitting.
node.prior
    An expression of the form normal or normal(k) that sets the prior used on the averages within nodes.
resid.prior
    An expression of the form chisq or chisq(df, quant) that sets the prior used on the residual/error variance.
proposal.probs
    Named numeric vector or NULL, optionally specifying the proposal rules and their probabilities. Elements should be "birth_death", "change", and "swap" to control tree change proposals, and "birth" to give the relative frequency of birth/death in the "birth_death" step.
control
    An object inheriting from dbartsControl, created by the dbartsControl function.
sigma
    A positive numeric estimate of the residual standard deviation. If NA, a linear model is used with all of the predictors to obtain one.

details
-------
"Discrete sampler" refers to the fact that dbarts is implemented using ReferenceClasses, so that there exists a mutable object constructed in C++ that is largely obscured from R. The dbarts function is the primary way of creating a dbartsSampler, for which a variety of methods exist.

value
-----
A reference object of dbartsSampler.
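The weights and offset arguments change the likelihood as stated above. A small standard-library sketch of both formulas follows. probit_prob and weighted_variance are illustrative names, not part of dbarts. Phi is computed from math.erf. With f(x) = 0 the probability reduces to Phi(offset), so an offset of Phi^{-1}(0.2) sets a baseline probability of 0.2.

```python
import math
from statistics import NormalDist


def probit_prob(f_x: float, offset: float = 0.0) -> float:
    """P(Y = 1 | X = x) = Phi(f(x) + offset), with Phi the standard normal CDF."""
    z = f_x + offset
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def weighted_variance(sigma: float, w: float) -> float:
    """Variance of y | x under y | x ~ N(f(x), sigma^2 / w)."""
    return sigma ** 2 / w


# With f(x) = 0 the probability is Phi(offset), so this offset sets a 20% baseline.
baseline_offset = NormalDist().inv_cdf(0.2)
print(probit_prob(0.0, baseline_offset))
# Doubling an observation's weight halves its noise variance.
print(weighted_variance(1.5, 1.0), weighted_variance(1.5, 2.0))
```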
- property control: dbartsControl
Return the control object of the sampler, as a dbartsControl wrapper.
- property data: dbartsData
Return the data object of the sampler, as a dbartsData wrapper.
- property state: NamedList | None
The per-chain sampler states; None unless cached with updateState.
- run(numBurnIn=None, numSamples=None, *, updateState=None, numThreads=None)
Run the sampler for numBurnIn burn-in iterations plus numSamples kept iterations.
Either count left to None is filled from the control object. The draws are returned as a dict of arrays, or None if the sampler is run for zero samples.
- Parameters:
  - numBurnIn (int | None, default: None) – Number of burn-in iterations to discard.
  - numSamples (int | None, default: None) – Number of posterior samples to keep.
  - updateState (bool | None, default: None) – Whether to refresh the cached state after the run.
  - numThreads (int | None, default: None) – Number of threads to use for the run.
- Returns:
  RunSamples | None – The draws as a dict of arrays, or None for a zero-sample run.
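The count resolution that run describes can be sketched as follows. resolve_run_counts is hypothetical. Its control_n_burn and control_n_samples arguments stand in for the burn-in and sample counts held by the dbartsControl object, whose Python attribute names are not documented here.

```python
from typing import Optional, Tuple


def resolve_run_counts(
    num_burn_in: Optional[int],
    num_samples: Optional[int],
    control_n_burn: int,
    control_n_samples: int,
) -> Tuple[int, int, int]:
    """Return (burn-in, kept samples, total iterations) for one call to run.

    Hypothetical helper: a count left to None falls back to the control value,
    mirroring the rule documented for run.
    """
    burn = control_n_burn if num_burn_in is None else num_burn_in
    kept = control_n_samples if num_samples is None else num_samples
    return burn, kept, burn + kept


print(resolve_run_counts(None, None, 100, 800))
print(resolve_run_counts(200, 0, 100, 800))  # a burn-in-only run keeps nothing
```

Because the sampler is mutable, one way to separate warm-up from sampling is a run with zero kept samples (for which run returns None) followed by a run with zero burn-in. Treat this as a pattern inferred from the description above, not a documented recipe.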
- sampleTreesFromPrior(*, updateState=None)
Draw the tree structures from the prior, keeping the node parameters.
This leaves the sampler in an invalid state until the node parameters are drawn too.
- sampleNodeParametersFromPrior(*, updateState=None)
Draw the end-node parameters from the prior, keeping the trees.
- copy(*, shallow=None)
Create a deep (default) or shallow copy of the sampler.
The R method is broken once the sampler state has been cached; to use it, create the sampler with dbartsControl(updateState=False).
- predict(x_test, offset_test=None, *, n_threads=None)
Predict at new points without re-running the sampler.
By default it uses the current trees and gives a single prediction per point. With a keepTrees control, it uses each kept set of trees.
- Parameters:
  - x_test (Float64[ndarray, 'm p'] | DataFrame) – New test predictors, with the same columns as the model.
  - offset_test (Float64[ndarray, 'm'] | float | None, default: None) – Offset for the new points.
  - n_threads (int | None, default: None) – Number of threads to use; if more than one, chains are predicted in parallel.
- Returns:
  Float64[ndarray, 'm'] | Float64[ndarray, 'm ndpost'] – One prediction per point, or the kept-tree draws with a keepTrees control.
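Because predict returns either one value per point or an m × ndpost matrix of draws, downstream code usually has to handle both shapes. The following is a minimal numpy sketch. posterior_summary is a hypothetical helper, and the equal-tailed interval is one common choice, not something predict computes itself.

```python
import numpy as np


def posterior_summary(pred, level=0.95):
    """Per-point (mean, lower, upper) from predict's output.

    pred has shape (m,) when only the current trees are used, or (m, ndpost)
    with one column per kept set of trees under a keepTrees control.
    """
    pred = np.asarray(pred, dtype=float)
    if pred.ndim == 1:
        # A single prediction per point: there is no spread to summarize.
        return pred, pred.copy(), pred.copy()
    tail = (1.0 - level) / 2.0
    mean = pred.mean(axis=1)
    lower, upper = np.quantile(pred, [tail, 1.0 - tail], axis=1)
    return mean, lower, upper


draws = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])  # m = 2 points, ndpost = 3
mean, lower, upper = posterior_summary(draws, level=1.0)  # level=1.0 gives min/max
```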
- setControl(newControl)
Replace the control object of the sampler; the new control must have n_samples set.
- Parameters:
  - newControl (dbartsControl) – The replacement dbartsControl.
- Return type:
- setData(newData, *, updateState=None)
Replace the data object of the sampler (a dbartsData).
- Parameters:
  - newData (dbartsData) – The replacement dbartsData.
  - updateState (bool | None, default: None) – Whether to refresh the sampler's cached state afterwards.
- Return type:
- setOffset(offset, *, updateScale=None, updateState=None)
Replace the offset vector.
- Parameters:
  - offset (Float64[ndarray, 'n'] | float | None) – The replacement offset (a scalar is expanded to all observations), or None to clear it.
  - updateScale (bool | None, default: None) – Whether BART's internal scale is updated for the new offset; only valid during burn-in.
  - updateState (bool | None, default: None) – Whether to refresh the sampler's cached state afterwards.
- Return type:
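setOffset expands a scalar offset to every observation. The following is a hypothetical numpy helper that reproduces that documented rule on the Python side, for example to validate an offset before passing it. The shape check is an added assumption, not documented behaviour.

```python
import numpy as np


def expand_offset(offset, n):
    """Return a length-n offset vector, or None to clear the offset.

    Hypothetical helper: a scalar is broadcast to all n observations, as
    setOffset documents; rejecting other shapes is an extra assumption.
    """
    if offset is None:
        return None
    arr = np.asarray(offset, dtype=float)
    if arr.ndim == 0:
        return np.full(n, float(arr))
    if arr.shape != (n,):
        raise ValueError(f"expected a scalar or shape ({n},), got {arr.shape}")
    return arr
```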
- setPredictor(x, column=None, forceUpdate=None, *, updateCutPoints=None, updateState=None)
Replace the predictor matrix, or only the column selected by the 1-based column argument.
Unforced updates (forceUpdate=False, the default for a single column) return whether the update succeeded. The update fails if it leaves any tree with an empty leaf, and the change is then rolled back. Whole-matrix updates are forced by default.
- Parameters:
  - x (Float64[ndarray, 'n cols'] | Float64[ndarray, 'n']) – The replacement predictors: a whole matrix, or a single column's values when column is given.
  - column (int | String[ndarray, 'cols'] | None, default: None) – The 1-based index or name of the single column to replace; the whole matrix is replaced if omitted.
  - forceUpdate (bool | None, default: None) – Whether to keep the update even if it leaves a tree with an empty leaf; defaults to True for a whole matrix and False for a column.
  - updateCutPoints (bool | None, default: None) – Whether to recompute the decision-rule cutpoints from the new predictors.
  - updateState (bool | None, default: None) – Whether to refresh the sampler's cached state afterwards.
- Returns:
  Int32[ndarray, '1'] | None – Whether the update succeeded for an unforced update, else None.
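The following toy model shows the unforced-update rule. The change is kept only if every leaf still contains at least one observation; otherwise it is rolled back. This is not how dbarts stores trees internally. try_set_column and its leaf representation (a list of (column, cutpoint, go_left) conditions per leaf) are hypothetical. They only illustrate the success/rollback semantics, using a 1-based column as setPredictor does.

```python
import numpy as np


def try_set_column(X, leaves, column, values, force_update=False):
    """Toy version of an unforced single-column predictor update.

    X       -- (n, p) float array, modified in place if the update is kept.
    leaves  -- one entry per leaf: a list of (col, cut, go_left) conditions with
               0-based col; a row lands in the leaf if it meets every condition.
    column  -- 1-based column to replace, as in setPredictor.
    Returns True if the update is kept, False if it was rolled back.
    """
    j = column - 1
    old = X[:, j].copy()
    X[:, j] = values
    if force_update:
        return True  # forced: keep the change even if a leaf is now empty
    for conditions in leaves:
        in_leaf = np.ones(X.shape[0], dtype=bool)
        for col, cut, go_left in conditions:
            in_leaf &= (X[:, col] <= cut) if go_left else (X[:, col] > cut)
        if not in_leaf.any():
            X[:, j] = old  # an empty leaf: roll the change back
            return False
    return True


# One tree with a single split on the first column at 1.5.
leaves = [[(0, 1.5, True)], [(0, 1.5, False)]]
X = np.array([[0.0, 5.0], [1.0, 5.0], [2.0, 5.0], [3.0, 5.0]])
print(try_set_column(X, leaves, 1, np.zeros(4)))  # right leaf empties, so rolled back
```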
- setTestPredictor(x_test, column=None)
Replace the test predictor matrix, or only the column selected by the 1-based column argument.
- Parameters:
  - x_test (Float64[ndarray, 'm cols'] | Float64[ndarray, 'm']) – The replacement test predictors: a whole matrix, or a single column's values when column is given.
  - column (int | String[ndarray, 'cols'] | None, default: None) – The 1-based index or name of the single column to replace; the whole matrix is replaced if omitted.
- Return type:
- setTestPredictorAndOffset(x_test, offset_test)
Replace the test predictor matrix and the test offset.
- printTrees(treeNums, chainNums=None, sampleNums=None)
Print the given trees to the R console.
- Parameters:
  - treeNums (int | Integer[ndarray, 't']) – 1-based indices of the trees to print.
  - chainNums (int | Integer[ndarray, 'c'] | None, default: None) – 1-based indices of the chains to print; all chains if omitted.
  - sampleNums (int | Integer[ndarray, 's'] | None, default: None) – 1-based indices of the samples to print; the current trees if omitted.
- Return type:
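printTrees, plotTree, and setPredictor take 1-based indices, as R does. Below is a hypothetical helper for converting 0-based Python indices. The commented call at the end assumes a sampler created as on this page and is not executed here.

```python
import numpy as np


def to_one_based(indices):
    """Convert 0-based Python indices to the 1-based int32 array R expects."""
    arr = np.atleast_1d(np.asarray(indices))
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError("indices must be integers")
    if (arr < 0).any():
        raise ValueError("negative indices have no 1-based equivalent")
    return (arr + 1).astype(np.int32)


# With a sampler built as on this page (not run here), the first two trees
# of every chain, from the current state, would be printed by:
# sampler.printTrees(to_one_based([0, 1]))
```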
- plotTree(treeNum, chainNum=None, sampleNum=None, *, treePlotPars=None)
Plot the given tree with R graphics.
- Parameters:
  - treeNum (int) – 1-based index of the tree to plot.
  - chainNum (int | None, default: None) – 1-based index of the chain to plot from.
  - sampleNum (int | None, default: None) – 1-based index of the sample to plot from.
  - treePlotPars (dict[str, float] | None, default: None) – Plot geometry, as a dict with keys 'nodeHeight', 'nodeWidth', and 'nodeGap'.
- Return type: