Skip to content

Commit 51e606a

Browse files
cyccbxhlhuanjin.cb
authored and committed
Fix the NaN recall in xgboost training: Update the tag of xgboost to v2.0.3; Add train code in train iteration; Passed all examples.
1 parent 767f0ea commit 51e606a

File tree

6 files changed

+48
-12
lines changed

6 files changed

+48
-12
lines changed

examples/basic/src/main.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -66,15 +66,15 @@ fn main() {
6666

6767
// save and load model file
6868
println!("\nSaving and loading Booster model...");
69-
booster.save("xgb.model").unwrap();
70-
let booster = Booster::load("xgb.model").unwrap();
69+
booster.save("xgb.json").unwrap();
70+
let booster = Booster::load("xgb.json").unwrap();
7171
let preds2 = booster.predict(&dtest).unwrap();
7272
assert_eq!(preds, preds2);
7373

7474
// save and load data matrix file
7575
println!("\nSaving and loading matrix data...");
7676
dtest.save("test.dmat").unwrap();
77-
let dtest2 = DMatrix::load("test.dmat").unwrap();
77+
let dtest2 = DMatrix::load_binary("test.dmat").unwrap();
7878
assert_eq!(booster.predict(&dtest2).unwrap(), preds);
7979

8080
// error handling example

examples/generalised_linear_model/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ fn main() {
1212

1313
// load train and test matrices from text files (in LibSVM format)
1414
println!("Custom objective example...");
15-
let dtrain = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap();
16-
let dtest = DMatrix::load("../../xgboost-sys/xgboost/demo/data/agaricus.txt.test").unwrap();
15+
let dtrain = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#).unwrap();
16+
let dtest = DMatrix::load(r#"{"uri": "../../xgboost-sys/xgboost/demo/data/agaricus.txt.test?format=libsvm"}"#).unwrap();
1717

1818
// configure objectives, metrics, etc.
1919
let learning_params = parameters::learning::LearningTaskParametersBuilder::default()

src/booster.rs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,29 @@ impl Booster {
148148
dmats
149149
};
150150

151-
let bst = Booster::new_with_cached_dmats(&params.booster_params, &cached_dmats)?;
152-
for i in 0..params.boost_rounds as i32 {
151+
let mut bst = Booster::new_with_cached_dmats(&params.booster_params, &cached_dmats)?;
152+
// load distributed code checkpoint from rabit
153+
let mut version = bst.load_rabit_checkpoint()?;
154+
debug!("Loaded Rabit checkpoint: version={}", version);
155+
assert!(unsafe { xgboost_sys::RabitGetWorldSize() != 1 || version == 0 });
156+
let start_iteration = version / 2;
157+
for i in start_iteration..params.boost_rounds as i32 {
158+
// distributed code: need to resume to this point
159+
// skip first update if a recovery step
160+
if version % 2 == 0 {
161+
if let Some(objective_fn) = params.custom_objective_fn {
162+
debug!("Boosting in round: {}", i);
163+
bst.update_custom(params.dtrain, objective_fn)?;
164+
} else {
165+
debug!("Updating in round: {}", i);
166+
bst.update(params.dtrain, i)?;
167+
}
168+
let _ = bst.save_rabit_checkpoint()?;
169+
version += 1;
170+
}
171+
172+
assert!(unsafe { xgboost_sys::RabitGetWorldSize() == 1 || version == xgboost_sys::RabitVersionNumber() });
173+
153174
if let Some(eval_sets) = params.evaluation_sets {
154175
let mut dmat_eval_results = bst.eval_set(eval_sets, i)?;
155176

@@ -182,6 +203,10 @@ impl Booster {
182203
}
183204
println!();
184205
}
206+
207+
// do checkpoint after evaluation, in case evaluation also updates booster.
208+
let _ = bst.save_rabit_checkpoint();
209+
version += 1;
185210
}
186211

187212
Ok(bst)
@@ -536,6 +561,16 @@ impl Booster {
536561
Ok(out_vec.join("\n"))
537562
}
538563

564+
/// Loads a Rabit checkpoint into this booster and returns its version number.
///
/// Calls `XGBoosterLoadRabitCheckpoint` on this booster's handle; the C API
/// writes the version through an out-parameter, which is returned on success.
///
/// # Errors
///
/// Propagates whatever error `xgb_call!` produces for the underlying call.
///
/// NOTE(review): the training loop asserts `version == 0` when the Rabit world
/// size is 1, so 0 presumably means "no checkpoint to resume from" — confirm
/// against the XGBoost C API documentation.
pub(crate) fn load_rabit_checkpoint(&self) -> XGBResult<i32> {
565+
// Out-parameter filled in by the C API call below.
let mut version = 0;
566+
xgb_call!(xgboost_sys::XGBoosterLoadRabitCheckpoint(self.handle, &mut version))?;
567+
Ok(version)
568+
}
569+
570+
/// Saves a Rabit checkpoint for this booster.
///
/// Thin wrapper around `XGBoosterSaveRabitCheckpoint`, invoked on this
/// booster's handle.
///
/// # Errors
///
/// Returns the result of `xgb_call!` for the underlying call unchanged, so any
/// failure it reports is handed to the caller.
pub(crate) fn save_rabit_checkpoint(&self) -> XGBResult<()> {
571+
xgb_call!(xgboost_sys::XGBoosterSaveRabitCheckpoint(self.handle))
572+
}
573+
539574
pub fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> {
540575
let name = ffi::CString::new(name).unwrap();
541576
let value = ffi::CString::new(value).unwrap();

src/parameters/learning.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ impl Clone for Objective {
8181
impl ToString for Objective {
8282
fn to_string(&self) -> String {
8383
match *self {
84-
Objective::RegLinear => "reg:linear".to_owned(),
84+
Objective::RegLinear => "reg:squarederror".to_owned(),
8585
Objective::RegLogistic => "reg:logistic".to_owned(),
8686
Objective::BinaryLogistic => "binary:logistic".to_owned(),
8787
Objective::BinaryLogisticRaw => "binary:logitraw".to_owned(),
88-
Objective::GpuRegLinear => "gpu:reg:linear".to_owned(),
88+
Objective::GpuRegLinear => "gpu:reg:squarederror".to_owned(),
8989
Objective::GpuRegLogistic => "gpu:reg:logistic".to_owned(),
9090
Objective::GpuBinaryLogistic => "gpu:binary:logistic".to_owned(),
9191
Objective::GpuBinaryLogisticRaw => "gpu:binary:logitraw".to_owned(),

xgboost-sys/build.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ fn main() {
3333
#[cfg(not(feature = "cuda"))]
3434
let mut dst = Config::new(&xgb_root);
3535

36-
let mut dst = dst.uses_cxx11()
36+
let dst = dst.uses_cxx11()
3737
.define("BUILD_STATIC_LIB", "ON");
3838

3939
#[cfg(target_os = "macos")]
@@ -54,7 +54,6 @@ fn main() {
5454

5555
let bindings = bindgen::Builder::default()
5656
.header("wrapper.h")
57-
.blocklist_item("std::.*")// stdlib is not well supported by bindgen
5857
.clang_args(&["-x", "c++", "-std=c++11"])
5958
.clang_arg(format!("-I{}", xgb_root.join("include").display()))
6059
.clang_arg(format!("-I{}", xgb_root.join("rabit/include").display()))
@@ -86,6 +85,8 @@ fn main() {
8685
println!("cargo:rustc-link-lib=c++");
8786
println!("cargo:rustc-link-lib=dylib=omp");
8887
} else {
88+
println!("cargo:rustc-cxxflags=-std=c++17");
89+
println!("cargo:rustc-link-lib=stdc++fs");
8990
println!("cargo:rustc-link-lib=stdc++");
9091
println!("cargo:rustc-link-lib=dylib=gomp");
9192
}

xgboost-sys/xgboost

Submodule xgboost updated 797 files

0 commit comments

Comments
 (0)