1use crate::dataset::Dataset;
8use crate::error::{Result, ScryLearnError};
9use crate::weights::{compute_sample_weights, ClassWeight};
10
11use super::{
12 compute_impurity, compute_impurity_weighted, majority_class, weighted_majority_class,
13 BestSplit, FlatTree, SplitCriterion, TreeNode,
14};
15
16pub(crate) fn presort_indices(data: &Dataset, indices: &[usize]) -> Vec<Vec<usize>> {
22 let n_features = data.n_features();
23 let mut sorted_by_feature = Vec::with_capacity(n_features);
24 for feat_idx in 0..n_features {
25 let col = &data.features[feat_idx];
26 let mut sorted = indices.to_vec();
27 sorted.sort_unstable_by(|&a, &b| {
28 col[a]
29 .partial_cmp(&col[b])
30 .unwrap_or(std::cmp::Ordering::Equal)
31 });
32 sorted_by_feature.push(sorted);
33 }
34 sorted_by_feature
35}
36
/// Restricts each per-feature sorted index list to the samples flagged
/// `true` in `membership`, preserving the existing sort order.
fn filter_sorted(global_sorted: &[Vec<usize>], membership: &[bool]) -> Vec<Vec<usize>> {
    let mut filtered = Vec::with_capacity(global_sorted.len());
    for feat_order in global_sorted {
        let mut kept = Vec::new();
        for &idx in feat_order {
            if membership[idx] {
                kept.push(idx);
            }
        }
        filtered.push(kept);
    }
    filtered
}
44
/// Splits every per-feature sorted list into (left, right) around
/// `threshold` on `split_col`, keeping each side's sort order.
///
/// Left-side indices are compacted in place into the input vectors
/// (no allocation); right-side indices go into fresh vectors sized by
/// `right_count`. `_left_count` is accepted for signature symmetry.
fn partition_sorted(
    mut sorted_by_feature: Vec<Vec<usize>>,
    split_col: &[f64],
    threshold: f64,
    _left_count: usize,
    right_count: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<usize>>) {
    let mut right_sorted = Vec::with_capacity(sorted_by_feature.len());
    for feat_sorted in &mut sorted_by_feature {
        let mut right = Vec::with_capacity(right_count);
        // Keep `<= threshold` entries in place; siphon the rest off to
        // `right`. `retain` preserves relative order on both sides.
        feat_sorted.retain(|&idx| {
            if split_col[idx] <= threshold {
                true
            } else {
                right.push(idx);
                false
            }
        });
        right_sorted.push(right);
    }
    (sorted_by_feature, right_sorted)
}
77
/// Fills `feature_buf` with the candidate feature indices for one split.
///
/// With `max_features = None` the buffer holds all `0..n_features`.
/// Otherwise a partial Fisher-Yates shuffle selects
/// `min(max_features, n_features)` distinct indices; the RNG is drawn
/// exactly once per selected index, so call order determines the sample.
fn fill_feature_buf(
    feature_buf: &mut Vec<usize>,
    n_features: usize,
    max_features: Option<usize>,
    rng: &mut crate::rng::FastRng,
) {
    feature_buf.clear();
    feature_buf.extend(0..n_features);
    if let Some(max_f) = max_features {
        let m = max_f.min(n_features);
        // Partial Fisher-Yates: after i iterations, positions 0..i hold a
        // uniform random sample (without replacement) of the features.
        for i in 0..m {
            let j = rng.usize(i..n_features);
            feature_buf.swap(i, j);
        }
        feature_buf.truncate(m);
    }
}
101
/// CART-style decision tree classifier with optional class weighting
/// and cost-complexity (ccp) pruning.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct DecisionTreeClassifier {
    /// Maximum depth of the tree; `None` means unlimited.
    max_depth: Option<usize>,
    /// Minimum number of samples a node needs before it may be split.
    min_samples_split: usize,
    /// Minimum number of samples each child of a split must keep.
    min_samples_leaf: usize,
    /// Number of randomly chosen features tried per split; `None` = all.
    max_features: Option<usize>,
    /// Impurity criterion used to score candidate splits.
    criterion: SplitCriterion,
    /// Cost-complexity pruning strength; pruning runs only when > 0.
    ccp_alpha: f64,
    /// Class-weighting scheme; `Uniform` uses the unweighted fast path.
    pub(crate) class_weight: ClassWeight,
    // Per-sample weights derived from `class_weight`; populated during
    // fitting only and cleared when the fit completes.
    pub(crate) sample_weights: Option<Vec<f64>>,
    /// Flattened tree used for prediction; `None` until fitted.
    pub(crate) flat_tree: Option<FlatTree>,
    /// Number of classes seen at fit time.
    n_classes: usize,
    /// Number of features seen at fit time.
    n_features: usize,
    /// Impurity-decrease importances, normalized to sum to 1 after fit.
    pub(crate) feature_importances_: Vec<f64>,
    /// Serialization schema version (defaults when absent in old payloads).
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
129
130impl DecisionTreeClassifier {
    /// Creates a classifier with defaults: unlimited depth,
    /// `min_samples_split = 2`, `min_samples_leaf = 1`, all features
    /// considered per split, Gini impurity, no pruning, uniform class
    /// weights.
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: None,
            criterion: SplitCriterion::Gini,
            ccp_alpha: 0.0,
            class_weight: ClassWeight::Uniform,
            sample_weights: None,
            flat_tree: None,
            n_classes: 0,
            n_features: 0,
            feature_importances_: Vec::new(),
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }
149
    /// Sets the maximum tree depth (builder-style).
    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = Some(d);
        self
    }

    /// Sets the minimum number of samples a node needs to be split.
    pub fn min_samples_split(mut self, n: usize) -> Self {
        self.min_samples_split = n;
        self
    }

    /// Sets the minimum number of samples each child leaf must keep.
    pub fn min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    /// Limits how many randomly chosen features are tried per split.
    pub fn max_features(mut self, n: usize) -> Self {
        self.max_features = Some(n);
        self
    }

    /// Sets the impurity criterion used to score splits.
    pub fn criterion(mut self, c: SplitCriterion) -> Self {
        self.criterion = c;
        self
    }

    /// Sets the class-weighting scheme applied during fitting.
    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
        self.class_weight = cw;
        self
    }

    /// Sets the cost-complexity pruning strength; values > 0 enable
    /// pruning after the tree is grown.
    pub fn ccp_alpha(mut self, alpha: f64) -> Self {
        self.ccp_alpha = alpha;
        self
    }
