tangram_tree/lib.rs
/*!
This crate implements machine learning models for regression and classification using ensembles of decision trees. It has many similarities to [LightGBM](https://github.com/microsoft/lightgbm), [XGBoost](https://github.com/dmlc/xgboost), and others, but is written in pure Rust.

For an example of regression, see `benchmarks/boston.rs`. For an example of binary classification, see `benchmarks/heart_disease.rs`. For an example of multiclass classification, see `benchmarks/iris.rs`.
*/

pub use self::{
    binary_classifier::{BinaryClassifier, BinaryClassifierTrainOutput},
    multiclass_classifier::{MulticlassClassifier, MulticlassClassifierTrainOutput},
    regressor::{Regressor, RegressorTrainOutput},
};
use bitvec::prelude::*;
use tangram_progress_counter::ProgressCounter;

mod binary_classifier;
mod choose_best_split;
mod compute_bin_stats;
mod compute_binned_features;
mod compute_binning_instructions;
mod compute_feature_importances;
mod multiclass_classifier;
mod pool;
mod rearrange_examples_index;
mod regressor;
pub mod serialize;
mod shap;
#[cfg(feature = "timing")]
mod timing;
mod train;
mod train_tree;
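
/// A `Progress` carries the hooks that training uses to report progress and check for cancellation: a kill chip that another thread can activate to stop training, and a callback that receives a `TrainProgressEvent` as each phase starts and finishes. A minimal sketch of wiring one up, assuming `tangram_kill_chip::KillChip` implements `Default`:
///
/// ```no_run
/// let kill_chip = tangram_kill_chip::KillChip::default();
/// let mut handle_progress_event = |event: tangram_tree::TrainProgressEvent| {
///     if let tangram_tree::TrainProgressEvent::TrainDone = event {
///         println!("training finished");
///     }
/// };
/// let progress = tangram_tree::Progress {
///     kill_chip: &kill_chip,
///     handle_progress_event: &mut handle_progress_event,
/// };
/// ```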
pub struct Progress<'a> {
    pub kill_chip: &'a tangram_kill_chip::KillChip,
    pub handle_progress_event: &'a mut dyn FnMut(TrainProgressEvent),
}

/// These are the options passed to `Regressor::train`, `BinaryClassifier::train`, and `MulticlassClassifier::train`.
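///
/// A minimal sketch of overriding a few of the defaults below via struct update syntax (the particular values here are illustrative, not tuned recommendations):
///
/// ```
/// let options = tangram_tree::TrainOptions {
///     learning_rate: 0.05,
///     max_rounds: 200,
///     ..Default::default()
/// };
/// ```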
#[derive(Clone, Debug)]
pub struct TrainOptions {
    /// This option controls whether binned features will be laid out in row major or column major order. Each will produce the same result, but row major will be faster for datasets with more rows and fewer columns, while column major will be faster for datasets with fewer rows and more columns.
    pub binned_features_layout: BinnedFeaturesLayout,
    /// If true, the model will include the loss on the training data after each round.
    pub compute_losses: bool,
    /// This option controls early stopping. If it is `Some`, then early stopping will be enabled. If it is `None`, then early stopping will be disabled.
    pub early_stopping_options: Option<EarlyStoppingOptions>,
    /// This option sets the L2 regularization value for continuous splits, which helps avoid overfitting.
    pub l2_regularization_for_continuous_splits: f32,
    /// This option sets the L2 regularization value for discrete splits, which helps avoid overfitting.
    pub l2_regularization_for_discrete_splits: f32,
    /// The learning rate scales the leaf values to control the effect each tree has on the output.
    pub learning_rate: f32,
    /// This is the maximum depth of a single tree. If this value is `None`, the depth will not be limited.
    pub max_depth: Option<usize>,
    /// This is the maximum number of examples to consider when determining the bin thresholds for number features.
    pub max_examples_for_computing_bin_thresholds: usize,
    /// This is the maximum number of leaf nodes in a single tree.
    pub max_leaf_nodes: usize,
    /// This is the maximum number of rounds of training that will occur. Fewer rounds may be trained if early stopping is enabled.
    pub max_rounds: usize,
    /// When computing the bin thresholds for number features, this is the maximum number of bins for valid values to create. If the number of unique values in the number feature is less than this value, the thresholds will be equal to the unique values, which can improve accuracy when number features have a small set of possible values.
    pub max_valid_bins_for_number_features: u8,
    /// A split will only be considered valid if the number of training examples sent to each of the resulting children is at least this value.
    pub min_examples_per_node: usize,
    /// A node will only be split if the best split achieves at least this minimum gain.
    pub min_gain_to_split: f32,
    /// A split will only be considered valid if the sum of hessians in each of the resulting children is at least this value.
    pub min_sum_hessians_per_node: f32,
    /// When choosing which direction each enum variant should be sent in a discrete split, the enum variants are sorted by a score computed from the sum of gradients and hessians for examples with that enum variant. This smoothing factor is added to the denominator of that score.
    pub smoothing_factor_for_discrete_bin_sorting: f32,
}

impl Default for TrainOptions {
    fn default() -> TrainOptions {
        TrainOptions {
            binned_features_layout: BinnedFeaturesLayout::ColumnMajor,
            compute_losses: false,
            early_stopping_options: None,
            l2_regularization_for_continuous_splits: 0.0,
            l2_regularization_for_discrete_splits: 10.0,
            learning_rate: 0.1,
            max_depth: None,
            max_examples_for_computing_bin_thresholds: 200_000,
            max_leaf_nodes: 31,
            max_rounds: 100,
            max_valid_bins_for_number_features: 255,
            min_examples_per_node: 20,
            min_gain_to_split: 0.0,
            min_sum_hessians_per_node: 1e-3,
            smoothing_factor_for_discrete_bin_sorting: 10.0,
        }
    }
}

/// This enum defines whether binned features will be laid out in row major or column major order.
#[derive(Clone, Copy, Debug)]
pub enum BinnedFeaturesLayout {
    RowMajor,
    ColumnMajor,
}

/// The parameters in this struct control how to determine whether training should stop early after each round or epoch.
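///
/// A minimal sketch with illustrative values: hold out 10% of the training data, and stop if five rounds pass without the loss decreasing by at least `1e-3`:
///
/// ```
/// let early_stopping_options = tangram_tree::EarlyStoppingOptions {
///     early_stopping_fraction: 0.1,
///     n_rounds_without_improvement_to_stop: 5,
///     min_decrease_in_loss_for_significant_change: 1e-3,
/// };
/// ```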
#[derive(Clone, Debug)]
pub struct EarlyStoppingOptions {
    /// This is the fraction of the dataset that is set aside to compute the early stopping metric.
    pub early_stopping_fraction: f32,
    /// If this many rounds or epochs pass without a significant improvement in the early stopping metric over the previous round or epoch, training will be stopped early.
    pub n_rounds_without_improvement_to_stop: usize,
    /// This is the minimum decrease in the early stopping metric for a round or epoch to be considered a significant improvement over the previous round or epoch.
    pub min_decrease_in_loss_for_significant_change: f32,
}

/// This enum describes the training progress.
#[derive(Clone, Debug)]
pub enum TrainProgressEvent {
    Initialize(ProgressCounter),
    InitializeDone,
    Train(ProgressCounter),
    TrainDone,
}

/// Trees are stored as a `Vec` of `Node`s. Each branch in the tree has two indexes into the `Vec`, one for each of its children.
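///
/// A sketch of a hand-built tree with a single continuous split and two leaves. Trees are normally produced by training; this example only illustrates the node layout, and it assumes `tangram_table::TableValue` has a `Number` variant, as the `as_number` call in `predict` implies:
///
/// ```
/// use tangram_tree::{
///     BranchNode, BranchSplit, BranchSplitContinuous, LeafNode, Node, SplitDirection, Tree,
/// };
///
/// let tree = Tree {
///     nodes: vec![
///         // Node 0: the root. Examples with feature 0 <= 0.5 go to node 1, others to node 2.
///         Node::Branch(BranchNode {
///             left_child_index: 1,
///             right_child_index: 2,
///             split: BranchSplit::Continuous(BranchSplitContinuous {
///                 feature_index: 0,
///                 split_value: 0.5,
///                 invalid_values_direction: SplitDirection::Left,
///             }),
///             examples_fraction: 1.0,
///         }),
///         // Nodes 1 and 2: the leaves.
///         Node::Leaf(LeafNode { value: -1.0, examples_fraction: 0.5 }),
///         Node::Leaf(LeafNode { value: 1.0, examples_fraction: 0.5 }),
///     ],
/// };
/// assert_eq!(tree.predict(&[tangram_table::TableValue::Number(0.25)]), -1.0);
/// ```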
#[derive(Clone, Debug)]
pub struct Tree {
    pub nodes: Vec<Node>,
}

impl Tree {
    /// Make a prediction.
    pub fn predict(&self, example: &[tangram_table::TableValue]) -> f32 {
        // Start at the root node.
        let mut node_index = 0;
        // Traverse the tree until we get to a leaf.
        unsafe {
            loop {
                match self.nodes.get_unchecked(node_index) {
                    // We made it to a leaf! The prediction is the leaf's value.
                    Node::Leaf(LeafNode { value, .. }) => return *value as f32,
                    // This branch uses a continuous split.
                    Node::Branch(BranchNode {
                        left_child_index,
                        right_child_index,
                        split:
                            BranchSplit::Continuous(BranchSplitContinuous {
                                feature_index,
                                split_value,
                                ..
                            }),
                        ..
                    }) => {
                        node_index = if example.get_unchecked(*feature_index).as_number().unwrap()
                            <= split_value
                        {
                            *left_child_index
                        } else {
                            *right_child_index
                        };
                    }
                    // This branch uses a discrete split.
                    Node::Branch(BranchNode {
                        left_child_index,
                        right_child_index,
                        split:
                            BranchSplit::Discrete(BranchSplitDiscrete {
                                feature_index,
                                directions,
                                ..
                            }),
                        ..
                    }) => {
                        // Invalid enum values are sent to bin 0.
                        let bin_index = if let Some(bin_index) =
                            example.get_unchecked(*feature_index).as_enum().unwrap()
                        {
                            bin_index.get()
                        } else {
                            0
                        };
                        // Look up which direction examples in this bin should be sent.
                        let direction = (*directions.get(bin_index).unwrap()).into();
                        node_index = match direction {
                            SplitDirection::Left => *left_child_index,
                            SplitDirection::Right => *right_child_index,
                        };
                    }
                }
            }
        }
    }
}

/// A node is either a branch or a leaf.
#[derive(Clone, Debug)]
pub enum Node {
    Branch(BranchNode),
    Leaf(LeafNode),
}

impl Node {
    /// Return `Some(&BranchNode)` if this node is a branch, `None` otherwise.
    pub fn as_branch(&self) -> Option<&BranchNode> {
        match self {
            Node::Branch(branch) => Some(branch),
            _ => None,
        }
    }

    /// Return `Some(&LeafNode)` if this node is a leaf, `None` otherwise.
    pub fn as_leaf(&self) -> Option<&LeafNode> {
        match self {
            Node::Leaf(leaf) => Some(leaf),
            _ => None,
        }
    }

    /// Return the fraction of training examples that passed through this node, whether it is a branch or a leaf.
    pub fn examples_fraction(&self) -> f32 {
        match self {
            Node::Leaf(LeafNode {
                examples_fraction, ..
            }) => *examples_fraction,
            Node::Branch(BranchNode {
                examples_fraction, ..
            }) => *examples_fraction,
        }
    }
}

/// A `BranchNode` is a branch in a tree.
#[derive(Clone, Debug)]
pub struct BranchNode {
    /// This is the index in the tree's node vector for this node's left child.
    pub left_child_index: usize,
    /// This is the index in the tree's node vector for this node's right child.
    pub right_child_index: usize,
    /// When making predictions, an example will be sent either to the right or left child. The `split` contains the information necessary to determine which way it will go.
    pub split: BranchSplit,
    /// Branch nodes store the fraction of training examples that passed through them during training. This is used to compute SHAP values.
    pub examples_fraction: f32,
}

/// A `BranchSplit` describes how examples are sent to the left or right child given their feature values. A `Continuous` split is used for number features, and a `Discrete` split is used for enum features.
#[derive(Clone, Debug)]
pub enum BranchSplit {
    Continuous(BranchSplitContinuous),
    Discrete(BranchSplitDiscrete),
}

/// A continuous branch split takes the value of a single number feature and compares it with a `split_value`. If the value is less than or equal to `split_value`, the example is sent left; otherwise, it is sent right.
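///
/// A minimal sketch of constructing one directly (the field values here are illustrative):
///
/// ```
/// let split = tangram_tree::BranchSplitContinuous {
///     feature_index: 3,
///     split_value: 0.5,
///     invalid_values_direction: tangram_tree::SplitDirection::Left,
/// };
/// ```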
#[derive(Clone, Debug)]
pub struct BranchSplitContinuous {
    /// This is the index of the feature to get the value for.
    pub feature_index: usize,
    /// This is the threshold value of the split.
    pub split_value: f32,
    /// This is the direction in which invalid values should be sent.
    pub invalid_values_direction: SplitDirection,
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SplitDirection {
    Left,
    Right,
}
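
/// In a discrete split's `directions` bitset, `false` encodes `Left` and `true` encodes `Right`, matching these conversions. A quick round-trip check:
///
/// ```
/// use tangram_tree::SplitDirection;
///
/// assert_eq!(SplitDirection::from(false), SplitDirection::Left);
/// assert_eq!(bool::from(SplitDirection::Right), true);
/// ```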
impl From<bool> for SplitDirection {
    fn from(value: bool) -> Self {
        match value {
            false => SplitDirection::Left,
            true => SplitDirection::Right,
        }
    }
}

impl From<SplitDirection> for bool {
    fn from(value: SplitDirection) -> Self {
        match value {
            SplitDirection::Left => false,
            SplitDirection::Right => true,
        }
    }
}

/// A discrete branch split takes the value of a single enum feature and looks up in a bitset which way the example should be sent.
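///
/// A sketch of building a `directions` bitset by hand, assuming the `bitvec!` macro syntax from the bitvec 0.19-era API that matches the `BitVec<Lsb0, u8>` type used here. Index 0 is the bin used for invalid values; `false` sends a bin left and `true` sends it right:
///
/// ```
/// use bitvec::prelude::*;
///
/// // Send bins 0 and 2 left, and bin 1 right.
/// let directions: BitVec<Lsb0, u8> = bitvec![Lsb0, u8; 0, 1, 0];
/// ```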
#[derive(Clone, Debug)]
pub struct BranchSplitDiscrete {
    /// This is the index of the feature to get the value for.
    pub feature_index: usize,
    /// This specifies which direction, left or right, an example should be sent, based on the value of the chosen feature.
    pub directions: BitVec<Lsb0, u8>,
}

/// The leaves in a tree hold the values to output for examples that get sent to them.
#[derive(Clone, Debug)]
pub struct LeafNode {
    /// This is the value to output.
    pub value: f64,
    /// Leaf nodes store the fraction of training examples that were sent to them during training. This is used to compute SHAP values.
    pub examples_fraction: f32,
}

impl BranchSplit {
    /// Return the index of the feature this split uses.
    pub fn feature_index(&self) -> usize {
        match self {
            BranchSplit::Continuous(s) => s.feature_index,
            BranchSplit::Discrete(s) => s.feature_index,
        }
    }
}