// rusty_ai/forests/classifier.rs

1use crate::data::dataset::{Dataset, Number, WholeNumber};
2use crate::metrics::confusion::ClassificationMetrics;
3use crate::trees::classifier::DecisionTreeClassifier;
4use crate::trees::params::TreeClassifierParams;
5use nalgebra::{DMatrix, DVector};
6use rand::rngs::StdRng;
7use rand::{Rng, SeedableRng};
8use rayon::prelude::*;
9use std::collections::HashMap;
10use std::error::Error;
11
12use super::params::ForestParams;
13
/// A random forest classifier: an ensemble of decision trees whose majority
/// vote yields the final class prediction.
///
/// `XT` is the numeric type of the input features and `YT` is the whole-number
/// type of the target labels.
#[derive(Clone, Debug)]
pub struct RandomForestClassifier<XT: Number, YT: WholeNumber> {
    // Ensemble-level state: the fitted trees plus forest settings
    // (number of trees, bootstrap sample size).
    forest_params: ForestParams<DecisionTreeClassifier<XT, YT>>,
    // Per-tree settings shared by every tree in the forest
    // (criterion, min_samples_split, max_depth).
    tree_params: TreeClassifierParams,
}
19
// Marker impl: the empty body means the forest relies entirely on the trait's
// provided classification-metric methods.
impl<XT: Number, YT: WholeNumber> ClassificationMetrics<YT> for RandomForestClassifier<XT, YT> {}
21
// Delegates to `new()` so `Default::default()` and `new()` can never disagree
// on the default parameters.
impl<XT: Number, YT: WholeNumber> Default for RandomForestClassifier<XT, YT> {
    fn default() -> Self {
        Self::new()
    }
}
27
/// Implementation of the [`RandomForestClassifier`].
///
/// The `RandomForestClassifier` is a machine learning algorithm that combines multiple decision trees to make predictions.
/// It is used for classification tasks where the input features are of type `XT` and the target labels are of type `YT`.
///
/// # Example
///
/// ```rust
/// use rusty_ai::forests::classifier::RandomForestClassifier;
/// use rusty_ai::data::dataset::Dataset;
/// use nalgebra::{DMatrix, DVector};
///
/// // Create a mock dataset
/// let x = DMatrix::from_row_slice(
///     6,
///     2,
///     &[1.0, 2.0, 1.1, 2.1, 1.2, 2.2, 3.0, 4.0, 3.1, 4.1, 3.2, 4.2],
/// );
/// let y = DVector::from_vec(vec![0, 0, 0, 1, 1, 1]);
/// let dataset = Dataset::new(x, y);
///
/// // Create a random forest classifier with default parameters
/// let mut forest = RandomForestClassifier::<f64, u8>::default();
///
/// // Fit the classifier to the dataset
/// forest.fit(&dataset, Some(42)).unwrap();
///
/// // Make predictions on new features
/// let features = DMatrix::from_row_slice(
///     2,
///     2,
///     &[
///         1.0, 2.0, // Should be classified as class 0
///         3.0, 4.0, // Should be classified as class 1
///     ],
/// );
/// let predictions = forest.predict(&features).unwrap();
/// println!("Predictions: {:?}", predictions);
/// ```
67
68impl<XT: Number, YT: WholeNumber> RandomForestClassifier<XT, YT> {
69    /// Creates a new instance of the Random Forest Classifier.
70    ///
71    /// This function initializes the classifier with empty frequency maps and an empty
72    /// vector to store the count of unique feature values.
73    ///
74    /// # Returns
75    ///
76    /// A new instance of the Random Forest Classifier.
77    pub fn new() -> Self {
78        Self {
79            forest_params: ForestParams::new(),
80            tree_params: TreeClassifierParams::new(),
81        }
82    }
83
84    /// Creates a new instance of the Random Forest Classifier with specified parameters.
85    ///
86    /// # Arguments
87    ///
88    /// * `num_trees` - The number of trees in the forest. If not specified, defaults to 3.
89    /// * `min_samples_split` - The minimum number of samples required to split an internal node. If not specified, defaults to 2.
90    /// * `max_depth` - The maximum depth of the decision trees. If not specified, defaults to None.
91    /// * `criterion` - The function to measure the quality of a split. If not specified, defaults to "gini".
92    /// * `sample_size` - The size of the random subsets of the dataset to train each tree. If not specified, defaults to None.
93    ///
94    /// # Returns
95    ///
96    /// A `Result` containing the Random Forest Classifier instance or an error.
97    pub fn with_params(
98        num_trees: Option<usize>,
99        min_samples_split: Option<u16>,
100        max_depth: Option<u16>,
101        criterion: Option<String>,
102        sample_size: Option<usize>,
103    ) -> Result<Self, Box<dyn Error>> {
104        let mut forest = Self::new();
105
106        forest.set_num_trees(num_trees.unwrap_or(3))?;
107        forest.set_sample_size(sample_size)?;
108        forest.set_min_samples_split(min_samples_split.unwrap_or(2))?;
109        forest.set_max_depth(max_depth)?;
110        forest.set_criterion(criterion.unwrap_or("gini".to_string()))?;
111        Ok(forest)
112    }
113
114    /// Sets the decision trees of the random forest.
115    ///
116    /// # Arguments
117    ///
118    /// * `trees` - A vector of DecisionTreeClassifier instances.
119    pub fn set_trees(&mut self, trees: Vec<DecisionTreeClassifier<XT, YT>>) {
120        self.forest_params.set_trees(trees);
121    }
122
123    /// Sets the number of trees in the random forest.
124    ///
125    /// # Arguments
126    ///
127    /// * `num_trees` - The number of trees.
128    ///
129    /// # Returns
130    ///
131    /// A `Result` indicating success or an error.
132    pub fn set_num_trees(&mut self, num_trees: usize) -> Result<(), Box<dyn Error>> {
133        self.forest_params.set_num_trees(num_trees)
134    }
135
136    /// Sets the sample size for each tree in the random forest.
137    ///
138    /// # Arguments
139    ///
140    /// * `sample_size` - The sample size.
141    ///
142    /// # Returns
143    ///
144    /// A `Result` indicating success or an error.
145    pub fn set_sample_size(&mut self, sample_size: Option<usize>) -> Result<(), Box<dyn Error>> {
146        self.forest_params.set_sample_size(sample_size)
147    }
148
149    /// Sets the minimum number of samples required to split an internal node in each decision tree.
150    ///
151    /// # Arguments
152    ///
153    /// * `min_samples_split` - The minimum number of samples.
154    ///
155    /// # Returns
156    ///
157    /// A `Result` indicating success or an error.
158    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
159        self.tree_params.set_min_samples_split(min_samples_split)
160    }
161
162    /// Sets the maximum depth of each decision tree in the random forest.
163    ///
164    /// # Arguments
165    ///
166    /// * `max_depth` - The maximum depth.
167    ///
168    /// # Returns
169    ///
170    /// A `Result` indicating success or an error.
171    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
172        self.tree_params.set_max_depth(max_depth)
173    }
174
175    /// Sets the criterion function to measure the quality of a split in each decision tree.
176    ///
177    /// # Arguments
178    ///
179    /// * `criterion` - The criterion function.
180    ///
181    /// # Returns
182    ///
183    /// A `Result` indicating success or an error.
184    pub fn set_criterion(&mut self, criterion: String) -> Result<(), Box<dyn Error>> {
185        self.tree_params.set_criterion(criterion)
186    }
187
188    /// Returns a reference to the decision trees in the random forest.
189    pub fn trees(&self) -> &Vec<DecisionTreeClassifier<XT, YT>> {
190        self.forest_params.trees()
191    }
192
193    /// Returns the number of trees in the random forest.
194    pub fn num_trees(&self) -> usize {
195        self.forest_params.num_trees()
196    }
197
198    /// Returns the sample size for each tree in the random forest.
199    pub fn sample_size(&self) -> Option<usize> {
200        self.forest_params.sample_size()
201    }
202
203    /// Returns the minimum number of samples required to split an internal node in each decision tree.
204    pub fn min_samples_split(&self) -> u16 {
205        self.tree_params.min_samples_split()
206    }
207
208    /// Returns the maximum depth of each decision tree in the random forest.
209    pub fn max_depth(&self) -> Option<u16> {
210        self.tree_params.max_depth()
211    }
212
213    /// Returns a reference to the criterion function used to measure the quality of a split in each decision tree.
214    pub fn criterion(&self) -> &String {
215        &self.tree_params.criterion
216    }
217
218    /// Fits the random forest to the given dataset.
219    ///
220    /// # Arguments
221    ///
222    /// * `dataset` - The dataset to fit the random forest to.
223    /// * `seed` - The seed for the random number generator used to generate random subsets of the dataset. If not specified, a random seed will be used.
224    ///
225    /// # Returns
226    ///
227    /// A `Result` indicating whether the fitting process was successful or an error occurred.
228    pub fn fit(
229        &mut self,
230        dataset: &Dataset<XT, YT>,
231        seed: Option<u64>,
232    ) -> Result<String, Box<dyn Error>> {
233        let mut rng = match seed {
234            Some(seed) => StdRng::seed_from_u64(seed),
235            _ => StdRng::from_entropy(),
236        };
237
238        let seeds = (0..self.num_trees())
239            .map(|_| rng.gen::<u64>())
240            .collect::<Vec<_>>();
241
242        match self.sample_size() {
243            Some(sample_size) if sample_size > dataset.nrows() => {
244                return Err(format!(
245                    "The set sample size is greater than the dataset size. {} > {}",
246                    sample_size,
247                    dataset.nrows()
248                )
249                .into());
250            }
251            None => self.set_sample_size(Some(dataset.nrows() / self.num_trees()))?,
252            _ => {}
253        }
254
255        let trees: Result<Vec<_>, String> = seeds
256            .into_par_iter()
257            .map(|tree_seed| {
258                let subset = dataset.samples(self.sample_size().unwrap(), Some(tree_seed));
259                let mut tree = DecisionTreeClassifier::with_params(
260                    Some(self.criterion().clone()),
261                    Some(self.min_samples_split()),
262                    self.max_depth(),
263                )
264                .map_err(|error| error.to_string())?;
265                tree.fit(&subset).map_err(|error| error.to_string())?;
266                Ok(tree)
267            })
268            .collect();
269        self.set_trees(trees?);
270        Ok("Finished building the trees".into())
271    }
272
273    /// Predicts the class labels for the given features using the random forest.
274    ///
275    /// # Arguments
276    ///
277    /// * `features` - The features to predict the class labels for.
278    ///
279    /// # Returns
280    ///
281    /// A `Result` containing a vector of predicted class labels or an error if the prediction
282    /// process fails.
283    pub fn predict(&self, features: &DMatrix<XT>) -> Result<DVector<YT>, Box<dyn Error>> {
284        let mut predictions = DVector::from_element(features.nrows(), YT::from_u8(0).unwrap());
285
286        for i in 0..features.nrows() {
287            let mut class_counts = HashMap::new();
288            for tree in self.trees() {
289                let prediction = tree
290                    .predict(&DMatrix::from_row_slice(
291                        1,
292                        features.ncols(),
293                        features.row(i).transpose().as_slice(),
294                    ))
295                    .map_err(|error| error.to_string())?;
296                *class_counts.entry(prediction[0]).or_insert(0) += 1;
297            }
298
299            let chosen_class = class_counts
300                .into_iter()
301                .max_by_key(|&(_, count)| count)
302                .map(|(class, _)| class)
303                .ok_or(
304                    "Prediction failure. No trees built or class counts are empty.".to_string(),
305                )?;
306            predictions[i] = chosen_class;
307        }
308        Ok(predictions)
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a tiny, linearly separable 6x2 dataset: three rows of class 0
    // clustered near (1, 2) and three rows of class 1 near (3, 4).
    fn create_mock_dataset() -> Dataset<f64, u8> {
        let x = DMatrix::from_row_slice(
            6,
            2,
            &[1.0, 2.0, 1.1, 2.1, 1.2, 2.2, 3.0, 4.0, 3.1, 4.1, 3.2, 4.2],
        );
        let y = DVector::from_vec(vec![0, 0, 0, 1, 1, 1]);
        Dataset::new(x, y)
    }

    // `Default` must produce the same configuration as `new()`.
    #[test]
    fn test_default() {
        let forest = RandomForestClassifier::<f64, u8>::default();
        assert_eq!(forest.num_trees(), 3); // Default number of trees
        assert_eq!(forest.min_samples_split(), 2); // Default min_samples_split
    }

    #[test]
    fn test_new() {
        let forest = RandomForestClassifier::<f64, u8>::new();
        assert_eq!(forest.num_trees(), 3); // Default number of trees
        assert_eq!(forest.min_samples_split(), 2); // Default min_samples_split
    }

    // Every explicitly supplied parameter must be stored verbatim.
    #[test]
    fn test_with_params() {
        let forest = RandomForestClassifier::<f64, u8>::with_params(
            Some(10),                    // num_trees
            Some(4),                     // min_samples_split
            Some(5),                     // max_depth
            Some("entropy".to_string()), // criterion
            Some(100),                   // sample_size
        )
        .unwrap();
        assert_eq!(forest.num_trees(), 10);
        assert_eq!(forest.min_samples_split(), 4);
        assert_eq!(forest.max_depth(), Some(5));
        assert_eq!(forest.criterion(), "entropy");
        assert_eq!(forest.sample_size(), Some(100));
    }

    // A sample size of 0 must be rejected with the exact validation message.
    #[test]
    fn test_too_low_sample_size() {
        let forest = RandomForestClassifier::<f64, u8>::new().set_sample_size(Some(0));
        assert!(forest.is_err());
        assert_eq!(
            forest.unwrap_err().to_string(),
            "The sample size must be greater than 0."
        );
    }

    // A forest of a single tree is not a forest; must be rejected.
    #[test]
    fn test_too_low_num_trees() {
        let forest = RandomForestClassifier::<f64, u8>::new().set_num_trees(1);
        assert!(forest.is_err());
        assert_eq!(
            forest.unwrap_err().to_string(),
            "The number of trees must be greater than 1."
        );
    }

    #[test]
    fn test_fit() {
        let mut forest = RandomForestClassifier::<f64, u8>::new();
        let dataset = create_mock_dataset();
        let fit_result = forest.fit(&dataset, Some(42)); // Using a fixed seed for reproducibility
        assert!(fit_result.is_ok());
        assert_eq!(forest.trees().len(), 3); // Should have 3 trees after fitting
    }

    // Fitting with a sample size larger than the dataset must fail up front.
    #[test]
    fn test_fit_too_many_samples() {
        let mut forest = RandomForestClassifier::<f64, u8>::new();
        let _ = forest.set_sample_size(Some(1000));
        let dataset = create_mock_dataset();
        let fit_result = forest.fit(&dataset, Some(42)); // Using a fixed seed for reproducibility

        assert!(fit_result.is_err());
        assert_eq!(
            fit_result.unwrap_err().to_string(),
            "The set sample size is greater than the dataset size. 1000 > 6"
        );
    }

    // End-to-end: fit on the separable dataset, then check that one point from
    // each cluster is assigned its cluster's label.
    #[test]
    fn test_predict() {
        let mut forest = RandomForestClassifier::<f64, u8>::new();
        let _ = forest.set_sample_size(Some(3));
        let dataset = create_mock_dataset();
        forest.fit(&dataset, Some(42)).unwrap();

        let features = DMatrix::from_row_slice(
            2,
            2,
            &[
                1.0, 2.0, // Should be classified as class 0
                3.0, 4.0, // Should be classified as class 1
            ],
        );
        let predictions = forest.predict(&features).unwrap();
        assert_eq!(predictions, DVector::from_vec(vec![0, 1]));
    }
}