xgboost/parameters/tree.rs
1//! BoosterParameters for controlling tree boosters.
2//!
3//!
4use std::default::Default;
5
6use super::Interval;
7
/// The tree construction algorithm used in XGBoost (see the description in the
/// [reference paper](http://arxiv.org/abs/1603.02754)).
///
/// The distributed and external memory versions only support the approximate algorithm.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TreeMethod {
    /// Use a heuristic to choose the faster algorithm.
    ///
    /// * For small to medium datasets, exact greedy will be used.
    /// * For very large datasets, the approximate algorithm will be chosen.
    /// * Because the old behavior was to always use exact greedy on a single machine, the user will get a
    ///   message when the approximate algorithm is chosen, to notify them of this choice.
    Auto,

    /// Exact greedy algorithm.
    Exact,

    /// Approximate greedy algorithm using sketching and histograms.
    Approx,

    /// Fast histogram-optimized approximate greedy algorithm. It uses some performance improvements
    /// such as bin caching.
    Hist,

    /// GPU implementation of the exact algorithm.
    GpuExact,

    /// GPU implementation of the hist algorithm.
    GpuHist,
}

/// Formats the method as the lowercase identifier that is emitted for the `tree_method` parameter.
///
/// `Display` is implemented instead of `ToString` directly: `to_string()` stays available through the
/// standard library's blanket impl, and the value can additionally be used with `format!` and friends.
impl std::fmt::Display for TreeMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match *self {
            TreeMethod::Auto => "auto",
            TreeMethod::Exact => "exact",
            TreeMethod::Approx => "approx",
            TreeMethod::Hist => "hist",
            TreeMethod::GpuExact => "gpu_exact",
            TreeMethod::GpuHist => "gpu_hist",
        };
        f.write_str(name)
    }
}

impl Default for TreeMethod {
    /// Returns `TreeMethod::Auto`, i.e. lets the heuristic pick the algorithm.
    fn default() -> Self {
        TreeMethod::Auto
    }
}
55
/// Provides a modular way to construct and to modify the trees. This is an advanced parameter that is usually
/// set automatically, depending on some other parameters. However, it can also be set explicitly by a user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TreeUpdater {
    /// Non-distributed column-based construction of trees.
    GrowColMaker,

    /// Distributed tree construction with column-based data splitting mode.
    DistCol,

    /// Distributed tree construction with row-based data splitting based on a global proposal of histogram
    /// counting.
    GrowHistMaker,

    /// Based on local histogram counting.
    GrowLocalHistMaker,

    /// Uses the approximate sketching algorithm.
    GrowSkMaker,

    /// Synchronizes trees in all distributed nodes.
    Sync,

    /// Refreshes a tree's statistics and/or leaf values based on the current data.
    /// Note that no random subsampling of data rows is performed.
    Refresh,

    /// Prunes the splits where loss < min_split_loss (or gamma).
    Prune,
}

/// Formats the updater as the lowercase identifier that is emitted in the `updater` parameter list.
///
/// `Display` is implemented instead of `ToString` directly: `to_string()` stays available through the
/// standard library's blanket impl, and the value can additionally be used with `format!` and friends.
impl std::fmt::Display for TreeUpdater {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match *self {
            TreeUpdater::GrowColMaker => "grow_colmaker",
            TreeUpdater::DistCol => "distcol",
            TreeUpdater::GrowHistMaker => "grow_histmaker",
            TreeUpdater::GrowLocalHistMaker => "grow_local_histmaker",
            TreeUpdater::GrowSkMaker => "grow_skmaker",
            TreeUpdater::Sync => "sync",
            TreeUpdater::Refresh => "refresh",
            TreeUpdater::Prune => "prune",
        };
        f.write_str(name)
    }
}
100
/// The type of boosting process to run.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProcessType {
    /// The normal boosting process, which creates new trees.
    Default,

    /// Starts from an existing model and only updates its trees. In each boosting iteration,
    /// a tree from the initial model is taken, a specified sequence of updater plugins is run for that tree,
    /// and a modified tree is added to the new model. The new model will have either the same or a smaller
    /// number of trees, depending on the number of boosting iterations performed.
    /// Currently, the following built-in updater plugins can be meaningfully used with this process type:
    /// 'refresh', 'prune'. With 'update', one cannot use updater plugins that create new trees.
    Update,
}

/// Formats the process type as the lowercase identifier that is emitted for the `process_type` parameter.
///
/// `Display` is implemented instead of `ToString` directly: `to_string()` stays available through the
/// standard library's blanket impl, and the value can additionally be used with `format!` and friends.
impl std::fmt::Display for ProcessType {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match *self {
            ProcessType::Default => "default",
            ProcessType::Update => "update",
        };
        f.write_str(name)
    }
}

impl Default for ProcessType {
    /// Returns the `ProcessType::Default` variant (the normal boosting process).
    fn default() -> Self {
        ProcessType::Default
    }
}
128
/// Controls the way new nodes are added to the tree.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GrowPolicy {
    /// Split at the nodes closest to the root.
    Depthwise,

    /// Split at the nodes with the highest loss change.
    LossGuide,
}