215
    /// Fits the tree on every sample in `data`.
    ///
    /// # Errors
    /// Fails if `data` contains non-finite values or has no samples.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let indices: Vec<usize> = (0..data.n_samples()).collect();
        self.fit_on_indices(data, &indices)
    }
222
    /// Fits on the given subset of sample indices, presorting the subset
    /// per feature before delegating to the shared fitting core.
    pub(crate) fn fit_on_indices(
        &mut self,
        data: &Dataset,
        sample_indices: &[usize],
    ) -> Result<()> {
        let sorted_by_feature = presort_indices(data, sample_indices);
        self.fit_with_sorted(data, sample_indices, sorted_by_feature)
    }
239
    /// Fits on a subset reusing an already-presorted global ordering:
    /// each global per-feature list is filtered down to the subset,
    /// avoiding a re-sort. The length of the first global list bounds
    /// the valid sample-index range.
    pub(crate) fn fit_on_indices_presorted(
        &mut self,
        data: &Dataset,
        sample_indices: &[usize],
        global_sorted: &[Vec<usize>],
    ) -> Result<()> {
        let membership_len = global_sorted.first().map_or(0, Vec::len);
        // Mark which global sample indices belong to this fit.
        let mut membership = vec![false; membership_len];
        for &i in sample_indices {
            membership[i] = true;
        }
        let sorted_by_feature = filter_sorted(global_sorted, &membership);
        self.fit_with_sorted(data, sample_indices, sorted_by_feature)
    }
260
    /// Shared fitting core: builds the tree from presorted per-feature
    /// index lists, optionally prunes it, flattens it for prediction,
    /// and normalizes feature importances.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::EmptyDataset`] when `sample_indices`
    /// is empty.
    fn fit_with_sorted(
        &mut self,
        data: &Dataset,
        sample_indices: &[usize],
        sorted_by_feature: Vec<Vec<usize>>,
    ) -> Result<()> {
        let n = sample_indices.len();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        self.n_features = data.n_features();
        self.n_classes = data.n_classes();
        self.feature_importances_ = vec![0.0; self.n_features];

        // Derive per-sample weights from the class-weight scheme;
        // Uniform skips weighting and takes the faster unweighted path.
        let weights = match &self.class_weight {
            ClassWeight::Uniform => None,
            cw => Some(compute_sample_weights(&data.target, cw)),
        };
        self.sample_weights = weights;

        let mut feature_buf = Vec::with_capacity(self.n_features);
        // Fixed seed: feature subsampling is deterministic across fits.
        let mut split_rng = crate::rng::FastRng::new(0);

        let tree = if self.sample_weights.is_some() {
            self.build_tree_weighted(
                data,
                sorted_by_feature,
                n,
                0,
                &mut feature_buf,
                &mut split_rng,
            )
        } else {
            self.build_tree(
                data,
                sorted_by_feature,
                n,
                0,
                &mut feature_buf,
                &mut split_rng,
            )
        };

        // Apply cost-complexity pruning only when requested.
        let tree = if self.ccp_alpha > 0.0 {
            tree.prune_ccp(self.ccp_alpha)
        } else {
            tree
        };

        let flat = FlatTree::from_tree_node(&tree, self.n_classes);
        self.flat_tree = Some(flat);

        // Normalize importances to sum to 1 when any split occurred.
        let total: f64 = self.feature_importances_.iter().sum();
        if total > 0.0 {
            for imp in &mut self.feature_importances_ {
                *imp /= total;
            }
        }

        // Weights are only needed during building; drop them afterwards.
        self.sample_weights = None;

        Ok(())
    }
