// scry_learn/tree/cart/mod.rs

mod builder;
mod flat;
mod node;

pub(crate) use builder::presort_indices;
pub use builder::{DecisionTreeClassifier, DecisionTreeRegressor};
pub use flat::FlatTree;
pub use node::TreeNode;

/// Impurity criterion used when searching for splits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum SplitCriterion {
    /// Gini impurity, for classification.
    Gini,
    /// Shannon entropy (information gain), for classification.
    Entropy,
    /// Mean squared error, for regression.
    Mse,
}
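
/// Sentinel child index meaning "no child": no real node can occupy index
/// `u32::MAX`, so leaf nodes can be flagged without an extra tag bit.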
pub(crate) const LEAF_SENTINEL: u32 = u32::MAX;

/// The best split found for a node during greedy tree construction.
pub(super) struct BestSplit {
    pub(super) feature_idx: usize,
    pub(super) threshold: f64,
    pub(super) impurity_decrease: f64,
}

pub(super) fn compute_impurity(counts: &[usize], n: usize, criterion: SplitCriterion) -> f64 {
    if n == 0 {
        return 0.0;
    }
    let n_f = n as f64;
    match criterion {
        SplitCriterion::Gini => {
            // Gini: 1 - sum(p_k^2)
            let sum_sq: f64 = counts
                .iter()
                .map(|&c| {
                    let p = c as f64 / n_f;
                    p * p
                })
                .sum();
            1.0 - sum_sq
        }
        SplitCriterion::Entropy => {
            // Entropy: -sum(p_k * log2(p_k)), skipping empty classes.
            let mut entropy = 0.0;
            for &c in counts {
                if c > 0 {
                    let p = c as f64 / n_f;
                    entropy -= p * p.log2();
                }
            }
            entropy
        }
        SplitCriterion::Mse => {
            // Class counts carry no information for MSE; the regression
            // impurity is variance-based and computed in the regressor path.
            0.0
        }
    }
}

/// Index of the most frequent class, returned as an f64 label (0.0 if empty).
pub(super) fn majority_class(counts: &[usize]) -> f64 {
    counts
        .iter()
        .enumerate()
        .max_by_key(|&(_, &count)| count)
        .map_or(0.0, |(idx, _)| idx as f64)
}

/// Variant of `compute_impurity` that takes fractional (weighted) class counts.
pub(super) fn compute_impurity_weighted(
    counts: &[f64],
    total: f64,
    criterion: SplitCriterion,
) -> f64 {
    if total < 1e-12 {
        return 0.0;
    }
    match criterion {
        SplitCriterion::Gini => {
            let sum_sq: f64 = counts
                .iter()
                .map(|&c| {
                    let p = c / total;
                    p * p
                })
                .sum();
            1.0 - sum_sq
        }
        SplitCriterion::Entropy => {
            let mut entropy = 0.0;
            for &c in counts {
                if c > 1e-12 {
                    let p = c / total;
                    entropy -= p * p.log2();
                }
            }
            entropy
        }
        SplitCriterion::Mse => 0.0,
    }
}

/// Index of the class with the largest weighted count, as an f64 label.
pub(super) fn weighted_majority_class(counts: &[f64]) -> f64 {
    counts
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map_or(0.0, |(idx, _)| idx as f64)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Dataset;

    fn make_linearly_separable() -> Dataset {
        let features = vec![(0..20).map(|i| i as f64).collect()];
        let target: Vec<f64> = (0..20).map(|i| if i < 10 { 0.0 } else { 1.0 }).collect();
        Dataset::new(features, target, vec!["x".into()], "class")
    }

    #[test]
    fn test_decision_tree_perfect_split() {
        let data = make_linearly_separable();
        let mut dt = DecisionTreeClassifier::new();
        dt.fit(&data).unwrap();

        let matrix = data.feature_matrix();
        let preds = dt.predict(&matrix).unwrap();
        let acc = preds
            .iter()
            .zip(data.target.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
            .count() as f64
            / data.n_samples() as f64;

        assert!(
            acc >= 0.95,
            "expected ≥95% accuracy on linearly separable data, got {:.1}%",
            acc * 100.0
        );
    }

    #[test]
    fn test_feature_importance_sums_to_one() {
        let data = make_linearly_separable();
        let mut dt = DecisionTreeClassifier::new();
        dt.fit(&data).unwrap();

        let importances = dt.feature_importances().unwrap();
        let total: f64 = importances.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-6,
            "feature importances should sum to 1.0, got {total}"
        );
    }
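
    // The majority-class helpers return the index of the largest (possibly
    // weighted) count as an f64 label, falling back to class 0 for empty
    // input; a quick illustrative check of both.
    #[test]
    fn test_majority_class_helpers() {
        assert_eq!(majority_class(&[3, 7, 2]), 1.0);
        assert_eq!(weighted_majority_class(&[0.5, 0.1, 2.4]), 2.0);
        assert_eq!(majority_class(&[]), 0.0);
    }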

    #[test]
    fn test_max_depth() {
        let data = make_linearly_separable();
        let mut dt = DecisionTreeClassifier::new().max_depth(2);
        dt.fit(&data).unwrap();
        // max_depth(2) plus the root level.
        assert!(dt.depth() <= 2 + 1);
    }

    #[test]
    fn test_predict_proba() {
        let data = make_linearly_separable();
        let mut dt = DecisionTreeClassifier::new();
        dt.fit(&data).unwrap();

        let sample_class0 = vec![2.0];
        let proba = dt.predict_proba(&[sample_class0]).unwrap();
        assert!(proba[0][0] > 0.5, "should predict class 0 with >50%");
    }

    #[test]
    fn test_regressor_basic() {
        let features = vec![(0..50).map(|i| i as f64).collect()];
        let target: Vec<f64> = (0..50).map(|i| i as f64).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut dt = DecisionTreeRegressor::new().max_depth(10);
        dt.fit(&data).unwrap();

        let matrix = data.feature_matrix();
        let preds = dt.predict(&matrix).unwrap();

        let mse: f64 = preds
            .iter()
            .zip(data.target.iter())
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f64>()
            / data.n_samples() as f64;

        assert!(mse < 5.0, "MSE on training data should be low, got {mse}");
    }

    #[test]
    fn test_not_fitted_error() {
        let dt = DecisionTreeClassifier::new();
        assert!(dt.predict(&[vec![1.0]]).is_err());
    }
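
    // Closed-form impurity values: a balanced two-class node has Gini
    // 1 - (0.5^2 + 0.5^2) = 0.5 and entropy 1.0 bit, while a pure node
    // scores 0 under both criteria.
    #[test]
    fn test_impurity_known_values() {
        assert!((compute_impurity(&[5, 5], 10, SplitCriterion::Gini) - 0.5).abs() < 1e-12);
        assert!((compute_impurity(&[5, 5], 10, SplitCriterion::Entropy) - 1.0).abs() < 1e-12);
        assert_eq!(compute_impurity(&[10, 0], 10, SplitCriterion::Gini), 0.0);
        assert_eq!(compute_impurity(&[10, 0], 10, SplitCriterion::Entropy), 0.0);
    }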

    /// Three well-separated 2-D clusters of 50 samples each (classes 0-2).
    fn make_iris_like() -> Dataset {
        let mut rng = crate::rng::FastRng::new(42);
        let n = 150;
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);
        for _ in 0..50 {
            f1.push(rng.f64() * 2.0);
            f2.push(rng.f64() * 2.0);
            target.push(0.0);
        }
        for _ in 0..50 {
            f1.push(rng.f64() * 2.0 + 3.0);
            f2.push(rng.f64() * 2.0 + 3.0);
            target.push(1.0);
        }
        for _ in 0..50 {
            f1.push(rng.f64() * 2.0 + 6.0);
            f2.push(rng.f64() * 2.0);
            target.push(2.0);
        }
        Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        )
    }

    #[test]
    fn test_ccp_alpha_reduces_depth() {
        let data = make_iris_like();

        let mut dt_full = DecisionTreeClassifier::new();
        dt_full.fit(&data).unwrap();
        let depth_full = dt_full.depth();
        let leaves_full = dt_full.n_leaves();

        let mut dt_pruned = DecisionTreeClassifier::new().ccp_alpha(0.02);
        dt_pruned.fit(&data).unwrap();
        let depth_pruned = dt_pruned.depth();
        let leaves_pruned = dt_pruned.n_leaves();

        eprintln!("Full tree: depth={depth_full}, leaves={leaves_full}");
        eprintln!("Pruned tree: depth={depth_pruned}, leaves={leaves_pruned}");

        assert!(
            leaves_pruned <= leaves_full,
            "Pruned tree should have no more leaves than the full tree: {leaves_pruned} vs {leaves_full}"
        );
    }

    #[test]
    fn test_ccp_alpha_zero_no_change() {
        let data = make_iris_like();

        let mut dt_zero = DecisionTreeClassifier::new().ccp_alpha(0.0);
        dt_zero.fit(&data).unwrap();
        let mut dt_default = DecisionTreeClassifier::new();
        dt_default.fit(&data).unwrap();

        assert_eq!(
            dt_zero.n_leaves(),
            dt_default.n_leaves(),
            "ccp_alpha=0.0 should not change the tree"
        );
    }
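
    // For integral weights the weighted impurity helper should agree exactly
    // with the unweighted one, since both reduce to the same proportions.
    #[test]
    fn test_weighted_impurity_matches_unweighted() {
        for &crit in &[SplitCriterion::Gini, SplitCriterion::Entropy] {
            let unweighted = compute_impurity(&[3, 7], 10, crit);
            let weighted = compute_impurity_weighted(&[3.0, 7.0], 10.0, crit);
            assert!(
                (unweighted - weighted).abs() < 1e-12,
                "criteria should agree for integral weights"
            );
        }
    }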

    #[test]
    fn test_ccp_alpha_large_collapses_to_root() {
        let data = make_iris_like();
        let mut dt = DecisionTreeClassifier::new().ccp_alpha(1000.0);
        dt.fit(&data).unwrap();
        assert_eq!(
            dt.n_leaves(),
            1,
            "Very large ccp_alpha should collapse to a single leaf"
        );
    }

    #[test]
    fn test_regressor_ccp_alpha() {
        let features = vec![(0..100).map(|i| i as f64).collect()];
        let target: Vec<f64> = (0..100).map(|i| (i as f64).sin()).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut dt_full = DecisionTreeRegressor::new();
        dt_full.fit(&data).unwrap();

        let mut dt_pruned = DecisionTreeRegressor::new().ccp_alpha(0.01);
        dt_pruned.fit(&data).unwrap();

        let full_leaves = dt_full.flat_tree().unwrap().n_leaves();
        let pruned_leaves = dt_pruned.flat_tree().unwrap().n_leaves();

        eprintln!("Regressor: full={full_leaves} leaves, pruned={pruned_leaves} leaves");
        assert!(
            pruned_leaves <= full_leaves,
            "Pruned regressor should have no more leaves: {pruned_leaves} vs {full_leaves}"
        );
    }

    #[test]
    fn test_pruning_path_monotonic() {
        let data = make_iris_like();
        let mut dt = DecisionTreeClassifier::new();
        dt.fit(&data).unwrap();

        let (alphas, impurities) = dt.cost_complexity_pruning_path(&data).unwrap();

        assert!(alphas.len() >= 2, "Should have at least 2 pruning steps");
        for w in alphas.windows(2) {
            assert!(
                w[1] >= w[0] - 1e-12,
                "Alphas should be monotonically non-decreasing: {} -> {}",
                w[0],
                w[1]
            );
        }
        for w in impurities.windows(2) {
            assert!(
                w[1] >= w[0] - 1e-12,
                "Impurities should be non-decreasing: {} -> {}",
                w[0],
                w[1]
            );
        }
        eprintln!("Pruning path: {} steps", alphas.len());
        eprintln!("Alphas: {:?}", &alphas[..alphas.len().min(5)]);
    }
}