1use crate::data::dataset::{Dataset, Number, WholeNumber};
2use crate::metrics::confusion::ClassificationMetrics;
3use crate::trees::classifier::DecisionTreeClassifier;
4use crate::trees::params::TreeClassifierParams;
5use nalgebra::{DMatrix, DVector};
6use rand::rngs::StdRng;
7use rand::{Rng, SeedableRng};
8use rayon::prelude::*;
9use std::collections::HashMap;
10use std::error::Error;
11
12use super::params::ForestParams;
13
/// A random forest classifier: an ensemble of [`DecisionTreeClassifier`]s,
/// each fitted on a seeded subsample of the training data, predicting by
/// majority vote across trees.
///
/// `XT` is the numeric feature type, `YT` the integer class-label type.
#[derive(Clone, Debug)]
pub struct RandomForestClassifier<XT: Number, YT: WholeNumber> {
    // Ensemble-level settings: the fitted trees, their count, and the
    // per-tree subsample size.
    forest_params: ForestParams<DecisionTreeClassifier<XT, YT>>,
    // Settings forwarded to every tree: criterion, min_samples_split, max_depth.
    tree_params: TreeClassifierParams,
}
19
20impl<XT: Number, YT: WholeNumber> ClassificationMetrics<YT> for RandomForestClassifier<XT, YT> {}
21
/// `Default` simply delegates to [`RandomForestClassifier::new`], so
/// `default()` and `new()` are guaranteed to agree.
impl<XT: Number, YT: WholeNumber> Default for RandomForestClassifier<XT, YT> {
    fn default() -> Self {
        Self::new()
    }
}
27
28impl<XT: Number, YT: WholeNumber> RandomForestClassifier<XT, YT> {
69 pub fn new() -> Self {
78 Self {
79 forest_params: ForestParams::new(),
80 tree_params: TreeClassifierParams::new(),
81 }
82 }
83
84 pub fn with_params(
98 num_trees: Option<usize>,
99 min_samples_split: Option<u16>,
100 max_depth: Option<u16>,
101 criterion: Option<String>,
102 sample_size: Option<usize>,
103 ) -> Result<Self, Box<dyn Error>> {
104 let mut forest = Self::new();
105
106 forest.set_num_trees(num_trees.unwrap_or(3))?;
107 forest.set_sample_size(sample_size)?;
108 forest.set_min_samples_split(min_samples_split.unwrap_or(2))?;
109 forest.set_max_depth(max_depth)?;
110 forest.set_criterion(criterion.unwrap_or("gini".to_string()))?;
111 Ok(forest)
112 }
113
114 pub fn set_trees(&mut self, trees: Vec<DecisionTreeClassifier<XT, YT>>) {
120 self.forest_params.set_trees(trees);
121 }
122
123 pub fn set_num_trees(&mut self, num_trees: usize) -> Result<(), Box<dyn Error>> {
133 self.forest_params.set_num_trees(num_trees)
134 }
135
136 pub fn set_sample_size(&mut self, sample_size: Option<usize>) -> Result<(), Box<dyn Error>> {
146 self.forest_params.set_sample_size(sample_size)
147 }
148
149 pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
159 self.tree_params.set_min_samples_split(min_samples_split)
160 }
161
162 pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
172 self.tree_params.set_max_depth(max_depth)
173 }
174
175 pub fn set_criterion(&mut self, criterion: String) -> Result<(), Box<dyn Error>> {
185 self.tree_params.set_criterion(criterion)
186 }
187
188 pub fn trees(&self) -> &Vec<DecisionTreeClassifier<XT, YT>> {
190 self.forest_params.trees()
191 }
192
193 pub fn num_trees(&self) -> usize {
195 self.forest_params.num_trees()
196 }
197
198 pub fn sample_size(&self) -> Option<usize> {
200 self.forest_params.sample_size()
201 }
202
203 pub fn min_samples_split(&self) -> u16 {
205 self.tree_params.min_samples_split()
206 }
207
208 pub fn max_depth(&self) -> Option<u16> {
210 self.tree_params.max_depth()
211 }
212
213 pub fn criterion(&self) -> &String {
215 &self.tree_params.criterion
216 }
217
218 pub fn fit(
229 &mut self,
230 dataset: &Dataset<XT, YT>,
231 seed: Option<u64>,
232 ) -> Result<String, Box<dyn Error>> {
233 let mut rng = match seed {
234 Some(seed) => StdRng::seed_from_u64(seed),
235 _ => StdRng::from_entropy(),
236 };
237
238 let seeds = (0..self.num_trees())
239 .map(|_| rng.gen::<u64>())
240 .collect::<Vec<_>>();
241
242 match self.sample_size() {
243 Some(sample_size) if sample_size > dataset.nrows() => {
244 return Err(format!(
245 "The set sample size is greater than the dataset size. {} > {}",
246 sample_size,
247 dataset.nrows()
248 )
249 .into());
250 }
251 None => self.set_sample_size(Some(dataset.nrows() / self.num_trees()))?,
252 _ => {}
253 }
254
255 let trees: Result<Vec<_>, String> = seeds
256 .into_par_iter()
257 .map(|tree_seed| {
258 let subset = dataset.samples(self.sample_size().unwrap(), Some(tree_seed));
259 let mut tree = DecisionTreeClassifier::with_params(
260 Some(self.criterion().clone()),
261 Some(self.min_samples_split()),
262 self.max_depth(),
263 )
264 .map_err(|error| error.to_string())?;
265 tree.fit(&subset).map_err(|error| error.to_string())?;
266 Ok(tree)
267 })
268 .collect();
269 self.set_trees(trees?);
270 Ok("Finished building the trees".into())
271 }
272
273 pub fn predict(&self, features: &DMatrix<XT>) -> Result<DVector<YT>, Box<dyn Error>> {
284 let mut predictions = DVector::from_element(features.nrows(), YT::from_u8(0).unwrap());
285
286 for i in 0..features.nrows() {
287 let mut class_counts = HashMap::new();
288 for tree in self.trees() {
289 let prediction = tree
290 .predict(&DMatrix::from_row_slice(
291 1,
292 features.ncols(),
293 features.row(i).transpose().as_slice(),
294 ))
295 .map_err(|error| error.to_string())?;
296 *class_counts.entry(prediction[0]).or_insert(0) += 1;
297 }
298
299 let chosen_class = class_counts
300 .into_iter()
301 .max_by_key(|&(_, count)| count)
302 .map(|(class, _)| class)
303 .ok_or(
304 "Prediction failure. No trees built or class counts are empty.".to_string(),
305 )?;
306 predictions[i] = chosen_class;
307 }
308 Ok(predictions)
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Six samples, two features, two well-separated classes.
    fn create_mock_dataset() -> Dataset<f64, u8> {
        let features = DMatrix::from_row_slice(
            6,
            2,
            &[1.0, 2.0, 1.1, 2.1, 1.2, 2.2, 3.0, 4.0, 3.1, 4.1, 3.2, 4.2],
        );
        let labels = DVector::from_vec(vec![0, 0, 0, 1, 1, 1]);
        Dataset::new(features, labels)
    }

    #[test]
    fn test_default() {
        let forest = RandomForestClassifier::<f64, u8>::default();
        assert_eq!(forest.num_trees(), 3);
        assert_eq!(forest.min_samples_split(), 2);
    }

    #[test]
    fn test_new() {
        let forest = RandomForestClassifier::<f64, u8>::new();
        assert_eq!(forest.num_trees(), 3);
        assert_eq!(forest.min_samples_split(), 2);
    }

    #[test]
    fn test_with_params() {
        let forest = RandomForestClassifier::<f64, u8>::with_params(
            Some(10),
            Some(4),
            Some(5),
            Some("entropy".to_string()),
            Some(100),
        )
        .unwrap();

        assert_eq!(forest.num_trees(), 10);
        assert_eq!(forest.min_samples_split(), 4);
        assert_eq!(forest.max_depth(), Some(5));
        assert_eq!(forest.criterion(), "entropy");
        assert_eq!(forest.sample_size(), Some(100));
    }

    #[test]
    fn test_too_low_sample_size() {
        let result = RandomForestClassifier::<f64, u8>::new().set_sample_size(Some(0));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "The sample size must be greater than 0."
        );
    }

    #[test]
    fn test_too_low_num_trees() {
        let result = RandomForestClassifier::<f64, u8>::new().set_num_trees(1);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "The number of trees must be greater than 1."
        );
    }

    #[test]
    fn test_fit() {
        let mut forest = RandomForestClassifier::<f64, u8>::new();
        let dataset = create_mock_dataset();

        let fit_result = forest.fit(&dataset, Some(42));
        assert!(fit_result.is_ok());
        assert_eq!(forest.trees().len(), 3);
    }

    #[test]
    fn test_fit_too_many_samples() {
        let mut forest = RandomForestClassifier::<f64, u8>::new();
        let _ = forest.set_sample_size(Some(1000));
        let dataset = create_mock_dataset();

        let fit_result = forest.fit(&dataset, Some(42));
        assert!(fit_result.is_err());
        assert_eq!(
            fit_result.unwrap_err().to_string(),
            "The set sample size is greater than the dataset size. 1000 > 6"
        );
    }

    #[test]
    fn test_predict() {
        let mut forest = RandomForestClassifier::<f64, u8>::new();
        let _ = forest.set_sample_size(Some(3));
        let dataset = create_mock_dataset();
        forest.fit(&dataset, Some(42)).unwrap();

        // One row from the heart of each class.
        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let predictions = forest.predict(&features).unwrap();
        assert_eq!(predictions, DVector::from_vec(vec![0, 1]));
    }
}