331
    /// Predicts a class label for each feature row.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::NotFitted`] if `fit` has not been
    /// called, or a schema-version error for models deserialized from an
    /// incompatible version.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        let ft = self.flat_tree.as_ref().ok_or(ScryLearnError::NotFitted)?;
        Ok(ft.predict(features))
    }
340
341 pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
343 let ft = self.flat_tree.as_ref().ok_or(ScryLearnError::NotFitted)?;
344 let n_classes = self.n_classes;
345 Ok(features
346 .iter()
347 .map(|row| ft.predict_proba_sample(row, n_classes))
348 .collect())
349 }
350
351 pub fn feature_importances(&self) -> Result<Vec<f64>> {
353 if self.flat_tree.is_none() {
354 return Err(ScryLearnError::NotFitted);
355 }
356 Ok(self.feature_importances_.clone())
357 }
358
    /// Returns the flattened tree, or `None` before fitting.
    pub fn flat_tree(&self) -> Option<&FlatTree> {
        self.flat_tree.as_ref()
    }

    /// Returns the fitted tree's depth (0 before fitting).
    pub fn depth(&self) -> usize {
        self.flat_tree.as_ref().map_or(0, FlatTree::depth)
    }

    /// Returns the fitted tree's leaf count (0 before fitting).
    pub fn n_leaves(&self) -> usize {
        self.flat_tree.as_ref().map_or(0, FlatTree::n_leaves)
    }

    /// Returns the number of features seen at fit time.
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Returns the number of classes seen at fit time.
    pub fn n_classes(&self) -> usize {
        self.n_classes
    }
383
384 pub fn cost_complexity_pruning_path(&self, data: &Dataset) -> Result<(Vec<f64>, Vec<f64>)> {
393 let mut unpruned = self.clone();
395 unpruned.ccp_alpha = 0.0;
396 unpruned.fit(data)?;
397
398 let indices: Vec<usize> = (0..data.n_samples()).collect();
400 let sorted_by_feature = presort_indices(data, &indices);
401 let n = indices.len();
402 let mut feature_buf = Vec::with_capacity(unpruned.n_features);
403 let mut split_rng = crate::rng::FastRng::new(0);
404
405 let tree = if unpruned.sample_weights.is_some() {
406 unpruned.build_tree_weighted(
407 data,
408 sorted_by_feature,
409 n,
410 0,
411 &mut feature_buf,
412 &mut split_rng,
413 )
414 } else {
415 unpruned.build_tree(
416 data,
417 sorted_by_feature,
418 n,
419 0,
420 &mut feature_buf,
421 &mut split_rng,
422 )
423 };
424 Ok(tree.cost_complexity_pruning_path())
425 }
426
    /// Recursively grows an unweighted classification subtree over the
    /// samples listed in `sorted_by_feature`.
    ///
    /// `sorted_by_feature[f]` holds the node's sample indices sorted by
    /// feature `f`; all lists cover the same sample set, so list 0 doubles
    /// as the membership list. `n_root_samples` is the root node size,
    /// used to scale impurity decreases for `feature_importances_`.
    fn build_tree(
        &mut self,
        data: &Dataset,
        sorted_by_feature: Vec<Vec<usize>>,
        n_root_samples: usize,
        depth: usize,
        feature_buf: &mut Vec<usize>,
        split_rng: &mut crate::rng::FastRng,
    ) -> TreeNode {
        // Every per-feature list holds the same samples; use feature 0's.
        let active = &sorted_by_feature[0];
        let n_actual = active.len();

        let mut class_counts = vec![0usize; self.n_classes];
        for &idx in active {
            let c = data.target[idx] as usize;
            // Ignore out-of-range labels rather than panicking on index.
            if c < self.n_classes {
                class_counts[c] += 1;
            }
        }
        let impurity = compute_impurity(&class_counts, n_actual, self.criterion);

        // Stopping conditions: depth limit, node too small to split, or a
        // (numerically) pure node.
        let max_depth_reached = self.max_depth.is_some_and(|d| depth >= d);
        let too_few_samples = n_actual < self.min_samples_split;
        let is_pure = impurity < 1e-12;

        if max_depth_reached || too_few_samples || is_pure {
            return TreeNode::Leaf {
                prediction: majority_class(&class_counts),
                n_samples: n_actual,
                class_counts,
                impurity,
            };
        }

        let best = self.find_best_split(
            data,
            &sorted_by_feature,
            &class_counts,
            n_actual,
            feature_buf,
            split_rng,
        );

        let node_prediction = majority_class(&class_counts);

        match best {
            // No admissible split: this node becomes a leaf.
            None => TreeNode::Leaf {
                prediction: node_prediction,
                n_samples: n_actual,
                class_counts,
                impurity,
            },
            Some(split) => {
                let col = &data.features[split.feature_idx];
                let threshold = split.threshold;

                // Count child sizes up front so min_samples_leaf can veto
                // the split before any partitioning work is done.
                let mut left_count = 0usize;
                let mut right_count = 0usize;
                for &idx in active {
                    if col[idx] <= threshold {
                        left_count += 1;
                    } else {
                        right_count += 1;
                    }
                }

                if left_count < self.min_samples_leaf || right_count < self.min_samples_leaf {
                    return TreeNode::Leaf {
                        prediction: node_prediction,
                        n_samples: n_actual,
                        class_counts,
                        impurity,
                    };
                }

                // `split.impurity_decrease` stores the children's weighted
                // impurity, so `impurity - it` is the actual decrease,
                // scaled by this node's share of the root samples.
                let weighted_impurity_decrease = (n_actual as f64 / n_root_samples as f64)
                    * (impurity - split.impurity_decrease);
                self.feature_importances_[split.feature_idx] += weighted_impurity_decrease.max(0.0);

                let (left_sorted, right_sorted) =
                    partition_sorted(sorted_by_feature, col, threshold, left_count, right_count);

                let left = self.build_tree(
                    data,
                    left_sorted,
                    n_root_samples,
                    depth + 1,
                    feature_buf,
                    split_rng,
                );
                let right = self.build_tree(
                    data,
                    right_sorted,
                    n_root_samples,
                    depth + 1,
                    feature_buf,
                    split_rng,
                );

                TreeNode::Split {
                    feature_idx: split.feature_idx,
                    threshold,
                    left: Box::new(left),
                    right: Box::new(right),
                    n_samples: n_actual,
                    impurity,
                    class_counts,
                    prediction: node_prediction,
                }
            }
        }
    }
