xgboost/parameters/
tree.rs

1//! BoosterParameters for controlling tree boosters.
2//!
3//!
4use std::default::Default;
5
6use super::Interval;
7
/// The tree construction algorithm used in XGBoost (see description in the
/// [reference paper](http://arxiv.org/abs/1603.02754)).
///
/// Distributed and external memory versions only support the approximate algorithm.
///
/// Formatting a value with `Display` (and therefore `to_string()`) yields the string that is
/// emitted as the value of the `tree_method` parameter, e.g. `TreeMethod::GpuHist` becomes
/// `"gpu_hist"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TreeMethod {
    /// Use a heuristic to choose the faster algorithm.
    ///
    /// * For small to medium datasets, exact greedy will be used.
    /// * For very large datasets, the approximate algorithm will be chosen.
    /// * Because the old behavior was to always use exact greedy on a single machine, the user will get
    ///   a message when the approximate algorithm is chosen, to notify them of this choice.
    Auto,

    /// Exact greedy algorithm.
    Exact,

    /// Approximate greedy algorithm using sketching and histogram.
    Approx,

    /// Fast histogram optimized approximate greedy algorithm. It uses some performance improvements
    /// such as bins caching.
    Hist,

    /// GPU implementation of exact algorithm.
    GpuExact,

    /// GPU implementation of hist algorithm.
    GpuHist,
}

// `Display` is implemented instead of `ToString` directly: the standard library's blanket
// `impl<T: Display> ToString for T` keeps `to_string()` working for existing callers while
// also allowing the type to be used with `{}` in `format!`/`write!`.
impl ::std::fmt::Display for TreeMethod {
    /// Writes the lowercase parameter value for this tree method.
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let name = match *self {
            TreeMethod::Auto => "auto",
            TreeMethod::Exact => "exact",
            TreeMethod::Approx => "approx",
            TreeMethod::Hist => "hist",
            TreeMethod::GpuExact => "gpu_exact",
            TreeMethod::GpuHist => "gpu_hist",
        };
        // `pad` honours width/alignment flags; with no flags it writes `name` unchanged.
        f.pad(name)
    }
}

impl Default for TreeMethod {
    /// Returns `TreeMethod::Auto`.
    fn default() -> Self { TreeMethod::Auto }
}
55
/// Provides a modular way to construct and to modify the trees. This is an advanced parameter that is usually set
/// automatically, depending on some other parameters. However, it can also be set explicitly by a user.
///
/// Formatting a value with `Display` (and therefore `to_string()`) yields the updater's plugin name,
/// e.g. `TreeUpdater::GrowColMaker` becomes `"grow_colmaker"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TreeUpdater {
    /// Non-distributed column-based construction of trees.
    GrowColMaker,

    /// Distributed tree construction with column-based data splitting mode.
    DistCol,

    /// Distributed tree construction with row-based data splitting based on global proposal of histogram counting.
    GrowHistMaker,

    /// Based on local histogram counting.
    GrowLocalHistMaker,

    /// Uses the approximate sketching algorithm.
    GrowSkMaker,

    /// Synchronizes trees in all distributed nodes.
    Sync,

    /// Refreshes tree's statistics and/or leaf values based on the current data.
    /// Note that no random subsampling of data rows is performed.
    Refresh,

    /// Prunes the splits where loss < min_split_loss (or gamma).
    Prune,
}

// `Display` is implemented instead of `ToString` directly: the standard library's blanket
// `impl<T: Display> ToString for T` keeps `to_string()` working for existing callers while
// also allowing the type to be used with `{}` in `format!`/`write!`.
impl ::std::fmt::Display for TreeUpdater {
    /// Writes the lowercase plugin name for this updater.
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let name = match *self {
            TreeUpdater::GrowColMaker => "grow_colmaker",
            TreeUpdater::DistCol => "distcol",
            TreeUpdater::GrowHistMaker => "grow_histmaker",
            TreeUpdater::GrowLocalHistMaker => "grow_local_histmaker",
            TreeUpdater::GrowSkMaker => "grow_skmaker",
            TreeUpdater::Sync => "sync",
            TreeUpdater::Refresh => "refresh",
            TreeUpdater::Prune => "prune",
        };
        // `pad` honours width/alignment flags; with no flags it writes `name` unchanged.
        f.pad(name)
    }
}
100
/// A type of boosting process to run.
///
/// Formatting a value with `Display` (and therefore `to_string()`) yields the string that is
/// emitted as the value of the `process_type` parameter (`"default"` or `"update"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProcessType {
    /// The normal boosting process which creates new trees.
    Default,

    /// Starts from an existing model and only updates its trees. In each boosting iteration,
    /// a tree from the initial model is taken, a specified sequence of updater plugins is run for that tree,
    /// and a modified tree is added to the new model. The new model would have either the same or smaller number of
    /// trees, depending on the number of boosting iterations performed.
    /// Currently, the following built-in updater plugins could be meaningfully used with this process type:
    /// 'refresh', 'prune'. With 'update', one cannot use updater plugins that create new trees.
    Update,
}

