
scry_learn/tree/cart/mod.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! CART (Classification And Regression Trees) implementation.
//!
//! Implements the full CART algorithm with Gini impurity, entropy,
//! and MSE split criteria. Supports feature bagging for Random Forest.
//!
//! Trees are built recursively using `TreeNode`, then flattened into a
//! contiguous `FlatTree` (`Vec<FlatNode>`) for cache-optimal prediction.
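//!
//! Illustrative usage (a sketch: the import paths are an assumption, since
//! this module's public re-exports are defined elsewhere in the crate; the
//! tests at the bottom of this file show the canonical API):
//!
//! ```ignore
//! use scry_learn::dataset::Dataset;
//! use scry_learn::tree::cart::DecisionTreeClassifier;
//!
//! // One feature, two classes separated at x = 10 (mirrors the tests below).
//! let features = vec![(0..20).map(|i| i as f64).collect()];
//! let target: Vec<f64> = (0..20).map(|i| if i < 10 { 0.0 } else { 1.0 }).collect();
//! let data = Dataset::new(features, target, vec!["x".into()], "class");
//!
//! let mut dt = DecisionTreeClassifier::new().max_depth(3);
//! dt.fit(&data).unwrap();
//! let preds = dt.predict(&data.feature_matrix()).unwrap();
//! assert_eq!(preds.len(), 20);
//! ```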

mod builder;
mod flat;
mod node;

pub(crate) use builder::presort_indices;
pub use builder::{DecisionTreeClassifier, DecisionTreeRegressor};
pub use flat::FlatTree;
pub use node::TreeNode;

// ---------------------------------------------------------------------------
// Split criterion
// ---------------------------------------------------------------------------

/// Split quality criterion.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum SplitCriterion {
    /// Gini impurity: `1 - Σ(pᵢ²)`.
    Gini,
    /// Information entropy: `-Σ(pᵢ log₂ pᵢ)`.
    Entropy,
    /// Mean squared error (for regression).
    Mse,
}
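
// Worked example: a balanced two-class node with counts [5, 5] has
// p = [0.5, 0.5], so Gini = 1 - (0.25 + 0.25) = 0.5 and entropy =
// -(0.5·log₂ 0.5 + 0.5·log₂ 0.5) = 1.0 bit, the maximum for two classes.
// A pure node ([10, 0]) scores 0.0 under either criterion. These values are
// checked in `test_impurity_known_values` at the bottom of this file.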

/// Leaf sentinel — stored in `FlatNode::right` to indicate a leaf node.
pub(crate) const LEAF_SENTINEL: u32 = u32::MAX;

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// The best split found for a node: the feature to split on, the threshold
/// value, and the resulting decrease in impurity.
pub(super) struct BestSplit {
    pub(super) feature_idx: usize,
    pub(super) threshold: f64,
    pub(super) impurity_decrease: f64,
}

/// Impurity of a node from integer class counts under `criterion`.
/// An empty node (`n == 0`) is defined to have zero impurity.
pub(super) fn compute_impurity(counts: &[usize], n: usize, criterion: SplitCriterion) -> f64 {
    if n == 0 {
        return 0.0;
    }
    let n_f = n as f64;
    match criterion {
        SplitCriterion::Gini => {
            let sum_sq: f64 = counts
                .iter()
                .map(|&c| {
                    let p = c as f64 / n_f;
                    p * p
                })
                .sum();
            1.0 - sum_sq
        }
        SplitCriterion::Entropy => {
            let mut entropy = 0.0;
            for &c in counts {
                if c > 0 {
                    let p = c as f64 / n_f;
                    entropy -= p * p.log2();
                }
            }
            entropy
        }
        SplitCriterion::Mse => {
            // MSE is not applicable to class counts — used only in the regressor.
            0.0
        }
    }
}

/// Index of the most frequent class, returned as an `f64` label.
/// Ties resolve to the last maximal index; an empty slice yields `0.0`.
pub(super) fn majority_class(counts: &[usize]) -> f64 {
    counts
        .iter()
        .enumerate()
        .max_by_key(|&(_, &count)| count)
        .map_or(0.0, |(idx, _)| idx as f64)
}

// ---------------------------------------------------------------------------
// Weighted impurity helpers (for class_weight support)
// ---------------------------------------------------------------------------

/// Weighted-count version of [`compute_impurity`]: `counts` holds per-class
/// weight sums and `total` their overall sum. A (near-)empty node has zero
/// impurity.
pub(super) fn compute_impurity_weighted(
    counts: &[f64],
    total: f64,
    criterion: SplitCriterion,
) -> f64 {
    if total < 1e-12 {
        return 0.0;
    }
    match criterion {
        SplitCriterion::Gini => {
            let sum_sq: f64 = counts
                .iter()
                .map(|&c| {
                    let p = c / total;
                    p * p
                })
                .sum();
            1.0 - sum_sq
        }
        SplitCriterion::Entropy => {
            let mut entropy = 0.0;
            for &c in counts {
                if c > 1e-12 {
                    let p = c / total;
                    entropy -= p * p.log2();
                }
            }
            entropy
        }
        SplitCriterion::Mse => 0.0,
    }
}

/// Weighted analogue of [`majority_class`]: index of the largest per-class
/// weight sum, returned as an `f64` label.
pub(super) fn weighted_majority_class(counts: &[f64]) -> f64 {
    counts
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map_or(0.0, |(idx, _)| idx as f64)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Dataset;

    fn make_linearly_separable() -> Dataset {
        // Class 0: x < 10, Class 1: x >= 10
        let features = vec![(0..20).map(|i| i as f64).collect()];
        let target: Vec<f64> = (0..20).map(|i| if i < 10 { 0.0 } else { 1.0 }).collect();
        Dataset::new(features, target, vec!["x".into()], "class")
    }

    #[test]
    fn test_decision_tree_perfect_split() {
        let data = make_linearly_separable();
        let mut dt = DecisionTreeClassifier::new();
        dt.fit(&data).unwrap();

        let matrix = data.feature_matrix();
        let preds = dt.predict(&matrix).unwrap();
        let acc = preds
            .iter()
            .zip(data.target.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
            .count() as f64
            / data.n_samples() as f64;

        assert!(
            acc >= 0.95,
            "expected ≥95% accuracy on linearly separable data, got {:.1}%",
            acc * 100.0
        );
    }

    #[test]
    fn test_feature_importance_sums_to_one() {
        let data = make_linearly_separable();
        let mut dt = DecisionTreeClassifier::new();
        dt.fit(&data).unwrap();

        let importances = dt.feature_importances().unwrap();
        let total: f64 = importances.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-6,
            "feature importances should sum to 1.0, got {total}"
        );
    }

    #[test]
    fn test_max_depth() {
        let data = make_linearly_separable();
        let mut dt = DecisionTreeClassifier::new().max_depth(2);
        dt.fit(&data).unwrap();
        assert!(dt.depth() <= 2 + 1); // reported depth counts the leaf level, so max_depth(2) allows 3
    }

    #[test]
    fn test_predict_proba() {
        let data = make_linearly_separable();
        let mut dt = DecisionTreeClassifier::new();
        dt.fit(&data).unwrap();

        let sample_class0 = vec![2.0]; // clearly class 0
        let proba = dt.predict_proba(&[sample_class0]).unwrap();
        assert!(proba[0][0] > 0.5, "should predict class 0 with >50%");
    }

    #[test]
    fn test_regressor_basic() {
        // y = x
        let features = vec![(0..50).map(|i| i as f64).collect()];
        let target: Vec<f64> = (0..50).map(|i| i as f64).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut dt = DecisionTreeRegressor::new().max_depth(10);
        dt.fit(&data).unwrap();

        let matrix = data.feature_matrix();
        let preds = dt.predict(&matrix).unwrap();

        // Should get low MSE on training data.
        let mse: f64 = preds
            .iter()
            .zip(data.target.iter())
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f64>()
            / data.n_samples() as f64;

        assert!(mse < 5.0, "MSE on training data should be low, got {mse}");
    }

    #[test]
    fn test_not_fitted_error() {
        let dt = DecisionTreeClassifier::new();
        assert!(dt.predict(&[vec![1.0]]).is_err());
    }

    // -------------------------------------------------------------------
    // Cost-complexity pruning tests
    // -------------------------------------------------------------------

    fn make_iris_like() -> Dataset {
        // A small 3-class dataset with enough samples to build a deep tree.
        let mut rng = crate::rng::FastRng::new(42);
        let n = 150;
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);
        for _ in 0..50 {
            f1.push(rng.f64() * 2.0);
            f2.push(rng.f64() * 2.0);
            target.push(0.0);
        }
        for _ in 0..50 {
            f1.push(rng.f64() * 2.0 + 3.0);
            f2.push(rng.f64() * 2.0 + 3.0);
            target.push(1.0);
        }
        for _ in 0..50 {
            f1.push(rng.f64() * 2.0 + 6.0);
            f2.push(rng.f64() * 2.0);
            target.push(2.0);
        }
        Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        )
    }

    #[test]
    fn test_ccp_alpha_reduces_depth() {
        let data = make_iris_like();

        let mut dt_full = DecisionTreeClassifier::new();
        dt_full.fit(&data).unwrap();
        let depth_full = dt_full.depth();
        let leaves_full = dt_full.n_leaves();

        let mut dt_pruned = DecisionTreeClassifier::new().ccp_alpha(0.02);
        dt_pruned.fit(&data).unwrap();
        let depth_pruned = dt_pruned.depth();
        let leaves_pruned = dt_pruned.n_leaves();

        eprintln!("Full tree: depth={depth_full}, leaves={leaves_full}");
        eprintln!("Pruned tree: depth={depth_pruned}, leaves={leaves_pruned}");

        assert!(
            leaves_pruned <= leaves_full,
            "Pruned tree should have no more leaves than the full tree: {leaves_pruned} vs {leaves_full}"
        );
    }

    #[test]
    fn test_ccp_alpha_zero_no_change() {
        let data = make_iris_like();

        let mut dt_zero = DecisionTreeClassifier::new().ccp_alpha(0.0);
        dt_zero.fit(&data).unwrap();
        let mut dt_default = DecisionTreeClassifier::new();
        dt_default.fit(&data).unwrap();

        assert_eq!(
            dt_zero.n_leaves(),
            dt_default.n_leaves(),
            "ccp_alpha=0.0 should not change the tree"
        );
    }

    #[test]
    fn test_ccp_alpha_large_collapses_to_root() {
        let data = make_iris_like();
        let mut dt = DecisionTreeClassifier::new().ccp_alpha(1000.0);
        dt.fit(&data).unwrap();
        assert_eq!(
            dt.n_leaves(),
            1,
            "Very large ccp_alpha should collapse to a single leaf"
        );
    }

    #[test]
    fn test_regressor_ccp_alpha() {
        let features = vec![(0..100).map(|i| i as f64).collect()];
        let target: Vec<f64> = (0..100).map(|i| (i as f64).sin()).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut dt_full = DecisionTreeRegressor::new();
        dt_full.fit(&data).unwrap();

        let mut dt_pruned = DecisionTreeRegressor::new().ccp_alpha(0.01);
        dt_pruned.fit(&data).unwrap();

        let full_leaves = dt_full.flat_tree().unwrap().n_leaves();
        let pruned_leaves = dt_pruned.flat_tree().unwrap().n_leaves();

        eprintln!("Regressor: full={full_leaves} leaves, pruned={pruned_leaves} leaves");
        assert!(
            pruned_leaves <= full_leaves,
            "Pruned regressor should have no more leaves than the full one: {pruned_leaves} vs {full_leaves}"
        );
    }

    #[test]
    fn test_pruning_path_monotonic() {
        let data = make_iris_like();
        let mut dt = DecisionTreeClassifier::new();
        dt.fit(&data).unwrap();

        let (alphas, impurities) = dt.cost_complexity_pruning_path(&data).unwrap();

        assert!(alphas.len() >= 2, "Should have at least 2 pruning steps");
        // Alphas should be monotonically non-decreasing.
        for w in alphas.windows(2) {
            assert!(
                w[1] >= w[0] - 1e-12,
                "Alphas should be monotonically non-decreasing: {} -> {}",
                w[0],
                w[1]
            );
        }
        // Impurities should be monotonically non-decreasing.
        for w in impurities.windows(2) {
            assert!(
                w[1] >= w[0] - 1e-12,
                "Impurities should be non-decreasing: {} -> {}",
                w[0],
                w[1]
            );
        }
        eprintln!("Pruning path: {} steps", alphas.len());
        eprintln!("Alphas: {:?}", &alphas[..alphas.len().min(5)]);
    }
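
    // Sanity checks for the impurity and majority helpers, using the
    // hand-computed values from the worked example next to `SplitCriterion`.
    // (Added as an illustrative cross-check; it exercises only items defined
    // in this module.)
    #[test]
    fn test_impurity_known_values() {
        // Balanced two-class node [5, 5]: maximal impurity.
        assert!((compute_impurity(&[5, 5], 10, SplitCriterion::Gini) - 0.5).abs() < 1e-12);
        assert!((compute_impurity(&[5, 5], 10, SplitCriterion::Entropy) - 1.0).abs() < 1e-12);
        // Pure node [10, 0]: zero impurity under either criterion.
        assert!(compute_impurity(&[10, 0], 10, SplitCriterion::Gini).abs() < 1e-12);
        assert!(compute_impurity(&[10, 0], 10, SplitCriterion::Entropy).abs() < 1e-12);
        // The weighted variant agrees when the weights equal the counts.
        let w = compute_impurity_weighted(&[5.0, 5.0], 10.0, SplitCriterion::Gini);
        assert!((w - 0.5).abs() < 1e-12);
        // Majority helpers return the index of the largest count / weight sum.
        assert_eq!(majority_class(&[3, 7, 2]), 1.0);
        assert_eq!(weighted_majority_class(&[0.5, 2.5, 1.0]), 1.0);
    }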
}