553
    /// Finds the axis-aligned split minimizing the weighted child
    /// impurity, using one left-to-right sweep over each candidate
    /// feature's presorted indices.
    ///
    /// Returns `None` when no threshold satisfies `min_samples_leaf` on
    /// both sides (e.g. all feature values equal). The returned
    /// `BestSplit::impurity_decrease` holds the weighted *child* impurity
    /// (lower is better); the caller subtracts it from the parent
    /// impurity to obtain the actual decrease.
    fn find_best_split(
        &self,
        data: &Dataset,
        sorted_by_feature: &[Vec<usize>],
        parent_counts: &[usize],
        n_parent: usize,
        feature_buf: &mut Vec<usize>,
        split_rng: &mut crate::rng::FastRng,
    ) -> Option<BestSplit> {
        let n_features = data.n_features();
        let mut best: Option<BestSplit> = None;

        // Choose the (possibly random) subset of features to evaluate.
        fill_feature_buf(feature_buf, n_features, self.max_features, split_rng);

        for &feat_idx in feature_buf.iter() {
            let col = &data.features[feat_idx];
            let sorted = &sorted_by_feature[feat_idx];

            // Running class histogram of everything left of the cursor.
            let mut left_counts = vec![0usize; self.n_classes];
            let mut left_n = 0;
            let mut prev_val = f64::NEG_INFINITY;

            for &idx in sorted {
                let val = col[idx];

                // Only cut between two distinct feature values.
                if left_n > 0 && (val - prev_val).abs() > 1e-12 {
                    let right_n = n_parent - left_n;
                    if left_n >= self.min_samples_leaf && right_n >= self.min_samples_leaf {
                        // Right counts are the parent's minus the left's.
                        let right_counts: Vec<usize> = parent_counts
                            .iter()
                            .zip(left_counts.iter())
                            .map(|(&p, &l)| p - l)
                            .collect();

                        let left_imp = compute_impurity(&left_counts, left_n, self.criterion);
                        let right_imp = compute_impurity(&right_counts, right_n, self.criterion);
                        let weighted_imp = (left_n as f64 * left_imp + right_n as f64 * right_imp)
                            / n_parent as f64;

                        // Cut halfway between the two adjacent values.
                        let threshold = f64::midpoint(prev_val, val);

                        let is_better = best
                            .as_ref()
                            .is_none_or(|b| weighted_imp < b.impurity_decrease);

                        if is_better {
                            best = Some(BestSplit {
                                feature_idx: feat_idx,
                                threshold,
                                impurity_decrease: weighted_imp,
                            });
                        }
                    }
                }

                // Advance the cursor past this sample.
                let class = data.target[idx] as usize;
                if class < self.n_classes {
                    left_counts[class] += 1;
                }
                left_n += 1;
                prev_val = val;
            }
        }

        best
    }