/// Formats the policy as the lowercase identifier that is emitted for the `grow_policy` parameter.
///
/// `Display` is implemented instead of `ToString` directly: `to_string()` stays available through the
/// standard library's blanket impl, and the value can additionally be used with `format!` and friends.
impl std::fmt::Display for GrowPolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match *self {
            GrowPolicy::Depthwise => "depthwise",
            GrowPolicy::LossGuide => "lossguide",
        };
        f.write_str(name)
    }
}

impl Default for GrowPolicy {
    /// Returns `GrowPolicy::Depthwise`.
    fn default() -> Self {
        GrowPolicy::Depthwise
    }
}
151
/// The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Predictor {
    /// Multicore CPU prediction algorithm.
    Cpu,

    /// Prediction using the GPU. Default for the 'gpu_exact' and 'gpu_hist' tree methods.
    Gpu,
}

/// Formats the predictor as the identifier that is emitted for the `predictor` parameter.
///
/// `Display` is implemented instead of `ToString` directly: `to_string()` stays available through the
/// standard library's blanket impl, and the value can additionally be used with `format!` and friends.
impl std::fmt::Display for Predictor {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match *self {
            Predictor::Cpu => "cpu_predictor",
            Predictor::Gpu => "gpu_predictor",
        };
        f.write_str(name)
    }
}