// `Display` is implemented instead of `ToString` directly: the standard library's blanket
// `impl<T: Display> ToString for T` keeps `to_string()` working for existing callers while
// also allowing the type to be used with `{}` in `format!`/`write!`.
impl ::std::fmt::Display for ProcessType {
    /// Writes the lowercase parameter value for this process type.
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let name = match *self {
            ProcessType::Default => "default",
            ProcessType::Update => "update",
        };
        // `pad` honours width/alignment flags; with no flags it writes `name` unchanged.
        f.pad(name)
    }
}

impl Default for ProcessType {
    /// Returns the `ProcessType::Default` variant (normal boosting).
    fn default() -> Self { ProcessType::Default }
}
128
/// Controls the way new nodes are added to the tree.
///
/// Formatting a value with `Display` (and therefore `to_string()`) yields the string that is
/// emitted as the value of the `grow_policy` parameter (`"depthwise"` or `"lossguide"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GrowPolicy {
    /// Split at nodes closest to the root.
    Depthwise,

    /// Split at nodes with highest loss change.
    LossGuide,
}

// `Display` is implemented instead of `ToString` directly: the standard library's blanket
// `impl<T: Display> ToString for T` keeps `to_string()` working for existing callers while
// also allowing the type to be used with `{}` in `format!`/`write!`.
impl ::std::fmt::Display for GrowPolicy {
    /// Writes the lowercase parameter value for this grow policy.
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let name = match *self {
            GrowPolicy::Depthwise => "depthwise",
            GrowPolicy::LossGuide => "lossguide",
        };
        // `pad` honours width/alignment flags; with no flags it writes `name` unchanged.
        f.pad(name)
    }
}

impl Default for GrowPolicy {
    /// Returns `GrowPolicy::Depthwise`.
    fn default() -> Self { GrowPolicy::Depthwise }
}
151
/// The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
///
/// Formatting a value with `Display` (and therefore `to_string()`) yields the string that is
/// emitted as the value of the `predictor` parameter (`"cpu_predictor"` or `"gpu_predictor"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Predictor {
    /// Multicore CPU prediction algorithm.
    Cpu,

    /// Prediction using GPU. Default for 'gpu_exact' and 'gpu_hist' tree method.
    Gpu,
}

// `Display` is implemented instead of `ToString` directly: the standard library's blanket
// `impl<T: Display> ToString for T` keeps `to_string()` working for existing callers while
// also allowing the type to be used with `{}` in `format!`/`write!`.
impl ::std::fmt::Display for Predictor {
    /// Writes the parameter value for this predictor.
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let name = match *self {
            Predictor::Cpu => "cpu_predictor",
            Predictor::Gpu => "gpu_predictor",
        };
        // `pad` honours width/alignment flags; with no flags it writes `name` unchanged.
        f.pad(name)
    }
}

impl Default for Predictor {
    /// Returns `Predictor::Cpu`.
    fn default() -> Self { Predictor::Cpu }
}
174
/// Parameters for the tree booster (`gbtree`). Create using
/// [`TreeBoosterParametersBuilder`](struct.TreeBoosterParametersBuilder.html).
///
/// The builder runs `TreeBoosterParametersBuilder::validate` when building, which range-checks
/// `eta`, `subsample`, `colsample_bytree`, `colsample_bylevel` and `sketch_eps`.
/// The defaults listed on each field are the values produced by this type's `Default` implementation.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
#[builder(default)]
pub struct TreeBoosterParameters {
    /// Step size shrinkage used in updates to prevent overfitting. After each boosting step, we can directly
    /// get the weights of new features, and eta shrinks the feature weights to make the boosting process
    /// more conservative.
    ///
    /// * range: [0.0, 1.0]
    /// * default: 0.3
    eta: f32,