623
    /// Weighted analogue of `build_tree`: impurity and leaf predictions
    /// use per-sample weights (from `class_weight`), while structural
    /// constraints (`min_samples_split`, `min_samples_leaf`) still count
    /// raw samples.
    ///
    /// # Panics
    /// Panics if `self.sample_weights` is `None` — the caller selects
    /// this path only when weights are set.
    fn build_tree_weighted(
        &mut self,
        data: &Dataset,
        sorted_by_feature: Vec<Vec<usize>>,
        n_root_samples: usize,
        depth: usize,
        feature_buf: &mut Vec<usize>,
        split_rng: &mut crate::rng::FastRng,
    ) -> TreeNode {
        let weights = self.sample_weights.as_ref().expect("weights must be set");
        // Every per-feature list holds the same samples; use feature 0's.
        let active = &sorted_by_feature[0];
        let n_actual = active.len();

        // Weighted class histogram plus raw counts for node bookkeeping.
        let mut w_counts = vec![0.0_f64; self.n_classes];
        let mut w_total = 0.0_f64;
        let mut class_counts = vec![0usize; self.n_classes];

        for &idx in active {
            let c = data.target[idx] as usize;
            let w = weights[idx];
            if c < self.n_classes {
                w_counts[c] += w;
                class_counts[c] += 1;
            }
            w_total += w;
        }

        let impurity = compute_impurity_weighted(&w_counts, w_total, self.criterion);

        // Stopping conditions: depth limit, node too small, or pure node.
        let max_depth_reached = self.max_depth.is_some_and(|d| depth >= d);
        let too_few_samples = n_actual < self.min_samples_split;
        let is_pure = impurity < 1e-12;

        if max_depth_reached || too_few_samples || is_pure {
            return TreeNode::Leaf {
                prediction: weighted_majority_class(&w_counts),
                n_samples: n_actual,
                class_counts,
                impurity,
            };
        }

        let best = self.find_best_split_weighted(
            data,
            &sorted_by_feature,
            &w_counts,
            w_total,
            n_actual,
            feature_buf,
            split_rng,
        );

        let node_prediction = weighted_majority_class(&w_counts);

        match best {
            // No admissible split: this node becomes a leaf.
            None => TreeNode::Leaf {
                prediction: node_prediction,
                n_samples: n_actual,
                class_counts,
                impurity,
            },
            Some(split) => {
                let col = &data.features[split.feature_idx];
                let threshold = split.threshold;

                // Count child sizes up front so min_samples_leaf can veto
                // the split before any partitioning work is done.
                let mut left_count = 0usize;
                let mut right_count = 0usize;
                for &idx in active {
                    if col[idx] <= threshold {
                        left_count += 1;
                    } else {
                        right_count += 1;
                    }
                }

                if left_count < self.min_samples_leaf || right_count < self.min_samples_leaf {
                    return TreeNode::Leaf {
                        prediction: node_prediction,
                        n_samples: n_actual,
                        class_counts,
                        impurity,
                    };
                }

                // `split.impurity_decrease` stores the children's weighted
                // impurity; scale the decrease by the node's sample share.
                let weighted_impurity_decrease = (n_actual as f64 / n_root_samples as f64)
                    * (impurity - split.impurity_decrease);
                self.feature_importances_[split.feature_idx] += weighted_impurity_decrease.max(0.0);

                let (left_sorted, right_sorted) =
                    partition_sorted(sorted_by_feature, col, threshold, left_count, right_count);

                let left = self.build_tree_weighted(
                    data,
                    left_sorted,
                    n_root_samples,
                    depth + 1,
                    feature_buf,
                    split_rng,
                );
                let right = self.build_tree_weighted(
                    data,
                    right_sorted,
                    n_root_samples,
                    depth + 1,
                    feature_buf,
                    split_rng,
                );

                TreeNode::Split {
                    feature_idx: split.feature_idx,
                    threshold,
                    left: Box::new(left),
                    right: Box::new(right),
                    n_samples: n_actual,
                    impurity,
                    class_counts,
                    prediction: node_prediction,
                }
            }
        }
    }
753
    /// Weighted analogue of `find_best_split`: impurities use weight
    /// masses while `min_samples_leaf` still counts raw samples. The
    /// returned `BestSplit::impurity_decrease` again holds the weighted
    /// *child* impurity (lower is better).
    ///
    /// # Panics
    /// Panics if `self.sample_weights` is `None`.
    fn find_best_split_weighted(
        &self,
        data: &Dataset,
        sorted_by_feature: &[Vec<usize>],
        parent_w_counts: &[f64],
        w_parent_total: f64,
        n_parent: usize,
        feature_buf: &mut Vec<usize>,
        split_rng: &mut crate::rng::FastRng,
    ) -> Option<BestSplit> {
        let weights = self.sample_weights.as_ref().expect("weights must be set");
        let n_features = data.n_features();
        let mut best: Option<BestSplit> = None;

        // Choose the (possibly random) subset of features to evaluate.
        fill_feature_buf(feature_buf, n_features, self.max_features, split_rng);

        for &feat_idx in feature_buf.iter() {
            let col = &data.features[feat_idx];
            let sorted = &sorted_by_feature[feat_idx];

            // Running weighted histogram of everything left of the cursor.
            let mut left_w_counts = vec![0.0_f64; self.n_classes];
            let mut left_w_total = 0.0_f64;
            let mut left_n = 0usize;
            let mut prev_val = f64::NEG_INFINITY;

            for &idx in sorted {
                let val = col[idx];
                let w = weights[idx];

                // Only cut between two distinct feature values.
                if left_n > 0 && (val - prev_val).abs() > 1e-12 {
                    let right_n = n_parent - left_n;
                    if left_n >= self.min_samples_leaf && right_n >= self.min_samples_leaf {
                        let right_w_total = w_parent_total - left_w_total;
                        // Right masses by subtraction; clamp at 0 to guard
                        // against floating-point drift going negative.
                        let right_w_counts: Vec<f64> = parent_w_counts
                            .iter()
                            .zip(left_w_counts.iter())
                            .map(|(&p, &l)| (p - l).max(0.0))
                            .collect();

                        let left_imp =
                            compute_impurity_weighted(&left_w_counts, left_w_total, self.criterion);
                        let right_imp = compute_impurity_weighted(
                            &right_w_counts,
                            right_w_total,
                            self.criterion,
                        );
                        let weighted_imp =
                            (left_w_total * left_imp + right_w_total * right_imp) / w_parent_total;

                        // Cut halfway between the two adjacent values.
                        let threshold = f64::midpoint(prev_val, val);

                        let is_better = best
                            .as_ref()
                            .is_none_or(|b| weighted_imp < b.impurity_decrease);

                        if is_better {
                            best = Some(BestSplit {
                                feature_idx: feat_idx,
                                threshold,
                                impurity_decrease: weighted_imp,
                            });
                        }
                    }
                }

                // Advance the cursor past this sample.
                let class = data.target[idx] as usize;
                if class < self.n_classes {
                    left_w_counts[class] += w;
                }
                left_w_total += w;
                left_n += 1;
                prev_val = val;
            }
        }

        best
    }
