sklears_model_selection/ensemble_evaluation.rs

//! Ensemble Cross-Validation and Diversity-Based Evaluation
//!
//! This module provides advanced evaluation methods for ensemble models, including
//! specialized cross-validation techniques, diversity measures, stability analysis,
//! and out-of-bag evaluation strategies designed specifically for ensemble learning.
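//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`: it assumes the crate paths below and
//! pre-built `preds`/`labels` arrays):
//!
//! ```ignore
//! use sklears_model_selection::ensemble_evaluation::{
//!     evaluate_ensemble, EnsembleEvaluationConfig,
//! };
//!
//! // `preds`: (n_samples, n_models) matrix of per-model predictions,
//! // `labels`: ground-truth targets, closure: the scoring metric.
//! let result = evaluate_ensemble(
//!     &preds,
//!     &labels,
//!     None,
//!     None,
//!     |p, y| Ok(1.0 - (p - y).mapv(Float::abs).mean().unwrap()),
//!     Some(EnsembleEvaluationConfig::default()),
//! ).unwrap();
//! println!("mean performance: {}", result.ensemble_performance.mean_performance);
//! ```
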
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::SeedableRng;
use sklears_core::types::Float;
use std::collections::HashMap;

/// Ensemble evaluation strategies
#[derive(Debug, Clone)]
pub enum EnsembleEvaluationStrategy {
    /// Out-of-bag evaluation for bootstrap-based ensembles
    OutOfBag {
        bootstrap_samples: usize,
        confidence_level: Float,
    },
    /// Ensemble-specific cross-validation
    EnsembleCrossValidation {
        cv_strategy: EnsembleCVStrategy,
        n_folds: usize,
    },
    /// Diversity-based evaluation
    DiversityEvaluation {
        diversity_measures: Vec<DiversityMeasure>,
        diversity_threshold: Float,
    },
    /// Stability analysis across different data splits
    StabilityAnalysis {
        n_bootstrap_samples: usize,
        stability_metrics: Vec<StabilityMetric>,
    },
    /// Progressive ensemble evaluation
    ProgressiveEvaluation {
        ensemble_sizes: Vec<usize>,
        selection_strategy: ProgressiveSelectionStrategy,
    },
    /// Multi-objective ensemble evaluation
    MultiObjectiveEvaluation {
        objectives: Vec<EvaluationObjective>,
        trade_off_analysis: bool,
    },
}
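
// Note: the `DiversityEvaluation` and `ProgressiveEvaluation` strategies need
// the per-model prediction matrix (`model_predictions`) at evaluation time and
// return an error without it; the remaining strategies operate on the ensemble
// prediction matrix alone.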

/// Cross-validation strategies specific to ensembles
#[derive(Debug, Clone)]
pub enum EnsembleCVStrategy {
    /// Standard k-fold CV for ensembles
    KFoldEnsemble,
    /// Stratified CV maintaining ensemble member diversity
    StratifiedEnsemble,
    /// Leave-one-model-out CV
    LeaveOneModelOut,
    /// Bootstrap CV for ensemble components
    BootstrapEnsemble { n_bootstrap: usize },
    /// Nested CV for ensemble and member optimization
    NestedEnsemble { inner_cv: usize, outer_cv: usize },
    /// Time series CV for temporal ensembles
    TimeSeriesEnsemble { n_splits: usize, test_size: Float },
}

/// Diversity measures for ensemble evaluation
#[derive(Debug, Clone)]
pub enum DiversityMeasure {
    /// Q-statistic between pairs of classifiers
    QStatistic,
    /// Correlation coefficient between predictions
    CorrelationCoefficient,
    /// Disagreement measure
    DisagreementMeasure,
    /// Double-fault measure
    DoubleFaultMeasure,
    /// Entropy-based diversity
    EntropyDiversity,
    /// Kohavi-Wolpert variance
    KohaviWolpertVariance,
    /// Interrater agreement (Kappa)
    InterraterAgreement,
    /// Measure of difficulty
    DifficultyMeasure,
    /// Generalized diversity index
    GeneralizedDiversity { alpha: Float },
}

/// Stability metrics for ensemble evaluation
#[derive(Debug, Clone)]
pub enum StabilityMetric {
    /// Prediction stability across bootstrap samples
    PredictionStability,
    /// Model selection stability
    ModelSelectionStability,
    /// Weight stability for weighted ensembles
    WeightStability,
    /// Performance stability
    PerformanceStability,
    /// Ranking stability of ensemble members
    RankingStability,
}

/// Progressive selection strategies
#[derive(Debug, Clone)]
pub enum ProgressiveSelectionStrategy {
    /// Forward selection based on performance
    ForwardSelection,
    /// Backward elimination
    BackwardElimination,
    /// Diversity-driven selection
    DiversityDriven,
    /// Performance-diversity trade-off
    PerformanceDiversityTradeoff { alpha: Float },
}

/// Evaluation objectives for multi-objective analysis
#[derive(Debug, Clone)]
pub enum EvaluationObjective {
    /// Predictive accuracy
    Accuracy,
    /// Model diversity
    Diversity,
    /// Computational efficiency
    Efficiency,
    /// Memory usage
    MemoryUsage,
    /// Robustness to outliers
    Robustness,
    /// Interpretability
    Interpretability,
    /// Fairness across groups
    Fairness,
}

/// Ensemble evaluation configuration
#[derive(Debug, Clone)]
pub struct EnsembleEvaluationConfig {
    pub strategy: EnsembleEvaluationStrategy,
    pub evaluation_metrics: Vec<String>,
    pub confidence_level: Float,
    pub n_repetitions: usize,
    pub parallel_evaluation: bool,
    pub random_state: Option<u64>,
    pub verbose: bool,
}

/// Ensemble evaluation result
#[derive(Debug, Clone)]
pub struct EnsembleEvaluationResult {
    pub ensemble_performance: EnsemblePerformanceMetrics,
    pub diversity_analysis: DiversityAnalysis,
    pub stability_analysis: Option<StabilityAnalysis>,
    pub member_contributions: Vec<MemberContribution>,
    pub out_of_bag_scores: Option<OutOfBagScores>,
    pub progressive_performance: Option<ProgressivePerformance>,
    pub multi_objective_analysis: Option<MultiObjectiveAnalysis>,
}

/// Comprehensive ensemble performance metrics
#[derive(Debug, Clone)]
pub struct EnsemblePerformanceMetrics {
    pub mean_performance: Float,
    pub std_performance: Float,
    pub confidence_interval: (Float, Float),
    pub individual_fold_scores: Vec<Float>,
    pub ensemble_vs_best_member: Float,
    pub ensemble_vs_average_member: Float,
    pub performance_gain: Float,
}

/// Diversity analysis results
#[derive(Debug, Clone)]
pub struct DiversityAnalysis {
    pub overall_diversity: Float,
    pub pairwise_diversities: Array2<Float>,
    pub diversity_by_measure: HashMap<String, Float>,
    pub diversity_distribution: Vec<Float>,
    pub optimal_diversity_size: Option<usize>,
}

/// Stability analysis results
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    pub prediction_stability: Float,
    pub model_selection_stability: Float,
    pub weight_stability: Option<Float>,
    pub performance_stability: Float,
    pub stability_confidence_intervals: HashMap<String, (Float, Float)>,
}

/// Individual member contribution analysis
#[derive(Debug, Clone)]
pub struct MemberContribution {
    pub member_id: usize,
    pub member_name: String,
    pub individual_performance: Float,
    pub marginal_contribution: Float,
    pub shapley_value: Option<Float>,
    pub removal_impact: Float,
    pub diversity_contribution: Float,
}

/// Out-of-bag evaluation scores
#[derive(Debug, Clone)]
pub struct OutOfBagScores {
    pub oob_score: Float,
    pub oob_confidence_interval: (Float, Float),
    pub feature_importance: Option<Array1<Float>>,
    pub prediction_intervals: Option<Array2<Float>>,
    pub individual_oob_scores: Vec<Float>,
}

/// Progressive performance analysis
#[derive(Debug, Clone)]
pub struct ProgressivePerformance {
    pub ensemble_sizes: Vec<usize>,
    pub performance_curve: Vec<Float>,
    pub diversity_curve: Vec<Float>,
    pub efficiency_curve: Vec<Float>,
    pub optimal_size: usize,
    pub diminishing_returns_threshold: Option<usize>,
}

/// Multi-objective analysis results
#[derive(Debug, Clone)]
pub struct MultiObjectiveAnalysis {
    pub pareto_front: Vec<(Float, Float)>,
    pub objective_scores: HashMap<String, Float>,
    pub trade_off_analysis: HashMap<String, Float>,
    pub dominated_solutions: Vec<usize>,
    pub compromise_solution: Option<usize>,
}

/// Ensemble evaluator
#[derive(Debug)]
pub struct EnsembleEvaluator {
    config: EnsembleEvaluationConfig,
    rng: StdRng,
}

impl Default for EnsembleEvaluationConfig {
    fn default() -> Self {
        Self {
            strategy: EnsembleEvaluationStrategy::EnsembleCrossValidation {
                cv_strategy: EnsembleCVStrategy::KFoldEnsemble,
                n_folds: 5,
            },
            evaluation_metrics: vec!["accuracy".to_string(), "f1_score".to_string()],
            confidence_level: 0.95,
            n_repetitions: 1,
            parallel_evaluation: false,
            random_state: None,
            verbose: false,
        }
    }
}
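
// A sketch of overriding selected fields while keeping the remaining defaults
// (struct-update syntax), as the tests at the bottom of this file also do:
//
//     let config = EnsembleEvaluationConfig {
//         n_repetitions: 10,
//         random_state: Some(42),
//         ..Default::default()
//     };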

impl EnsembleEvaluator {
    /// Create a new ensemble evaluator
    pub fn new(config: EnsembleEvaluationConfig) -> Self {
        let rng = match config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };

        Self { config, rng }
    }

    /// Evaluate ensemble using specified strategy
    pub fn evaluate<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        ensemble_weights: Option<&Array1<Float>>,
        model_predictions: Option<&Array2<Float>>,
        evaluation_fn: F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        match &self.config.strategy {
            EnsembleEvaluationStrategy::OutOfBag { .. } => self.evaluate_out_of_bag(
                ensemble_predictions,
                true_labels,
                ensemble_weights,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::EnsembleCrossValidation { .. } => self
                .evaluate_cross_validation(
                    ensemble_predictions,
                    true_labels,
                    ensemble_weights,
                    model_predictions,
                    &evaluation_fn,
                ),
            EnsembleEvaluationStrategy::DiversityEvaluation { .. } => self.evaluate_diversity(
                ensemble_predictions,
                true_labels,
                model_predictions,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::StabilityAnalysis { .. } => self.evaluate_stability(
                ensemble_predictions,
                true_labels,
                ensemble_weights,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::ProgressiveEvaluation { .. } => self.evaluate_progressive(
                ensemble_predictions,
                true_labels,
                model_predictions,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::MultiObjectiveEvaluation { .. } => self
                .evaluate_multi_objective(
                    ensemble_predictions,
                    true_labels,
                    ensemble_weights,
                    &evaluation_fn,
                ),
        }
    }

    /// Out-of-bag evaluation implementation
    fn evaluate_out_of_bag<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        ensemble_weights: Option<&Array1<Float>>,
        evaluation_fn: &F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        let (bootstrap_samples, confidence_level) = match &self.config.strategy {
            EnsembleEvaluationStrategy::OutOfBag {
                bootstrap_samples,
                confidence_level,
            } => (*bootstrap_samples, *confidence_level),
            _ => unreachable!(),
        };

        let n_samples = ensemble_predictions.nrows();
        let n_models = ensemble_predictions.ncols();

        let mut oob_scores = Vec::new();
        let mut oob_predictions_all = Vec::new();

        for _ in 0..bootstrap_samples {
            // Generate bootstrap sample
            let bootstrap_indices: Vec<usize> = (0..n_samples)
                .map(|_| self.rng.gen_range(0..n_samples))
                .collect();

            // Find out-of-bag samples
            let mut oob_indices = Vec::new();
            for i in 0..n_samples {
                if !bootstrap_indices.contains(&i) {
                    oob_indices.push(i);
                }
            }

            if oob_indices.is_empty() {
                continue;
            }

            // Calculate OOB predictions
            let oob_ensemble_preds = self.calculate_ensemble_predictions(
                ensemble_predictions,
                &oob_indices,
                ensemble_weights,
            )?;

            let oob_true_labels =
                Array1::from_vec(oob_indices.iter().map(|&i| true_labels[i]).collect());

            let oob_score = evaluation_fn(&oob_ensemble_preds, &oob_true_labels)?;
            oob_scores.push(oob_score);
            oob_predictions_all.push(oob_ensemble_preds);
        }

        // Guard against the (unlikely) case where every bootstrap covered all samples.
        if oob_scores.is_empty() {
            return Err("no out-of-bag samples were produced; increase bootstrap_samples".into());
        }

        let mean_oob_score = oob_scores.iter().sum::<Float>() / oob_scores.len() as Float;
        let std_oob_score = {
            let variance = oob_scores
                .iter()
                .map(|&score| (score - mean_oob_score).powi(2))
                .sum::<Float>()
                / oob_scores.len() as Float;
            variance.sqrt()
        };

        // Normal-approximation interval. The configured confidence_level is not
        // yet used for a quantile lookup; the z-score below is the normal
        // quantile for 95% confidence only.
        let _alpha = 1.0 - confidence_level;
        let z_score = 1.96;
        let margin_of_error = z_score * std_oob_score / (oob_scores.len() as Float).sqrt();
        let confidence_interval = (
            mean_oob_score - margin_of_error,
            mean_oob_score + margin_of_error,
        );

        let oob_scores_result = OutOfBagScores {
            oob_score: mean_oob_score,
            oob_confidence_interval: confidence_interval,
            feature_importance: None, // Could be calculated if feature data were available
            prediction_intervals: None, // Could be calculated from OOB predictions
            individual_oob_scores: oob_scores,
        };

        // Calculate basic ensemble performance
        let ensemble_preds = self.calculate_ensemble_predictions(
            ensemble_predictions,
            &(0..n_samples).collect::<Vec<_>>(),
            ensemble_weights,
        )?;
        let ensemble_score = evaluation_fn(&ensemble_preds, true_labels)?;

        let ensemble_performance = EnsemblePerformanceMetrics {
            mean_performance: ensemble_score,
            std_performance: std_oob_score,
            confidence_interval,
            individual_fold_scores: vec![ensemble_score],
            ensemble_vs_best_member: 0.0, // Would need individual model scores
            ensemble_vs_average_member: 0.0,
            performance_gain: 0.0,
        };

        Ok(EnsembleEvaluationResult {
            ensemble_performance,
            diversity_analysis: DiversityAnalysis {
                overall_diversity: 0.0,
                pairwise_diversities: Array2::zeros((n_models, n_models)),
                diversity_by_measure: HashMap::new(),
                diversity_distribution: Vec::new(),
                optimal_diversity_size: None,
            },
            stability_analysis: None,
            member_contributions: Vec::new(),
            out_of_bag_scores: Some(oob_scores_result),
            progressive_performance: None,
            multi_objective_analysis: None,
        })
    }
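
// For a bootstrap sample of size n drawn with replacement from n points, each
// point is out-of-bag with probability (1 - 1/n)^n, which tends to e^(-1), or
// about 36.8%, so each OOB score above is computed on roughly a third of the
// data on average.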

    /// Cross-validation evaluation implementation
    fn evaluate_cross_validation<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        ensemble_weights: Option<&Array1<Float>>,
        model_predictions: Option<&Array2<Float>>,
        evaluation_fn: &F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        let (_cv_strategy, n_folds) = match &self.config.strategy {
            EnsembleEvaluationStrategy::EnsembleCrossValidation {
                cv_strategy,
                n_folds,
            } => (cv_strategy, *n_folds),
            _ => unreachable!(),
        };

        let n_samples = ensemble_predictions.nrows();
        if n_folds == 0 || n_folds > n_samples {
            return Err("n_folds must be between 1 and the number of samples".into());
        }
        let fold_size = n_samples / n_folds;
        let mut fold_scores = Vec::new();
        let mut diversity_scores = Vec::new();

        for fold in 0..n_folds {
            let test_start = fold * fold_size;
            let test_end = if fold == n_folds - 1 {
                n_samples
            } else {
                (fold + 1) * fold_size
            };
            let test_indices: Vec<usize> = (test_start..test_end).collect();

            // Calculate ensemble predictions for the test fold
            let test_ensemble_preds = self.calculate_ensemble_predictions(
                ensemble_predictions,
                &test_indices,
                ensemble_weights,
            )?;

            let test_true_labels =
                Array1::from_vec(test_indices.iter().map(|&i| true_labels[i]).collect());

            let fold_score = evaluation_fn(&test_ensemble_preds, &test_true_labels)?;
            fold_scores.push(fold_score);

            // Calculate diversity for this fold if model predictions are available
            if let Some(model_preds) = model_predictions {
                let mut fold_data = Vec::new();
                for &i in test_indices.iter() {
                    fold_data.extend(model_preds.row(i).iter().cloned());
                }
                let fold_model_preds =
                    Array2::from_shape_vec((test_indices.len(), model_preds.ncols()), fold_data)?;

                let diversity = self.calculate_q_statistic(&fold_model_preds)?;
                diversity_scores.push(diversity);
            }
        }

        let mean_performance = fold_scores.iter().sum::<Float>() / fold_scores.len() as Float;
        let std_performance = {
            let variance = fold_scores
                .iter()
                .map(|&score| (score - mean_performance).powi(2))
                .sum::<Float>()
                / fold_scores.len() as Float;
            variance.sqrt()
        };

        let z_score = 1.96; // Normal quantile for 95% confidence
        let margin_of_error = z_score * std_performance / (fold_scores.len() as Float).sqrt();
        let confidence_interval = (
            mean_performance - margin_of_error,
            mean_performance + margin_of_error,
        );

        let ensemble_performance = EnsemblePerformanceMetrics {
            mean_performance,
            std_performance,
            confidence_interval,
            individual_fold_scores: fold_scores,
            ensemble_vs_best_member: 0.0, // Would need individual model analysis
            ensemble_vs_average_member: 0.0,
            performance_gain: 0.0,
        };

        let diversity_analysis = if !diversity_scores.is_empty() {
            let mean_diversity =
                diversity_scores.iter().sum::<Float>() / diversity_scores.len() as Float;
            DiversityAnalysis {
                overall_diversity: mean_diversity,
                pairwise_diversities: Array2::zeros((0, 0)), // Would calculate pairwise if needed
                diversity_by_measure: {
                    let mut map = HashMap::new();
                    map.insert("q_statistic".to_string(), mean_diversity);
                    map
                },
                diversity_distribution: diversity_scores,
                optimal_diversity_size: None,
            }
        } else {
            DiversityAnalysis {
                overall_diversity: 0.0,
                pairwise_diversities: Array2::zeros((0, 0)),
                diversity_by_measure: HashMap::new(),
                diversity_distribution: Vec::new(),
                optimal_diversity_size: None,
            }
        };

        Ok(EnsembleEvaluationResult {
            ensemble_performance,
            diversity_analysis,
            stability_analysis: None,
            member_contributions: Vec::new(),
            out_of_bag_scores: None,
            progressive_performance: None,
            multi_objective_analysis: None,
        })
    }

    /// Diversity evaluation implementation
    fn evaluate_diversity<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        model_predictions: Option<&Array2<Float>>,
        evaluation_fn: &F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        let (diversity_measures, _diversity_threshold) = match &self.config.strategy {
            EnsembleEvaluationStrategy::DiversityEvaluation {
                diversity_measures,
                diversity_threshold,
            } => (diversity_measures, *diversity_threshold),
            _ => unreachable!(),
        };

        if let Some(model_preds) = model_predictions {
            let n_models = model_preds.ncols();
            let mut diversity_by_measure = HashMap::new();
            let mut pairwise_diversities = Array2::zeros((n_models, n_models));

            for measure in diversity_measures {
                let diversity_value = match measure {
                    DiversityMeasure::QStatistic => self.calculate_q_statistic(model_preds)?,
                    DiversityMeasure::CorrelationCoefficient => {
                        self.calculate_correlation_coefficient(model_preds)?
                    }
                    DiversityMeasure::DisagreementMeasure => {
                        self.calculate_disagreement_measure(model_preds)?
                    }
                    DiversityMeasure::DoubleFaultMeasure => {
                        self.calculate_double_fault_measure(model_preds, true_labels)?
                    }
                    DiversityMeasure::EntropyDiversity => {
                        self.calculate_entropy_diversity(model_preds)?
                    }
                    DiversityMeasure::KohaviWolpertVariance => {
                        self.calculate_kw_variance(model_preds, true_labels)?
                    }
                    DiversityMeasure::InterraterAgreement => {
                        self.calculate_interrater_agreement(model_preds)?
                    }
                    DiversityMeasure::DifficultyMeasure => {
                        self.calculate_difficulty_measure(model_preds, true_labels)?
                    }
                    DiversityMeasure::GeneralizedDiversity { alpha } => {
                        self.calculate_generalized_diversity(model_preds, *alpha)?
                    }
                };

                diversity_by_measure.insert(format!("{:?}", measure), diversity_value);
            }

            // Calculate pairwise diversities
            for i in 0..n_models {
                for j in i + 1..n_models {
                    // Assemble an (n_samples, 2) matrix holding the two model
                    // columns side by side, in row-major order.
                    let pair_preds =
                        Array2::from_shape_fn((model_preds.nrows(), 2), |(row, col)| {
                            if col == 0 {
                                model_preds[[row, i]]
                            } else {
                                model_preds[[row, j]]
                            }
                        });
                    let pair_diversity = self.calculate_q_statistic(&pair_preds)?;
                    pairwise_diversities[[i, j]] = pair_diversity;
                    pairwise_diversities[[j, i]] = pair_diversity;
                }
            }

            let overall_diversity =
                diversity_by_measure.values().sum::<Float>() / diversity_by_measure.len() as Float;

            let diversity_analysis = DiversityAnalysis {
                overall_diversity,
                pairwise_diversities,
                diversity_by_measure,
                diversity_distribution: Vec::new(), // Could add distribution analysis
                optimal_diversity_size: None,       // Could calculate an optimal size
            };

            // Calculate basic ensemble performance
            let ensemble_preds = self.calculate_ensemble_predictions(
                ensemble_predictions,
                &(0..ensemble_predictions.nrows()).collect::<Vec<_>>(),
                None,
            )?;
            let ensemble_score = evaluation_fn(&ensemble_preds, true_labels)?;

            let ensemble_performance = EnsemblePerformanceMetrics {
                mean_performance: ensemble_score,
                std_performance: 0.0,
                confidence_interval: (ensemble_score, ensemble_score),
                individual_fold_scores: vec![ensemble_score],
                ensemble_vs_best_member: 0.0,
                ensemble_vs_average_member: 0.0,
                performance_gain: 0.0,
            };

            Ok(EnsembleEvaluationResult {
                ensemble_performance,
                diversity_analysis,
                stability_analysis: None,
                member_contributions: Vec::new(),
                out_of_bag_scores: None,
                progressive_performance: None,
                multi_objective_analysis: None,
            })
        } else {
            Err("Model predictions required for diversity evaluation".into())
        }
    }

    /// Stability analysis implementation
    fn evaluate_stability<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        ensemble_weights: Option<&Array1<Float>>,
        evaluation_fn: &F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        let (n_bootstrap_samples, _stability_metrics) = match &self.config.strategy {
            EnsembleEvaluationStrategy::StabilityAnalysis {
                n_bootstrap_samples,
                stability_metrics,
            } => (*n_bootstrap_samples, stability_metrics),
            _ => unreachable!(),
        };

        let n_samples = ensemble_predictions.nrows();
        let mut bootstrap_scores = Vec::new();
        let mut bootstrap_predictions = Vec::new();

        for _ in 0..n_bootstrap_samples {
            // Generate bootstrap sample
            let bootstrap_indices: Vec<usize> = (0..n_samples)
                .map(|_| self.rng.gen_range(0..n_samples))
                .collect();

            let bootstrap_preds = self.calculate_ensemble_predictions(
                ensemble_predictions,
                &bootstrap_indices,
                ensemble_weights,
            )?;

            let bootstrap_labels =
                Array1::from_vec(bootstrap_indices.iter().map(|&i| true_labels[i]).collect());

            let bootstrap_score = evaluation_fn(&bootstrap_preds, &bootstrap_labels)?;
            bootstrap_scores.push(bootstrap_score);
            bootstrap_predictions.push(bootstrap_preds);
        }

        // Calculate prediction stability
        let prediction_stability = self.calculate_prediction_stability(&bootstrap_predictions)?;

        // Calculate performance stability
        let mean_score = bootstrap_scores.iter().sum::<Float>() / bootstrap_scores.len() as Float;
        let score_variance = bootstrap_scores
            .iter()
            .map(|&score| (score - mean_score).powi(2))
            .sum::<Float>()
            / bootstrap_scores.len() as Float;
        let performance_stability = 1.0 / (1.0 + score_variance); // Higher variance = lower stability

        let stability_analysis = StabilityAnalysis {
            prediction_stability,
            model_selection_stability: 0.8, // Placeholder - would need model selection data
            weight_stability: None,         // Would calculate if weights were provided
            performance_stability,
            stability_confidence_intervals: HashMap::new(), // Could add CIs
        };

        let ensemble_performance = EnsemblePerformanceMetrics {
            mean_performance: mean_score,
            std_performance: score_variance.sqrt(),
            // Mean +/- one standard deviation; not a calibrated confidence interval.
            confidence_interval: (
                mean_score - score_variance.sqrt(),
                mean_score + score_variance.sqrt(),
            ),
            individual_fold_scores: bootstrap_scores,
            ensemble_vs_best_member: 0.0,
            ensemble_vs_average_member: 0.0,
            performance_gain: 0.0,
        };

        Ok(EnsembleEvaluationResult {
            ensemble_performance,
            diversity_analysis: DiversityAnalysis {
                overall_diversity: 0.0,
                pairwise_diversities: Array2::zeros((0, 0)),
                diversity_by_measure: HashMap::new(),
                diversity_distribution: Vec::new(),
                optimal_diversity_size: None,
            },
            stability_analysis: Some(stability_analysis),
            member_contributions: Vec::new(),
            out_of_bag_scores: None,
            progressive_performance: None,
            multi_objective_analysis: None,
        })
    }

    /// Progressive evaluation implementation
    fn evaluate_progressive<F>(
        &mut self,
        _ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        model_predictions: Option<&Array2<Float>>,
        evaluation_fn: &F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        let (ensemble_sizes, _selection_strategy) = match &self.config.strategy {
            EnsembleEvaluationStrategy::ProgressiveEvaluation {
                ensemble_sizes,
                selection_strategy,
            } => (ensemble_sizes, selection_strategy),
            _ => unreachable!(),
        };

        if let Some(model_preds) = model_predictions {
            let mut evaluated_sizes = Vec::new();
            let mut performance_curve = Vec::new();
            let mut diversity_curve = Vec::new();
            let n_models = model_preds.ncols();

            for &size in ensemble_sizes {
                if size == 0 || size > n_models {
                    continue; // Skip sizes the model pool cannot satisfy
                }

                // Select the first `size` models (simplified; the configured
                // selection strategy could drive a smarter choice) and assemble
                // their columns into a row-major (n_samples, size) matrix.
                let selected_predictions =
                    Array2::from_shape_fn((model_preds.nrows(), size), |(row, col)| {
                        model_preds[[row, col]]
                    });

                let ensemble_preds = selected_predictions.mean_axis(Axis(1)).unwrap();
                let performance = evaluation_fn(&ensemble_preds, true_labels)?;
                performance_curve.push(performance);

                // Calculate diversity for the selected models
                let diversity = self.calculate_q_statistic(&selected_predictions)?;
                diversity_curve.push(diversity);
                evaluated_sizes.push(size);
            }

            if performance_curve.is_empty() {
                return Err("no requested ensemble size fits the available models".into());
            }

            // Find the optimal size (highest performance); the curves are
            // indexed by the sizes that were actually evaluated.
            let optimal_size_idx = performance_curve
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);
            let optimal_size = evaluated_sizes[optimal_size_idx];

            let n_evaluated = evaluated_sizes.len();
            let progressive_performance = ProgressivePerformance {
                ensemble_sizes: evaluated_sizes,
                performance_curve,
                diversity_curve,
                efficiency_curve: vec![1.0; n_evaluated], // Placeholder
                optimal_size,
                diminishing_returns_threshold: None, // Could calculate
            };

            let ensemble_performance = EnsemblePerformanceMetrics {
                mean_performance: progressive_performance.performance_curve[optimal_size_idx],
                std_performance: 0.0,
                confidence_interval: (0.0, 0.0),
                individual_fold_scores: vec![],
                ensemble_vs_best_member: 0.0,
                ensemble_vs_average_member: 0.0,
                performance_gain: 0.0,
            };

            Ok(EnsembleEvaluationResult {
                ensemble_performance,
                diversity_analysis: DiversityAnalysis {
                    overall_diversity: 0.0,
                    pairwise_diversities: Array2::zeros((0, 0)),
                    diversity_by_measure: HashMap::new(),
                    diversity_distribution: Vec::new(),
                    optimal_diversity_size: Some(optimal_size),
                },
                stability_analysis: None,
                member_contributions: Vec::new(),
                out_of_bag_scores: None,
                progressive_performance: Some(progressive_performance),
                multi_objective_analysis: None,
            })
        } else {
            Err("Model predictions required for progressive evaluation".into())
        }
    }

    /// Multi-objective evaluation implementation
    fn evaluate_multi_objective<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        ensemble_weights: Option<&Array1<Float>>,
        evaluation_fn: &F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        // Simplified multi-objective evaluation
        let ensemble_preds = self.calculate_ensemble_predictions(
            ensemble_predictions,
            &(0..ensemble_predictions.nrows()).collect::<Vec<_>>(),
            ensemble_weights,
        )?;
        let performance = evaluation_fn(&ensemble_preds, true_labels)?;

        let mut objective_scores = HashMap::new();
        objective_scores.insert("accuracy".to_string(), performance);
        objective_scores.insert("diversity".to_string(), 0.5); // Placeholder
        objective_scores.insert("efficiency".to_string(), 0.8); // Placeholder

        let multi_objective_analysis = MultiObjectiveAnalysis {
            pareto_front: vec![(performance, 0.5)], // (accuracy, diversity)
            objective_scores,
            trade_off_analysis: HashMap::new(),
            dominated_solutions: Vec::new(),
            compromise_solution: Some(0),
        };

        let ensemble_performance = EnsemblePerformanceMetrics {
            mean_performance: performance,
            std_performance: 0.0,
            confidence_interval: (performance, performance),
            individual_fold_scores: vec![performance],
            ensemble_vs_best_member: 0.0,
            ensemble_vs_average_member: 0.0,
            performance_gain: 0.0,
        };

        Ok(EnsembleEvaluationResult {
            ensemble_performance,
            diversity_analysis: DiversityAnalysis {
                overall_diversity: 0.0,
                pairwise_diversities: Array2::zeros((0, 0)),
                diversity_by_measure: HashMap::new(),
                diversity_distribution: Vec::new(),
                optimal_diversity_size: None,
            },
            stability_analysis: None,
            member_contributions: Vec::new(),
            out_of_bag_scores: None,
            progressive_performance: None,
            multi_objective_analysis: Some(multi_objective_analysis),
        })
    }

    /// Calculate ensemble predictions for given indices
    fn calculate_ensemble_predictions(
        &self,
        ensemble_predictions: &Array2<Float>,
        indices: &[usize],
        weights: Option<&Array1<Float>>,
    ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
        let mut selected_data = Vec::new();
        for &i in indices.iter() {
            selected_data.extend(ensemble_predictions.row(i).iter().cloned());
        }
        let selected_predictions =
            Array2::from_shape_vec((indices.len(), ensemble_predictions.ncols()), selected_data)?;

        if let Some(w) = weights {
            // Weighted average
            Ok(selected_predictions.dot(w))
        } else {
            // Simple average
            Ok(selected_predictions.mean_axis(Axis(1)).unwrap())
        }
    }
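
// Note: the weighted branch above computes `selected_predictions.dot(w)`, a
// weighted *sum* across models; no normalization is applied, so callers should
// pass weights that sum to 1 when a weighted average is intended.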

    /// Calculate the (average pairwise) Q-statistic diversity measure.
    ///
    /// For a pair of models, Q = (N11*N00 - N01*N10) / (N11*N00 + N01*N10),
    /// where N11/N00 count samples on which both models fall on the same side
    /// of the 0.5 threshold and N10/N01 count the splits. Q lies in [-1, 1];
    /// lower values indicate more diverse pairs.
    fn calculate_q_statistic(
        &self,
        predictions: &Array2<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_models = predictions.ncols();
        if n_models < 2 {
            return Ok(0.0);
        }

        let mut q_sum = 0.0;
        let mut pairs = 0;

        for i in 0..n_models {
            for j in i + 1..n_models {
                let pred_i = predictions.column(i);
                let pred_j = predictions.column(j);

                let mut n11 = 0; // Both predict the positive class
                let mut n10 = 0; // i positive, j negative
                let mut n01 = 0; // i negative, j positive
                let mut n00 = 0; // Both predict the negative class

                for k in 0..predictions.nrows() {
                    let i_positive = pred_i[k] > 0.5;
                    let j_positive = pred_j[k] > 0.5;

                    match (i_positive, j_positive) {
                        (true, true) => n11 += 1,
                        (true, false) => n10 += 1,
                        (false, true) => n01 += 1,
                        (false, false) => n00 += 1,
                    }
                }

                let numerator = (n11 * n00 - n01 * n10) as Float;
                let denominator = (n11 * n00 + n01 * n10) as Float;

                if denominator != 0.0 {
                    q_sum += numerator / denominator;
                    pairs += 1;
                }
            }
        }

        Ok(if pairs > 0 {
            q_sum / pairs as Float
        } else {
            0.0
        })
    }

    /// Calculate a correlation-based diversity measure: one minus the mean
    /// absolute Pearson correlation over all model pairs, so higher values
    /// indicate more diverse (less correlated) members.
    fn calculate_correlation_coefficient(
        &self,
        predictions: &Array2<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_models = predictions.ncols();
        if n_models < 2 {
            return Ok(0.0);
        }

        let mut correlations = Vec::new();
        for i in 0..n_models {
            for j in i + 1..n_models {
                let pred_i = predictions.column(i);
                let pred_j = predictions.column(j);

                let mean_i = pred_i.mean().unwrap_or(0.0);
                let mean_j = pred_j.mean().unwrap_or(0.0);

                let mut covariance = 0.0;
                let mut var_i = 0.0;
                let mut var_j = 0.0;

                for k in 0..predictions.nrows() {
                    let diff_i = pred_i[k] - mean_i;
                    let diff_j = pred_j[k] - mean_j;
                    covariance += diff_i * diff_j;
                    var_i += diff_i * diff_i;
                    var_j += diff_j * diff_j;
                }

                let correlation = if var_i > 0.0 && var_j > 0.0 {
                    covariance / (var_i.sqrt() * var_j.sqrt())
                } else {
                    0.0
                };

                correlations.push(correlation.abs());
            }
        }

        Ok(1.0 - correlations.iter().sum::<Float>() / correlations.len() as Float)
    }

    /// Calculate disagreement measure
    fn calculate_disagreement_measure(
        &self,
        predictions: &Array2<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_models = predictions.ncols();
        if n_models < 2 {
            return Ok(0.0);
        }

        let mut disagreement_sum = 0.0;
        let mut pairs = 0;

        for i in 0..n_models {
            for j in i + 1..n_models {
                let pred_i = predictions.column(i);
                let pred_j = predictions.column(j);

                let mut disagreements = 0;
                for k in 0..predictions.nrows() {
                    if (pred_i[k] > 0.5) != (pred_j[k] > 0.5) {
                        disagreements += 1;
                    }
                }

                disagreement_sum += disagreements as Float / predictions.nrows() as Float;
                pairs += 1;
            }
        }

        Ok(if pairs > 0 {
            disagreement_sum / pairs as Float
        } else {
            0.0
        })
    }

    /// Calculate double fault measure
    fn calculate_double_fault_measure(
        &self,
        predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_models = predictions.ncols();
        if n_models < 2 {
            return Ok(0.0);
        }

        let mut double_fault_sum = 0.0;
        let mut pairs = 0;

        for i in 0..n_models {
            for j in i + 1..n_models {
                let pred_i = predictions.column(i);
                let pred_j = predictions.column(j);

                let mut double_faults = 0;
                for k in 0..predictions.nrows() {
                    let i_wrong = (pred_i[k] > 0.5) != (true_labels[k] > 0.5);
                    let j_wrong = (pred_j[k] > 0.5) != (true_labels[k] > 0.5);

                    if i_wrong && j_wrong {
                        double_faults += 1;
                    }
                }

                double_fault_sum += double_faults as Float / predictions.nrows() as Float;
                pairs += 1;
            }
        }

        Ok(if pairs > 0 {
            double_fault_sum / pairs as Float
        } else {
            0.0
        })
    }

    /// Calculate entropy-based diversity from the per-sample split of
    /// positive (> 0.5) versus negative votes across models
    fn calculate_entropy_diversity(
        &self,
        predictions: &Array2<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_samples = predictions.nrows();
        let n_models = predictions.ncols();

        let mut entropy_sum = 0.0;

        for i in 0..n_samples {
            // Fraction of models voting for the positive class on this sample
            let positive_count = predictions
                .row(i)
                .iter()
                .filter(|&&pred| pred > 0.5)
                .count() as Float;

            let p = positive_count / n_models as Float;
            if p > 0.0 && p < 1.0 {
                entropy_sum += -p * p.log2() - (1.0 - p) * (1.0 - p).log2();
            }
        }

        Ok(entropy_sum / n_samples as Float)
    }

    /// Calculate Kohavi-Wolpert variance
    fn calculate_kw_variance(
        &self,
        predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_samples = predictions.nrows();
        let n_models = predictions.ncols();

        let mut variance_sum = 0.0;

        for i in 0..n_samples {
            let correct_count = predictions
                .row(i)
                .iter()
                .filter(|&&pred| (pred > 0.5) == (true_labels[i] > 0.5))
                .count() as Float;

            let l = correct_count / n_models as Float;
            variance_sum += l * (1.0 - l);
        }

        Ok(variance_sum / n_samples as Float)
    }

    /// Calculate the mean pairwise agreement rate (a simplified stand-in for
    /// Cohen's kappa: raw agreement without the chance correction)
    fn calculate_interrater_agreement(
        &self,
        predictions: &Array2<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_models = predictions.ncols();
        if n_models < 2 {
            return Ok(0.0);
        }

        let mut agreement_sum = 0.0;
        let mut pairs = 0;

        for i in 0..n_models {
            for j in i + 1..n_models {
                let pred_i = predictions.column(i);
                let pred_j = predictions.column(j);

                let mut agreements = 0;
                for k in 0..predictions.nrows() {
                    if (pred_i[k] > 0.5) == (pred_j[k] > 0.5) {
                        agreements += 1;
                    }
                }

                agreement_sum += agreements as Float / predictions.nrows() as Float;
                pairs += 1;
            }
        }

        Ok(if pairs > 0 {
            agreement_sum / pairs as Float
        } else {
            0.0
        })
    }

    /// Calculate the difficulty measure (mean per-sample error rate across members)
    fn calculate_difficulty_measure(
        &self,
        predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_samples = predictions.nrows();
        let n_models = predictions.ncols();

        let mut difficulty_sum = 0.0;

        for i in 0..n_samples {
            let error_count = predictions
                .row(i)
                .iter()
                .filter(|&&pred| (pred > 0.5) != (true_labels[i] > 0.5))
                .count() as Float;

            difficulty_sum += error_count / n_models as Float;
        }

        Ok(difficulty_sum / n_samples as Float)
    }

    /// Calculate generalized diversity index
    fn calculate_generalized_diversity(
        &self,
        predictions: &Array2<Float>,
        alpha: Float,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_samples = predictions.nrows();
        let n_models = predictions.ncols();

        let mut diversity_sum = 0.0;

        for i in 0..n_samples {
            // Fraction of models voting for the positive class on this sample
            let positive_count = predictions
                .row(i)
                .iter()
                .filter(|&&pred| pred > 0.5)
                .count() as Float;

            let p = positive_count / n_models as Float;
            if alpha != 1.0 {
                diversity_sum += (1.0 - p.powf(alpha) - (1.0 - p).powf(alpha))
                    / (2.0_f64.powf(1.0 - alpha) as Float - 1.0);
            } else {
                // Shannon entropy case (alpha = 1)
                if p > 0.0 && p < 1.0 {
                    diversity_sum += -p * p.log2() - (1.0 - p) * (1.0 - p).log2();
                }
            }
        }

        Ok(diversity_sum / n_samples as Float)
    }

    /// Calculate prediction stability across bootstrap samples
    fn calculate_prediction_stability(
        &self,
        predictions: &[Array1<Float>],
    ) -> Result<Float, Box<dyn std::error::Error>> {
        if predictions.len() < 2 {
            return Ok(1.0);
        }

        let n_samples = predictions[0].len();
        let mut stability_sum = 0.0;

        for i in 0..n_samples {
            let sample_predictions: Vec<Float> = predictions.iter().map(|p| p[i]).collect();
            let mean_pred =
                sample_predictions.iter().sum::<Float>() / sample_predictions.len() as Float;
            let variance = sample_predictions
                .iter()
                .map(|&pred| (pred - mean_pred).powi(2))
                .sum::<Float>()
                / sample_predictions.len() as Float;

            stability_sum += 1.0 / (1.0 + variance); // Higher variance = lower stability
        }

        Ok(stability_sum / n_samples as Float)
    }
}

/// Convenience function for ensemble evaluation
pub fn evaluate_ensemble<F>(
    ensemble_predictions: &Array2<Float>,
    true_labels: &Array1<Float>,
    ensemble_weights: Option<&Array1<Float>>,
    model_predictions: Option<&Array2<Float>>,
    evaluation_fn: F,
    config: Option<EnsembleEvaluationConfig>,
) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
where
    F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
{
    let config = config.unwrap_or_default();
    let mut evaluator = EnsembleEvaluator::new(config);
    evaluator.evaluate(
        ensemble_predictions,
        true_labels,
        ensemble_weights,
        model_predictions,
        evaluation_fn,
    )
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    fn mock_evaluation_function(
        predictions: &Array1<Float>,
        labels: &Array1<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let correct = predictions
            .iter()
            .zip(labels.iter())
            .filter(|(&pred, &label)| (pred > 0.5) == (label > 0.5))
            .count();
        Ok(correct as Float / predictions.len() as Float)
    }

    #[test]
    fn test_ensemble_evaluator_creation() {
        let config = EnsembleEvaluationConfig::default();
        let evaluator = EnsembleEvaluator::new(config);
        assert_eq!(evaluator.config.confidence_level, 0.95);
    }

    #[test]
    fn test_out_of_bag_evaluation() {
        let ensemble_predictions = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .unwrap();
        let true_labels = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::OutOfBag {
                bootstrap_samples: 10,
                confidence_level: 0.95,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            None,
            mock_evaluation_function,
            Some(config),
        )
        .unwrap();

        assert!(result.out_of_bag_scores.is_some());
        let oob_scores = result.out_of_bag_scores.unwrap();
        assert!(oob_scores.oob_score >= 0.0 && oob_scores.oob_score <= 1.0);
    }
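
    // A sketch exercising the stability strategy end to end; it reuses the
    // fixtures above and only asserts coarse invariants of the reported scores
    // (both stability values are of the form 1/(1 + variance), so they must
    // fall in (0, 1]).
    #[test]
    fn test_stability_evaluation() {
        let ensemble_predictions = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .unwrap();
        let true_labels = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::StabilityAnalysis {
                n_bootstrap_samples: 10,
                stability_metrics: vec![StabilityMetric::PredictionStability],
            },
            random_state: Some(42),
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            None,
            mock_evaluation_function,
            Some(config),
        )
        .unwrap();

        let stability = result.stability_analysis.expect("stability result");
        assert!(stability.prediction_stability > 0.0 && stability.prediction_stability <= 1.0);
        assert!(stability.performance_stability > 0.0 && stability.performance_stability <= 1.0);
    }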

    #[test]
    fn test_diversity_evaluation() {
        let ensemble_predictions = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .unwrap();
        let true_labels = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);
        let model_predictions = ensemble_predictions.clone();

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::DiversityEvaluation {
                diversity_measures: vec![
                    DiversityMeasure::QStatistic,
                    DiversityMeasure::DisagreementMeasure,
                ],
                diversity_threshold: 0.5,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            Some(&model_predictions),
            mock_evaluation_function,
            Some(config),
        )
        .unwrap();

        assert!(!result.diversity_analysis.diversity_by_measure.is_empty());
        assert!(result.diversity_analysis.overall_diversity >= 0.0);
    }
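
    // A sketch exercising the progressive strategy: the reported curves should
    // line up one-to-one with the ensemble sizes that were actually evaluated,
    // and the optimal size must be one of them.
    #[test]
    fn test_progressive_evaluation() {
        let ensemble_predictions = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .unwrap();
        let true_labels = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);
        let model_predictions = ensemble_predictions.clone();

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::ProgressiveEvaluation {
                ensemble_sizes: vec![1, 2, 3],
                selection_strategy: ProgressiveSelectionStrategy::ForwardSelection,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            Some(&model_predictions),
            mock_evaluation_function,
            Some(config),
        )
        .unwrap();

        let progressive = result.progressive_performance.expect("progressive result");
        assert_eq!(
            progressive.performance_curve.len(),
            progressive.ensemble_sizes.len()
        );
        assert!(progressive.ensemble_sizes.contains(&progressive.optimal_size));
    }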

    #[test]
    fn test_cross_validation_evaluation() {
        let ensemble_predictions = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .unwrap();
        let true_labels = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::EnsembleCrossValidation {
                cv_strategy: EnsembleCVStrategy::KFoldEnsemble,
                n_folds: 5,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            None,
            mock_evaluation_function,
            Some(config),
        )
        .unwrap();

        assert!(!result
            .ensemble_performance
            .individual_fold_scores
            .is_empty());
        assert!(result.ensemble_performance.mean_performance >= 0.0);
    }
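
    // A sketch checking two analytic endpoints of the pairwise Q-statistic:
    // identical predictors with mixed outputs give Q = 1, while predictors
    // that always land on opposite sides of the 0.5 threshold give Q = -1.
    #[test]
    fn test_q_statistic_extremes() {
        let evaluator = EnsembleEvaluator::new(EnsembleEvaluationConfig::default());

        // Two identical columns with both high and low outputs.
        let identical =
            Array2::from_shape_vec((4, 2), vec![0.9, 0.9, 0.1, 0.1, 0.8, 0.8, 0.2, 0.2]).unwrap();
        let q = evaluator.calculate_q_statistic(&identical).unwrap();
        assert!((q - 1.0).abs() < 1e-9);

        // Two columns that always disagree across the 0.5 threshold.
        let complementary =
            Array2::from_shape_vec((4, 2), vec![0.9, 0.1, 0.1, 0.9, 0.8, 0.2, 0.2, 0.8]).unwrap();
        let q = evaluator.calculate_q_statistic(&complementary).unwrap();
        assert!((q + 1.0).abs() < 1e-9);
    }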
}