scry_learn/tree/cart/node.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Recursive tree node representation used during construction.
3//!
4//! `TreeNode` is the recursive enum used to build the tree, which is then
5//! flattened into a `FlatTree` for cache-optimal prediction.
6
/// A node in the decision tree (recursive representation).
///
/// Used during tree construction, then flattened into a `FlatTree` for
/// cache-optimal prediction. Exposed publicly for visualization.
///
/// Both variants carry the per-node statistics (`n_samples`, `impurity`,
/// `class_counts`, `prediction`), so a `Split` can be collapsed into a
/// `Leaf` during pruning without consulting the training data again.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum TreeNode {
    /// A leaf node — produces a prediction.
    Leaf {
        /// Predicted class (classification) or value (regression).
        prediction: f64,
        /// Number of training samples that reached this node.
        n_samples: usize,
        /// Class distribution at this node (classification only).
        class_counts: Vec<usize>,
        /// Impurity at this node.
        impurity: f64,
    },
    /// An internal split node.
    Split {
        /// Index of the feature used for the split.
        feature_idx: usize,
        /// Threshold value: left if ≤ threshold, right if > threshold.
        threshold: f64,
        /// Left child (≤ threshold).
        left: Box<TreeNode>,
        /// Right child (> threshold).
        right: Box<TreeNode>,
        /// Number of training samples that reached this node.
        n_samples: usize,
        /// Impurity at this node (before split).
        impurity: f64,
        /// Class distribution at this node.
        class_counts: Vec<usize>,
        /// Majority class prediction at this node; becomes the leaf
        /// prediction if this node is collapsed by pruning.
        prediction: f64,
    },
}

impl TreeNode {
    /// Predict for a single sample by walking the tree from this node to a leaf.
    ///
    /// At every split the sample goes left when
    /// `sample[feature_idx] <= threshold` and right otherwise. A NaN feature
    /// value fails that comparison and is therefore routed to the right child.
    ///
    /// # Panics
    ///
    /// Panics if `sample` is shorter than a `feature_idx` met along the path.
    pub fn predict(&self, sample: &[f64]) -> f64 {
        // Iterative descent: stack usage stays constant however deep the tree is.
        let mut node = self;
        loop {
            match node {
                TreeNode::Leaf { prediction, .. } => return *prediction,
                TreeNode::Split {
                    feature_idx,
                    threshold,
                    left,
                    right,
                    ..
                } => {
                    node = if sample[*feature_idx] <= *threshold {
                        &**left
                    } else {
                        &**right
                    };
                }
            }
        }
    }

    /// Get class probabilities for a single sample.
    ///
    /// Routes `sample` to a leaf with the same rule as [`TreeNode::predict`],
    /// then returns a vector of length `n_classes` whose `i`-th entry is
    /// `class_counts[i] / n_samples` for that leaf. Classes without a recorded
    /// count get `0.0`; counts at indices `>= n_classes` are ignored.
    ///
    /// A leaf with `n_samples == 0` yields all zeros instead of dividing by
    /// zero.
    ///
    /// # Panics
    ///
    /// Panics if `sample` is shorter than a `feature_idx` met along the path.
    pub fn predict_proba(&self, sample: &[f64], n_classes: usize) -> Vec<f64> {
        let mut node = self;
        loop {
            match node {
                TreeNode::Leaf {
                    class_counts,
                    n_samples,
                    ..
                } => {
                    let mut proba = vec![0.0; n_classes];
                    if *n_samples > 0 {
                        let total = *n_samples as f64;
                        // `zip` stops at the shorter side, so neither extra counts
                        // nor extra classes can index out of bounds.
                        for (p, &count) in proba.iter_mut().zip(class_counts) {
                            *p = count as f64 / total;
                        }
                    }
                    return proba;
                }
                TreeNode::Split {
                    feature_idx,
                    threshold,
                    left,
                    right,
                    ..
                } => {
                    node = if sample[*feature_idx] <= *threshold {
                        &**left
                    } else {
                        &**right
                    };
                }
            }
        }
    }

    /// Depth of this subtree, counted in nodes: a lone leaf has depth 1.
    pub fn depth(&self) -> usize {
        match self {
            TreeNode::Leaf { .. } => 1,
            TreeNode::Split { left, right, .. } => 1 + left.depth().max(right.depth()),
        }
    }

    /// Number of leaf nodes in this subtree.
    pub fn n_leaves(&self) -> usize {
        match self {
            TreeNode::Leaf { .. } => 1,
            TreeNode::Split { left, right, .. } => left.n_leaves() + right.n_leaves(),
        }
    }

    /// Number of samples at this node.
    pub fn n_samples(&self) -> usize {
        match self {
            TreeNode::Leaf { n_samples, .. } | TreeNode::Split { n_samples, .. } => *n_samples,
        }
    }

    /// Sum of weighted leaf impurities: Σ(impurity_leaf × n_samples_leaf).
    ///
    /// This is R(T_t) in the cost-complexity pruning literature.
    pub fn total_leaf_impurity(&self) -> f64 {
        match self {
            TreeNode::Leaf {
                impurity,
                n_samples,
                ..
            } => *impurity * (*n_samples as f64),
            TreeNode::Split { left, right, .. } => {
                left.total_leaf_impurity() + right.total_leaf_impurity()
            }
        }
    }