833}
834
/// `Default` mirrors [`DecisionTreeClassifier::new`].
impl Default for DecisionTreeClassifier {
    fn default() -> Self {
        Self::new()
    }
}
840
/// CART-style decision tree regressor using variance (MSE) reduction,
/// with optional cost-complexity (ccp) pruning.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct DecisionTreeRegressor {
    /// Maximum depth of the tree; `None` means unlimited.
    max_depth: Option<usize>,
    /// Minimum number of samples a node needs before it may be split.
    min_samples_split: usize,
    /// Minimum number of samples each child of a split must keep.
    min_samples_leaf: usize,
    /// Number of randomly chosen features tried per split; `None` = all.
    max_features: Option<usize>,
    /// Cost-complexity pruning strength; pruning runs only when > 0.
    ccp_alpha: f64,
    /// Flattened tree used for prediction; `None` until fitted.
    pub(crate) flat_tree: Option<FlatTree>,
    /// Number of features seen at fit time.
    n_features: usize,
    /// Impurity-decrease importances, normalized to sum to 1 after fit.
    pub(crate) feature_importances_: Vec<f64>,
    /// Serialization schema version (defaults when absent in old payloads).
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
862
863impl DecisionTreeRegressor {
    /// Creates a regressor with defaults: unlimited depth,
    /// `min_samples_split = 2`, `min_samples_leaf = 1`, all features
    /// considered per split, no pruning.
    pub fn new() -> Self {
        Self {
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: None,
            ccp_alpha: 0.0,
            flat_tree: None,
            n_features: 0,
            feature_importances_: Vec::new(),
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }
878
    /// Sets the maximum tree depth (builder-style).
    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = Some(d);
        self
    }

    /// Sets the minimum number of samples a node needs to be split.
    pub fn min_samples_split(mut self, n: usize) -> Self {
        self.min_samples_split = n;
        self
    }

    /// Sets the minimum number of samples each child leaf must keep.
    pub fn min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    /// Limits how many randomly chosen features are tried per split.
    pub fn max_features(mut self, n: usize) -> Self {
        self.max_features = Some(n);
        self
    }

    /// Sets the cost-complexity pruning strength; values > 0 enable
    /// pruning after the tree is grown.
    pub fn ccp_alpha(mut self, alpha: f64) -> Self {
        self.ccp_alpha = alpha;
        self
    }
912
    /// Fits the tree on every sample in `data`.
    ///
    /// # Errors
    /// Fails if `data` contains non-finite values or has no samples.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let indices: Vec<usize> = (0..data.n_samples()).collect();
        self.fit_on_indices(data, &indices)
    }
