Skip to main content

scry_learn/tree/cart/
builder.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Decision tree classifier and regressor implementations.
3//!
4//! Contains the CART tree-building algorithm with pre-sorted indices,
5//! feature bagging, class weighting, and cost-complexity pruning.
6
7use crate::dataset::Dataset;
8use crate::error::{Result, ScryLearnError};
9use crate::weights::{compute_sample_weights, ClassWeight};
10
11use super::{
12    compute_impurity, compute_impurity_weighted, majority_class, weighted_majority_class,
13    BestSplit, FlatTree, SplitCriterion, TreeNode,
14};
15
16// ---------------------------------------------------------------------------
17// Helpers
18// ---------------------------------------------------------------------------
19
20/// Pre-sort sample indices by each feature value. O(n·log n) per feature.
21pub(crate) fn presort_indices(data: &Dataset, indices: &[usize]) -> Vec<Vec<usize>> {
22    let n_features = data.n_features();
23    let mut sorted_by_feature = Vec::with_capacity(n_features);
24    for feat_idx in 0..n_features {
25        let col = &data.features[feat_idx];
26        let mut sorted = indices.to_vec();
27        sorted.sort_unstable_by(|&a, &b| {
28            col[a]
29                .partial_cmp(&col[b])
30                .unwrap_or(std::cmp::Ordering::Equal)
31        });
32        sorted_by_feature.push(sorted);
33    }
34    sorted_by_feature
35}
36
/// Restrict each globally pre-sorted index list to the samples flagged in `membership`.
///
/// `global_sorted[f]` lists sample indices ordered by feature `f`.
/// `membership[i]` is `true` when sample `i` should be kept. Relative order is
/// preserved, so every output list remains sorted by its feature.
fn filter_sorted(global_sorted: &[Vec<usize>], membership: &[bool]) -> Vec<Vec<usize>> {
    let mut filtered = Vec::with_capacity(global_sorted.len());
    for feature_order in global_sorted {
        let kept: Vec<usize> = feature_order
            .iter()
            .filter(|idx| membership[**idx])
            .copied()
            .collect();
        filtered.push(kept);
    }
    filtered
}
44
/// Split every per-feature sorted index list into left (`<= threshold`) and
/// right (`> threshold`) parts according to `split_col`.
///
/// Both parts keep the original order, so they stay sorted by their feature.
/// The incoming vectors are filtered in place and returned as the left child.
/// Only the right child's vectors are freshly allocated, each pre-sized to
/// `right_count`. Samples whose split value is NaN fail the `<=` test and go
/// right. `_left_count` is accepted for call-site symmetry but not needed.
fn partition_sorted(
    mut sorted_by_feature: Vec<Vec<usize>>,
    split_col: &[f64],
    threshold: f64,
    _left_count: usize,
    right_count: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<usize>>) {
    let right_sorted = sorted_by_feature
        .iter_mut()
        .map(|feat_sorted| {
            let mut right = Vec::with_capacity(right_count);
            // `retain` visits elements once, in order, and compacts the
            // survivors in place: a stable in-place partition.
            feat_sorted.retain(|&idx| {
                let goes_left = split_col[idx] <= threshold;
                if !goes_left {
                    right.push(idx);
                }
                goes_left
            });
            right
        })
        .collect();
    (sorted_by_feature, right_sorted)
}
77
78/// Populate the feature buffer with indices, optionally shuffled for feature bagging.
79///
80/// When `max_features` is set, uses `rng` to select a random subset via
81/// partial Fisher-Yates shuffle. The caller must supply a mutable RNG whose
82/// state advances between calls so that each split considers a *different*
83/// random feature subset (critical for Random Forest decorrelation).
84fn fill_feature_buf(
85    feature_buf: &mut Vec<usize>,
86    n_features: usize,
87    max_features: Option<usize>,
88    rng: &mut crate::rng::FastRng,
89) {
90    feature_buf.clear();
91    feature_buf.extend(0..n_features);
92    if let Some(max_f) = max_features {
93        let m = max_f.min(n_features);
94        for i in 0..m {
95            let j = rng.usize(i..n_features);
96            feature_buf.swap(i, j);
97        }
98        feature_buf.truncate(m);
99    }
100}
101
102// ---------------------------------------------------------------------------
103// Decision Tree Classifier
104// ---------------------------------------------------------------------------
105
/// CART decision tree for classification.
///
/// Configure with the builder-style setters, train with `fit`, then query
/// `predict` / `predict_proba`. Hyper-parameter fields come first; the
/// remaining fields are populated by training.
// NOTE: field order is part of the serde layout for non-self-describing
// formats — do not reorder fields.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct DecisionTreeClassifier {
    /// Maximum depth (root = depth 0); `None` means no depth limit.
    max_depth: Option<usize>,
    /// Nodes with fewer samples than this are turned into leaves.
    min_samples_split: usize,
    /// Minimum number of samples each child of a split must receive.
    min_samples_leaf: usize,
    /// Number of randomly drawn features examined per split; `None` examines all.
    max_features: Option<usize>,
    /// Impurity measure used to score candidate splits.
    criterion: SplitCriterion,
    /// Cost-complexity pruning strength; `0.0` disables pruning.
    ccp_alpha: f64,
    /// Class weighting strategy for imbalanced datasets.
    pub(crate) class_weight: ClassWeight,
    /// Per-sample weights computed from `class_weight` during fit.
    /// Cleared again once training finishes.
    pub(crate) sample_weights: Option<Vec<f64>>,
    /// Flattened tree for cache-optimal prediction; `None` until fitted.
    pub(crate) flat_tree: Option<FlatTree>,
    /// Number of classes taken from the training dataset.
    n_classes: usize,
    /// Number of features taken from the training dataset.
    n_features: usize,
    /// Impurity-decrease importance per feature, normalised to sum to 1
    /// whenever at least one split contributed.
    pub(crate) feature_importances_: Vec<f64>,
    /// Serialization schema version, checked before prediction.
    // Payloads written without this field deserialize it as 0 (`serde(default)`).
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
129
130impl DecisionTreeClassifier {
131    /// Create a new classifier with default parameters.
132    pub fn new() -> Self {
133        Self {
134            max_depth: None,
135            min_samples_split: 2,
136            min_samples_leaf: 1,
137            max_features: None,
138            criterion: SplitCriterion::Gini,
139            ccp_alpha: 0.0,
140            class_weight: ClassWeight::Uniform,
141            sample_weights: None,
142            flat_tree: None,
143            n_classes: 0,
144            n_features: 0,
145            feature_importances_: Vec::new(),
146            _schema_version: crate::version::SCHEMA_VERSION,
147        }
148    }
149
150    /// Set maximum tree depth.
151    pub fn max_depth(mut self, d: usize) -> Self {
152        self.max_depth = Some(d);
153        self
154    }
155
156    /// Set minimum samples required to split an internal node.
157    pub fn min_samples_split(mut self, n: usize) -> Self {
158        self.min_samples_split = n;
159        self
160    }
161
162    /// Set minimum samples required in a leaf node.
163    pub fn min_samples_leaf(mut self, n: usize) -> Self {
164        self.min_samples_leaf = n;
165        self
166    }
167
168    /// Set maximum features to consider per split (for random forest).
169    pub fn max_features(mut self, n: usize) -> Self {
170        self.max_features = Some(n);
171        self
172    }
173
174    /// Set the split criterion.
175    pub fn criterion(mut self, c: SplitCriterion) -> Self {
176        self.criterion = c;
177        self
178    }
179
180    /// Set class weighting strategy for imbalanced datasets.
181    ///
182    /// When set to [`ClassWeight::Balanced`], minority classes receive
183    /// higher weight in impurity calculations, improving their recall.
184    ///
185    /// # Example
186    /// ```
187    /// use scry_learn::tree::DecisionTreeClassifier;
188    /// use scry_learn::weights::ClassWeight;
189    ///
190    /// let dt = DecisionTreeClassifier::new()
191    ///     .class_weight(ClassWeight::Balanced);
192    /// ```
193    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
194        self.class_weight = cw;
195        self
196    }
197
198    /// Set cost-complexity pruning parameter.
199    ///
200    /// Subtrees with effective alpha ≤ `ccp_alpha` are pruned after
201    /// tree construction. A value of 0.0 (default) disables pruning.
202    /// Larger values produce smaller, more regularized trees.
203    ///
204    /// # Example
205    /// ```
206    /// use scry_learn::tree::DecisionTreeClassifier;
207    ///
208    /// let dt = DecisionTreeClassifier::new()
209    ///     .ccp_alpha(0.01);
210    /// ```
211    pub fn ccp_alpha(mut self, alpha: f64) -> Self {
212        self.ccp_alpha = alpha;
213        self
214    }
215
216    /// Train the decision tree on a dataset.
217    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
218        data.validate_finite()?;
219        let indices: Vec<usize> = (0..data.n_samples()).collect();
220        self.fit_on_indices(data, &indices)
221    }
222
223    /// Train the decision tree on a dataset using a subset of sample indices.
224    ///
225    /// This is the production path used by Random Forest — avoids copying
226    /// the dataset by training directly on indices into the original data.
227    ///
228    /// Internally uses pre-sorted indices: sorts once at the root (O(n·log n)),
229    /// then partitions at each node (O(n) per feature per node) — matching
230    /// scikit-learn's optimized CART implementation.
231    pub(crate) fn fit_on_indices(
232        &mut self,
233        data: &Dataset,
234        sample_indices: &[usize],
235    ) -> Result<()> {
236        let sorted_by_feature = presort_indices(data, sample_indices);
237        self.fit_with_sorted(data, sample_indices, sorted_by_feature)
238    }
239
240    /// Train using pre-sorted indices shared across trees (RF memory optimization).
241    ///
242    /// `global_sorted` contains ALL dataset indices sorted by each feature.
243    /// Filters to only the bootstrap sample indices, then builds the tree
244    /// using partitioned sorted arrays.
245    pub(crate) fn fit_on_indices_presorted(
246        &mut self,
247        data: &Dataset,
248        sample_indices: &[usize],
249        global_sorted: &[Vec<usize>],
250    ) -> Result<()> {
251        // Filter global sorted arrays to only include bootstrap sample indices.
252        let membership_len = global_sorted.first().map_or(0, Vec::len);
253        let mut membership = vec![false; membership_len];
254        for &i in sample_indices {
255            membership[i] = true;
256        }
257        let sorted_by_feature = filter_sorted(global_sorted, &membership);
258        self.fit_with_sorted(data, sample_indices, sorted_by_feature)
259    }
260
261    /// Internal: fit using pre-filtered, per-node sorted arrays.
262    fn fit_with_sorted(
263        &mut self,
264        data: &Dataset,
265        sample_indices: &[usize],
266        sorted_by_feature: Vec<Vec<usize>>,
267    ) -> Result<()> {
268        let n = sample_indices.len();
269        if n == 0 {
270            return Err(ScryLearnError::EmptyDataset);
271        }
272
273        self.n_features = data.n_features();
274        self.n_classes = data.n_classes();
275        self.feature_importances_ = vec![0.0; self.n_features];
276
277        // Compute per-sample weights if class_weight is non-uniform.
278        let weights = match &self.class_weight {
279            ClassWeight::Uniform => None,
280            cw => Some(compute_sample_weights(&data.target, cw)),
281        };
282        self.sample_weights = weights;
283
284        let mut feature_buf = Vec::with_capacity(self.n_features);
285        let mut split_rng = crate::rng::FastRng::new(0);
286
287        let tree = if self.sample_weights.is_some() {
288            self.build_tree_weighted(
289                data,
290                sorted_by_feature,
291                n,
292                0,
293                &mut feature_buf,
294                &mut split_rng,
295            )
296        } else {
297            self.build_tree(
298                data,
299                sorted_by_feature,
300                n,
301                0,
302                &mut feature_buf,
303                &mut split_rng,
304            )
305        };
306
307        // Apply cost-complexity pruning if requested.
308        let tree = if self.ccp_alpha > 0.0 {
309            tree.prune_ccp(self.ccp_alpha)
310        } else {
311            tree
312        };
313
314        // Flatten recursive tree into contiguous array for prediction.
315        let flat = FlatTree::from_tree_node(&tree, self.n_classes);
316        self.flat_tree = Some(flat);
317
318        // Normalize feature importances to sum to 1.
319        let total: f64 = self.feature_importances_.iter().sum();
320        if total > 0.0 {
321            for imp in &mut self.feature_importances_ {
322                *imp /= total;
323            }
324        }
325
326        // Free training-only data.
327        self.sample_weights = None;
328
329        Ok(())
330    }
331
332    /// Predict class labels for a feature matrix.
333    ///
334    /// `features` is row-major: `features[sample_idx][feature_idx]`.
335    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
336        crate::version::check_schema_version(self._schema_version)?;
337        let ft = self.flat_tree.as_ref().ok_or(ScryLearnError::NotFitted)?;
338        Ok(ft.predict(features))
339    }
340
341    /// Predict class probabilities for a feature matrix.
342    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
343        let ft = self.flat_tree.as_ref().ok_or(ScryLearnError::NotFitted)?;
344        let n_classes = self.n_classes;
345        Ok(features
346            .iter()
347            .map(|row| ft.predict_proba_sample(row, n_classes))
348            .collect())
349    }
350
351    /// Get feature importances (sum of weighted impurity decreases).
352    pub fn feature_importances(&self) -> Result<Vec<f64>> {
353        if self.flat_tree.is_none() {
354            return Err(ScryLearnError::NotFitted);
355        }
356        Ok(self.feature_importances_.clone())
357    }
358
359    /// Get the flat tree (for direct access).
360    pub fn flat_tree(&self) -> Option<&FlatTree> {
361        self.flat_tree.as_ref()
362    }
363
364    /// Tree depth.
365    pub fn depth(&self) -> usize {
366        self.flat_tree.as_ref().map_or(0, FlatTree::depth)
367    }
368
369    /// Number of leaf nodes.
370    pub fn n_leaves(&self) -> usize {
371        self.flat_tree.as_ref().map_or(0, FlatTree::n_leaves)
372    }
373
374    /// Number of features the model was trained on.
375    pub fn n_features(&self) -> usize {
376        self.n_features
377    }
378
379    /// Number of classes.
380    pub fn n_classes(&self) -> usize {
381        self.n_classes
382    }
383
384    /// Compute the cost-complexity pruning path for this classifier.
385    ///
386    /// Trains an unpruned tree, then returns `(ccp_alphas, total_impurities)`,
387    /// a sequence of effective alpha values and the corresponding total tree
388    /// impurity at each pruning step. Useful for selecting `ccp_alpha` via
389    /// the elbow method.
390    ///
391    /// The classifier must be fitted before calling this method.
392    pub fn cost_complexity_pruning_path(&self, data: &Dataset) -> Result<(Vec<f64>, Vec<f64>)> {
393        // Build an unpruned tree to compute the path.
394        let mut unpruned = self.clone();
395        unpruned.ccp_alpha = 0.0;
396        unpruned.fit(data)?;
397
398        // Rebuild the recursive tree from the dataset to get the path.
399        let indices: Vec<usize> = (0..data.n_samples()).collect();
400        let sorted_by_feature = presort_indices(data, &indices);
401        let n = indices.len();
402        let mut feature_buf = Vec::with_capacity(unpruned.n_features);
403        let mut split_rng = crate::rng::FastRng::new(0);
404
405        let tree = if unpruned.sample_weights.is_some() {
406            unpruned.build_tree_weighted(
407                data,
408                sorted_by_feature,
409                n,
410                0,
411                &mut feature_buf,
412                &mut split_rng,
413            )
414        } else {
415            unpruned.build_tree(
416                data,
417                sorted_by_feature,
418                n,
419                0,
420                &mut feature_buf,
421                &mut split_rng,
422            )
423        };
424        Ok(tree.cost_complexity_pruning_path())
425    }
426
427    // -----------------------------------------------------------------------
428    // Recursive tree building (unweighted)
429    // -----------------------------------------------------------------------
430
431    /// Build tree using partitioned sorted arrays.
432    ///
433    /// `sorted_by_feature[feat_idx]` contains only this node's sample indices,
434    /// sorted by that feature's value. No membership bitset needed.
435    fn build_tree(
436        &mut self,
437        data: &Dataset,
438        sorted_by_feature: Vec<Vec<usize>>,
439        n_root_samples: usize,
440        depth: usize,
441        feature_buf: &mut Vec<usize>,
442        split_rng: &mut crate::rng::FastRng,
443    ) -> TreeNode {
444        let active = &sorted_by_feature[0];
445        let n_actual = active.len();
446
447        // Collect class counts.
448        let mut class_counts = vec![0usize; self.n_classes];
449        for &idx in active {
450            let c = data.target[idx] as usize;
451            if c < self.n_classes {
452                class_counts[c] += 1;
453            }
454        }
455        let impurity = compute_impurity(&class_counts, n_actual, self.criterion);
456
457        // Check stopping conditions.
458        let max_depth_reached = self.max_depth.is_some_and(|d| depth >= d);
459        let too_few_samples = n_actual < self.min_samples_split;
460        let is_pure = impurity < 1e-12;
461
462        if max_depth_reached || too_few_samples || is_pure {
463            return TreeNode::Leaf {
464                prediction: majority_class(&class_counts),
465                n_samples: n_actual,
466                class_counts,
467                impurity,
468            };
469        }
470
471        // Find best split.
472        let best = self.find_best_split(
473            data,
474            &sorted_by_feature,
475            &class_counts,
476            n_actual,
477            feature_buf,
478            split_rng,
479        );
480
481        let node_prediction = majority_class(&class_counts);
482
483        match best {
484            None => TreeNode::Leaf {
485                prediction: node_prediction,
486                n_samples: n_actual,
487                class_counts,
488                impurity,
489            },
490            Some(split) => {
491                let col = &data.features[split.feature_idx];
492                let threshold = split.threshold;
493
494                // Count left/right.
495                let mut left_count = 0usize;
496                let mut right_count = 0usize;
497                for &idx in active {
498                    if col[idx] <= threshold {
499                        left_count += 1;
500                    } else {
501                        right_count += 1;
502                    }
503                }
504
505                if left_count < self.min_samples_leaf || right_count < self.min_samples_leaf {
506                    return TreeNode::Leaf {
507                        prediction: node_prediction,
508                        n_samples: n_actual,
509                        class_counts,
510                        impurity,
511                    };
512                }
513
514                // Record feature importance.
515                let weighted_impurity_decrease = (n_actual as f64 / n_root_samples as f64)
516                    * (impurity - split.impurity_decrease);
517                self.feature_importances_[split.feature_idx] += weighted_impurity_decrease.max(0.0);
518
519                // Partition sorted arrays into left/right children.
520                let (left_sorted, right_sorted) =
521                    partition_sorted(sorted_by_feature, col, threshold, left_count, right_count);
522
523                let left = self.build_tree(
524                    data,
525                    left_sorted,
526                    n_root_samples,
527                    depth + 1,
528                    feature_buf,
529                    split_rng,
530                );
531                let right = self.build_tree(
532                    data,
533                    right_sorted,
534                    n_root_samples,
535                    depth + 1,
536                    feature_buf,
537                    split_rng,
538                );
539
540                TreeNode::Split {
541                    feature_idx: split.feature_idx,
542                    threshold,
543                    left: Box::new(left),
544                    right: Box::new(right),
545                    n_samples: n_actual,
546                    impurity,
547                    class_counts,
548                    prediction: node_prediction,
549                }
550            }
551        }
552    }
553
554    /// Find the best split by scanning sorted arrays — O(n) per feature.
555    fn find_best_split(
556        &self,
557        data: &Dataset,
558        sorted_by_feature: &[Vec<usize>],
559        parent_counts: &[usize],
560        n_parent: usize,
561        feature_buf: &mut Vec<usize>,
562        split_rng: &mut crate::rng::FastRng,
563    ) -> Option<BestSplit> {
564        let n_features = data.n_features();
565        let mut best: Option<BestSplit> = None;
566
567        fill_feature_buf(feature_buf, n_features, self.max_features, split_rng);
568
569        for &feat_idx in feature_buf.iter() {
570            let col = &data.features[feat_idx];
571            let sorted = &sorted_by_feature[feat_idx];
572
573            let mut left_counts = vec![0usize; self.n_classes];
574            let mut left_n = 0;
575            let mut prev_val = f64::NEG_INFINITY;
576
577            for &idx in sorted {
578                let val = col[idx];
579
580                // Check threshold between previous and current value.
581                if left_n > 0 && (val - prev_val).abs() > 1e-12 {
582                    let right_n = n_parent - left_n;
583                    if left_n >= self.min_samples_leaf && right_n >= self.min_samples_leaf {
584                        let right_counts: Vec<usize> = parent_counts
585                            .iter()
586                            .zip(left_counts.iter())
587                            .map(|(&p, &l)| p - l)
588                            .collect();
589
590                        let left_imp = compute_impurity(&left_counts, left_n, self.criterion);
591                        let right_imp = compute_impurity(&right_counts, right_n, self.criterion);
592                        let weighted_imp = (left_n as f64 * left_imp + right_n as f64 * right_imp)
593                            / n_parent as f64;
594
595                        let threshold = f64::midpoint(prev_val, val);
596
597                        let is_better = best
598                            .as_ref()
599                            .is_none_or(|b| weighted_imp < b.impurity_decrease);
600
601                        if is_better {
602                            best = Some(BestSplit {
603                                feature_idx: feat_idx,
604                                threshold,
605                                impurity_decrease: weighted_imp,
606                            });
607                        }
608                    }
609                }
610
611                // Add current sample to left side.
612                let class = data.target[idx] as usize;
613                if class < self.n_classes {
614                    left_counts[class] += 1;
615                }
616                left_n += 1;
617                prev_val = val;
618            }
619        }
620
621        best
622    }
623
624    // -------------------------------------------------------------------
625    // Weighted tree building (class_weight support)
626    // -------------------------------------------------------------------
627
628    /// Build tree using partitioned sorted arrays with per-sample weights.
629    fn build_tree_weighted(
630        &mut self,
631        data: &Dataset,
632        sorted_by_feature: Vec<Vec<usize>>,
633        n_root_samples: usize,
634        depth: usize,
635        feature_buf: &mut Vec<usize>,
636        split_rng: &mut crate::rng::FastRng,
637    ) -> TreeNode {
638        let weights = self.sample_weights.as_ref().expect("weights must be set");
639        let active = &sorted_by_feature[0];
640        let n_actual = active.len();
641
642        // Collect weighted and unweighted class counts.
643        let mut w_counts = vec![0.0_f64; self.n_classes];
644        let mut w_total = 0.0_f64;
645        let mut class_counts = vec![0usize; self.n_classes];
646
647        for &idx in active {
648            let c = data.target[idx] as usize;
649            let w = weights[idx];
650            if c < self.n_classes {
651                w_counts[c] += w;
652                class_counts[c] += 1;
653            }
654            w_total += w;
655        }
656
657        let impurity = compute_impurity_weighted(&w_counts, w_total, self.criterion);
658
659        // Check stopping conditions.
660        let max_depth_reached = self.max_depth.is_some_and(|d| depth >= d);
661        let too_few_samples = n_actual < self.min_samples_split;
662        let is_pure = impurity < 1e-12;
663
664        if max_depth_reached || too_few_samples || is_pure {
665            return TreeNode::Leaf {
666                prediction: weighted_majority_class(&w_counts),
667                n_samples: n_actual,
668                class_counts,
669                impurity,
670            };
671        }
672
673        let best = self.find_best_split_weighted(
674            data,
675            &sorted_by_feature,
676            &w_counts,
677            w_total,
678            n_actual,
679            feature_buf,
680            split_rng,
681        );
682
683        let node_prediction = weighted_majority_class(&w_counts);
684
685        match best {
686            None => TreeNode::Leaf {
687                prediction: node_prediction,
688                n_samples: n_actual,
689                class_counts,
690                impurity,
691            },
692            Some(split) => {
693                let col = &data.features[split.feature_idx];
694                let threshold = split.threshold;
695
696                let mut left_count = 0usize;
697                let mut right_count = 0usize;
698                for &idx in active {
699                    if col[idx] <= threshold {
700                        left_count += 1;
701                    } else {
702                        right_count += 1;
703                    }
704                }
705
706                if left_count < self.min_samples_leaf || right_count < self.min_samples_leaf {
707                    return TreeNode::Leaf {
708                        prediction: node_prediction,
709                        n_samples: n_actual,
710                        class_counts,
711                        impurity,
712                    };
713                }
714
715                // Record feature importance.
716                let weighted_impurity_decrease = (n_actual as f64 / n_root_samples as f64)
717                    * (impurity - split.impurity_decrease);
718                self.feature_importances_[split.feature_idx] += weighted_impurity_decrease.max(0.0);
719
720                let (left_sorted, right_sorted) =
721                    partition_sorted(sorted_by_feature, col, threshold, left_count, right_count);
722
723                let left = self.build_tree_weighted(
724                    data,
725                    left_sorted,
726                    n_root_samples,
727                    depth + 1,
728                    feature_buf,
729                    split_rng,
730                );
731                let right = self.build_tree_weighted(
732                    data,
733                    right_sorted,
734                    n_root_samples,
735                    depth + 1,
736                    feature_buf,
737                    split_rng,
738                );
739
740                TreeNode::Split {
741                    feature_idx: split.feature_idx,
742                    threshold,
743                    left: Box::new(left),
744                    right: Box::new(right),
745                    n_samples: n_actual,
746                    impurity,
747                    class_counts,
748                    prediction: node_prediction,
749                }
750            }
751        }
752    }
753
754    /// Weighted variant of find_best_split — uses f64 weighted counts.
755    fn find_best_split_weighted(
756        &self,
757        data: &Dataset,
758        sorted_by_feature: &[Vec<usize>],
759        parent_w_counts: &[f64],
760        w_parent_total: f64,
761        n_parent: usize,
762        feature_buf: &mut Vec<usize>,
763        split_rng: &mut crate::rng::FastRng,
764    ) -> Option<BestSplit> {
765        let weights = self.sample_weights.as_ref().expect("weights must be set");
766        let n_features = data.n_features();
767        let mut best: Option<BestSplit> = None;
768
769        fill_feature_buf(feature_buf, n_features, self.max_features, split_rng);
770
771        for &feat_idx in feature_buf.iter() {
772            let col = &data.features[feat_idx];
773            let sorted = &sorted_by_feature[feat_idx];
774
775            let mut left_w_counts = vec![0.0_f64; self.n_classes];
776            let mut left_w_total = 0.0_f64;
777            let mut left_n = 0usize;
778            let mut prev_val = f64::NEG_INFINITY;
779
780            for &idx in sorted {
781                let val = col[idx];
782                let w = weights[idx];
783
784                if left_n > 0 && (val - prev_val).abs() > 1e-12 {
785                    let right_n = n_parent - left_n;
786                    if left_n >= self.min_samples_leaf && right_n >= self.min_samples_leaf {
787                        let right_w_total = w_parent_total - left_w_total;
788                        let right_w_counts: Vec<f64> = parent_w_counts
789                            .iter()
790                            .zip(left_w_counts.iter())
791                            .map(|(&p, &l)| (p - l).max(0.0))
792                            .collect();
793
794                        let left_imp =
795                            compute_impurity_weighted(&left_w_counts, left_w_total, self.criterion);
796                        let right_imp = compute_impurity_weighted(
797                            &right_w_counts,
798                            right_w_total,
799                            self.criterion,
800                        );
801                        let weighted_imp =
802                            (left_w_total * left_imp + right_w_total * right_imp) / w_parent_total;
803
804                        let threshold = f64::midpoint(prev_val, val);
805
806                        let is_better = best
807                            .as_ref()
808                            .is_none_or(|b| weighted_imp < b.impurity_decrease);
809
810                        if is_better {
811                            best = Some(BestSplit {
812                                feature_idx: feat_idx,
813                                threshold,
814                                impurity_decrease: weighted_imp,
815                            });
816                        }
817                    }
818                }
819
820                // Add current sample to left side.
821                let class = data.target[idx] as usize;
822                if class < self.n_classes {
823                    left_w_counts[class] += w;
824                }
825                left_w_total += w;
826                left_n += 1;
827                prev_val = val;
828            }
829        }
830
831        best
832    }
833}
834
835impl Default for DecisionTreeClassifier {
836    fn default() -> Self {
837        Self::new()
838    }
839}
840
841// ---------------------------------------------------------------------------
842// Decision Tree Regressor
843// ---------------------------------------------------------------------------
844
/// CART decision tree for regression.
///
/// Configured through builder-style setters; the last group of fields is
/// populated by training.
// NOTE: field order is part of the serde layout for non-self-describing
// formats — do not reorder fields.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct DecisionTreeRegressor {
    /// Maximum tree depth; `None` means no limit.
    max_depth: Option<usize>,
    /// Minimum samples required to split a node.
    min_samples_split: usize,
    /// Minimum samples required in a leaf.
    min_samples_leaf: usize,
    /// Maximum features examined per split; `None` presumably means all
    /// (as in the classifier) — confirm against the regressor's builder.
    max_features: Option<usize>,
    /// Cost-complexity pruning strength; `0.0` disables pruning.
    ccp_alpha: f64,
    /// Flattened tree for cache-optimal prediction; `None` until fitted.
    pub(crate) flat_tree: Option<FlatTree>,
    /// Number of features seen during training.
    n_features: usize,
    /// Per-feature importances — presumably normalised impurity decreases as
    /// in the classifier; TODO confirm against the regressor's fit.
    pub(crate) feature_importances_: Vec<f64>,
    /// Serialization schema version.
    // Payloads written without this field deserialize it as 0 (`serde(default)`).
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
862
863impl DecisionTreeRegressor {
864    /// Create a new regressor with default parameters.
865    pub fn new() -> Self {
866        Self {
867            max_depth: None,
868            min_samples_split: 2,
869            min_samples_leaf: 1,
870            max_features: None,
871            ccp_alpha: 0.0,
872            flat_tree: None,
873            n_features: 0,
874            feature_importances_: Vec::new(),
875            _schema_version: crate::version::SCHEMA_VERSION,
876        }
877    }
878
879    /// Set maximum tree depth.
880    pub fn max_depth(mut self, d: usize) -> Self {
881        self.max_depth = Some(d);
882        self
883    }
884
885    /// Set minimum samples required to split.
886    pub fn min_samples_split(mut self, n: usize) -> Self {
887        self.min_samples_split = n;
888        self
889    }
890
891    /// Set minimum samples required in a leaf.
892    pub fn min_samples_leaf(mut self, n: usize) -> Self {
893        self.min_samples_leaf = n;
894        self
895    }
896
897    /// Set maximum features per split (for random forest).
898    pub fn max_features(mut self, n: usize) -> Self {
899        self.max_features = Some(n);
900        self
901    }
902
903    /// Set cost-complexity pruning parameter.
904    ///
905    /// Subtrees with effective alpha ≤ `ccp_alpha` are pruned after
906    /// tree construction. A value of 0.0 (default) disables pruning.
907    /// Larger values produce smaller, more regularized trees.
908    pub fn ccp_alpha(mut self, alpha: f64) -> Self {
909        self.ccp_alpha = alpha;
910        self
911    }
912
913    /// Train on a dataset.
914    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
915        data.validate_finite()?;
916        let indices: Vec<usize> = (0..data.n_samples()).collect();
917        self.fit_on_indices(data, &indices)
918    }
919
920    /// Train on a dataset using a subset of sample indices.
921    ///
922    /// Production path for Random Forest — trains directly on indices
923    /// into the original data, avoiding dataset copies.
924    pub(crate) fn fit_on_indices(
925        &mut self,
926        data: &Dataset,
927        sample_indices: &[usize],
928    ) -> Result<()> {
929        let n = sample_indices.len();
930        if n == 0 {
931            return Err(ScryLearnError::EmptyDataset);
932        }
933        self.n_features = data.n_features();
934        self.feature_importances_ = vec![0.0; self.n_features];
935
936        let sorted_by_feature = presort_indices(data, sample_indices);
937        let mut feature_buf = Vec::with_capacity(self.n_features);
938        let mut split_rng = crate::rng::FastRng::new(0);
939
940        let tree = self.build_tree_reg(
941            data,
942            sorted_by_feature,
943            n,
944            0,
945            &mut feature_buf,
946            &mut split_rng,
947        );
948
949        // Apply cost-complexity pruning if requested.
950        let tree = if self.ccp_alpha > 0.0 {
951            tree.prune_ccp(self.ccp_alpha)
952        } else {
953            tree
954        };
955
956        // Flatten recursive tree into contiguous array for prediction.
957        // Regression trees don't need class probabilities (n_classes=0).
958        let flat = FlatTree::from_tree_node(&tree, 0);
959        self.flat_tree = Some(flat);
960
961        let total: f64 = self.feature_importances_.iter().sum();
962        if total > 0.0 {
963            for imp in &mut self.feature_importances_ {
964                *imp /= total;
965            }
966        }
967        Ok(())
968    }
969
970    /// Train using pre-sorted indices (GBT/RF optimization — sort once, reuse each round).
971    ///
972    /// `global_sorted` contains ALL dataset indices sorted by each feature.
973    /// Filters to only the requested sample indices, then builds the tree.
974    pub(crate) fn fit_on_indices_presorted(
975        &mut self,
976        data: &Dataset,
977        sample_indices: &[usize],
978        global_sorted: &[Vec<usize>],
979    ) -> Result<()> {
980        let n = sample_indices.len();
981        if n == 0 {
982            return Err(ScryLearnError::EmptyDataset);
983        }
984        self.n_features = data.n_features();
985        self.feature_importances_ = vec![0.0; self.n_features];
986
987        // Filter global sorted arrays to only include requested sample indices.
988        let membership_len = global_sorted.first().map_or(0, Vec::len);
989        let mut membership = vec![false; membership_len];
990        for &i in sample_indices {
991            membership[i] = true;
992        }
993        let sorted_by_feature = filter_sorted(global_sorted, &membership);
994        let mut feature_buf = Vec::with_capacity(self.n_features);
995        let mut split_rng = crate::rng::FastRng::new(0);
996
997        let tree = self.build_tree_reg(
998            data,
999            sorted_by_feature,
1000            n,
1001            0,
1002            &mut feature_buf,
1003            &mut split_rng,
1004        );
1005
1006        let tree = if self.ccp_alpha > 0.0 {
1007            tree.prune_ccp(self.ccp_alpha)
1008        } else {
1009            tree
1010        };
1011
1012        let flat = FlatTree::from_tree_node(&tree, 0);
1013        self.flat_tree = Some(flat);
1014
1015        let total: f64 = self.feature_importances_.iter().sum();
1016        if total > 0.0 {
1017            for imp in &mut self.feature_importances_ {
1018                *imp /= total;
1019            }
1020        }
1021        Ok(())
1022    }
1023
1024    /// Predict values.
1025    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
1026        crate::version::check_schema_version(self._schema_version)?;
1027        let ft = self.flat_tree.as_ref().ok_or(ScryLearnError::NotFitted)?;
1028        Ok(ft.predict(features))
1029    }
1030
1031    /// Feature importances.
1032    pub fn feature_importances(&self) -> Result<Vec<f64>> {
1033        if self.flat_tree.is_none() {
1034            return Err(ScryLearnError::NotFitted);
1035        }
1036        Ok(self.feature_importances_.clone())
1037    }
1038
1039    /// Get the flat tree.
1040    pub fn flat_tree(&self) -> Option<&FlatTree> {
1041        self.flat_tree.as_ref()
1042    }
1043
1044    /// Number of features.
1045    pub fn n_features(&self) -> usize {
1046        self.n_features
1047    }
1048
1049    /// Build tree using partitioned sorted arrays — no membership bitset.
1050    fn build_tree_reg(
1051        &mut self,
1052        data: &Dataset,
1053        sorted_by_feature: Vec<Vec<usize>>,
1054        n_root_samples: usize,
1055        depth: usize,
1056        feature_buf: &mut Vec<usize>,
1057        split_rng: &mut crate::rng::FastRng,
1058    ) -> TreeNode {
1059        let active = &sorted_by_feature[0];
1060        let n_actual = active.len();
1061
1062        if n_actual == 0 {
1063            return TreeNode::Leaf {
1064                prediction: 0.0,
1065                n_samples: 0,
1066                class_counts: Vec::new(),
1067                impurity: 0.0,
1068            };
1069        }
1070
1071        // Compute mean/MSE from active indices directly.
1072        let mut sum = 0.0;
1073        let mut sq_sum = 0.0;
1074        for &idx in active {
1075            let v = data.target[idx];
1076            sum += v;
1077            sq_sum += v * v;
1078        }
1079        let mean = sum / n_actual as f64;
1080        // Clamp to 0.0: the textbook formula E[X²]-E[X]² can go slightly
1081        // negative due to floating-point catastrophic cancellation.
1082        let mse = (sq_sum / n_actual as f64 - mean * mean).max(0.0);
1083
1084        let max_depth_reached = self.max_depth.is_some_and(|d| depth >= d);
1085        let too_few = n_actual < self.min_samples_split;
1086
1087        if max_depth_reached || too_few || mse < 1e-12 {
1088            return TreeNode::Leaf {
1089                prediction: mean,
1090                n_samples: n_actual,
1091                class_counts: Vec::new(),
1092                impurity: mse,
1093            };
1094        }
1095
1096        let best = self.find_best_split_reg(
1097            data,
1098            &sorted_by_feature,
1099            sum,
1100            sq_sum,
1101            n_actual,
1102            feature_buf,
1103            split_rng,
1104        );
1105
1106        match best {
1107            None => TreeNode::Leaf {
1108                prediction: mean,
1109                n_samples: n_actual,
1110                class_counts: Vec::new(),
1111                impurity: mse,
1112            },
1113            Some(split) => {
1114                let col = &data.features[split.feature_idx];
1115                let threshold = split.threshold;
1116
1117                let mut left_count = 0usize;
1118                let mut right_count = 0usize;
1119                for &idx in active {
1120                    if col[idx] <= threshold {
1121                        left_count += 1;
1122                    } else {
1123                        right_count += 1;
1124                    }
1125                }
1126
1127                if left_count < self.min_samples_leaf || right_count < self.min_samples_leaf {
1128                    return TreeNode::Leaf {
1129                        prediction: mean,
1130                        n_samples: n_actual,
1131                        class_counts: Vec::new(),
1132                        impurity: mse,
1133                    };
1134                }
1135
1136                let decrease =
1137                    (n_actual as f64 / n_root_samples as f64) * (mse - split.impurity_decrease);
1138                self.feature_importances_[split.feature_idx] += decrease.max(0.0);
1139
1140                let (left_sorted, right_sorted) =
1141                    partition_sorted(sorted_by_feature, col, threshold, left_count, right_count);
1142
1143                let left = self.build_tree_reg(
1144                    data,
1145                    left_sorted,
1146                    n_root_samples,
1147                    depth + 1,
1148                    feature_buf,
1149                    split_rng,
1150                );
1151                let right = self.build_tree_reg(
1152                    data,
1153                    right_sorted,
1154                    n_root_samples,
1155                    depth + 1,
1156                    feature_buf,
1157                    split_rng,
1158                );
1159
1160                TreeNode::Split {
1161                    feature_idx: split.feature_idx,
1162                    threshold,
1163                    left: Box::new(left),
1164                    right: Box::new(right),
1165                    n_samples: n_actual,
1166                    impurity: mse,
1167                    class_counts: Vec::new(),
1168                    prediction: mean,
1169                }
1170            }
1171        }
1172    }
1173
1174    /// Find best regression split using incremental variance — O(n) per feature.
1175    fn find_best_split_reg(
1176        &self,
1177        data: &Dataset,
1178        sorted_by_feature: &[Vec<usize>],
1179        total_sum: f64,
1180        total_sq: f64,
1181        n_parent: usize,
1182        feature_buf: &mut Vec<usize>,
1183        split_rng: &mut crate::rng::FastRng,
1184    ) -> Option<BestSplit> {
1185        let n_features = data.n_features();
1186        let mut best: Option<BestSplit> = None;
1187
1188        fill_feature_buf(feature_buf, n_features, self.max_features, split_rng);
1189
1190        for &feat_idx in feature_buf.iter() {
1191            let col = &data.features[feat_idx];
1192            let sorted = &sorted_by_feature[feat_idx];
1193
1194            let mut left_sum = 0.0;
1195            let mut left_sq_sum = 0.0;
1196            let mut left_n = 0usize;
1197            let mut prev_val = f64::NEG_INFINITY;
1198
1199            for &idx in sorted {
1200                let feat_val = col[idx];
1201
1202                // Check threshold between previous and current.
1203                if left_n > 0 && (feat_val - prev_val).abs() > 1e-12 {
1204                    let right_n = n_parent - left_n;
1205                    if left_n >= self.min_samples_leaf && right_n >= self.min_samples_leaf {
1206                        let left_mse = (left_sq_sum / left_n as f64
1207                            - (left_sum / left_n as f64).powi(2))
1208                        .max(0.0);
1209                        let right_sum = total_sum - left_sum;
1210                        let right_sq = total_sq - left_sq_sum;
1211                        let right_mse = (right_sq / right_n as f64
1212                            - (right_sum / right_n as f64).powi(2))
1213                        .max(0.0);
1214
1215                        let weighted = (left_n as f64 * left_mse + right_n as f64 * right_mse)
1216                            / n_parent as f64;
1217
1218                        let threshold = f64::midpoint(prev_val, feat_val);
1219
1220                        let is_better =
1221                            best.as_ref().is_none_or(|b| weighted < b.impurity_decrease);
1222                        if is_better {
1223                            best = Some(BestSplit {
1224                                feature_idx: feat_idx,
1225                                threshold,
1226                                impurity_decrease: weighted,
1227                            });
1228                        }
1229                    }
1230                }
1231
1232                let target_val = data.target[idx];
1233                left_sum += target_val;
1234                left_sq_sum += target_val * target_val;
1235                left_n += 1;
1236                prev_val = feat_val;
1237            }
1238        }
1239        best
1240    }
1241}
1242
1243impl Default for DecisionTreeRegressor {
1244    fn default() -> Self {
1245        Self::new()
1246    }
1247}