use super::{node::TreeNode, params::TreeParams};
use crate::{
    data::dataset::{Dataset, RealNumber},
    metrics::errors::RegressionMetrics,
};
use nalgebra::{DMatrix, DVector};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::{error::Error, f64, marker::PhantomData};

pub struct SplitData<T: RealNumber> {
    pub feature_index: usize,
    pub threshold: T,
    pub left: Dataset<T, T>,
    pub right: Dataset<T, T>,
    information_gain: f64,
}

#[derive(Clone, Debug)]
pub struct DecisionTreeRegressor<T: RealNumber> {
    root: Option<Box<TreeNode<T, T>>>,
    tree_params: TreeParams,

    _marker: PhantomData<T>,
}

impl<T: RealNumber> Default for DecisionTreeRegressor<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: RealNumber> RegressionMetrics<T> for DecisionTreeRegressor<T> {}

impl<T: RealNumber> DecisionTreeRegressor<T> {
    pub fn new() -> Self {
        Self {
            root: None,
            tree_params: TreeParams::new(),
            _marker: PhantomData,
        }
    }

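    /// Builds a regressor from optional hyperparameters, defaulting to a
    /// minimum of 2 samples per split and an unbounded depth.
    ///
    /// A usage sketch (marked `ignore` since it is illustrative rather than a
    /// compiled doctest; it assumes `TreeParams` accepts these values):
    ///
    /// ```ignore
    /// let tree = DecisionTreeRegressor::<f64>::with_params(Some(4), Some(8))?;
    /// assert_eq!(tree.min_samples_split(), 4);
    /// assert_eq!(tree.max_depth(), Some(8));
    /// ```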
    pub fn with_params(
        min_samples_split: Option<u16>,
        max_depth: Option<u16>,
    ) -> Result<Self, Box<dyn Error>> {
        let mut tree = Self::new();

        tree.set_min_samples_split(min_samples_split.unwrap_or(2))?;
        tree.set_max_depth(max_depth)?;
        Ok(tree)
    }

    pub fn set_min_samples_split(&mut self, min_samples_split: u16) -> Result<(), Box<dyn Error>> {
        self.tree_params.set_min_samples_split(min_samples_split)
    }

    pub fn set_max_depth(&mut self, max_depth: Option<u16>) -> Result<(), Box<dyn Error>> {
        self.tree_params.set_max_depth(max_depth)
    }

    pub fn max_depth(&self) -> Option<u16> {
        self.tree_params.max_depth()
    }

    pub fn min_samples_split(&self) -> u16 {
        self.tree_params.min_samples_split()
    }

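    /// Fits the tree on `dataset`. The initial depth is `Some(0)` when a
    /// maximum depth is configured and `None` otherwise, so unbounded trees
    /// skip depth tracking entirely.
    ///
    /// A usage sketch (`ignore`; the `Dataset` construction mirrors the unit
    /// tests below):
    ///
    /// ```ignore
    /// let dataset = Dataset::new(x, y);
    /// let mut tree = DecisionTreeRegressor::<f64>::new();
    /// tree.fit(&dataset)?;
    /// ```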
    pub fn fit(&mut self, dataset: &Dataset<T, T>) -> Result<String, Box<dyn Error>> {
        self.root = Some(Box::new(self.build_tree(
            dataset,
            // Start depth tracking at 0 only when a max depth is configured.
            self.max_depth().map(|_| 0),
            self.variance(&dataset.y),
        )?));
        Ok("Finished building the tree.".into())
    }

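    /// Predicts one value per row of `prediction_features` by routing each
    /// row down the fitted tree; errors if `fit` has not been called yet.
    ///
    /// A usage sketch (`ignore`; assumes `tree` was fitted as above):
    ///
    /// ```ignore
    /// let test_x = DMatrix::from_vec(3, 1, vec![2.0, 3.0, 4.0]);
    /// let predictions = tree.predict(&test_x)?;
    /// assert_eq!(predictions.len(), 3);
    /// ```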
    pub fn predict(&self, prediction_features: &DMatrix<T>) -> Result<DVector<T>, String> {
        let root = self
            .root
            .as_ref()
            .ok_or_else(|| "Tree wasn't built yet.".to_string())?;
        let predictions: Vec<_> = prediction_features
            .row_iter()
            .map(|row| Self::make_prediction(row.transpose(), root))
            .collect();

        Ok(DVector::from_vec(predictions))
    }

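    // Routes a single feature vector down the tree: a node carrying a `value`
    // is a leaf; otherwise the feature at `feature_index` is compared against
    // `threshold`, descending left on `<=` and right on `>`.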
    fn make_prediction(features: DVector<T>, node: &TreeNode<T, T>) -> T {
        if let Some(value) = &node.value {
            return *value;
        }
        if &features[node.feature_index.unwrap()] <= node.threshold.as_ref().unwrap() {
            Self::make_prediction(features, node.left.as_ref().unwrap())
        } else {
            Self::make_prediction(features, node.right.as_ref().unwrap())
        }
    }

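    // Grows the tree top-down. A node becomes a leaf holding the mean target
    // when it has too few samples, the configured depth is exhausted, its
    // targets are near-homogeneous, or no split yields positive gain.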
    fn build_tree(
        &self,
        dataset: &Dataset<T, T>,
        current_depth: Option<u16>,
        base_variance: f64,
    ) -> Result<TreeNode<T, T>, Box<dyn Error>> {
        let (x, y) = &dataset.into_parts();
        let (num_samples, num_features) = x.shape();

        // A node is treated as homogeneous once its variance drops below 1%
        // of the variance at the root.
        let is_homogeneous = self.variance(y) < 0.01 * base_variance;
        // `current_depth` is `Some` only when a max depth is configured, so
        // the comparison is `None <= None` (always true) for unbounded trees
        // and `Some(depth) <= Some(max)` otherwise.
        if num_samples >= self.min_samples_split().into()
            && current_depth <= self.max_depth()
            && !is_homogeneous
        {
            // Evaluate the best split of every feature in parallel.
            let splits = (0..num_features)
                .into_par_iter()
                .map(|feature_idx| self.get_split(dataset, feature_idx))
                .collect::<Vec<_>>();

            let valid_splits = splits
                .into_iter()
                .filter_map(Result::ok)
                .collect::<Vec<_>>();

            if valid_splits.is_empty() {
                return Ok(TreeNode::new(Some(self.mean(y))));
            }

            let best_split = valid_splits
                .into_iter()
                .max_by(|split1, split2| {
                    split1
                        .information_gain
                        .partial_cmp(&split2.information_gain)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .ok_or("No best split found.")?;
            let left_child = best_split.left;
            let right_child = best_split.right;
            if best_split.information_gain > 0.0 {
                let new_depth = current_depth.map(|depth| depth + 1);
                let left_node = self.build_tree(&left_child, new_depth, base_variance)?;
                let right_node = self.build_tree(&right_child, new_depth, base_variance)?;
                return Ok(TreeNode {
                    feature_index: Some(best_split.feature_index),
                    threshold: Some(best_split.threshold),
                    left: Some(Box::new(left_node)),
                    right: Some(Box::new(right_node)),
                    value: None,
                });
            }
        }

        let leaf_value = self.mean(y);
        Ok(TreeNode::new(Some(leaf_value)))
    }

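    // Finds the best threshold for one feature by trying every unique value
    // in that column and keeping the split with the highest variance
    // reduction. Errors when no candidate yields two non-empty children
    // (e.g. a constant column).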
    fn get_split(
        &self,
        dataset: &Dataset<T, T>,
        feature_index: usize,
    ) -> Result<SplitData<T>, String> {
        let mut best_split: Option<SplitData<T>> = None;
        let mut best_information_gain = f64::NEG_INFINITY;

        let mut unique_values: Vec<_> = dataset.x.column(feature_index).iter().cloned().collect();
        unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        unique_values.dedup();

        for value in &unique_values {
            let (left_child, right_child) = dataset.split_on_threshold(feature_index, *value);

            if left_child.is_not_empty() && right_child.is_not_empty() {
                let current_information_gain =
                    self.calculate_variance_reduction(&dataset.y, &left_child.y, &right_child.y);

                if current_information_gain > best_information_gain {
                    best_information_gain = current_information_gain;
                    best_split = Some(SplitData {
                        feature_index,
                        threshold: *value,
                        left: left_child,
                        right: right_child,
                        information_gain: current_information_gain,
                    });
                }
            }
        }
        best_split.ok_or_else(|| "No split found.".to_string())
    }

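    // Information gain of a regression split as weighted variance reduction:
    // Var(parent) - (n_left / n) * Var(left) - (n_right / n) * Var(right).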
    fn calculate_variance_reduction(
        &self,
        parent_y: &DVector<T>,
        left_y: &DVector<T>,
        right_y: &DVector<T>,
    ) -> f64 {
        let variance = self.variance(parent_y);
        let left_variance = self.variance(left_y);
        let right_variance = self.variance(right_y);
        let num_samples = parent_y.len() as f64;
        variance
            - (left_variance * (left_y.len() as f64) / num_samples)
            - (right_variance * (right_y.len() as f64) / num_samples)
    }

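    // Population variance (divides by `n`, not `n - 1`); the squared
    // deviations are accumulated in `T` and converted to `f64` at the end.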
    fn variance(&self, y: &DVector<T>) -> f64 {
        let mean = self.mean(y);
        let sum_of_squares = y.iter().fold(T::from_f64(0.0).unwrap(), |acc, x| {
            acc + (*x - mean) * (*x - mean)
        });
        T::to_f64(&sum_of_squares).unwrap() / y.len() as f64
    }

    fn mean(&self, y: &DVector<T>) -> T {
        let zero = T::from_f64(0.0).unwrap();
        let sum: T = y.iter().fold(zero, |acc, x| acc + *x);
        sum / T::from_usize(y.len()).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DVector;

    #[test]
    fn test_mean() {
        let y = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        let mean = regressor.mean(&y);
        assert_eq!(mean, 3.5);
    }

    #[test]
    fn test_variance() {
        let y = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        let variance = regressor.variance(&y);
        assert_eq!(variance, 2.0);
    }

    #[test]
    fn test_calculate_variance_reduction() {
        let parent_y = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let left_y = DVector::from_vec(vec![1.0, 2.0]);
        let right_y = DVector::from_vec(vec![3.0, 4.0, 5.0]);
        let regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        let variance_reduction =
            regressor.calculate_variance_reduction(&parent_y, &left_y, &right_y);
        assert!(variance_reduction > 0.0);
    }

    #[test]
    fn test_fit_and_predict() {
        let x = DMatrix::from_vec(6, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let y = DVector::from_vec(vec![1.0, 4.0, 9.0, 16.0, 25.0, 36.0]);
        let dataset = Dataset::new(x, y);
        let mut regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        regressor.fit(&dataset).unwrap();

        let test_x = DMatrix::from_vec(3, 1, vec![2.0, 3.0, 4.0]);
        let predictions = regressor.predict(&test_x).unwrap();

        assert_eq!(predictions.len(), 3);
        assert!(predictions.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_fit_and_predict_with_single_row() {
        let x = DMatrix::from_vec(1, 2, vec![1.0, 2.0]);
        let y = DVector::from_vec(vec![1.0]);
        let dataset = Dataset::new(x, y);
        let mut regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        regressor.fit(&dataset).unwrap();

        let test_x = DMatrix::from_vec(1, 2, vec![2.0, 3.0]);
        let predictions = regressor.predict(&test_x).unwrap();

        assert_eq!(predictions.len(), 1);
        assert!(predictions.iter().all(|&x| x >= 0.0));
    }
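
    // Additional sketches covering the error path and the builder. They rely
    // only on the API exercised above; `with_params` is assumed to accept
    // these (reasonable) hyperparameter values.
    #[test]
    fn test_predict_before_fit_fails() {
        let regressor: DecisionTreeRegressor<f64> = DecisionTreeRegressor::new();
        let test_x = DMatrix::from_vec(1, 1, vec![1.0]);
        assert!(regressor.predict(&test_x).is_err());
    }

    #[test]
    fn test_with_params_roundtrip() {
        let regressor: DecisionTreeRegressor<f64> =
            DecisionTreeRegressor::with_params(Some(4), Some(8)).unwrap();
        assert_eq!(regressor.min_samples_split(), 4);
        assert_eq!(regressor.max_depth(), Some(8));
    }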
}