919
920 pub(crate) fn fit_on_indices(
925 &mut self,
926 data: &Dataset,
927 sample_indices: &[usize],
928 ) -> Result<()> {
929 let n = sample_indices.len();
930 if n == 0 {
931 return Err(ScryLearnError::EmptyDataset);
932 }
933 self.n_features = data.n_features();
934 self.feature_importances_ = vec![0.0; self.n_features];
935
936 let sorted_by_feature = presort_indices(data, sample_indices);
937 let mut feature_buf = Vec::with_capacity(self.n_features);
938 let mut split_rng = crate::rng::FastRng::new(0);
939
940 let tree = self.build_tree_reg(
941 data,
942 sorted_by_feature,
943 n,
944 0,
945 &mut feature_buf,
946 &mut split_rng,
947 );
948
949 let tree = if self.ccp_alpha > 0.0 {
951 tree.prune_ccp(self.ccp_alpha)
952 } else {
953 tree
954 };
955
956 let flat = FlatTree::from_tree_node(&tree, 0);
959 self.flat_tree = Some(flat);
960
961 let total: f64 = self.feature_importances_.iter().sum();
962 if total > 0.0 {
963 for imp in &mut self.feature_importances_ {
964 *imp /= total;
965 }
966 }
967 Ok(())
968 }
969
970 pub(crate) fn fit_on_indices_presorted(
975 &mut self,
976 data: &Dataset,
977 sample_indices: &[usize],
978 global_sorted: &[Vec<usize>],
979 ) -> Result<()> {
980 let n = sample_indices.len();
981 if n == 0 {
982 return Err(ScryLearnError::EmptyDataset);
983 }
984 self.n_features = data.n_features();
985 self.feature_importances_ = vec![0.0; self.n_features];
986
987 let membership_len = global_sorted.first().map_or(0, Vec::len);
989 let mut membership = vec![false; membership_len];
990 for &i in sample_indices {
991 membership[i] = true;
992 }
993 let sorted_by_feature = filter_sorted(global_sorted, &membership);
994 let mut feature_buf = Vec::with_capacity(self.n_features);
995 let mut split_rng = crate::rng::FastRng::new(0);
996
997 let tree = self.build_tree_reg(
998 data,
999 sorted_by_feature,
1000 n,
1001 0,
1002 &mut feature_buf,
1003 &mut split_rng,
1004 );
1005
1006 let tree = if self.ccp_alpha > 0.0 {
1007 tree.prune_ccp(self.ccp_alpha)
1008 } else {
1009 tree
1010 };
1011
1012 let flat = FlatTree::from_tree_node(&tree, 0);
1013 self.flat_tree = Some(flat);
1014
1015 let total: f64 = self.feature_importances_.iter().sum();
1016 if total > 0.0 {
1017 for imp in &mut self.feature_importances_ {
1018 *imp /= total;
1019 }
1020 }
1021 Ok(())
1022 }
1023
    /// Predicts a target value for each feature row.
    ///
    /// # Errors
    /// Returns [`ScryLearnError::NotFitted`] if `fit` has not been
    /// called, or a schema-version error for models deserialized from an
    /// incompatible version.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        let ft = self.flat_tree.as_ref().ok_or(ScryLearnError::NotFitted)?;
        Ok(ft.predict(features))
    }
1030
1031 pub fn feature_importances(&self) -> Result<Vec<f64>> {
1033 if self.flat_tree.is_none() {
1034 return Err(ScryLearnError::NotFitted);
1035 }
1036 Ok(self.feature_importances_.clone())
1037 }
1038
    /// Returns the flattened tree, or `None` before fitting.
    pub fn flat_tree(&self) -> Option<&FlatTree> {
        self.flat_tree.as_ref()
    }

    /// Returns the number of features seen at fit time.
    pub fn n_features(&self) -> usize {
        self.n_features
    }
1048
    /// Recursively grows a regression subtree over the samples listed in
    /// `sorted_by_feature`, predicting the node mean and splitting on
    /// variance (MSE) reduction.
    ///
    /// `sorted_by_feature[f]` holds the node's sample indices sorted by
    /// feature `f`; list 0 doubles as the membership list.
    /// `n_root_samples` scales impurity decreases for
    /// `feature_importances_`.
    fn build_tree_reg(
        &mut self,
        data: &Dataset,
        sorted_by_feature: Vec<Vec<usize>>,
        n_root_samples: usize,
        depth: usize,
        feature_buf: &mut Vec<usize>,
        split_rng: &mut crate::rng::FastRng,
    ) -> TreeNode {
        // Every per-feature list holds the same samples; use feature 0's.
        let active = &sorted_by_feature[0];
        let n_actual = active.len();

        // Degenerate empty node: fall back to a zero-prediction leaf.
        if n_actual == 0 {
            return TreeNode::Leaf {
                prediction: 0.0,
                n_samples: 0,
                class_counts: Vec::new(),
                impurity: 0.0,
            };
        }

        // Mean and variance via sum / sum-of-squares; clamp at 0 to
        // guard against floating-point drift going negative.
        let mut sum = 0.0;
        let mut sq_sum = 0.0;
        for &idx in active {
            let v = data.target[idx];
            sum += v;
            sq_sum += v * v;
        }
        let mean = sum / n_actual as f64;
        let mse = (sq_sum / n_actual as f64 - mean * mean).max(0.0);

        // Stopping conditions: depth limit, node too small, or
        // (numerically) zero variance.
        let max_depth_reached = self.max_depth.is_some_and(|d| depth >= d);
        let too_few = n_actual < self.min_samples_split;

        if max_depth_reached || too_few || mse < 1e-12 {
            return TreeNode::Leaf {
                prediction: mean,
                n_samples: n_actual,
                class_counts: Vec::new(),
                impurity: mse,
            };
        }

        let best = self.find_best_split_reg(
            data,
            &sorted_by_feature,
            sum,
            sq_sum,
            n_actual,
            feature_buf,
            split_rng,
        );

        match best {
            // No admissible split: this node becomes a leaf.
            None => TreeNode::Leaf {
                prediction: mean,
                n_samples: n_actual,
                class_counts: Vec::new(),
                impurity: mse,
            },
            Some(split) => {
                let col = &data.features[split.feature_idx];
                let threshold = split.threshold;

                // Count child sizes up front so min_samples_leaf can veto
                // the split before any partitioning work is done.
                let mut left_count = 0usize;
                let mut right_count = 0usize;
                for &idx in active {
                    if col[idx] <= threshold {
                        left_count += 1;
                    } else {
                        right_count += 1;
                    }
                }

                if left_count < self.min_samples_leaf || right_count < self.min_samples_leaf {
                    return TreeNode::Leaf {
                        prediction: mean,
                        n_samples: n_actual,
                        class_counts: Vec::new(),
                        impurity: mse,
                    };
                }

                // `split.impurity_decrease` stores the children's weighted
                // MSE; scale the decrease by the node's sample share.
                let decrease =
                    (n_actual as f64 / n_root_samples as f64) * (mse - split.impurity_decrease);
                self.feature_importances_[split.feature_idx] += decrease.max(0.0);

                let (left_sorted, right_sorted) =
                    partition_sorted(sorted_by_feature, col, threshold, left_count, right_count);

                let left = self.build_tree_reg(
                    data,
                    left_sorted,
                    n_root_samples,
                    depth + 1,
                    feature_buf,
                    split_rng,
                );
                let right = self.build_tree_reg(
                    data,
                    right_sorted,
                    n_root_samples,
                    depth + 1,
                    feature_buf,
                    split_rng,
                );

                TreeNode::Split {
                    feature_idx: split.feature_idx,
                    threshold,
                    left: Box::new(left),
                    right: Box::new(right),
                    n_samples: n_actual,
                    impurity: mse,
                    class_counts: Vec::new(),
                    prediction: mean,
                }
            }
        }
    }
