Support for multi threading in various parts of the sampling algorithm #182

Open
wants to merge 18 commits into
base: main
16 changes: 16 additions & 0 deletions .gitattributes
@@ -0,0 +1,16 @@
* text=auto

*.c text eol=lf
*.h text eol=lf
*.cc text eol=lf
*.cuh text eol=lf
*.cu text eol=lf
*.py text eol=lf
*.txt text eol=lf
*.R text eol=lf

*.sh text eol=lf
*.ac text eol=lf

*.md text eol=lf
*.csv text eol=lf
7 changes: 7 additions & 0 deletions .github/workflows/cpp-test.yml
@@ -56,6 +56,12 @@ jobs:
shell: bash
run: |
echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT"

- name: Set up dependencies (linux clang)
# Set up OpenMP on ubuntu-latest for the clang compiler toolset (unlike GCC and MSVC, clang does not ship with an OpenMP runtime)
if: matrix.os == 'ubuntu-latest' && matrix.cpp_compiler == 'clang++'
run: |
sudo apt-get update && sudo apt-get install -y libomp-dev

- name: Configure CMake
# Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make.
@@ -69,6 +75,7 @@ jobs:
-DUSE_SANITIZER=OFF
-DBUILD_TEST=ON
-DBUILD_DEBUG_TARGETS=OFF
-DUSE_OPENMP=ON
-S ${{ github.workspace }}

- name: Build
13 changes: 12 additions & 1 deletion .github/workflows/pypi-wheels.yml
@@ -21,26 +21,37 @@ jobs:
include:
- os: ubuntu-latest
cibw_archs: "x86_64"
macos_deployment_target: "10.13" # Unused, just setting the variable as a placeholder
- os: ubuntu-24.04-arm
cibw_archs: "aarch64"
macos_deployment_target: "10.13" # Unused, just setting the variable as a placeholder
- os: windows-latest
cibw_archs: "auto64"
macos_deployment_target: "10.13" # Unused, just setting the variable as a placeholder
- os: macos-13
cibw_archs: "x86_64"
macos_deployment_target: "13.0"
- os: macos-14
cibw_archs: "arm64"
macos_deployment_target: "14.0"

steps:
- uses: actions/checkout@v4
with:
submodules: 'recursive'

- name: Set up openmp (macos)
# Set up OpenMP on macOS, since it does not ship with the Apple Clang compiler suite
if: matrix.os == 'macos-13' || matrix.os == 'macos-14'
run: |
brew install libomp

- name: Build wheels
uses: pypa/cibuildwheel@v2.23.2
env:
CIBW_SKIP: "pp* *-musllinux_* *-win32"
CIBW_ARCHS: ${{ matrix.cibw_archs }}
MACOSX_DEPLOYMENT_TARGET: "10.13"
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos_deployment_target }}

- uses: actions/upload-artifact@v4
with:
6 changes: 6 additions & 0 deletions .github/workflows/python-test.yml
@@ -30,6 +30,12 @@ jobs:
with:
python-version: "3.10"
cache: "pip"

- name: Set up openmp (macos)
# Set up OpenMP on macOS, since it does not ship with the Apple Clang compiler suite
if: matrix.os == 'macos-latest'
run: |
brew install libomp

- name: Install Package with Relevant Dependencies
run: |
5 changes: 5 additions & 0 deletions .github/workflows/r-test.yml
@@ -22,6 +22,11 @@ jobs:
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
- name: Prevent conversion of line endings on Windows
if: startsWith(matrix.os, 'windows')
shell: pwsh
run: git config --global core.autocrlf false

- uses: actions/checkout@v4
with:
submodules: 'recursive'
50 changes: 50 additions & 0 deletions .github/workflows/regression-test.yml
@@ -0,0 +1,50 @@
on:
workflow_dispatch:

name: Running stochtree on benchmark datasets

jobs:
stochtree_r_bart:
name: stochtree-r-bart-regression-test
runs-on: ubuntu-latest

steps:
- name: Checkout stochtree repo
uses: actions/checkout@v4
with:
submodules: 'recursive'

- name: Setup pandoc
uses: r-lib/actions/setup-pandoc@v2

- name: Setup R
uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true

- name: Create a properly formatted version of the stochtree R package in a subfolder
run: |
Rscript cran-bootstrap.R 0 0 1

- name: Setup R dependencies
uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::testthat, any::decor, local::stochtree_cran

- name: Create output directory for regression test results
run: |
mkdir -p tools/regression/stochtree_bart_r_results

- name: Run the regression test benchmark suite
run: |
Rscript tools/regression/regression_test_dispatch_bart.R

- name: Collate and analyze regression test results
run: |
Rscript tools/regression/regression_test_analysis_bart.R

- name: Store benchmark test results as an artifact of the run
uses: actions/upload-artifact@v4
with:
name: stochtree-r-bart-summary
path: tools/regression/stochtree_bart_r_results/stochtree_bart_r_summary.csv
5 changes: 5 additions & 0 deletions .gitignore
@@ -70,6 +70,11 @@ po/*~
# RStudio Connect folder
rsconnect/

# Configuration files generated by R build
config.status
config.log
src/Makevars

## Python gitignore

# Byte-compiled / optimized / DLL files
83 changes: 69 additions & 14 deletions CMakeLists.txt
@@ -1,6 +1,8 @@
# Build options
option(USE_DEBUG "Set to ON for Debug mode" OFF)
option(USE_DEBUG "Build with debug symbols and without optimization" OFF)
option(USE_SANITIZER "Use sanitizer flags" OFF)
option(USE_OPENMP "Use openMP" ON)
option(USE_HOMEBREW_FALLBACK "(macOS-only) also look in 'brew --prefix' for libraries (e.g. OpenMP)" ON)
option(BUILD_TEST "Build C++ tests with Google Test" OFF)
option(BUILD_DEBUG_TARGETS "Build Standalone C++ Programs for Debugging" ON)
option(BUILD_PYTHON "Build Shared Library for Python Package" OFF)
@@ -9,8 +11,8 @@ option(BUILD_PYTHON "Build Shared Library for Python Package" OFF)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Default to CMake 3.16
cmake_minimum_required(VERSION 3.16)
# Default to CMake 3.20
cmake_minimum_required(VERSION 3.20)

# Define the project
project(stochtree LANGUAGES C CXX)
@@ -34,6 +36,13 @@ if(USE_DEBUG)
add_definitions(-DDEBUG)
endif()

# Linker flags (empty by default, updated if using openmp)
set(
STOCHTREE_LINK_FLAGS
""
)

# Unix / MinGW compiler flags
if(UNIX OR MINGW OR CYGWIN)
set(
CMAKE_CXX_FLAGS
@@ -42,11 +51,12 @@ if(UNIX OR MINGW OR CYGWIN)
if(USE_DEBUG)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O3")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unknown-pragmas -Wno-unused-private-field")
endif()

# MSVC compiler flags
if(MSVC)
set(
variables
@@ -72,6 +82,33 @@ else()
endif()
endif()

# OpenMP
if(USE_OPENMP)
add_definitions(-DSTOCHTREE_OPENMP_AVAILABLE)
if(APPLE)
find_package(OpenMP)
if(NOT OpenMP_FOUND)
if(USE_HOMEBREW_FALLBACK)
execute_process(COMMAND brew --prefix libomp
OUTPUT_VARIABLE HOMEBREW_LIBOMP_PREFIX
OUTPUT_STRIP_TRAILING_WHITESPACE)
set(OpenMP_C_FLAGS "-Xclang -fopenmp -I${HOMEBREW_LIBOMP_PREFIX}/include")
set(OpenMP_CXX_FLAGS "-Xclang -fopenmp -I${HOMEBREW_LIBOMP_PREFIX}/include")
set(OpenMP_C_INCLUDE_DIR "")
set(OpenMP_CXX_INCLUDE_DIR "")
set(OpenMP_C_LIB_NAMES libomp)
set(OpenMP_CXX_LIB_NAMES libomp)
set(OpenMP_libomp_LIBRARY ${HOMEBREW_LIBOMP_PREFIX}/lib/libomp.dylib)
endif()
find_package(OpenMP REQUIRED)
endif()
else()
find_package(OpenMP REQUIRED)
endif()
# Update flags with openmp
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()

# Header file directory
set(StochTree_HEADER_DIR ${PROJECT_SOURCE_DIR}/include)

@@ -80,6 +117,8 @@ set(BOOSTMATH_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/boost_math/include)

# Eigen header file directory
set(EIGEN_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/eigen)
add_definitions(-DEIGEN_MPL2_ONLY)
add_definitions(-DEIGEN_DONT_PARALLELIZE)

# fast_double_parser header file directory
set(FAST_DOUBLE_PARSER_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/fast_double_parser/include)
@@ -109,10 +148,11 @@ file(
add_library(stochtree_objs OBJECT ${SOURCES})

# Include the headers in the source library
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})

if(APPLE)
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
if(USE_OPENMP)
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
target_link_libraries(stochtree_objs PRIVATE ${OpenMP_libomp_LIBRARY})
else()
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
endif()

# Python shared library
@@ -122,8 +162,13 @@ if (BUILD_PYTHON)
pybind11_add_module(stochtree_cpp src/py_stochtree.cpp)

# Link to C++ source and headers
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs)
if(USE_OPENMP)
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY})
else()
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs)
endif()

# EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a
# define (VERSION_INFO) here.
@@ -154,8 +199,13 @@ if(BUILD_TEST)
file(GLOB CPP_TEST_SOURCES test/cpp/*.cpp)
add_executable(teststochtree ${CPP_TEST_SOURCES})
set(STOCHTREE_TEST_HEADER_DIR ${PROJECT_SOURCE_DIR}/test/cpp)
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main)
if(USE_OPENMP)
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main ${OpenMP_libomp_LIBRARY})
else()
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main)
endif()
gtest_discover_tests(teststochtree)
endif()

@@ -164,7 +214,12 @@ if(BUILD_DEBUG_TARGETS)
# Build test suite
add_executable(debugstochtree debug/api_debug.cpp)
set(StochTree_DEBUG_HEADER_DIR ${PROJECT_SOURCE_DIR}/debug)
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(debugstochtree PRIVATE stochtree_objs)
if(USE_OPENMP)
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
target_link_libraries(debugstochtree PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY})
else()
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(debugstochtree PRIVATE stochtree_objs)
endif()
endif()

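Reviewer note: the `CMakeLists.txt` changes above add the compile definition `STOCHTREE_OPENMP_AVAILABLE` whenever `USE_OPENMP` is `ON`. The C++ sources that consume it are not part of this excerpt, so the following is only a minimal sketch of how such a definition could guard OpenMP usage so that the same code also builds serially. The function name `SumOfSquares` and its signature are hypothetical and do not come from the PR.

```cpp
#include <cstddef>
#include <vector>

// Only pull in the OpenMP header when the build system defined the macro
// (the CMake changes above add -DSTOCHTREE_OPENMP_AVAILABLE if USE_OPENMP=ON).
#ifdef STOCHTREE_OPENMP_AVAILABLE
#include <omp.h>
#endif

// Hypothetical helper, not taken from the PR: sums the squares of `values`,
// splitting the loop across threads when OpenMP is compiled in.
double SumOfSquares(const std::vector<double>& values, int num_threads) {
  // The OpenMP num_threads clause needs a positive value, so clamp defensively.
  const int threads = num_threads > 0 ? num_threads : 1;
  (void)threads;  // avoids an unused-variable warning in serial builds
  const long n = static_cast<long>(values.size());
  double total = 0.0;
#ifdef STOCHTREE_OPENMP_AVAILABLE
#pragma omp parallel for reduction(+ : total) num_threads(threads)
#endif
  for (long i = 0; i < n; ++i) {
    const double v = values[static_cast<std::size_t>(i)];
    total += v * v;
  }
  return total;
}
```

Without the macro, the pragma is compiled out and the loop runs on a single thread, which is why the serial and parallel builds can share one code path.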
8 changes: 6 additions & 2 deletions R/bart.R
@@ -51,6 +51,7 @@
#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
#' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. Default: `-1`. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads.
#'
#' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
@@ -130,7 +131,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
rfx_working_parameter_prior_cov = NULL,
rfx_group_parameter_prior_cov = NULL,
rfx_variance_prior_shape = 1,
rfx_variance_prior_scale = 1
rfx_variance_prior_scale = 1,
num_threads = -1
)
general_params_updated <- preprocessParams(
general_params_default, general_params
@@ -186,6 +188,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov
rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape
rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale
num_threads <- general_params_updated$num_threads

# 2. Mean forest parameters
num_trees_mean <- mean_forest_params_updated$num_trees
@@ -795,7 +798,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
forest_model_mean$sample_one_iteration(
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
global_model_config = global_model_config, num_threads = num_threads,
keep_forest = keep_sample, gfr = TRUE
)

# Cache train set predictions since they are already computed during sampling
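
Reviewer note: the `R/bart.R` changes above document `num_threads` as falling back to `1` when OpenMP is unavailable and to the maximum number of available threads otherwise, with `-1` as the default passed from R. The C++ code that performs this resolution is not shown in this excerpt. A minimal sketch of one way it could work follows, assuming that any non-positive value means "choose automatically"; the function name `ResolveNumThreads` is hypothetical.

```cpp
#ifdef STOCHTREE_OPENMP_AVAILABLE
#include <omp.h>
#endif

// Hypothetical helper, not taken from the PR: maps the user-facing
// `num_threads` value to the thread count actually used.
int ResolveNumThreads(int requested) {
#ifdef STOCHTREE_OPENMP_AVAILABLE
  // Assumption: non-positive values (such as the -1 default) request the
  // maximum number of threads that OpenMP reports as available.
  if (requested <= 0) {
    return omp_get_max_threads();
  }
  return requested;
#else
  // Without OpenMP there is nothing to parallelize with, so always use 1.
  (void)requested;
  return 1;
#endif
}
```

Resolving the sentinel in a single place would keep every parallel region consistent, since each one could then be handed a value that is already known to be positive.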