Open
Description
In some simulation studies, stochtree (R package) is dramatically slower on Windows than on macOS or Linux (all machines are running on performant hardware, though the comparison is of course not completely apples-to-apples). Features that appear to trigger this massive performance differential:
- Data generating processes that encourage deep trees (because of deep interactions between features)
- Large sample sizes that support growing deep trees
- A large cutpoint grid in the grow-from-root algorithm (the cutpoint_grid_size parameter in the stochtree::bart and stochtree::bcf function signatures)
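To probe the third factor on its own, one option is to time the same fit across several grid sizes and record the platform next to each timing, so results from different machines can be compared side by side. The following is only a minimal sketch: `time_over_grid` and `fit_fn` are hypothetical names (not part of stochtree), and a stand-in workload is used so the snippet runs without the package installed.

```r
# Hypothetical helper: time a fitting function across several cutpoint grid
# sizes and tag every timing with the platform it was measured on.
time_over_grid <- function(fit_fn, grid_sizes) {
  elapsed <- vapply(grid_sizes, function(g) {
    # system.time() reports user, system and elapsed time; keep wall-clock time
    unname(system.time(fit_fn(g))["elapsed"])
  }, numeric(1))
  data.frame(
    os = unname(Sys.info()["sysname"]),
    r_version = R.version.string,
    cutpoint_grid_size = grid_sizes,
    elapsed_sec = elapsed
  )
}

# Stand-in workload (an assumption, used only to make the sketch runnable)
dummy_fit <- function(g) sum(sort(runif(1000 * g)))

timings <- time_over_grid(dummy_fit, c(10, 100, 500))
print(timings)
```

With stochtree loaded, `fit_fn` could instead wrap the `stochtree::bart` call from the reproduction script, passing its argument through as `cutpoint_grid_size` and leaving every other argument fixed.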
To view this performance gap, run the following code on both Windows and macOS / Linux:
# Load libraries
library(stochtree)
library(rnn)  # provides int2bin()
# Random seed
random_seed <- 1234
set.seed(random_seed)
# Fixed parameters
sample_size <- 500000
alpha <- 1.0
beta <- 0.1
ntree <- 50
num_iter <- 10
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 10
cutpoint_grid_size <- 10
min_samples_leaf <- 1
nu <- 3
lambda <- NULL
q <- 0.9
sigma2_init <- NULL
sample_tau <- FALSE
sample_sigma <- TRUE
# Initial DGP setup
n0 <- 500
p <- 10
n <- n0*(2^p)
k <- 2
p1 <- 20
noise <- 0.1
# Full factorial covariate reference frame
xtemp <- as.data.frame(as.factor(rep(0:(2^p-1),n0)))
xtemp1 <- rep(0:(2^p-1),n0)
x <- t(sapply(xtemp1,function(j) as.numeric(int2bin(j,p))))
X_superset <- x*abs(rnorm(length(x))) - (1-x)*abs(rnorm(length(x)))
# Generate outcome
M <- model.matrix(~.-1,data = xtemp)
M <- cbind(rep(1,n),M)
beta.true <- -10*abs(rnorm(ncol(M)))
beta.true[1] <- 0.5
non_zero_betas <- c(1,sample(1:ncol(M), p1-1))
beta.true[-non_zero_betas] <- 0
Y <- M %*% beta.true + rnorm(n, 0, noise)
y_superset <- as.numeric(Y>0)
# Downsample to desired n
subset_inds <- order(sample(1:nrow(X_superset), sample_size, replace = FALSE))
X <- X_superset[subset_inds,]
y <- y_superset[subset_inds]
system.time({
  bart_obj <- stochtree::bart(
    X_train = X, y_train = y, alpha = alpha, beta = beta,
    min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q,
    sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr,
    num_burnin = num_burnin, num_mcmc = num_mcmc, cutpoint_grid_size = cutpoint_grid_size,
    sample_tau = sample_tau, sample_sigma = sample_sigma, random_seed = random_seed
  )
})
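To see why this data generating process encourages deep trees, it may help to inspect a scaled-down version of the covariate construction. The sketch below uses p = 3 and n0 = 2 rather than the values in the script, and swaps `rnn::int2bin` for a base-R stand-in; `to_bits` is a hypothetical helper, and its bit order is an assumption that only affects column order. The sign of each feature encodes one bit of the cell index, so an effect tied to a single cell can only be isolated by splitting on up to p features.

```r
# Hypothetical base-R stand-in for rnn::int2bin (least significant bit first)
to_bits <- function(j, p) (j %/% 2^(0:(p - 1))) %% 2

p <- 3   # number of binary factors (the script above uses 10)
n0 <- 2  # replicates per cell (the script above uses 500)

# Full factorial design: each of the 2^p cells appears n0 times
cells <- rep(0:(2^p - 1), n0)
bits <- t(sapply(cells, to_bits, p = p))  # (n0 * 2^p) x p matrix of 0/1

# Same construction as X_superset: positive draw if the bit is 1, negative if 0
set.seed(1)
X_small <- bits * abs(rnorm(length(bits))) - (1 - bits) * abs(rnorm(length(bits)))

# Each row's sign pattern recovers its cell, so there are 2^p distinct patterns
sign_pattern <- apply(X_small > 0, 1, function(r) paste(as.integer(r), collapse = ""))
n_patterns <- length(unique(sign_pattern))
print(head(cbind(cell = cells, round(X_small, 2))))
print(n_patterns)
```

In the full script the outcome adds an effect for a random subset of cells, which is what makes the interactions as deep as p.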