1173
    /// Finds the split minimizing the weighted child MSE, using one
    /// left-to-right sweep with running sum / sum-of-squares per feature.
    ///
    /// Returns `None` when no threshold satisfies `min_samples_leaf` on
    /// both sides. The returned `BestSplit::impurity_decrease` holds the
    /// weighted *child* MSE (lower is better); the caller subtracts it
    /// from the parent MSE to obtain the actual decrease.
    fn find_best_split_reg(
        &self,
        data: &Dataset,
        sorted_by_feature: &[Vec<usize>],
        total_sum: f64,
        total_sq: f64,
        n_parent: usize,
        feature_buf: &mut Vec<usize>,
        split_rng: &mut crate::rng::FastRng,
    ) -> Option<BestSplit> {
        let n_features = data.n_features();
        let mut best: Option<BestSplit> = None;

        // Choose the (possibly random) subset of features to evaluate.
        fill_feature_buf(feature_buf, n_features, self.max_features, split_rng);

        for &feat_idx in feature_buf.iter() {
            let col = &data.features[feat_idx];
            let sorted = &sorted_by_feature[feat_idx];

            // Running sums of everything left of the cursor.
            let mut left_sum = 0.0;
            let mut left_sq_sum = 0.0;
            let mut left_n = 0usize;
            let mut prev_val = f64::NEG_INFINITY;

            for &idx in sorted {
                let feat_val = col[idx];

                // Only cut between two distinct feature values.
                if left_n > 0 && (feat_val - prev_val).abs() > 1e-12 {
                    let right_n = n_parent - left_n;
                    if left_n >= self.min_samples_leaf && right_n >= self.min_samples_leaf {
                        // Per-side MSE = E[x^2] - E[x]^2, clamped at 0 to
                        // guard against floating-point drift.
                        let left_mse = (left_sq_sum / left_n as f64
                            - (left_sum / left_n as f64).powi(2))
                        .max(0.0);
                        let right_sum = total_sum - left_sum;
                        let right_sq = total_sq - left_sq_sum;
                        let right_mse = (right_sq / right_n as f64
                            - (right_sum / right_n as f64).powi(2))
                        .max(0.0);

                        let weighted = (left_n as f64 * left_mse + right_n as f64 * right_mse)
                            / n_parent as f64;

                        // Cut halfway between the two adjacent values.
                        let threshold = f64::midpoint(prev_val, feat_val);

                        let is_better =
                            best.as_ref().is_none_or(|b| weighted < b.impurity_decrease);
                        if is_better {
                            best = Some(BestSplit {
                                feature_idx: feat_idx,
                                threshold,
                                impurity_decrease: weighted,
                            });
                        }
                    }
                }

                // Advance the cursor past this sample.
                let target_val = data.target[idx];
                left_sum += target_val;
                left_sq_sum += target_val * target_val;
                left_n += 1;
                prev_val = feat_val;
            }
        }
        best
    }
1241}
1242
/// `Default` mirrors [`DecisionTreeRegressor::new`].
impl Default for DecisionTreeRegressor {
    fn default() -> Self {
        Self::new()
    }
}