// scry_learn/tree/cart/node.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Recursive tree node representation used during construction.
//!
//! `TreeNode` is the recursive enum used to build the tree, which is then
//! flattened into a `FlatTree` for cache-optimal prediction.

/// A node in the decision tree (recursive representation).
///
/// Used during tree construction, then flattened into a `FlatTree` for
/// cache-optimal prediction. Exposed publicly for visualization.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum TreeNode {
    /// A leaf node — produces a prediction.
    Leaf {
        /// Predicted class (classification) or value (regression).
        prediction: f64,
        /// Number of training samples that reached this node.
        n_samples: usize,
        /// Class distribution at this node (classification only).
        class_counts: Vec<usize>,
        /// Impurity at this node.
        impurity: f64,
    },
    /// An internal split node.
    Split {
        /// Index of the feature used for the split.
        feature_idx: usize,
        /// Threshold value: left if ≤ threshold, right if > threshold.
        threshold: f64,
        /// Left child (≤ threshold).
        left: Box<TreeNode>,
        /// Right child (> threshold).
        right: Box<TreeNode>,
        /// Number of training samples that reached this node.
        n_samples: usize,
        /// Impurity at this node (before split).
        impurity: f64,
        /// Class distribution at this node.
        class_counts: Vec<usize>,
        /// Majority class prediction at this node — used as the leaf
        /// prediction if this split is collapsed during pruning.
        prediction: f64,
    },
}
46
47impl TreeNode {
48    /// Predict for a single sample by walking the tree.
49    pub fn predict(&self, sample: &[f64]) -> f64 {
50        match self {
51            TreeNode::Leaf { prediction, .. } => *prediction,
52            TreeNode::Split {
53                feature_idx,
54                threshold,
55                left,
56                right,
57                ..
58            } => {
59                if sample[*feature_idx] <= *threshold {
60                    left.predict(sample)
61                } else {
62                    right.predict(sample)
63                }
64            }
65        }
66    }
67
68    /// Get class probabilities for a single sample.
69    pub fn predict_proba(&self, sample: &[f64], n_classes: usize) -> Vec<f64> {
70        match self {
71            TreeNode::Leaf {
72                class_counts,
73                n_samples,
74                ..
75            } => {
76                let mut proba = vec![0.0; n_classes];
77                let total = *n_samples as f64;
78                for (i, &count) in class_counts.iter().enumerate() {
79                    if i < n_classes {
80                        proba[i] = count as f64 / total;
81                    }
82                }
83                proba
84            }
85            TreeNode::Split {
86                feature_idx,
87                threshold,
88                left,
89                right,
90                ..
91            } => {
92                if sample[*feature_idx] <= *threshold {
93                    left.predict_proba(sample, n_classes)
94                } else {
95                    right.predict_proba(sample, n_classes)
96                }
97            }
98        }
99    }
100
101    /// Depth of this subtree.
102    pub fn depth(&self) -> usize {
103        match self {
104            TreeNode::Leaf { .. } => 1,
105            TreeNode::Split { left, right, .. } => 1 + left.depth().max(right.depth()),
106        }
107    }
108
109    /// Number of leaf nodes in this subtree.
110    pub fn n_leaves(&self) -> usize {
111        match self {
112            TreeNode::Leaf { .. } => 1,
113            TreeNode::Split { left, right, .. } => left.n_leaves() + right.n_leaves(),
114        }
115    }
116
117    /// Number of samples at this node.
118    pub fn n_samples(&self) -> usize {
119        match self {
120            TreeNode::Leaf { n_samples, .. } | TreeNode::Split { n_samples, .. } => *n_samples,
121        }
122    }
123
124    /// Sum of weighted leaf impurities: Σ(impurity_leaf × n_samples_leaf).
125    ///
126    /// This is R(T_t) in the cost-complexity pruning literature.
127    pub fn total_leaf_impurity(&self) -> f64 {
128        match self {
129            TreeNode::Leaf {
130                impurity,
131                n_samples,
132                ..
133            } => *impurity * (*n_samples as f64),
134            TreeNode::Split { left, right, .. } => {
135                left.total_leaf_impurity() + right.total_leaf_impurity()
136            }
137        }
138    }
139
140    /// Minimal cost-complexity pruning (MCCP).
141    ///
142    /// Recursively prunes subtrees whose effective alpha is ≤ `ccp_alpha`.
143    /// Effective alpha = (R(t) - R(T_t)) / (|T_t| - 1), where R(t) is the
144    /// re-substitution error if this node were a leaf and R(T_t) is the
145    /// total leaf impurity of the subtree.
146    ///
147    /// This matches sklearn's `ccp_alpha` parameter behavior.
148    pub fn prune_ccp(self, ccp_alpha: f64) -> TreeNode {
149        match self {
150            TreeNode::Leaf { .. } => self,
151            TreeNode::Split {
152                feature_idx,
153                threshold,
154                left,
155                right,
156                n_samples,
157                impurity,
158                class_counts,
159                prediction,
160            } => {
161                // Recursively prune children first (bottom-up).
162                let pruned_left = left.prune_ccp(ccp_alpha);
163                let pruned_right = right.prune_ccp(ccp_alpha);
164
165                // Build the pruned split node to compute its subtree stats.
166                let subtree = TreeNode::Split {
167                    feature_idx,
168                    threshold,
169                    left: Box::new(pruned_left),
170                    right: Box::new(pruned_right),
171                    n_samples,
172                    impurity,
173                    class_counts: class_counts.clone(),
174                    prediction,
175                };
176
177                let n_leaves = subtree.n_leaves();
178                if n_leaves <= 1 {
179                    return subtree;
180                }
181
182                // R(t) = impurity if this node were a leaf.
183                let r_node = impurity * (n_samples as f64);
184                // R(T_t) = total leaf impurity of subtree.
185                let r_subtree = subtree.total_leaf_impurity();
186
187                let effective_alpha = (r_node - r_subtree) / (n_leaves as f64 - 1.0);
188
189                if effective_alpha <= ccp_alpha {
190                    // Collapse to leaf.
191                    TreeNode::Leaf {
192                        prediction,
193                        n_samples,
194                        class_counts,
195                        impurity,
196                    }
197                } else {
198                    subtree
199                }
200            }
201        }
202    }
203
204    /// Compute the cost-complexity pruning path.
205    ///
206    /// Returns `(ccp_alphas, total_impurities)` — a sequence of effective
207    /// alpha values and the corresponding total tree impurity at each
208    /// pruning step. Useful for elbow-method selection of `ccp_alpha`.
209    pub fn cost_complexity_pruning_path(&self) -> (Vec<f64>, Vec<f64>) {
210        let mut alphas = vec![0.0];
211        let mut impurities = vec![self.total_leaf_impurity()];
212
213        let mut current = self.clone();
214        loop {
215            // Find the minimum effective alpha across all internal nodes.
216            let min_alpha = Self::min_effective_alpha(&current);
217            match min_alpha {
218                None => break, // no more internal nodes
219                Some(alpha) => {
220                    current = current.prune_ccp(alpha);
221                    alphas.push(alpha);
222                    impurities.push(current.total_leaf_impurity());
223                }
224            }
225        }
226        (alphas, impurities)
227    }
228
229    /// Find the minimum effective alpha among all internal nodes.
230    fn min_effective_alpha(node: &TreeNode) -> Option<f64> {
231        match node {
232            TreeNode::Leaf { .. } => None,
233            TreeNode::Split {
234                left,
235                right,
236                n_samples,
237                impurity,
238                ..
239            } => {
240                let n_leaves = node.n_leaves();
241                let r_node = impurity * (*n_samples as f64);
242                let r_subtree = node.total_leaf_impurity();
243                let my_alpha = if n_leaves > 1 {
244                    Some((r_node - r_subtree) / (n_leaves as f64 - 1.0))
245                } else {
246                    None
247                };
248
249                let left_alpha = Self::min_effective_alpha(left);
250                let right_alpha = Self::min_effective_alpha(right);
251
252                [my_alpha, left_alpha, right_alpha]
253                    .iter()
254                    .filter_map(|a| *a)
255                    .reduce(f64::min)
256            }
257        }
258    }
259}