    /// Minimal cost-complexity pruning (MCCP).
    ///
    /// Recursively prunes subtrees whose effective alpha is ≤ `ccp_alpha`.
    /// Effective alpha = (R(t) - R(T_t)) / (|T_t| - 1), where R(t) is the
    /// re-substitution error if this node were a leaf and R(T_t) is the
    /// total leaf impurity of the subtree.
    ///
    /// Children are pruned before their parent, so a parent's effective alpha
    /// is computed from its already-pruned subtree. A NaN `ccp_alpha` (or a NaN
    /// effective alpha) never satisfies the comparison, so nothing is pruned
    /// at that node.
    ///
    /// Intended to mirror the `ccp_alpha` parameter of scikit-learn's trees.
    pub fn prune_ccp(self, ccp_alpha: f64) -> TreeNode {
        self.prune_with_stats(ccp_alpha).0
    }

    /// Compute the cost-complexity pruning path.
    ///
    /// Returns `(ccp_alphas, total_impurities)` — a sequence of effective
    /// alpha values and the corresponding total tree impurity at each
    /// pruning step. Useful for elbow-method selection of `ccp_alpha`.
    ///
    /// The first entry is always alpha `0.0` paired with the total leaf
    /// impurity of the unpruned tree. Each following entry is the smallest
    /// effective alpha of the current tree and the total leaf impurity after
    /// pruning with it. The path ends when only a leaf remains, or when a
    /// pruning step removes nothing (which happens when the smallest effective
    /// alpha is NaN, e.g. because a node impurity is NaN); such a step is not
    /// recorded.
    pub fn cost_complexity_pruning_path(&self) -> (Vec<f64>, Vec<f64>) {
        let (mut n_leaves, initial_risk, mut next_alpha) = self.subtree_stats();
        let mut alphas = vec![0.0];
        let mut impurities = vec![initial_risk];

        let mut current = self.clone();
        while let Some(alpha) = next_alpha {
            let (pruned, pruned_leaves, pruned_risk) = current.prune_with_stats(alpha);
            if pruned_leaves >= n_leaves {
                // Nothing was removed, so repeating the step would spin forever.
                break;
            }
            alphas.push(alpha);
            impurities.push(pruned_risk);
            n_leaves = pruned_leaves;
            next_alpha = Self::min_effective_alpha(&pruned);
            current = pruned;
        }
        (alphas, impurities)
    }

    /// Find the minimum effective alpha among all internal nodes.
    ///
    /// Returns `None` when `node` is a leaf (there is nothing to prune).
    fn min_effective_alpha(node: &TreeNode) -> Option<f64> {
        node.subtree_stats().2
    }

    /// One post-order pass returning `(n_leaves, R(T_t), min effective alpha)`.
    ///
    /// The values equal what `n_leaves()`, `total_leaf_impurity()` and a
    /// per-node effective-alpha scan would produce, without re-walking each
    /// subtree at every ancestor.
    fn subtree_stats(&self) -> (usize, f64, Option<f64>) {
        match self {
            TreeNode::Leaf {
                impurity,
                n_samples,
                ..
            } => (1, *impurity * (*n_samples as f64), None),
            TreeNode::Split {
                left,
                right,
                n_samples,
                impurity,
                ..
            } => {
                let (left_leaves, left_risk, left_alpha) = left.subtree_stats();
                let (right_leaves, right_risk, right_alpha) = right.subtree_stats();

                let n_leaves = left_leaves + right_leaves;
                let r_subtree = left_risk + right_risk;
                let r_node = *impurity * (*n_samples as f64);
                // A split always has at least two leaves, so the divisor is >= 1.
                let own_alpha = (r_node - r_subtree) / (n_leaves as f64 - 1.0);

                // `f64::min` prefers the non-NaN operand, so NaN only survives
                // when every candidate is NaN.
                let min_alpha = [Some(own_alpha), left_alpha, right_alpha]
                    .iter()
                    .filter_map(|a| *a)
                    .reduce(f64::min);
                (n_leaves, r_subtree, min_alpha)
            }
        }
    }

    /// Bottom-up pruning that also reports the pruned subtree's statistics.
    ///
    /// Returns `(pruned node, its leaf count, its total leaf impurity)` so the
    /// parent can compute its own effective alpha without another traversal.
    fn prune_with_stats(self, ccp_alpha: f64) -> (TreeNode, usize, f64) {
        match self {
            TreeNode::Leaf { .. } => {
                let risk = self.total_leaf_impurity();
                (self, 1, risk)
            }
            TreeNode::Split {
                feature_idx,
                threshold,
                left,
                right,
                n_samples,
                impurity,
                class_counts,
                prediction,
            } => {
                // Children first, so this node is judged on its pruned subtree.
                let (left, left_leaves, left_risk) = left.prune_with_stats(ccp_alpha);
                let (right, right_leaves, right_risk) = right.prune_with_stats(ccp_alpha);

                let n_leaves = left_leaves + right_leaves;
                // R(T_t) = total leaf impurity of the (pruned) subtree.
                let r_subtree = left_risk + right_risk;
                // R(t) = impurity if this node were a leaf.
                let r_node = impurity * (n_samples as f64);
                let effective_alpha = (r_node - r_subtree) / (n_leaves as f64 - 1.0);

                if effective_alpha <= ccp_alpha {
                    // Collapse to a leaf that keeps this node's statistics.
                    let collapsed = TreeNode::Leaf {
                        prediction,
                        n_samples,
                        class_counts,
                        impurity,
                    };
                    (collapsed, 1, r_node)
                } else {
                    let kept = TreeNode::Split {
                        feature_idx,
                        threshold,
                        left: Box::new(left),
                        right: Box::new(right),
                        n_samples,
                        impurity,
                        class_counts,
                        prediction,
                    };
                    (kept, n_leaves, r_subtree)
                }
            }
        }
    }
}