impl Default for Predictor {
    /// Returns `Predictor::Cpu`.
    fn default() -> Self {
        Predictor::Cpu
    }
}
174
175/// BoosterParameters for Tree Booster. Create using
176/// [`TreeBoosterParametersBuilder`](struct.TreeBoosterParametersBuilder.html).
177#[derive(Builder, Clone)]
178#[builder(build_fn(validate = "Self::validate"))]
179#[builder(default)]
180pub struct TreeBoosterParameters {
181 /// Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly
182 /// get the weights of new features, and eta actually shrinks the feature weights to make the boosting process
183 /// more conservative.
184 ///
185 /// * range: [0.0, 1.0]
186 /// * default: 0.3
187 eta: f32,
188
189 /// Minimum loss reduction required to make a further partition on a leaf node of the tree.
190 /// The larger, the more conservative the algorithm will be.
191 ///
192 /// * range: [0,∞]
193 /// * default: 0
194 gamma: u32,
195
196 /// Maximum depth of a tree, increase this value will make the model more complex / likely to be overfitting.
197 /// 0 indicates no limit, limit is required for depth-wise grow policy.
198 ///
199 /// * range: [0,∞]
200 /// * default: 6
201 max_depth: u32,
202
203 /// Minimum sum of instance weight (hessian) needed in a child. If the tree partition step results in a leaf
204 /// node with the sum of instance weight less than min_child_weight, then the building process will give up
205 /// further partitioning.
206 /// In linear regression mode, this simply corresponds to minimum number of instances needed to be in each node.
207 /// The larger, the more conservative the algorithm will be.
208 ///
209 /// * range: [0,∞]
210 /// * default: 1
211 min_child_weight: u32,
212
213 /// Maximum delta step we allow each tree’s weight estimation to be.
214 /// If the value is set to 0, it means there is no constraint. If it is set to a positive value,
215 /// it can help making the update step more conservative. Usually this parameter is not needed,
216 /// but it might help in logistic regression when class is extremely imbalanced.
217 /// Set it to value of 1-10 might help control the update.
218 ///
219 /// * range: [0,∞]
220 /// * default: 0
221 max_delta_step: u32,
222
223 /// Subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly collected half
224 /// of the data instances to grow trees and this will prevent overfitting.
225 ///
226 /// * range: (0, 1]
227 /// * default: 1.0
228 subsample: f32,
229
230 /// Subsample ratio of columns when constructing each tree.
231 ///
232 /// * range: (0.0, 1.0]
233 /// * default: 1.0
234 colsample_bytree: f32,
235
236 /// Subsample ratio of columns for each split, in each level.
237 ///
238 /// * range: (0.0, 1.0]
239 /// * default: 1.0
240 colsample_bylevel: f32,
241
242 /// L2 regularization term on weights, increase this value will make model more conservative.
243 ///
244 /// * default: 1
245 lambda: u32,
246
247 /// L1 regularization term on weights, increase this value will make model more conservative.
248 ///
249 /// * default: 0
250 alpha: u32,
251
252 /// The tree construction algorithm used in XGBoost.
253 #[builder(default = "TreeMethod::default()")]
254 tree_method: TreeMethod,
255
256 /// This is only used for approximate greedy algorithm.
257 /// This roughly translated into O(1 / sketch_eps) number of bins. Compared to directly select number of bins,
258 /// this comes with theoretical guarantee with sketch accuracy.
259 /// Usually user does not have to tune this. but consider setting to a lower number for more accurate enumeration.
260 ///
261 /// * range: (0.0, 1.0)
262 /// * default: 0.03
263 sketch_eps: f32,
264
265 /// Control the balance of positive and negative weights, useful for unbalanced classes.
266 /// A typical value to consider: sum(negative cases) / sum(positive cases).
267 ///
268 /// default: 1.0
269 scale_pos_weight: f32,
270
271 /// Sequence of tree updaters to run, providing a modular way to construct and to modify the trees.
272 ///
273 /// * default: [TreeUpdater::GrowColMaker, TreeUpdater::Prune]
274 updater: Vec<TreeUpdater>,
275
276 /// This is a parameter of the ‘refresh’ updater plugin. When this flag is true, tree leafs as well as tree nodes'
277 /// stats are updated. When it is false, only node stats are updated.
278 ///
279 /// * default: true
280 refresh_leaf: bool,
281
282 /// A type of boosting process to run.
283 ///
284 /// * default: ProcessType::Default
285 process_type: ProcessType,
286
287 /// Controls a way new nodes are added to the tree. Currently supported only if tree_method is set to 'hist'.
288 grow_policy: GrowPolicy,
289
290 /// Maximum number of nodes to be added. Only relevant for the `GrowPolicy::LossGuide` grow
291 /// policy.
292 ///
293 /// * default: 0
294 max_leaves: u32,
295
296 /// This is only used if 'hist' is specified as tree_method.
297 /// Maximum number of discrete bins to bucket continuous features.
298 /// Increasing this number improves the optimality of splits at the cost of higher computation time.
299 ///
300 /// * default: 256
301 max_bin: u32,
302
303 /// The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
304 ///
305 /// * default: [`Predictor::Cpu`](enum.Predictor.html#variant.Cpu)
306 predictor: Predictor,
307}
308
309impl Default for TreeBoosterParameters {
310 fn default() -> Self {
311 TreeBoosterParameters {
312 eta: 0.3,
313 gamma: 0,
314 max_depth: 6,
315 min_child_weight: 1,
316 max_delta_step: 0,
317 subsample: 1.0,
318 colsample_bytree: 1.0,
319 colsample_bylevel: 1.0,
320 lambda: 1,
321 alpha: 0,
322 tree_method: TreeMethod::default(),
323 sketch_eps: 0.03,
324 scale_pos_weight: 1.0,
325 updater: vec![TreeUpdater::GrowColMaker, TreeUpdater::Prune],
326 refresh_leaf: true,
327 process_type: ProcessType::default(),
328 grow_policy: GrowPolicy::default(),
329 max_leaves: 0,
330 max_bin: 256,
331 predictor: Predictor::default(),
332 }
333 }
334}
335
336impl TreeBoosterParameters {
337 pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> {
338 let mut v = Vec::new();
339
340 v.push(("booster".to_owned(), "gbtree".to_owned()));
341
342 v.push(("eta".to_owned(), self.eta.to_string()));
343 v.push(("gamma".to_owned(), self.gamma.to_string()));
344 v.push(("max_depth".to_owned(), self.max_depth.to_string()));
345 v.push(("min_child_weight".to_owned(), self.min_child_weight.to_string()));
346 v.push(("max_delta_step".to_owned(), self.max_delta_step.to_string()));
347 v.push(("subsample".to_owned(), self.subsample.to_string()));
348 v.push(("colsample_bytree".to_owned(), self.colsample_bytree.to_string()));
349 v.push(("colsample_bylevel".to_owned(), self.colsample_bylevel.to_string()));
350 v.push(("lambda".to_owned(), self.lambda.to_string()));
351 v.push(("alpha".to_owned(), self.alpha.to_string()));
352 v.push(("tree_method".to_owned(), self.tree_method.to_string()));
353 v.push(("sketch_eps".to_owned(), self.sketch_eps.to_string()));
354 v.push(("scale_pos_weight".to_owned(), self.scale_pos_weight.to_string()));
355 v.push(("updater".to_owned(), self.updater.iter().map(|u| u.to_string()).collect::<Vec<String>>().join(",")));
356 v.push(("refresh_leaf".to_owned(), (self.refresh_leaf as u8).to_string()));
357 v.push(("process_type".to_owned(), self.process_type.to_string()));
358 v.push(("grow_policy".to_owned(), self.grow_policy.to_string()));
359 v.push(("max_leaves".to_owned(), self.max_leaves.to_string()));
360 v.push(("max_bin".to_owned(), self.max_bin.to_string()));
361 v.push(("predictor".to_owned(), self.predictor.to_string()));
362
363 v
364 }
365}
366
367impl TreeBoosterParametersBuilder {
368 fn validate(&self) -> Result<(), String> {
369 Interval::new_closed_closed(0.0, 1.0).validate(&self.eta, "eta")?;
370 Interval::new_open_closed(0.0, 1.0).validate(&self.subsample, "subsample")?;
371 Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bytree, "colsample_bytree")?;
372 Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bylevel, "colsample_bylevel")?;
373 Interval::new_open_open(0.0, 1.0).validate(&self.sketch_eps, "sketch_eps")?;
374 Ok(())
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Both construction paths, the `Default` impl and an untouched builder, must yield the default `eta`.
    #[test]
    fn tree_params() {
        let from_default = TreeBoosterParameters::default();
        assert_eq!(from_default.eta, 0.3);

        let from_builder = TreeBoosterParametersBuilder::default().build().unwrap();
        assert_eq!(from_builder.eta, 0.3);
    }
}
389}