rusty_ai/trees/regressor.rs

//! Decision Tree Regressor
use super::{node::TreeNode, params::TreeParams};
use crate::{
    data::dataset::{Dataset, RealNumber},
    metrics::errors::RegressionMetrics,
};
use nalgebra::{DMatrix, DVector};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::{error::Error, f64, marker::PhantomData};

/// The outcome of splitting a dataset on a single feature and threshold,
/// together with the variance reduction the split achieves.
pub struct SplitData<T: RealNumber> {
    pub feature_index: usize,
    pub threshold: T,
    pub left: Dataset<T, T>,
    pub right: Dataset<T, T>,
    information_gain: f64,
}

/// Decision Tree Regressor
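///
/// # Examples
///
/// A minimal end-to-end sketch. It is marked `ignore` because it assumes the
/// crate is named `rusty_ai` and that `Dataset` and this regressor are
/// publicly reachable at the paths below:
///
/// ```ignore
/// use nalgebra::{DMatrix, DVector};
/// use rusty_ai::data::dataset::Dataset;
/// use rusty_ai::trees::regressor::DecisionTreeRegressor;
///
/// // Six samples of a single feature with y = x^2.
/// let x = DMatrix::from_vec(6, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// let y = DVector::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0, 36.0]);
/// let mut regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
/// regressor.fit(&Dataset::new(x, y)).unwrap();
///
/// // One prediction per input row.
/// let predictions = regressor.predict(&DMatrix::from_vec(2, 1, vec![2.5, 5.5])).unwrap();
/// assert_eq!(predictions.len(), 2);
/// ```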
#[derive(Clone, Debug)]
pub struct DecisionTreeRegressor<T: RealNumber> {
    root: Option<Box<TreeNode<T, T>>>,
    tree_params: TreeParams,

    _marker: PhantomData<T>,
}

impl<T: RealNumber> Default for DecisionTreeRegressor<T> {
    /// Creates a new instance of the decision tree regressor with default parameters.
    fn default() -> Self {
        Self::new()
    }
}

impl<T: RealNumber> RegressionMetrics<T> for DecisionTreeRegressor<T> {}

impl<T: RealNumber> DecisionTreeRegressor<T> {
    /// Creates a new instance of the decision tree regressor with default parameters.
    pub fn new() -> Self {
        Self {
            root: None,
            tree_params: TreeParams::new(),
            _marker: PhantomData,
        }
    }

    /// Creates a new instance of the decision tree regressor with custom parameters.
    ///
    /// # Arguments
    ///
    /// * `min_samples_split` - The minimum number of samples required to split an internal node.
    /// * `max_depth` - The maximum depth of the tree.
    ///
    /// # Returns
    ///
    /// A new instance of the decision tree regressor with the specified parameters.
    ///
    /// # Errors
    ///
    /// This method will return an error if the minimum number of samples to split is less than 2 or if the maximum depth is less than 1.
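    ///
    /// # Examples
    ///
    /// A short sketch of the success and error paths. It is `ignore`d because
    /// the crate name and public module path are assumed:
    ///
    /// ```ignore
    /// use rusty_ai::trees::regressor::DecisionTreeRegressor;
    ///
    /// let regressor = DecisionTreeRegressor::<f64>::with_params(Some(4), Some(3)).unwrap();
    /// assert_eq!(regressor.min_samples_split(), 4);
    /// assert_eq!(regressor.max_depth(), Some(3));
    ///
    /// // Fewer than 2 samples per split is rejected.
    /// assert!(DecisionTreeRegressor::<f64>::with_params(Some(1), None).is_err());
    /// ```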
    pub fn with_params(
        min_samples_split: Option<u16>,
        max_depth: Option<u16>,
    ) -> Result<Self, Box<dyn Error>> {
        let mut tree = Self::new();

        tree.set_min_samples_split(min_samples_split.unwrap_or(2))?;
        tree.set_max_depth(max_depth)?;
        Ok(tree)
    }

    /// Sets the minimum number of samples required to split an internal node.
    ///
    /// # Arguments
    ///
    /// * `min_samples_split` - The minimum number of samples required to split an internal node.
    ///
    /// # Errors
    ///
    /// This method will return an error if the minimum number of samples to split is less than 2.
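    ///
    /// # Examples
    ///
    /// A short sketch (`ignore`d; the crate name and public module path are assumed):
    ///
    /// ```ignore
    /// use rusty_ai::trees::regressor::DecisionTreeRegressor;
    ///
    /// let mut regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
    /// assert!(regressor.set_min_samples_split(5).is_ok());
    /// assert!(regressor.set_min_samples_split(1).is_err());
    /// ```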
    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
        self.tree_params.set_min_samples_split(min_samples_split)
    }

    /// Sets the maximum depth of the tree.
    ///
    /// # Arguments
    ///
    /// * `max_depth` - The maximum depth of the tree, or `None` for no depth limit.
    ///
    /// # Errors
    ///
    /// This method will return an error if the maximum depth is less than 1.
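    ///
    /// # Examples
    ///
    /// A short sketch (`ignore`d; the crate name and public module path are
    /// assumed, and `None` meaning "no limit" is inferred from the defaults
    /// used by `with_params`):
    ///
    /// ```ignore
    /// use rusty_ai::trees::regressor::DecisionTreeRegressor;
    ///
    /// let mut regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
    /// assert!(regressor.set_max_depth(Some(10)).is_ok());
    /// assert!(regressor.set_max_depth(None).is_ok()); // no depth limit
    /// assert!(regressor.set_max_depth(Some(0)).is_err());
    /// ```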
    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
        self.tree_params.set_max_depth(max_depth)
    }

    /// Returns the maximum depth of the tree, or `None` if no limit is set.
    pub fn max_depth(&self) -> Option<u16> {
        self.tree_params.max_depth()
    }

    /// Returns the minimum number of samples required to split an internal node.
    pub fn min_samples_split(&self) -> u16 {
        self.tree_params.min_samples_split()
    }

    /// Builds the decision tree from a dataset.
    ///
    /// # Arguments
    ///
    /// * `dataset` - The dataset containing features and labels.
    ///
    /// # Returns
    ///
    /// A string indicating that the tree was built successfully.
    ///
    /// # Errors
    ///
    /// This method will return an error if the tree couldn't be built.
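    ///
    /// # Examples
    ///
    /// A minimal sketch (`ignore`d; the crate name and public paths are assumed):
    ///
    /// ```ignore
    /// use nalgebra::{DMatrix, DVector};
    /// use rusty_ai::data::dataset::Dataset;
    /// use rusty_ai::trees::regressor::DecisionTreeRegressor;
    ///
    /// let x = DMatrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0]);
    /// let y = DVector::from_vec(vec![10.0, 20.0, 30.0, 40.0]);
    /// let mut regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
    /// let message = regressor.fit(&Dataset::new(x, y)).unwrap();
    /// assert_eq!(message, "Finished building the tree.");
    /// ```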
    pub fn fit(&mut self, dataset: &Dataset<T, T>) -> Result<String, Box<dyn Error>> {
        self.root = Some(Box::new(self.build_tree(
            dataset,
            // Depth tracking starts at `Some(0)` only when a maximum depth is
            // set; `None` means the depth is unbounded.
            self.max_depth().map(|_| 0),
            self.variance(&dataset.y),
        )?));
        Ok("Finished building the tree.".into())
    }

    /// Predicts the labels for new data.
    ///
    /// # Arguments
    ///
    /// * `prediction_features` - The matrix of features for the new data, one row per sample.
    ///
    /// # Returns
    ///
    /// A vector containing the predicted target values for the new data.
    ///
    /// # Errors
    ///
    /// This method will return an error if the tree wasn't built yet.
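    ///
    /// # Examples
    ///
    /// Calling `predict` before `fit` is an error; a fitted tree returns one
    /// value per input row. `ignore`d sketch; the crate name and public paths
    /// are assumed:
    ///
    /// ```ignore
    /// use nalgebra::{DMatrix, DVector};
    /// use rusty_ai::data::dataset::Dataset;
    /// use rusty_ai::trees::regressor::DecisionTreeRegressor;
    ///
    /// let mut regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
    /// assert!(regressor.predict(&DMatrix::from_vec(1, 1, vec![1.0])).is_err());
    ///
    /// let x = DMatrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0]);
    /// let y = DVector::from_vec(vec![10.0, 20.0, 30.0, 40.0]);
    /// regressor.fit(&Dataset::new(x, y)).unwrap();
    ///
    /// let predictions = regressor.predict(&DMatrix::from_vec(2, 1, vec![1.5, 3.5])).unwrap();
    /// assert_eq!(predictions.len(), 2);
    /// ```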
    pub fn predict(&self, prediction_features: &DMatrix<T>) -> Result<DVector<T>, String> {
        let root = self.root.as_ref().ok_or("Tree wasn't built yet.")?;
        let predictions: Vec<_> = prediction_features
            .row_iter()
            .map(|row| Self::make_prediction(row.transpose(), root))
            .collect();

        Ok(DVector::from_vec(predictions))
    }

    fn make_prediction(features: DVector<T>, node: &TreeNode<T, T>) -> T {
        // Leaf nodes store the prediction directly.
        if let Some(value) = &node.value {
            return *value;
        }
        // Internal nodes route the sample left when the feature value is at
        // or below the threshold, and right otherwise.
        let feature_value = &features[node.feature_index.unwrap()];
        if feature_value <= node.threshold.as_ref().unwrap() {
            Self::make_prediction(features, node.left.as_ref().unwrap())
        } else {
            Self::make_prediction(features, node.right.as_ref().unwrap())
        }
    }

    fn build_tree(
        &mut self,
        dataset: &Dataset<T, T>,
        current_depth: Option<u16>,
        base_variance: f64,
    ) -> Result<TreeNode<T, T>, Box<dyn Error>> {
        let (x, y) = &dataset.into_parts();
        let (num_samples, num_features) = x.shape();

        // Treat a node as homogenous once its variance drops below 1% of the
        // variance at the root, and stop splitting it.
        let is_homogenous = self.variance(y) < 0.01 * base_variance;
        // `Option` ordering: with no depth limit both sides are `None`, and
        // `None <= None` holds, so an unbounded tree always passes this check.
        if num_samples >= self.min_samples_split().into()
            && current_depth <= self.max_depth()
            && !is_homogenous
        {
            // Evaluate the best candidate split of every feature in parallel.
            let splits = (0..num_features)
                .into_par_iter()
                .map(|feature_idx| self.get_split(dataset, feature_idx))
                .collect::<Vec<_>>();

            let valid_splits = splits
                .into_iter()
                .filter_map(Result::ok)
                .collect::<Vec<_>>();

            if valid_splits.is_empty() {
                return Ok(TreeNode::new(Some(self.mean(y))));
            }

            // Keep the split with the largest variance reduction.
            let best_split = match valid_splits.into_iter().max_by(|split1, split2| {
                split1
                    .information_gain
                    .partial_cmp(&split2.information_gain)
                    .unwrap_or(std::cmp::Ordering::Equal)
            }) {
                Some(split) => split,
                _ => return Err("No best split found.".into()),
            };
            let left_child = best_split.left;
            let right_child = best_split.right;
            if best_split.information_gain > 0.0 {
                let new_depth = current_depth.map(|depth| depth + 1);
                let left_node = self.build_tree(&left_child, new_depth, base_variance)?;
                let right_node = self.build_tree(&right_child, new_depth, base_variance)?;
                return Ok(TreeNode {
                    feature_index: Some(best_split.feature_index),
                    threshold: Some(best_split.threshold),
                    left: Some(Box::new(left_node)),
                    right: Some(Box::new(right_node)),
                    value: None,
                });
            }
        }

        // Otherwise this node becomes a leaf that predicts the mean target.
        let leaf_value = self.mean(y);
        Ok(TreeNode::new(Some(leaf_value)))
    }

    /// Finds the split of `feature_index` with the highest variance reduction,
    /// trying each unique feature value as a threshold.
    fn get_split(
        &self,
        dataset: &Dataset<T, T>,
        feature_index: usize,
    ) -> Result<SplitData<T>, String> {
        let mut best_split: Option<SplitData<T>> = None;
        let mut best_information_gain = f64::NEG_INFINITY;

        let mut unique_values: Vec<_> = dataset.x.column(feature_index).iter().cloned().collect();
        unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        unique_values.dedup();

        for value in &unique_values {
            let (left_child, right_child) = dataset.split_on_threshold(feature_index, *value);

            // Only consider thresholds that actually separate the samples.
            if left_child.is_not_empty() && right_child.is_not_empty() {
                let current_information_gain =
                    self.calculate_variance_reduction(&dataset.y, &left_child.y, &right_child.y);

                if current_information_gain > best_information_gain {
                    best_split = Some(SplitData {
                        feature_index,
                        threshold: *value,
                        left: left_child,
                        right: right_child,
                        information_gain: current_information_gain,
                    });
                    best_information_gain = current_information_gain;
                }
            }
        }
        best_split.ok_or_else(|| "No split found.".into())
    }

    /// Computes the variance reduction of a split:
    /// `Var(parent) - (n_left / n) * Var(left) - (n_right / n) * Var(right)`.
    fn calculate_variance_reduction(
        &self,
        parent_y: &DVector<T>,
        left_y: &DVector<T>,
        right_y: &DVector<T>,
    ) -> f64 {
        let variance = self.variance(parent_y);
        let left_variance = self.variance(left_y);
        let right_variance = self.variance(right_y);
        let num_samples = parent_y.len() as f64;
        variance
            - (left_variance * (left_y.len() as f64) / num_samples)
            - (right_variance * (right_y.len() as f64) / num_samples)
    }

    /// Population variance of the target values.
    fn variance(&self, y: &DVector<T>) -> f64 {
        let mean = self.mean(y);
        let variance = y.iter().fold(T::from_f64(0.0).unwrap(), |acc, x| {
            acc + (*x - mean) * (*x - mean)
        });
        let variance_f64 = T::to_f64(&variance).unwrap();
        variance_f64 / y.len() as f64
    }

    /// Mean of the target values.
    fn mean(&self, y: &DVector<T>) -> T {
        let zero = T::from_f64(0.0).unwrap();
        let sum: T = y.iter().fold(zero, |acc, x| acc + *x);
        sum / T::from_usize(y.len()).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DVector;

    #[test]
    fn test_mean() {
        let y = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        let mean = regressor.mean(&y);
        assert_eq!(mean, 3.5);
    }

    #[test]
    fn test_variance() {
        let y = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        let variance = regressor.variance(&y);
        assert_eq!(variance, 2.0);
    }

    #[test]
    fn test_calculate_variance_reduction() {
        let parent_y = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let left_y = DVector::from_vec(vec![1.0, 2.0]);
        let right_y = DVector::from_vec(vec![3.0, 4.0, 5.0]);
        let regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        let variance_reduction =
            regressor.calculate_variance_reduction(&parent_y, &left_y, &right_y);
        assert!(variance_reduction > 0.0);
    }

    #[test]
    fn test_fit_and_predict() {
        let x = DMatrix::from_vec(6, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let y = DVector::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0, 36.0]);
        let dataset = Dataset::new(x, y);
        let mut regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        regressor.fit(&dataset).unwrap();

        let test_x = DMatrix::from_vec(3, 1, vec![2.0, 3.0, 4.0]);
        let predictions = regressor.predict(&test_x).unwrap();

        assert_eq!(predictions.len(), 3);
        assert!(predictions.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_fit_and_predict_with_single_row() {
        let x = DMatrix::from_vec(1, 2, vec![1.0, 2.0]);
        let y = DVector::from_vec(vec![1.0]);
        let dataset = Dataset::new(x, y);
        let mut regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        regressor.fit(&dataset).unwrap();

        let test_x = DMatrix::from_vec(1, 2, vec![2.0, 3.0]);
        let predictions = regressor.predict(&test_x).unwrap();

        assert_eq!(predictions.len(), 1);
        assert!(predictions.iter().all(|&x| x >= 0.0));
    }
}