tangram_tree/
lib.rs

/*!
This crate implements machine learning models for regression and classification using ensembles of decision trees. It has many similarities to [LightGBM](https://github.com/microsoft/LightGBM), [XGBoost](https://github.com/dmlc/xgboost), and others, but is written in pure Rust.

For an example of regression, see `benchmarks/boston.rs`. For an example of binary classification, see `benchmarks/heart_disease.rs`. For an example of multiclass classification, see `benchmarks/iris.rs`.
*/
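/*!
Below is a minimal sketch of what training looks like. It is illustrative rather than definitive: it assumes that `features` and `labels` have already been prepared as `tangram_table` views and that `tangram_kill_chip::KillChip` implements `Default`. See the benchmarks above for the exact calling convention.

```ignore
use tangram_tree::{Progress, Regressor, TrainOptions};

// Assumed to be prepared elsewhere: `features` is a `tangram_table::TableView`
// and `labels` is a `tangram_table::NumberTableColumnView`.
let kill_chip = tangram_kill_chip::KillChip::default();
let output = Regressor::train(
	features,
	labels,
	&TrainOptions::default(),
	Progress {
		kill_chip: &kill_chip,
		handle_progress_event: &mut |_| {},
	},
);
```
*/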

pub use self::{
	binary_classifier::{BinaryClassifier, BinaryClassifierTrainOutput},
	multiclass_classifier::{MulticlassClassifier, MulticlassClassifierTrainOutput},
	regressor::{Regressor, RegressorTrainOutput},
};
use bitvec::prelude::*;
use tangram_progress_counter::ProgressCounter;

mod binary_classifier;
mod choose_best_split;
mod compute_bin_stats;
mod compute_binned_features;
mod compute_binning_instructions;
mod compute_feature_importances;
mod multiclass_classifier;
mod pool;
mod rearrange_examples_index;
mod regressor;
pub mod serialize;
mod shap;
#[cfg(feature = "timing")]
mod timing;
mod train;
mod train_tree;

pub struct Progress<'a> {
	pub kill_chip: &'a tangram_kill_chip::KillChip,
	pub handle_progress_event: &'a mut dyn FnMut(TrainProgressEvent),
}

/// These are the options passed to `Regressor::train`, `BinaryClassifier::train`, and `MulticlassClassifier::train`.
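///
/// For example, to override a few options and keep the rest of the defaults:
///
/// ```
/// use tangram_tree::{BinnedFeaturesLayout, TrainOptions};
///
/// let options = TrainOptions {
/// 	learning_rate: 0.05,
/// 	max_rounds: 200,
/// 	binned_features_layout: BinnedFeaturesLayout::RowMajor,
/// 	..Default::default()
/// };
/// ```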
#[derive(Clone, Debug)]
pub struct TrainOptions {
	/// This option controls whether binned features will be laid out in row major or column major order. Each will produce the same result, but row major will be faster for datasets with more rows and fewer columns, while column major will be faster for datasets with fewer rows and more columns.
	pub binned_features_layout: BinnedFeaturesLayout,
	/// If true, the model will include the loss on the training data after each round.
	pub compute_losses: bool,
	/// This option controls early stopping. If it is `Some`, then early stopping will be enabled. If it is `None`, then early stopping will be disabled.
	pub early_stopping_options: Option<EarlyStoppingOptions>,
	/// This option sets the L2 regularization value for continuous splits, which helps avoid overfitting.
	pub l2_regularization_for_continuous_splits: f32,
	/// This option sets the L2 regularization value for discrete splits, which helps avoid overfitting.
	pub l2_regularization_for_discrete_splits: f32,
	/// The learning rate scales the leaf values to control the effect each tree has on the output.
	pub learning_rate: f32,
	/// This is the maximum depth of a single tree. If this value is `None`, the depth will not be limited.
	pub max_depth: Option<usize>,
	/// This is the maximum number of examples to consider when determining the bin thresholds for number features.
	pub max_examples_for_computing_bin_thresholds: usize,
	/// This is the maximum number of leaf nodes in a single tree.
	pub max_leaf_nodes: usize,
	/// This is the maximum number of rounds of training that will occur. Fewer rounds may be trained if early stopping is enabled.
	pub max_rounds: usize,
	/// When computing the bin thresholds for number features, this is the maximum number of bins for valid values to create. If the number of unique values in the number feature is less than this value, the thresholds will be equal to the unique values, which can improve accuracy when number features have a small set of possible values.
	pub max_valid_bins_for_number_features: u8,
	/// A split will only be considered valid if the number of training examples sent to each of the resulting children is at least this value.
	pub min_examples_per_node: usize,
	/// A node will only be split if the best split achieves at least this minimum gain.
	pub min_gain_to_split: f32,
	/// A split will only be considered valid if the sum of hessians in each of the resulting children is at least this value.
	pub min_sum_hessians_per_node: f32,
	/// When choosing which direction each enum variant should be sent in a discrete split, the enum variants are sorted by a score computed from the sum of gradients and hessians for examples with that enum variant. This smoothing factor is added to the denominator of that score.
	pub smoothing_factor_for_discrete_bin_sorting: f32,
}

impl Default for TrainOptions {
	fn default() -> TrainOptions {
		TrainOptions {
			binned_features_layout: BinnedFeaturesLayout::ColumnMajor,
			compute_losses: false,
			early_stopping_options: None,
			l2_regularization_for_continuous_splits: 0.0,
			l2_regularization_for_discrete_splits: 10.0,
			learning_rate: 0.1,
			max_depth: None,
			max_examples_for_computing_bin_thresholds: 200_000,
			max_leaf_nodes: 31,
			max_rounds: 100,
			max_valid_bins_for_number_features: 255,
			min_examples_per_node: 20,
			min_gain_to_split: 0.0,
			min_sum_hessians_per_node: 1e-3,
			smoothing_factor_for_discrete_bin_sorting: 10.0,
		}
	}
}

/// This enum defines whether binned features will be laid out in row major or column major order.
#[derive(Clone, Copy, Debug)]
pub enum BinnedFeaturesLayout {
	RowMajor,
	ColumnMajor,
}

/// The parameters in this struct control how to determine whether training should stop early after each round or epoch.
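///
/// For example, to hold out 10% of the dataset and stop after five rounds without an improvement of at least `1e-3` in the loss:
///
/// ```
/// use tangram_tree::{EarlyStoppingOptions, TrainOptions};
///
/// let options = TrainOptions {
/// 	early_stopping_options: Some(EarlyStoppingOptions {
/// 		early_stopping_fraction: 0.1,
/// 		n_rounds_without_improvement_to_stop: 5,
/// 		min_decrease_in_loss_for_significant_change: 1e-3,
/// 	}),
/// 	..Default::default()
/// };
/// ```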
#[derive(Clone, Debug)]
pub struct EarlyStoppingOptions {
	/// This is the fraction of the dataset that is set aside to compute the early stopping metric.
	pub early_stopping_fraction: f32,
	/// If this many rounds or epochs pass without a significant improvement in the early stopping metric over the previous round or epoch, training will be stopped early.
	pub n_rounds_without_improvement_to_stop: usize,
	/// This is the minimum decrease in the early stopping metric for a round or epoch to be considered a significant improvement over the previous round or epoch.
	pub min_decrease_in_loss_for_significant_change: f32,
}

/// This enum describes the progress events that occur during training.
#[derive(Clone, Debug)]
pub enum TrainProgressEvent {
	Initialize(ProgressCounter),
	InitializeDone,
	Train(ProgressCounter),
	TrainDone,
}

/// Trees are stored as a `Vec` of `Node`s. Each branch in the tree has two indexes into the `Vec`, one for each of its children.
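///
/// For example, this hand-built tree has a root branch at index 0 whose children are the leaves at indexes 1 and 2:
///
/// ```
/// use tangram_tree::*;
///
/// let tree = Tree {
/// 	nodes: vec![
/// 		Node::Branch(BranchNode {
/// 			left_child_index: 1,
/// 			right_child_index: 2,
/// 			split: BranchSplit::Continuous(BranchSplitContinuous {
/// 				feature_index: 0,
/// 				split_value: 0.5,
/// 				invalid_values_direction: SplitDirection::Left,
/// 			}),
/// 			examples_fraction: 1.0,
/// 		}),
/// 		Node::Leaf(LeafNode { value: -1.0, examples_fraction: 0.5 }),
/// 		Node::Leaf(LeafNode { value: 1.0, examples_fraction: 0.5 }),
/// 	],
/// };
/// assert!(tree.nodes[0].as_branch().is_some());
/// ```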
#[derive(Clone, Debug)]
pub struct Tree {
	pub nodes: Vec<Node>,
}

impl Tree {
	/// Make a prediction.
	pub fn predict(&self, example: &[tangram_table::TableValue]) -> f32 {
		// Start at the root node.
		let mut node_index = 0;
		// Traverse the tree until we get to a leaf.
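		// SAFETY: the node and feature indexes stored in the tree are assumed to be in bounds for `self.nodes` and `example`, which is why unchecked indexing is used in this hot loop.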
		unsafe {
			loop {
				match self.nodes.get_unchecked(node_index) {
					// We made it to a leaf! The prediction is the leaf's value.
					Node::Leaf(LeafNode { value, .. }) => return *value as f32,
					// This branch uses a continuous split.
					Node::Branch(BranchNode {
						left_child_index,
						right_child_index,
						split:
							BranchSplit::Continuous(BranchSplitContinuous {
								feature_index,
								split_value,
								..
							}),
						..
					}) => {
						node_index = if example.get_unchecked(*feature_index).as_number().unwrap()
							<= split_value
						{
							*left_child_index
						} else {
							*right_child_index
						};
					}
					// This branch uses a discrete split.
					Node::Branch(BranchNode {
						left_child_index,
						right_child_index,
						split:
							BranchSplit::Discrete(BranchSplitDiscrete {
								feature_index,
								directions,
								..
							}),
						..
					}) => {
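						// Examples with invalid values are looked up at bin 0.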
						let bin_index = if let Some(bin_index) =
							example.get_unchecked(*feature_index).as_enum().unwrap()
						{
							bin_index.get()
						} else {
							0
						};
						let direction = (*directions.get(bin_index).unwrap()).into();
						node_index = match direction {
							SplitDirection::Left => *left_child_index,
							SplitDirection::Right => *right_child_index,
						};
					}
				}
			}
		}
	}
}

/// A node is either a branch or a leaf.
#[derive(Clone, Debug)]
pub enum Node {
	Branch(BranchNode),
	Leaf(LeafNode),
}

impl Node {
	pub fn as_branch(&self) -> Option<&BranchNode> {
		match self {
			Node::Branch(branch) => Some(branch),
			_ => None,
		}
	}

	pub fn as_leaf(&self) -> Option<&LeafNode> {
		match self {
			Node::Leaf(leaf) => Some(leaf),
			_ => None,
		}
	}

	pub fn examples_fraction(&self) -> f32 {
		match self {
			Node::Leaf(LeafNode {
				examples_fraction, ..
			}) => *examples_fraction,
			Node::Branch(BranchNode {
				examples_fraction, ..
			}) => *examples_fraction,
		}
	}
}

/// A `BranchNode` is a branch in a tree.
#[derive(Clone, Debug)]
pub struct BranchNode {
	/// This is the index in the tree's node vector for this node's left child.
	pub left_child_index: usize,
	/// This is the index in the tree's node vector for this node's right child.
	pub right_child_index: usize,
	/// When making predictions, an example will be sent either to the right or left child. The `split` contains the information necessary to determine which way it will go.
	pub split: BranchSplit,
	/// Branch nodes store the fraction of training examples that passed through them during training. This is used to compute SHAP values.
	pub examples_fraction: f32,
}

/// A `BranchSplit` describes how examples are sent to the left or right child given their feature values. A `Continuous` split is used for number features, and a `Discrete` split is used for enum features.
#[derive(Clone, Debug)]
pub enum BranchSplit {
	Continuous(BranchSplitContinuous),
	Discrete(BranchSplitDiscrete),
}

/// A continuous branch split takes the value of a single number feature and compares it with a `split_value`. If the value is less than or equal to `split_value`, the example is sent left; otherwise, it is sent right.
#[derive(Clone, Debug)]
pub struct BranchSplitContinuous {
	/// This is the index of the feature to get the value for.
	pub feature_index: usize,
	/// This is the threshold value of the split.
	pub split_value: f32,
	/// This is the direction invalid values should be sent.
	pub invalid_values_direction: SplitDirection,
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SplitDirection {
	Left,
	Right,
}

impl From<bool> for SplitDirection {
	fn from(value: bool) -> Self {
		match value {
			false => SplitDirection::Left,
			true => SplitDirection::Right,
		}
	}
}

impl From<SplitDirection> for bool {
	fn from(value: SplitDirection) -> Self {
		match value {
			SplitDirection::Left => false,
			SplitDirection::Right => true,
		}
	}
}

/// A discrete branch split takes the value of a single enum feature and looks up in a bitset which way the example should be sent.
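///
/// Bit `i` of `directions` holds the direction for bin `i`, where `false` sends the example left and `true` sends it right, matching the `From<bool>` impl for `SplitDirection`; bin 0 is used for invalid values. A small illustration, mirroring the lookup in `Tree::predict`:
///
/// ```
/// use bitvec::prelude::*;
/// use tangram_tree::{BranchSplitDiscrete, SplitDirection};
///
/// let mut directions: BitVec<Lsb0, u8> = BitVec::new();
/// directions.push(false); // bin 0: invalid values go left
/// directions.push(true); // bin 1 goes right
/// directions.push(false); // bin 2 goes left
/// let split = BranchSplitDiscrete { feature_index: 3, directions };
/// let direction: SplitDirection = (*split.directions.get(1).unwrap()).into();
/// assert_eq!(direction, SplitDirection::Right);
/// ```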
#[derive(Clone, Debug)]
pub struct BranchSplitDiscrete {
	/// This is the index of the feature to get the value for.
	pub feature_index: usize,
	/// This specifies which direction, left or right, an example should be sent, based on the value of the chosen feature.
	pub directions: BitVec<Lsb0, u8>,
}

/// The leaves in a tree hold the values to output for examples that get sent to them.
#[derive(Clone, Debug)]
pub struct LeafNode {
	/// This is the value to output.
	pub value: f64,
	/// Leaf nodes store the fraction of training examples that were sent to them during training. This is used to compute SHAP values.
	pub examples_fraction: f32,
}

impl BranchSplit {
	pub fn feature_index(&self) -> usize {
		match self {
			BranchSplit::Continuous(s) => s.feature_index,
			BranchSplit::Discrete(s) => s.feature_index,
		}
	}
}