rusty_ai/forests/
regressor.rs

1use std::error::Error;
2
3use nalgebra::{DMatrix, DVector};
4use rand::{rngs::StdRng, Rng, SeedableRng};
5use rayon::prelude::*;
6
7use crate::{
8    data::dataset::{Dataset, RealNumber},
9    metrics::errors::RegressionMetrics,
10    trees::{params::TreeParams, regressor::DecisionTreeRegressor},
11};
12
13use super::params::ForestParams;
14
15#[derive(Clone, Debug)]
16pub struct RandomForestRegressor<T: RealNumber> {
17    forest_params: ForestParams<DecisionTreeRegressor<T>>,
18    tree_params: TreeParams,
19}
20
21impl<T: RealNumber> Default for RandomForestRegressor<T> {
22    /// Creates a new `RandomForestRegressor` with default parameters.
23    ///
24    /// # Returns
25    ///
26    /// A new instance of the `RandomForestRegressor`.
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl<T: RealNumber> RegressionMetrics<T> for RandomForestRegressor<T> {}
33
34impl<T: RealNumber> RandomForestRegressor<T> {
35    /// Creates a new `RandomForestRegressor` with default parameters.
36    ///
37    /// # Returns
38    ///
39    /// A new instance of the `RandomForestRegressor`.
40    pub fn new() -> Self {
41        Self {
42            forest_params: ForestParams::new(),
43            tree_params: TreeParams::new(),
44        }
45    }
46
47    /// Creates a new `RandomForestRegressor` with the specified parameters.
48    ///
49    /// # Arguments
50    ///
51    /// * `num_trees` - The number of trees in the random forest. If not specified, the default value is 3.
52    /// * `min_samples_split` - The minimum number of samples required to split an internal node. If not specified, the default value is 2.
53    /// * `max_depth` - The maximum depth of the decision trees. If not specified, there is no maximum depth.
54    /// * `sample_size` - The size of the random subsets of the training data used to train each tree. If not specified, the default value is calculated as the total number of samples divided by the number of trees.
55    ///
56    /// # Returns
57    ///
58    /// A `Result` containing the `RandomForestRegressor` if the parameters are valid, or a `Box<dyn Error>` if an error occurs.
59    pub fn with_params(
60        num_trees: Option<usize>,
61        min_samples_split: Option<u16>,
62        max_depth: Option<u16>,
63        sample_size: Option<usize>,
64    ) -> Result<Self, Box<dyn Error>> {
65        let mut forest = Self::new();
66
67        forest.set_num_trees(num_trees.unwrap_or(3))?;
68        forest.set_sample_size(sample_size)?;
69        forest.set_min_samples_split(min_samples_split.unwrap_or(2))?;
70        forest.set_max_depth(max_depth)?;
71        Ok(forest)
72    }
73
74    /// Sets the decision trees for the random forest regressor.
75    ///
76    /// # Arguments
77    ///
78    /// * `trees` - A vector of `DecisionTreeRegressor` instances.
79    pub fn set_trees(&mut self, trees: Vec<DecisionTreeRegressor<T>>) {
80        self.forest_params.set_trees(trees);
81    }
82
83    /// Sets the number of trees in the random forest regressor.
84    ///
85    /// # Arguments
86    ///
87    /// * `num_trees` - The number of trees.
88    ///
89    /// # Returns
90    ///
91    /// Returns `Ok(())` if successful, otherwise returns an error.
92    pub fn set_num_trees(&mut self, num_trees: usize) -> Result<(), Box<dyn Error>> {
93        self.forest_params.set_num_trees(num_trees)
94    }
95
96    /// Sets the sample size for each tree in the random forest regressor.
97    ///
98    /// # Arguments
99    ///
100    /// * `sample_size` - The sample size for each tree. Use `None` for full sample size.
101    ///
102    /// # Returns
103    ///
104    /// Returns `Ok(())` if successful, otherwise returns an error.
105    pub fn set_sample_size(&mut self, sample_size: Option<usize>) -> Result<(), Box<dyn Error>> {
106        self.forest_params.set_sample_size(sample_size)
107    }
108
109    /// Sets the minimum number of samples required to split an internal node in each decision tree.
110    ///
111    /// # Arguments
112    ///
113    /// * `min_samples_split` - The minimum number of samples required to split an internal node.
114    ///
115    /// # Returns
116    ///
117    /// Returns `Ok(())` if successful, otherwise returns an error.
118    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
119        self.tree_params.set_min_samples_split(min_samples_split)
120    }
121
122    /// Sets the maximum depth of each decision tree in the random forest regressor.
123    ///
124    /// # Arguments
125    ///
126    /// * `max_depth` - The maximum depth of each decision tree. Use `None` for unlimited depth.
127    ///
128    /// # Returns
129    ///
130    /// Returns `Ok(())` if successful, otherwise returns an error.
131    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
132        self.tree_params.set_max_depth(max_depth)
133    }
134
135    /// Returns a reference to the decision trees in the random forest regressor.
136    pub fn trees(&self) -> &Vec<DecisionTreeRegressor<T>> {
137        self.forest_params.trees()
138    }
139
140    /// Returns the number of trees in the random forest regressor.
141    pub fn num_trees(&self) -> usize {
142        self.forest_params.num_trees()
143    }
144
145    /// Returns the sample size for each tree in the random forest regressor.
146    pub fn sample_size(&self) -> Option<usize> {
147        self.forest_params.sample_size()
148    }
149
150    /// Returns the minimum number of samples required to split an internal node in each decision tree.
151    pub fn min_samples_split(&self) -> u16 {
152        self.tree_params.min_samples_split()
153    }
154
155    /// Returns the maximum depth of each decision tree in the random forest regressor.
156    pub fn max_depth(&self) -> Option<u16> {
157        self.tree_params.max_depth()
158    }
159
160    /// Fits the random forest regressor to the given dataset.
161    ///
162    /// # Arguments
163    ///
164    /// * `dataset` - The dataset to fit the random forest regressor to.
165    /// * `seed` - The seed for the random number generator. Use `None` for a random seed.
166    ///
167    /// # Returns
168    ///
169    /// Returns a string indicating the completion of the fitting process if successful, otherwise returns an error.
170    pub fn fit(
171        &mut self,
172        dataset: &Dataset<T, T>,
173        seed: Option<u64>,
174    ) -> Result<String, Box<dyn Error>> {
175        let mut rng = match seed {
176            Some(seed) => StdRng::seed_from_u64(seed),
177            _ => StdRng::from_entropy(),
178        };
179
180        let seeds = (0..self.num_trees())
181            .map(|_| rng.gen::<u64>())
182            .collect::<Vec<_>>();
183
184        match self.sample_size() {
185            Some(sample_size) if sample_size > dataset.x.nrows() => {
186                return Err("The set sample size is greater than the dataset size.".into())
187            }
188            None => self.set_sample_size(Some(dataset.x.nrows() / self.num_trees()))?,
189            _ => {}
190        }
191        let trees: Result<Vec<_>, String> = seeds
192            .into_par_iter()
193            .map(|tree_seed| {
194                let subset = dataset.samples(self.sample_size().unwrap(), Some(tree_seed));
195                let mut tree = DecisionTreeRegressor::with_params(
196                    Some(self.min_samples_split()),
197                    self.max_depth(),
198                )
199                .map_err(|error| error.to_string())?;
200                tree.fit(&subset).map_err(|error| error.to_string())?;
201                Ok(tree)
202            })
203            .collect();
204        self.set_trees(trees?);
205        Ok("Finished building the trees.".into())
206    }
207
208    /// Predicts the target values for the given features using the random forest regressor.
209    ///
210    /// # Arguments
211    ///
212    /// * `features` - The features to predict the target values for.
213    ///
214    /// # Returns
215    ///
216    /// Returns a vector of predicted target values if successful, otherwise returns an error.
217    pub fn predict(&self, features: &DMatrix<T>) -> Result<DVector<T>, Box<dyn Error>> {
218        let mut predictions = DVector::from_element(features.nrows(), T::from_f64(0.0).unwrap());
219
220        for i in 0..features.nrows() {
221            let mut total_prediction = T::from_f64(0.0).unwrap();
222            for tree in self.trees() {
223                let prediction = tree.predict(&DMatrix::from_row_slice(
224                    1,
225                    features.ncols(),
226                    features.row(i).transpose().as_slice(),
227                ))?;
228                total_prediction += prediction[0];
229            }
230
231            predictions[i] = total_prediction / T::from_usize(self.trees().len()).unwrap();
232        }
233        Ok(predictions)
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{DMatrix, DVector};

    /// Builds a six-row, two-feature dataset: the first three targets are 0.5 and the last
    /// three are 1.5.
    fn create_mock_dataset() -> Dataset<f64, f64> {
        let feature_values = [1.0, 2.0, 1.1, 2.1, 1.2, 2.2, 3.0, 4.0, 3.1, 4.1, 3.2, 4.2];
        let features = DMatrix::from_row_slice(6, 2, &feature_values);
        let targets = DVector::from_vec(vec![0.5, 0.5, 0.5, 1.5, 1.5, 1.5]);
        Dataset::new(features, targets)
    }

    /// The default regressor has three trees and a minimum split size of two.
    #[test]
    fn test_default() {
        let regressor: RandomForestRegressor<f64> = Default::default();
        assert_eq!(regressor.num_trees(), 3);
        assert_eq!(regressor.min_samples_split(), 2);
    }

    /// Every explicitly supplied parameter is readable back through its getter.
    #[test]
    fn test_with_params() {
        let built =
            RandomForestRegressor::<f64>::with_params(Some(10), Some(4), Some(5), Some(100));
        let regressor = built.unwrap();
        assert_eq!(regressor.num_trees(), 10);
        assert_eq!(regressor.min_samples_split(), 4);
        assert_eq!(regressor.max_depth(), Some(5));
        assert_eq!(regressor.sample_size(), Some(100));
    }

    /// Fitting with a fixed seed succeeds and yields one tree per configured tree slot.
    #[test]
    fn test_fit() {
        let data = create_mock_dataset();
        let mut regressor = RandomForestRegressor::<f64>::new();
        let outcome = regressor.fit(&data, Some(42));
        assert!(outcome.is_ok());
        assert_eq!(regressor.trees().len(), 3);
    }

    /// Predictions for two query rows stay within the range of the training targets.
    #[test]
    fn test_predict() {
        let data = create_mock_dataset();
        let mut regressor = RandomForestRegressor::<f64>::new();
        regressor.fit(&data, Some(42)).unwrap();

        let queries = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let predicted = regressor.predict(&queries).unwrap();
        assert_eq!(predicted.len(), 2);

        for &value in predicted.iter() {
            assert!((0.5..=1.5).contains(&value));
        }
    }
}