
Scaling issues with large trees and large datasets on Windows #77

@andrewherren

In some simulation studies, the stochtree R package is dramatically slower on Windows than on macOS or Linux. All of the machines involved have performant hardware, although the comparison is not exactly apples-to-apples. The following features of a simulation appear to trigger this large performance gap:

  1. Data-generating processes that encourage deep trees (because of deep interactions between features)
  2. Large sample sizes that support growing deep trees
  3. A large cutpoint grid in the grow-from-root algorithm (the cutpoint_grid_size parameter in the stochtree::bart and stochtree::bcf function signatures)
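These three factors compound one another. As a rough illustration, the R sketch below counts work units for one grow-from-root sweep under a deliberately simple cost model. The model is an assumption for illustration only, not a description of stochtree's implementation: each tree level is assumed to scan all n observations for each of the p features, and each internal node is assumed to score p times cutpoint_grid_size candidate splits. gfr_work_estimate is a hypothetical helper.

```r
# Back-of-the-envelope work estimate for grow-from-root (GFR) sampling.
# ASSUMPTION (not taken from stochtree's source): growing one tree level
# touches every observation once per feature to accumulate sufficient
# statistics, and every internal node then scores p * grid_size cutpoints.
# The tree is assumed to be a full binary tree of the given depth.
gfr_work_estimate <- function(n, p, depth, grid_size, num_trees, num_gfr) {
  levels <- seq_len(depth) - 1              # depths 0, 1, ..., depth - 1
  nodes_split <- sum(2^levels)              # internal nodes in a full tree
  data_passes_per_tree <- depth * n * p     # one pass over the data per level
  split_evals_per_tree <- nodes_split * p * grid_size
  per_tree <- data_passes_per_tree + split_evals_per_tree
  c(data_passes_per_tree = data_passes_per_tree,
    split_evals_per_tree = split_evals_per_tree,
    total = per_tree * num_trees * num_gfr)
}

# Compare shallow and deep trees under the settings of the reproduction script
shallow <- gfr_work_estimate(n = 500000, p = 10, depth = 4, grid_size = 10,
                             num_trees = 50, num_gfr = 10)
deep <- gfr_work_estimate(n = 500000, p = 10, depth = 14, grid_size = 10,
                          num_trees = 50, num_gfr = 10)
deep[["total"]] / shallow[["total"]]
```

Under this model, sample size and depth multiply the data-scanning term, while the grid size multiplies a per-node scoring term that grows exponentially with depth, so deep trees on large data make both terms large at once. The model says nothing about why Windows in particular is slower.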

To reproduce the gap, run the following script on Windows and on macOS or Linux, and compare the elapsed times reported by system.time():

```r
# Load libraries
library(stochtree)
library(rnn)  # provides int2bin(), used to build the full factorial design

# Random seed
random_seed <- 1234
set.seed(random_seed)

# Fixed parameters
sample_size <- 500000
alpha <- 1.0
beta <- 0.1
ntree <- 50
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 10
cutpoint_grid_size <- 10
min_samples_leaf <- 1
nu <- 3
lambda <- NULL
q <- 0.9
sigma2_init <- NULL
sample_tau <- FALSE
sample_sigma <- TRUE

# Initial DGP setup
n0 <- 500         # replicates of each covariate pattern
p <- 10           # number of binary features
n <- n0 * (2^p)   # 512,000 rows before downsampling
p1 <- 20          # number of nonzero coefficients, including the intercept
noise <- 0.1

# Full factorial covariate reference frame: all 2^p binary patterns,
# each repeated n0 times
xtemp1 <- rep(0:(2^p - 1), n0)
xtemp <- as.data.frame(as.factor(xtemp1))
x <- t(sapply(xtemp1, function(j) as.numeric(int2bin(j, p))))
# Turn each bit into a continuous covariate: positive draw if the bit is 1,
# negative draw if it is 0
X_superset <- x * abs(rnorm(length(x))) - (1 - x) * abs(rnorm(length(x)))

# Generate outcome from a sparse model with one dummy per covariate pattern,
# so each nonzero coefficient is a p-way interaction of the features
# (note: M is a dense 512,000 x 1,025 matrix, roughly 4.2 GB)
M <- model.matrix(~ . - 1, data = xtemp)
M <- cbind(rep(1, n), M)
beta.true <- -10 * abs(rnorm(ncol(M)))
beta.true[1] <- 0.5
# Keep the intercept plus p1 - 1 distinct pattern coefficients
non_zero_betas <- c(1, sample(2:ncol(M), p1 - 1))
beta.true[-non_zero_betas] <- 0
Y <- M %*% beta.true + rnorm(n, 0, noise)
y_superset <- as.numeric(Y > 0)

# Downsample to the desired n (random rows, kept in their original order)
subset_inds <- sort(sample(1:nrow(X_superset), sample_size, replace = FALSE))
X <- X_superset[subset_inds, ]
y <- y_superset[subset_inds]

system.time({
    bart_obj <- stochtree::bart(
        X_train = X, y_train = y, alpha = alpha, beta = beta,
        min_samples_leaf = min_samples_leaf, nu = nu, lambda = lambda, q = q,
        sigma2_init = sigma2_init, num_trees = ntree, num_gfr = num_gfr,
        num_burnin = num_burnin, num_mcmc = num_mcmc, cutpoint_grid_size = cutpoint_grid_size,
        sample_tau = sample_tau, sample_sigma = sample_sigma, random_seed = random_seed
    )
})
```
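Because the comparison is across operating systems, it can help to record the platform alongside each timing. The R sketch below is one way to do that with base R only. time_on_platform and sweep_grid_sizes are hypothetical helpers, and fit_with_grid stands for any one-argument function you supply that fits a model for a given cutpoint grid size.

```r
# Hypothetical helper: time a zero-argument function and tag the result with
# the operating system and R version it ran on.
time_on_platform <- function(fit_fn, label = "") {
  timing <- system.time(fit_fn())
  data.frame(
    label   = label,
    os      = Sys.info()[["sysname"]],   # e.g. "Windows", "Darwin", "Linux"
    r       = R.version$version.string,
    elapsed = timing[["elapsed"]],       # wall-clock seconds
    stringsAsFactors = FALSE
  )
}

# Hypothetical helper: repeat the timing for several cutpoint grid sizes.
# fit_with_grid must be a function of one argument (the grid size).
sweep_grid_sizes <- function(fit_with_grid, grid_sizes) {
  rows <- lapply(grid_sizes, function(g) {
    time_on_platform(function() fit_with_grid(g),
                     label = paste0("cutpoint_grid_size=", g))
  })
  do.call(rbind, rows)
}

# Smoke test with a cheap stand-in for the model fit
sweep_grid_sizes(function(g) sum(rnorm(g * 1000)), grid_sizes = c(10, 100))
```

With stochtree installed, fit_with_grid could wrap the stochtree::bart call from the script above, passing its argument as cutpoint_grid_size. Running the same sweep on Windows and on macOS or Linux and binding the results would show whether the gap widens as the grid grows.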