    /// Minimum loss reduction required to make a further partition on a leaf node of the tree.
    /// The larger, the more conservative the algorithm will be.
    ///
    /// * range: [0, ∞)
    /// * default: 0
    gamma: u32,

    /// Maximum depth of a tree. Increasing this value will make the model more complex and more likely to overfit.
    /// 0 indicates no limit; a limit is required for the depth-wise grow policy.
    ///
    /// * range: [0, ∞)
    /// * default: 6
    max_depth: u32,

    /// Minimum sum of instance weight (hessian) needed in a child. If the tree partition step results in a leaf
    /// node with the sum of instance weight less than min_child_weight, then the building process will give up
    /// further partitioning.
    /// In linear regression mode, this simply corresponds to the minimum number of instances needed to be in each node.
    /// The larger, the more conservative the algorithm will be.
    ///
    /// * range: [0, ∞)
    /// * default: 1
    min_child_weight: u32,

    /// Maximum delta step we allow each tree's weight estimation to be.
    /// If the value is set to 0, it means there is no constraint. If it is set to a positive value,
    /// it can help make the update step more conservative. Usually this parameter is not needed,
    /// but it might help in logistic regression when the classes are extremely imbalanced.
    /// Setting it to a value of 1-10 might help control the update.
    ///
    /// * range: [0, ∞)
    /// * default: 0
    max_delta_step: u32,

    /// Subsample ratio of the training instances. Setting it to 0.5 means that XGBoost randomly samples half
    /// of the data instances to grow trees, which helps prevent overfitting.
    ///
    /// * range: (0.0, 1.0]
    /// * default: 1.0
    subsample: f32,

    /// Subsample ratio of columns when constructing each tree.
    ///
    /// * range: (0.0, 1.0]
    /// * default: 1.0
    colsample_bytree: f32,

    /// Subsample ratio of columns for each split, in each level.
    ///
    /// * range: (0.0, 1.0]
    /// * default: 1.0
    colsample_bylevel: f32,

    /// L2 regularization term on weights. Increasing this value will make the model more conservative.
    ///
    /// * default: 1
    lambda: u32,

    /// L1 regularization term on weights. Increasing this value will make the model more conservative.
    ///
    /// * default: 0
    alpha: u32,

    /// The tree construction algorithm used in XGBoost.
    ///
    /// * default: `TreeMethod::default()`
    #[builder(default = "TreeMethod::default()")]
    tree_method: TreeMethod,

    /// This is only used for the approximate greedy algorithm.
    /// It roughly translates into O(1 / sketch_eps) number of bins. Compared to directly selecting the number of bins,
    /// this comes with a theoretical guarantee on sketch accuracy.
    /// Usually the user does not have to tune this, but consider setting it to a lower number for more accurate
    /// enumeration.
    ///
    /// * range: (0.0, 1.0)
    /// * default: 0.03
    sketch_eps: f32,

    /// Controls the balance of positive and negative weights, useful for unbalanced classes.
    /// A typical value to consider: sum(negative cases) / sum(positive cases).
    ///
    /// * default: 1.0
    scale_pos_weight: f32,

    /// Sequence of tree updaters to run, providing a modular way to construct and to modify the trees.
    ///
    /// * default: [TreeUpdater::GrowColMaker, TreeUpdater::Prune]
    updater: Vec<TreeUpdater>,

    /// This is a parameter of the 'refresh' updater plugin. When this flag is true, tree leaves as well as tree nodes'
    /// stats are updated. When it is false, only node stats are updated.
    ///
    /// * default: true
    refresh_leaf: bool,

    /// A type of boosting process to run.
    ///
    /// * default: ProcessType::Default
    process_type: ProcessType,

    /// Controls the way new nodes are added to the tree. Currently supported only if tree_method is set to 'hist'.
    ///
    /// * default: GrowPolicy::Depthwise
    grow_policy: GrowPolicy,

    /// Maximum number of nodes to be added. Only relevant for the `GrowPolicy::LossGuide` grow
    /// policy.
    ///
    /// * default: 0
    max_leaves: u32,

    /// This is only used if 'hist' is specified as tree_method.
    /// Maximum number of discrete bins to bucket continuous features.
    /// Increasing this number improves the optimality of splits at the cost of higher computation time.
    ///
    /// * default: 256
    max_bin: u32,

