scry_learn/ensemble/stacking.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Voting and Stacking ensemble classifiers.
//!
//! [`VotingClassifier`] combines multiple classifiers via hard (majority
//! vote) or soft (probability averaging) voting.
//!
//! [`StackingClassifier`] trains a meta-learner on out-of-fold predictions
//! from a set of base estimators.
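//!
//! # Examples
//!
//! A minimal end-to-end sketch (assumes a prepared `dataset: Dataset` and a
//! row-major `features: Vec<Vec<f64>>`; see the type-level docs below):
//!
//! ```ignore
//! use scry_learn::ensemble::{Voting, VotingClassifier};
//! use scry_learn::tree::DecisionTreeClassifier;
//!
//! let mut vc = VotingClassifier::new(vec![
//!     Box::new(DecisionTreeClassifier::new().max_depth(3)),
//!     Box::new(DecisionTreeClassifier::new().max_depth(5)),
//! ])
//! .voting(Voting::Hard);
//!
//! vc.fit(&dataset).unwrap();
//! let labels = vc.predict(&features).unwrap();
//! ```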

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};

// ---------------------------------------------------------------------------
// Voting strategy
// ---------------------------------------------------------------------------

/// Voting strategy for [`VotingClassifier`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Voting {
    /// Majority vote on predicted class labels.
    Hard,
    /// Average predicted probabilities, then take argmax.
    ///
    /// Requires all estimators to support `predict_proba`.
    Soft,
}

// ---------------------------------------------------------------------------
// Classifier wrapper — trait object for ensemble base learners
// ---------------------------------------------------------------------------

/// Trait object for classifiers that can be used in ensembles.
///
/// Covers the common interface: fit on a [`Dataset`], predict on features,
/// and optionally predict class probabilities.
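///
/// # Examples
///
/// A minimal sketch of a custom base learner (`ConstantClassifier` is a
/// hypothetical type shown for illustration, not part of the crate):
///
/// ```ignore
/// use scry_learn::dataset::Dataset;
/// use scry_learn::ensemble::EnsembleClassifier;
/// use scry_learn::error::Result;
///
/// /// Always predicts the label it was constructed with.
/// #[derive(Clone)]
/// struct ConstantClassifier {
///     label: f64,
/// }
///
/// impl EnsembleClassifier for ConstantClassifier {
///     fn fit(&mut self, _data: &Dataset) -> Result<()> {
///         Ok(()) // nothing to learn
///     }
///     fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
///         Ok(vec![self.label; features.len()])
///     }
///     fn clone_box(&self) -> Box<dyn EnsembleClassifier> {
///         Box::new(self.clone())
///     }
/// }
/// ```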
pub trait EnsembleClassifier: Send + Sync {
    /// Train on a dataset.
    fn fit(&mut self, data: &Dataset) -> Result<()>;

    /// Predict class labels for the given feature matrix.
    fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>>;

    /// Predict class probabilities (required for soft voting / stacking).
    ///
    /// Default implementation returns an error indicating the model does not
    /// support probability predictions.
    fn predict_proba(&self, _features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        Err(ScryLearnError::InvalidParameter(
            "this estimator does not support predict_proba".into(),
        ))
    }

    /// Clone into a boxed trait object.
    fn clone_box(&self) -> Box<dyn EnsembleClassifier>;
}

impl Clone for Box<dyn EnsembleClassifier> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

// ---------------------------------------------------------------------------
// EnsembleClassifier implementations for existing models
// ---------------------------------------------------------------------------

macro_rules! impl_ensemble_no_proba {
    ($ty:path) => {
        impl EnsembleClassifier for $ty {
            fn fit(&mut self, data: &Dataset) -> Result<()> {
                self.fit(data)
            }
            fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                self.predict(features)
            }
            fn clone_box(&self) -> Box<dyn EnsembleClassifier> {
                Box::new(self.clone())
            }
        }
    };
}

macro_rules! impl_ensemble_with_proba {
    ($ty:path) => {
        impl EnsembleClassifier for $ty {
            fn fit(&mut self, data: &Dataset) -> Result<()> {
                self.fit(data)
            }
            fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
                self.predict(features)
            }
            fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
                self.predict_proba(features)
            }
            fn clone_box(&self) -> Box<dyn EnsembleClassifier> {
                Box::new(self.clone())
            }
        }
    };
}

// Models that support predict_proba:
impl_ensemble_with_proba!(crate::tree::DecisionTreeClassifier);
impl_ensemble_with_proba!(crate::tree::RandomForestClassifier);
impl_ensemble_with_proba!(crate::naive_bayes::GaussianNB);
impl_ensemble_with_proba!(crate::naive_bayes::BernoulliNB);
impl_ensemble_with_proba!(crate::naive_bayes::MultinomialNB);

// Models without predict_proba:
impl_ensemble_no_proba!(crate::tree::DecisionTreeRegressor);
impl_ensemble_no_proba!(crate::linear::LogisticRegression);
impl_ensemble_no_proba!(crate::linear::LinearRegression);
impl_ensemble_no_proba!(crate::linear::LassoRegression);
impl_ensemble_no_proba!(crate::linear::ElasticNet);
impl_ensemble_no_proba!(crate::neighbors::KnnClassifier);
impl_ensemble_no_proba!(crate::neighbors::KnnRegressor);
impl_ensemble_no_proba!(crate::svm::LinearSVC);
impl_ensemble_no_proba!(crate::svm::LinearSVR);
#[cfg(feature = "experimental")]
impl_ensemble_no_proba!(crate::svm::KernelSVC);
#[cfg(feature = "experimental")]
impl_ensemble_no_proba!(crate::svm::KernelSVR);
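
// Additional in-crate models can opt in by invoking one of the macros above,
// e.g. `impl_ensemble_with_proba!(crate::some_module::SomeClassifier);`
// (hypothetical path, shown for illustration).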

// ---------------------------------------------------------------------------
// VotingClassifier
// ---------------------------------------------------------------------------

/// Combines multiple classifiers via voting.
///
/// In [`Voting::Hard`] mode, each estimator votes for a class and the majority
/// wins. In [`Voting::Soft`] mode, predicted probabilities are averaged and
/// the class with the highest average probability is selected.
///
/// # Examples
///
/// ```ignore
/// use scry_learn::ensemble::{VotingClassifier, Voting};
/// use scry_learn::tree::DecisionTreeClassifier;
///
/// let vc = VotingClassifier::new(vec![
///     Box::new(DecisionTreeClassifier::new().max_depth(3)),
///     Box::new(DecisionTreeClassifier::new().max_depth(5)),
///     Box::new(DecisionTreeClassifier::new().max_depth(7)),
/// ]).voting(Voting::Hard);
/// ```
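///
/// Soft voting with per-estimator weights (a sketch; decision trees implement
/// `predict_proba`, which [`Voting::Soft`] requires):
///
/// ```ignore
/// let vc = VotingClassifier::new(vec![
///     Box::new(DecisionTreeClassifier::new().max_depth(3)),
///     Box::new(DecisionTreeClassifier::new().max_depth(7)),
/// ])
/// .voting(Voting::Soft)
/// .weights(vec![1.0, 2.0]);
/// ```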
#[derive(Clone)]
#[non_exhaustive]
pub struct VotingClassifier {
    /// Base estimators.
    estimators: Vec<Box<dyn EnsembleClassifier>>,
    /// Voting strategy.
    voting_strategy: Voting,
    /// Optional weights for each estimator.
    weights: Option<Vec<f64>>,
    /// Whether the model has been fitted.
    fitted: bool,
    /// Number of unique classes seen during fit.
    n_classes: usize,
}

impl std::fmt::Debug for VotingClassifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VotingClassifier")
            .field("n_estimators", &self.estimators.len())
            .field("voting", &self.voting_strategy)
            .field("weights", &self.weights)
            .field("fitted", &self.fitted)
            .finish()
    }
}

impl VotingClassifier {
    /// Create a new voting classifier with the given base estimators.
    pub fn new(estimators: Vec<Box<dyn EnsembleClassifier>>) -> Self {
        Self {
            estimators,
            voting_strategy: Voting::Hard,
            weights: None,
            fitted: false,
            n_classes: 0,
        }
    }

    /// Set the voting strategy (default: [`Voting::Hard`]).
    pub fn voting(mut self, v: Voting) -> Self {
        self.voting_strategy = v;
        self
    }

    /// Set weights for each estimator (default: equal weights).
    pub fn weights(mut self, w: Vec<f64>) -> Self {
        self.weights = Some(w);
        self
    }

    /// Fit all base estimators on the given dataset.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.estimators.is_empty() {
            return Err(ScryLearnError::InvalidParameter(
                "VotingClassifier requires at least one estimator".into(),
            ));
        }
        if let Some(ref w) = self.weights {
            if w.len() != self.estimators.len() {
                return Err(ScryLearnError::InvalidParameter(format!(
                    "weights length ({}) must match estimators length ({})",
                    w.len(),
                    self.estimators.len(),
                )));
            }
        }

        self.n_classes = data.n_classes();

        for est in &mut self.estimators {
            est.fit(data)?;
        }
        self.fitted = true;
        Ok(())
    }

    /// Predict class labels via voting.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }

        match self.voting_strategy {
            Voting::Hard => self.predict_hard(features),
            Voting::Soft => self.predict_soft(features),
        }
    }

    /// Hard voting: majority vote across estimator predictions.
    fn predict_hard(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        let n = features.len();
        let n_classes = self.n_classes;

        // Collect predictions from all estimators.
        let all_preds: Vec<Vec<f64>> = self
            .estimators
            .iter()
            .map(|est| est.predict(features))
            .collect::<Result<_>>()?;

        let weights = self.effective_weights();

        let mut result = Vec::with_capacity(n);
        for sample_idx in 0..n {
            let mut votes = vec![0.0_f64; n_classes.max(1)];
            for (est_idx, preds) in all_preds.iter().enumerate() {
                let class = preds[sample_idx] as usize;
                if class < votes.len() {
                    votes[class] += weights[est_idx];
                }
            }
            let best_class = votes
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map_or(0, |(idx, _)| idx);
            result.push(best_class as f64);
        }

        Ok(result)
    }

    /// Soft voting: average predict_proba across estimators, take argmax.
    fn predict_soft(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        let n = features.len();
        let n_classes = self.n_classes;
        let weights = self.effective_weights();

        let mut avg_proba = vec![vec![0.0; n_classes]; n];

        for (est_idx, est) in self.estimators.iter().enumerate() {
            let probas = est.predict_proba(features)?;
            for (sample_idx, proba) in probas.iter().enumerate() {
                for (class_idx, &p) in proba.iter().enumerate() {
                    if class_idx < n_classes {
                        avg_proba[sample_idx][class_idx] += p * weights[est_idx];
                    }
                }
            }
        }

        let result: Vec<f64> = avg_proba
            .iter()
            .map(|proba| {
                proba
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map_or(0.0, |(idx, _)| idx as f64)
            })
            .collect();

        Ok(result)
    }

    /// Return the effective weights: user-provided if set, else uniform.
    fn effective_weights(&self) -> Vec<f64> {
        self.weights
            .clone()
            .unwrap_or_else(|| vec![1.0; self.estimators.len()])
    }
}

// ---------------------------------------------------------------------------
// StackingClassifier
// ---------------------------------------------------------------------------

/// Stacking (stacked generalization) classifier.
///
/// Trains base estimators via k-fold cross-validation, collects out-of-fold
/// predictions as meta-features, then trains a final meta-learner on those
/// meta-features. At predict time, base estimator predictions are fed to
/// the meta-learner.
///
/// # Examples
///
/// ```ignore
/// use scry_learn::ensemble::StackingClassifier;
/// use scry_learn::tree::DecisionTreeClassifier;
/// use scry_learn::linear::LogisticRegression;
///
/// let sc = StackingClassifier::new(
///     vec![
///         Box::new(DecisionTreeClassifier::new().max_depth(3)),
///         Box::new(DecisionTreeClassifier::new().max_depth(7)),
///     ],
///     Box::new(LogisticRegression::new()),
/// ).cv(5);
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct StackingClassifier {
    /// Base learners.
    estimators: Vec<Box<dyn EnsembleClassifier>>,
    /// Meta-learner trained on out-of-fold predictions.
    final_estimator: Box<dyn EnsembleClassifier>,
    /// Number of cross-validation folds.
    cv: usize,
    /// Random seed for fold generation.
    seed: u64,
    /// Whether the model has been fitted.
    fitted: bool,
}

impl std::fmt::Debug for StackingClassifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StackingClassifier")
            .field("n_estimators", &self.estimators.len())
            .field("cv", &self.cv)
            .field("fitted", &self.fitted)
            .finish()
    }
}

impl StackingClassifier {
    /// Create a new stacking classifier.
    ///
    /// `estimators` are the base learners; `final_estimator` is the meta-learner.
    pub fn new(
        estimators: Vec<Box<dyn EnsembleClassifier>>,
        final_estimator: Box<dyn EnsembleClassifier>,
    ) -> Self {
        Self {
            estimators,
            final_estimator,
            cv: 5,
            seed: 42,
            fitted: false,
        }
    }

    /// Set the number of CV folds (default: 5).
    pub fn cv(mut self, k: usize) -> Self {
        self.cv = k;
        self
    }

    /// Set the random seed for fold generation (default: 42).
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Fit the stacking classifier.
    ///
    /// 1. Split data into `cv` folds.
    /// 2. For each fold, train base learners on training folds and collect
    ///    out-of-fold predictions.
    /// 3. Assemble meta-features from out-of-fold predictions.
    /// 4. Train the final estimator on meta-features.
    /// 5. Re-train all base learners on the full dataset for prediction.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.estimators.is_empty() {
            return Err(ScryLearnError::InvalidParameter(
                "StackingClassifier requires at least one base estimator".into(),
            ));
        }
        if self.cv < 2 {
            return Err(ScryLearnError::InvalidParameter(
                "cv must be at least 2".into(),
            ));
        }

        let n_samples = data.n_samples();
        let n_estimators = self.estimators.len();

        // Generate fold indices.
        let folds = generate_fold_indices(n_samples, self.cv, self.seed);

        // Meta-feature matrix: n_samples rows × n_estimators columns.
        let mut meta_features = vec![vec![0.0; n_estimators]; n_samples];

        for test_indices in &folds {
            let train_indices: Vec<usize> = (0..n_samples)
                .filter(|i| !test_indices.contains(i))
                .collect();

            let train_data = data.subset(&train_indices);
            let test_features = Self::extract_features(data, test_indices);

            for (est_idx, est_template) in self.estimators.iter().enumerate() {
                let mut est = est_template.clone_box();
                est.fit(&train_data)?;
                let preds = est.predict(&test_features)?;

                for (local_idx, &global_idx) in test_indices.iter().enumerate() {
                    meta_features[global_idx][est_idx] = preds[local_idx];
                }
            }
        }

        // Build meta-dataset: features = meta_features, target = original target.
        let meta_columns: Vec<Vec<f64>> = (0..n_estimators)
            .map(|est_idx| meta_features.iter().map(|row| row[est_idx]).collect())
            .collect();
        let feature_names: Vec<String> = (0..n_estimators).map(|i| format!("est_{i}")).collect();

        let meta_dataset = Dataset::new(meta_columns, data.target.clone(), feature_names, "target");

        // Train the final estimator on meta-features.
        self.final_estimator.fit(&meta_dataset)?;

        // Re-train base learners on the full dataset for prediction time.
        for est in &mut self.estimators {
            est.fit(data)?;
        }

        self.fitted = true;
        Ok(())
    }

    /// Predict class labels using the stacking ensemble.
    ///
    /// Gets predictions from all base learners, then feeds them to the
    /// meta-learner.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }

        let n = features.len();
        let n_estimators = self.estimators.len();

        // Get base predictions.
        let base_preds: Vec<Vec<f64>> = self
            .estimators
            .iter()
            .map(|est| est.predict(features))
            .collect::<Result<_>>()?;

        // Assemble meta-features.
        let meta_features: Vec<Vec<f64>> = (0..n)
            .map(|i| (0..n_estimators).map(|j| base_preds[j][i]).collect())
            .collect();

        self.final_estimator.predict(&meta_features)
    }

    /// Extract row-major features for specific sample indices.
    fn extract_features(data: &Dataset, indices: &[usize]) -> Vec<Vec<f64>> {
        indices.iter().map(|&i| data.sample(i)).collect()
    }
}

/// Generate fold indices for k-fold cross-validation.
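///
/// Indices are shuffled with a seeded Fisher-Yates pass, then split into `k`
/// near-equal contiguous chunks; the first `n % k` folds take one extra
/// sample, so e.g. `n = 10, k = 3` yields fold sizes 4, 3, 3.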
fn generate_fold_indices(n: usize, k: usize, seed: u64) -> Vec<Vec<usize>> {
    let mut indices: Vec<usize> = (0..n).collect();
    let mut rng = crate::rng::FastRng::new(seed);

    // Fisher-Yates shuffle.
    for i in (1..indices.len()).rev() {
        let j = rng.usize(0..=i);
        indices.swap(i, j);
    }

    let fold_size = n / k;
    let remainder = n % k;
    let mut folds = Vec::with_capacity(k);
    let mut start = 0;
    for fold in 0..k {
        let extra = usize::from(fold < remainder);
        let end = start + fold_size + extra;
        folds.push(indices[start..end].to_vec());
        start = end;
    }

    folds
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::DecisionTreeClassifier;

    fn make_iris_like_data() -> Dataset {
        // 3-class classification with clear separation.
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        let mut target = Vec::new();
        let mut rng = crate::rng::FastRng::new(42);

        // Class 0: cluster around (1, 1)
        for _ in 0..40 {
            f1.push(1.0 + rng.f64() * 0.5);
            f2.push(1.0 + rng.f64() * 0.5);
            target.push(0.0);
        }
        // Class 1: cluster around (5, 5)
        for _ in 0..40 {
            f1.push(5.0 + rng.f64() * 0.5);
            f2.push(5.0 + rng.f64() * 0.5);
            target.push(1.0);
        }
        // Class 2: cluster around (1, 5)
        for _ in 0..40 {
            f1.push(1.0 + rng.f64() * 0.5);
            f2.push(5.0 + rng.f64() * 0.5);
            target.push(2.0);
        }

        Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        )
    }

    #[test]
    fn test_voting_hard_basic() {
        let data = make_iris_like_data();

        let mut vc = VotingClassifier::new(vec![
            Box::new(DecisionTreeClassifier::new().max_depth(3)),
            Box::new(DecisionTreeClassifier::new().max_depth(5)),
            Box::new(DecisionTreeClassifier::new().max_depth(7)),
        ])
        .voting(Voting::Hard);

        vc.fit(&data).unwrap();
        let features = data.feature_matrix();
        let preds = vc.predict(&features).unwrap();

        let acc = preds
            .iter()
            .zip(data.target.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
            .count() as f64
            / data.n_samples() as f64;

        assert!(
            acc >= 0.85,
            "VotingClassifier hard vote accuracy should be ≥ 85%, got {:.1}%",
            acc * 100.0,
        );
    }

    #[test]
    fn test_voting_soft_basic() {
        let data = make_iris_like_data();

        // Use DecisionTreeClassifier which supports predict_proba.
        let mut vc = VotingClassifier::new(vec![
            Box::new(DecisionTreeClassifier::new().max_depth(3)),
            Box::new(DecisionTreeClassifier::new().max_depth(5)),
            Box::new(DecisionTreeClassifier::new().max_depth(7)),
        ])
        .voting(Voting::Soft);

        vc.fit(&data).unwrap();
        let features = data.feature_matrix();
        let preds = vc.predict(&features).unwrap();

        let acc = preds
            .iter()
            .zip(data.target.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
            .count() as f64
            / data.n_samples() as f64;

        assert!(
            acc >= 0.85,
            "VotingClassifier soft vote accuracy should be ≥ 85%, got {:.1}%",
            acc * 100.0,
        );
    }

    #[test]
    fn test_voting_weighted() {
        let data = make_iris_like_data();

        let mut vc = VotingClassifier::new(vec![
            Box::new(DecisionTreeClassifier::new().max_depth(3)),
            Box::new(DecisionTreeClassifier::new().max_depth(5)),
        ])
        .voting(Voting::Hard)
        .weights(vec![1.0, 2.0]);

        vc.fit(&data).unwrap();
        let features = data.feature_matrix();
        let preds = vc.predict(&features).unwrap();
        assert_eq!(preds.len(), data.n_samples());
    }

    #[test]
    fn test_voting_not_fitted() {
        let vc = VotingClassifier::new(vec![Box::new(DecisionTreeClassifier::new())]);
        let result = vc.predict(&[vec![1.0, 2.0]]);
        assert!(result.is_err());
    }

    #[test]
    fn test_voting_empty_estimators() {
        let data = make_iris_like_data();
        let mut vc = VotingClassifier::new(vec![]);
        assert!(vc.fit(&data).is_err());
    }

    #[test]
    fn test_voting_weights_mismatch() {
        let data = make_iris_like_data();
        let mut vc = VotingClassifier::new(vec![Box::new(DecisionTreeClassifier::new())])
            .weights(vec![1.0, 2.0]); // mismatch: 2 weights for 1 estimator
        assert!(vc.fit(&data).is_err());
    }

    #[test]
    fn test_stacking_basic() {
        let data = make_iris_like_data();

        let mut sc = StackingClassifier::new(
            vec![
                Box::new(DecisionTreeClassifier::new().max_depth(3)),
                Box::new(DecisionTreeClassifier::new().max_depth(7)),
            ],
            Box::new(DecisionTreeClassifier::new().max_depth(5)),
        )
        .cv(3)
        .seed(42);

        sc.fit(&data).unwrap();
        let features = data.feature_matrix();
        let preds = sc.predict(&features).unwrap();

        assert_eq!(preds.len(), data.n_samples());

        let acc = preds
            .iter()
            .zip(data.target.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
            .count() as f64
            / data.n_samples() as f64;

        assert!(
            acc >= 0.70,
            "StackingClassifier accuracy should be ≥ 70%, got {:.1}%",
            acc * 100.0,
        );
    }

    #[test]
    fn test_stacking_not_fitted() {
        let sc = StackingClassifier::new(
            vec![Box::new(DecisionTreeClassifier::new())],
            Box::new(DecisionTreeClassifier::new()),
        );
        let result = sc.predict(&[vec![1.0, 2.0]]);
        assert!(result.is_err());
    }

    #[test]
    fn test_stacking_empty_estimators() {
        let data = make_iris_like_data();
        let mut sc = StackingClassifier::new(vec![], Box::new(DecisionTreeClassifier::new()));
        assert!(sc.fit(&data).is_err());
    }

    #[test]
    fn test_stacking_cv_too_small() {
        let data = make_iris_like_data();
        let mut sc = StackingClassifier::new(
            vec![Box::new(DecisionTreeClassifier::new())],
            Box::new(DecisionTreeClassifier::new()),
        )
        .cv(1);
        assert!(sc.fit(&data).is_err());
    }

    #[test]
    fn test_generate_fold_indices() {
        let folds = generate_fold_indices(10, 3, 42);
        assert_eq!(folds.len(), 3);
        let total: usize = folds.iter().map(std::vec::Vec::len).sum();
        assert_eq!(total, 10);
        // All indices present.
        let mut all: Vec<usize> = folds.into_iter().flatten().collect();
        all.sort_unstable();
        assert_eq!(all, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_voting_accuracy_ge_individual() {
        let data = make_iris_like_data();
        let features = data.feature_matrix();

        // Train individual trees and get their accuracies.
        let mut dt1 = DecisionTreeClassifier::new().max_depth(2);
        dt1.fit(&data).unwrap();
        let preds1 = dt1.predict(&features).unwrap();
        let acc1 = preds1
            .iter()
            .zip(data.target.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
            .count() as f64
            / data.n_samples() as f64;

        // Voting with 3 trees — accuracy should generally be >= worst individual.
        let mut vc = VotingClassifier::new(vec![
            Box::new(DecisionTreeClassifier::new().max_depth(2)),
            Box::new(DecisionTreeClassifier::new().max_depth(4)),
            Box::new(DecisionTreeClassifier::new().max_depth(6)),
        ])
        .voting(Voting::Hard);

        vc.fit(&data).unwrap();
        let preds_vc = vc.predict(&features).unwrap();
        let acc_vc = preds_vc
            .iter()
            .zip(data.target.iter())
            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
            .count() as f64
            / data.n_samples() as f64;

        // Ensemble should be at least as good as shallow tree.
        assert!(
            acc_vc >= acc1 - 0.05,
            "VotingClassifier ({:.1}%) should be ≥ individual DT ({:.1}%) - 5%",
            acc_vc * 100.0,
            acc1 * 100.0,
        );
    }
}
779}