// rusty_ai/forests/regressor.rs
use std::error::Error;
2
3use nalgebra::{DMatrix, DVector};
4use rand::{rngs::StdRng, Rng, SeedableRng};
5use rayon::prelude::*;
6
7use crate::{
8 data::dataset::{Dataset, RealNumber},
9 metrics::errors::RegressionMetrics,
10 trees::{params::TreeParams, regressor::DecisionTreeRegressor},
11};
12
13use super::params::ForestParams;
14
/// A random forest regressor: an ensemble of [`DecisionTreeRegressor`]s whose
/// per-tree predictions are averaged to produce the final regression output.
#[derive(Clone, Debug)]
pub struct RandomForestRegressor<T: RealNumber> {
    // Forest-level configuration (tree count, bootstrap sample size) and the
    // fitted trees themselves.
    forest_params: ForestParams<DecisionTreeRegressor<T>>,
    // Per-tree configuration (min samples per split, max depth) shared by
    // every tree in the forest.
    tree_params: TreeParams,
}
20
21impl<T: RealNumber> Default for RandomForestRegressor<T> {
22 fn default() -> Self {
28 Self::new()
29 }
30}
31
// Marker impl: the forest uses the trait's provided regression-metric methods
// as-is, so the body is empty.
impl<T: RealNumber> RegressionMetrics<T> for RandomForestRegressor<T> {}
33
34impl<T: RealNumber> RandomForestRegressor<T> {
35 pub fn new() -> Self {
41 Self {
42 forest_params: ForestParams::new(),
43 tree_params: TreeParams::new(),
44 }
45 }
46
47 pub fn with_params(
60 num_trees: Option<usize>,
61 min_samples_split: Option<u16>,
62 max_depth: Option<u16>,
63 sample_size: Option<usize>,
64 ) -> Result<Self, Box<dyn Error>> {
65 let mut forest = Self::new();
66
67 forest.set_num_trees(num_trees.unwrap_or(3))?;
68 forest.set_sample_size(sample_size)?;
69 forest.set_min_samples_split(min_samples_split.unwrap_or(2))?;
70 forest.set_max_depth(max_depth)?;
71 Ok(forest)
72 }
73
74 pub fn set_trees(&mut self, trees: Vec<DecisionTreeRegressor<T>>) {
80 self.forest_params.set_trees(trees);
81 }
82
83 pub fn set_num_trees(&mut self, num_trees: usize) -> Result<(), Box<dyn Error>> {
93 self.forest_params.set_num_trees(num_trees)
94 }
95
96 pub fn set_sample_size(&mut self, sample_size: Option<usize>) -> Result<(), Box<dyn Error>> {
106 self.forest_params.set_sample_size(sample_size)
107 }
108
109 pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
119 self.tree_params.set_min_samples_split(min_samples_split)
120 }
121
122 pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
132 self.tree_params.set_max_depth(max_depth)
133 }
134
135 pub fn trees(&self) -> &Vec<DecisionTreeRegressor<T>> {
137 self.forest_params.trees()
138 }
139
140 pub fn num_trees(&self) -> usize {
142 self.forest_params.num_trees()
143 }
144
145 pub fn sample_size(&self) -> Option<usize> {
147 self.forest_params.sample_size()
148 }
149
150 pub fn min_samples_split(&self) -> u16 {
152 self.tree_params.min_samples_split()
153 }
154
155 pub fn max_depth(&self) -> Option<u16> {
157 self.tree_params.max_depth()
158 }
159
160 pub fn fit(
171 &mut self,
172 dataset: &Dataset<T, T>,
173 seed: Option<u64>,
174 ) -> Result<String, Box<dyn Error>> {
175 let mut rng = match seed {
176 Some(seed) => StdRng::seed_from_u64(seed),
177 _ => StdRng::from_entropy(),
178 };
179
180 let seeds = (0..self.num_trees())
181 .map(|_| rng.gen::<u64>())
182 .collect::<Vec<_>>();
183
184 match self.sample_size() {
185 Some(sample_size) if sample_size > dataset.x.nrows() => {
186 return Err("The set sample size is greater than the dataset size.".into())
187 }
188 None => self.set_sample_size(Some(dataset.x.nrows() / self.num_trees()))?,
189 _ => {}
190 }
191 let trees: Result<Vec<_>, String> = seeds
192 .into_par_iter()
193 .map(|tree_seed| {
194 let subset = dataset.samples(self.sample_size().unwrap(), Some(tree_seed));
195 let mut tree = DecisionTreeRegressor::with_params(
196 Some(self.min_samples_split()),
197 self.max_depth(),
198 )
199 .map_err(|error| error.to_string())?;
200 tree.fit(&subset).map_err(|error| error.to_string())?;
201 Ok(tree)
202 })
203 .collect();
204 self.set_trees(trees?);
205 Ok("Finished building the trees.".into())
206 }
207
208 pub fn predict(&self, features: &DMatrix<T>) -> Result<DVector<T>, Box<dyn Error>> {
218 let mut predictions = DVector::from_element(features.nrows(), T::from_f64(0.0).unwrap());
219
220 for i in 0..features.nrows() {
221 let mut total_prediction = T::from_f64(0.0).unwrap();
222 for tree in self.trees() {
223 let prediction = tree.predict(&DMatrix::from_row_slice(
224 1,
225 features.ncols(),
226 features.row(i).transpose().as_slice(),
227 ))?;
228 total_prediction += prediction[0];
229 }
230
231 predictions[i] = total_prediction / T::from_usize(self.trees().len()).unwrap();
232 }
233 Ok(predictions)
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{DMatrix, DVector};

    /// Two well-separated clusters: rows near (1, 2) have target 0.5 and
    /// rows near (3, 4) have target 1.5.
    fn create_mock_dataset() -> Dataset<f64, f64> {
        let features = DMatrix::from_row_slice(
            6,
            2,
            &[1.0, 2.0, 1.1, 2.1, 1.2, 2.2, 3.0, 4.0, 3.1, 4.1, 3.2, 4.2],
        );
        let targets = DVector::from_vec(vec![0.5, 0.5, 0.5, 1.5, 1.5, 1.5]);
        Dataset::new(features, targets)
    }

    #[test]
    fn test_default() {
        // Default::default must match the documented defaults of `new`.
        let forest = RandomForestRegressor::<f64>::default();
        assert_eq!(forest.num_trees(), 3);
        assert_eq!(forest.min_samples_split(), 2);
    }

    #[test]
    fn test_with_params() {
        // Every explicitly supplied parameter must be stored verbatim.
        let forest =
            RandomForestRegressor::<f64>::with_params(Some(10), Some(4), Some(5), Some(100))
                .unwrap();
        assert_eq!(forest.num_trees(), 10);
        assert_eq!(forest.min_samples_split(), 4);
        assert_eq!(forest.max_depth(), Some(5));
        assert_eq!(forest.sample_size(), Some(100));
    }

    #[test]
    fn test_fit() {
        let dataset = create_mock_dataset();
        let mut forest = RandomForestRegressor::<f64>::new();

        // Seeded fit must succeed and build exactly `num_trees` trees.
        assert!(forest.fit(&dataset, Some(42)).is_ok());
        assert_eq!(forest.trees().len(), 3);
    }

    #[test]
    fn test_predict() {
        let dataset = create_mock_dataset();
        let mut forest = RandomForestRegressor::<f64>::new();
        forest.fit(&dataset, Some(42)).unwrap();

        let features = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let predictions = forest.predict(&features).unwrap();
        assert_eq!(predictions.len(), 2);

        // Averaged tree outputs must stay within the training-target range.
        assert!((0.5..=1.5).contains(&predictions[0]));
        assert!((0.5..=1.5).contains(&predictions[1]));
    }
}