    /// The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
    ///
    /// * default: [`Predictor::Cpu`](enum.Predictor.html#variant.Cpu)
    predictor: Predictor,
}
308
309impl Default for TreeBoosterParameters {
310    fn default() -> Self {
311        TreeBoosterParameters {
312            eta: 0.3,
313            gamma: 0,
314            max_depth: 6,
315            min_child_weight: 1,
316            max_delta_step: 0,
317            subsample: 1.0,
318            colsample_bytree: 1.0,
319            colsample_bylevel: 1.0,
320            lambda: 1,
321            alpha: 0,
322            tree_method: TreeMethod::default(),
323            sketch_eps: 0.03,
324            scale_pos_weight: 1.0,
325            updater: vec![TreeUpdater::GrowColMaker, TreeUpdater::Prune],
326            refresh_leaf: true,
327            process_type: ProcessType::default(),
328            grow_policy: GrowPolicy::default(),
329            max_leaves: 0,
330            max_bin: 256,
331            predictor: Predictor::default(),
332        }
333    }
334}
335
336impl TreeBoosterParameters {
337    pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> {
338        let mut v = Vec::new();
339
340        v.push(("booster".to_owned(), "gbtree".to_owned()));
341
342        v.push(("eta".to_owned(), self.eta.to_string()));
343        v.push(("gamma".to_owned(), self.gamma.to_string()));
344        v.push(("max_depth".to_owned(), self.max_depth.to_string()));
345        v.push(("min_child_weight".to_owned(), self.min_child_weight.to_string()));
346        v.push(("max_delta_step".to_owned(), self.max_delta_step.to_string()));
347        v.push(("subsample".to_owned(), self.subsample.to_string()));
348        v.push(("colsample_bytree".to_owned(), self.colsample_bytree.to_string()));
349        v.push(("colsample_bylevel".to_owned(), self.colsample_bylevel.to_string()));
350        v.push(("lambda".to_owned(), self.lambda.to_string()));
351        v.push(("alpha".to_owned(), self.alpha.to_string()));
352        v.push(("tree_method".to_owned(), self.tree_method.to_string()));
353        v.push(("sketch_eps".to_owned(), self.sketch_eps.to_string()));
354        v.push(("scale_pos_weight".to_owned(), self.scale_pos_weight.to_string()));
355        v.push(("updater".to_owned(), self.updater.iter().map(|u| u.to_string()).collect::<Vec<String>>().join(",")));
356        v.push(("refresh_leaf".to_owned(), (self.refresh_leaf as u8).to_string()));
357        v.push(("process_type".to_owned(), self.process_type.to_string()));
358        v.push(("grow_policy".to_owned(), self.grow_policy.to_string()));
359        v.push(("max_leaves".to_owned(), self.max_leaves.to_string()));
360        v.push(("max_bin".to_owned(), self.max_bin.to_string()));
361        v.push(("predictor".to_owned(), self.predictor.to_string()));
362
363        v
364    }
365}
366
impl TreeBoosterParametersBuilder {
    /// Range-checks the builder fields that have a bounded valid interval.
    ///
    /// Registered as the builder's `build_fn(validate = ...)` hook on `TreeBoosterParameters`.
    /// Checks run in the order written and the first failing check is returned via `?`;
    /// `Ok(())` is returned only when every check passes.
    ///
    /// NOTE(review): how `Interval::validate` treats a field that was never set on the builder
    /// is not visible here — confirm against `Interval`.
    fn validate(&self) -> Result<(), String> {
        // eta must lie in [0.0, 1.0].
        Interval::new_closed_closed(0.0, 1.0).validate(&self.eta, "eta")?;
        // Sampling ratios must lie in (0.0, 1.0].
        Interval::new_open_closed(0.0, 1.0).validate(&self.subsample, "subsample")?;
        Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bytree, "colsample_bytree")?;
        Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bylevel, "colsample_bylevel")?;
        // sketch_eps must lie strictly inside (0.0, 1.0).
        Interval::new_open_open(0.0, 1.0).validate(&self.sketch_eps, "sketch_eps")?;
        Ok(())
    }
}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// `eta` must be 0.3 both from the `Default` impl and from a builder with no fields set.
    #[test]
    fn tree_params() {
        let from_default = TreeBoosterParameters::default();
        assert_eq!(from_default.eta, 0.3);

        let from_builder = TreeBoosterParametersBuilder::default()
            .build()
            .unwrap();
        assert_eq!(from_builder.eta, 0.3);
    }
}