Skip to main content

sklears_model_selection/
ensemble_evaluation.rs

1//! Ensemble Cross-Validation and Diversity-Based Evaluation
2//!
3//! This module provides advanced evaluation methods for ensemble models including
4//! specialized cross-validation techniques, diversity measures, stability analysis,
5//! and out-of-bag evaluation strategies specifically designed for ensemble learning.
6
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::SeedableRng;
10use scirs2_core::RngExt;
11use sklears_core::types::Float;
12use std::collections::HashMap;
13
/// Ensemble evaluation strategies
///
/// Each variant selects a different evaluation pipeline in the evaluator and
/// carries the parameters for that pipeline.
#[derive(Debug, Clone)]
pub enum EnsembleEvaluationStrategy {
    /// Out-of-bag evaluation for bootstrap-based ensembles
    OutOfBag {
        /// Number of bootstrap resamples to draw.
        bootstrap_samples: usize,
        /// Confidence level for the reported interval (e.g. 0.95).
        confidence_level: Float,
    },
    /// Ensemble-specific cross-validation
    EnsembleCrossValidation {
        /// Fold-formation strategy.
        /// NOTE(review): bound but not consumed in `evaluate_cross_validation` — confirm intent.
        cv_strategy: EnsembleCVStrategy,
        /// Number of cross-validation folds.
        n_folds: usize,
    },
    /// Diversity-based evaluation
    DiversityEvaluation {
        /// Diversity measures to compute.
        diversity_measures: Vec<DiversityMeasure>,
        /// Diversity threshold.
        /// NOTE(review): currently unused in `evaluate_diversity` — confirm intent.
        diversity_threshold: Float,
    },
    /// Stability analysis across different data splits
    StabilityAnalysis {
        /// Number of bootstrap resamples used to probe stability.
        n_bootstrap_samples: usize,
        /// Stability metrics to compute.
        /// NOTE(review): currently unused in `evaluate_stability` — confirm intent.
        stability_metrics: Vec<StabilityMetric>,
    },
    /// Progressive ensemble evaluation
    ProgressiveEvaluation {
        /// Ensemble sizes to evaluate, in order.
        ensemble_sizes: Vec<usize>,
        /// Member-selection strategy as the ensemble grows.
        selection_strategy: ProgressiveSelectionStrategy,
    },
    /// Multi-objective ensemble evaluation
    MultiObjectiveEvaluation {
        /// Objectives to score the ensemble against.
        objectives: Vec<EvaluationObjective>,
        /// Whether to analyze trade-offs between objectives.
        trade_off_analysis: bool,
    },
}
50
/// Cross-validation strategies specific to ensembles
#[derive(Debug, Clone)]
pub enum EnsembleCVStrategy {
    /// Standard k-fold CV for ensembles
    KFoldEnsemble,
    /// Stratified CV maintaining ensemble member diversity
    StratifiedEnsemble,
    /// Leave-one-model-out CV
    LeaveOneModelOut,
    /// Bootstrap CV for ensemble components
    BootstrapEnsemble {
        /// Number of bootstrap resamples.
        n_bootstrap: usize,
    },
    /// Nested CV for ensemble and member optimization
    NestedEnsemble {
        /// Inner-loop fold count.
        inner_cv: usize,
        /// Outer-loop fold count.
        outer_cv: usize,
    },
    /// Time series CV for temporal ensembles
    TimeSeriesEnsemble {
        /// Number of temporal splits.
        n_splits: usize,
        /// Test-set size, presumably a fraction of the samples — TODO confirm.
        test_size: Float,
    },
}
67
/// Diversity measures for ensemble evaluation
///
/// Measures quantifying how much ensemble members disagree in their
/// predictions.
#[derive(Debug, Clone)]
pub enum DiversityMeasure {
    /// Q-statistic between pairs of classifiers
    QStatistic,
    /// Correlation coefficient between predictions
    CorrelationCoefficient,
    /// Disagreement measure
    DisagreementMeasure,
    /// Double-fault measure
    DoubleFaultMeasure,
    /// Entropy-based diversity
    EntropyDiversity,
    /// Kohavi-Wolpert variance
    KohaviWolpertVariance,
    /// Interrater agreement (Kappa)
    InterraterAgreement,
    /// Measurement of difficulty
    DifficultyMeasure,
    /// Generalized diversity index
    GeneralizedDiversity {
        /// Order parameter of the generalized index.
        alpha: Float,
    },
}
90
/// Stability metrics for ensemble evaluation
///
/// Aspects of an ensemble whose consistency across resamples can be measured.
#[derive(Debug, Clone)]
pub enum StabilityMetric {
    /// Prediction stability across bootstrap samples
    PredictionStability,
    /// Model selection stability
    ModelSelectionStability,
    /// Weight stability for weighted ensembles
    WeightStability,
    /// Performance stability
    PerformanceStability,
    /// Ranking stability of ensemble members
    RankingStability,
}
105
/// Progressive selection strategies
///
/// How ensemble members are chosen as the ensemble is grown or shrunk during
/// progressive evaluation.
#[derive(Debug, Clone)]
pub enum ProgressiveSelectionStrategy {
    /// Forward selection based on performance
    ForwardSelection,
    /// Backward elimination
    BackwardElimination,
    /// Diversity-driven selection
    DiversityDriven,
    /// Performance-diversity trade-off
    PerformanceDiversityTradeoff {
        /// Trade-off weight between performance and diversity — TODO confirm exact weighting.
        alpha: Float,
    },
}
118
/// Evaluation objectives for multi-objective analysis
///
/// Criteria an ensemble can be scored against when more than one objective is
/// considered simultaneously.
#[derive(Debug, Clone)]
pub enum EvaluationObjective {
    /// Predictive accuracy
    Accuracy,
    /// Model diversity
    Diversity,
    /// Computational efficiency
    Efficiency,
    /// Memory usage
    MemoryUsage,
    /// Robustness to outliers
    Robustness,
    /// Interpretability
    Interpretability,
    /// Fairness across groups
    Fairness,
}
137
/// Ensemble evaluation configuration
#[derive(Debug, Clone)]
pub struct EnsembleEvaluationConfig {
    /// Evaluation strategy and its strategy-specific parameters.
    pub strategy: EnsembleEvaluationStrategy,
    /// Names of metrics to report (e.g. "accuracy", "f1_score").
    pub evaluation_metrics: Vec<String>,
    /// Confidence level for reported intervals (e.g. 0.95).
    pub confidence_level: Float,
    /// Requested number of evaluation repetitions.
    /// NOTE(review): not consumed in the visible code paths — confirm.
    pub n_repetitions: usize,
    /// Whether evaluation may run in parallel.
    /// NOTE(review): not consumed in the visible code paths — confirm.
    pub parallel_evaluation: bool,
    /// Seed for the internal RNG; `None` seeds from the thread-local RNG.
    pub random_state: Option<u64>,
    /// Verbosity flag.
    pub verbose: bool,
}
149
/// Ensemble evaluation result
///
/// Aggregated output of an evaluation run; optional sections are populated
/// only by the strategy that computes them.
#[derive(Debug, Clone)]
pub struct EnsembleEvaluationResult {
    /// Overall performance summary of the ensemble.
    pub ensemble_performance: EnsemblePerformanceMetrics,
    /// Diversity summary (may contain zero/empty placeholders when the
    /// strategy does not compute diversity).
    pub diversity_analysis: DiversityAnalysis,
    /// Stability summary; populated by the stability-analysis strategy.
    pub stability_analysis: Option<StabilityAnalysis>,
    /// Per-member contribution breakdown.
    pub member_contributions: Vec<MemberContribution>,
    /// Out-of-bag scores; populated by the out-of-bag strategy.
    pub out_of_bag_scores: Option<OutOfBagScores>,
    /// Performance by ensemble size; populated by progressive evaluation.
    pub progressive_performance: Option<ProgressivePerformance>,
    /// Multi-objective results; populated by multi-objective evaluation.
    pub multi_objective_analysis: Option<MultiObjectiveAnalysis>,
}
161
/// Comprehensive ensemble performance metrics
#[derive(Debug, Clone)]
pub struct EnsemblePerformanceMetrics {
    /// Mean score across folds/resamples.
    pub mean_performance: Float,
    /// Standard deviation of the scores.
    pub std_performance: Float,
    /// (lower, upper) interval around the mean.
    pub confidence_interval: (Float, Float),
    /// Score of each individual fold or resample.
    pub individual_fold_scores: Vec<Float>,
    /// Ensemble score minus the best member's score.
    /// NOTE(review): set to 0.0 in the visible strategies (member scores unavailable).
    pub ensemble_vs_best_member: Float,
    /// Ensemble score minus the average member score.
    /// NOTE(review): set to 0.0 in the visible strategies.
    pub ensemble_vs_average_member: Float,
    /// Net performance gain from ensembling.
    /// NOTE(review): set to 0.0 in the visible strategies.
    pub performance_gain: Float,
}
173
/// Diversity analysis results
#[derive(Debug, Clone)]
pub struct DiversityAnalysis {
    /// Aggregate diversity score (mean over the computed measures).
    pub overall_diversity: Float,
    /// Symmetric matrix of pairwise member diversities (may be empty).
    pub pairwise_diversities: Array2<Float>,
    /// Diversity value keyed by measure name.
    pub diversity_by_measure: HashMap<String, Float>,
    /// Per-fold (or per-sample) diversity values.
    pub diversity_distribution: Vec<Float>,
    /// Ensemble size that maximizes diversity, when computed.
    pub optimal_diversity_size: Option<usize>,
}
183
/// Stability analysis results
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    /// Stability of predictions across bootstrap resamples.
    pub prediction_stability: Float,
    /// Stability of model selection.
    /// NOTE(review): hard-coded placeholder (0.8) in `evaluate_stability`.
    pub model_selection_stability: Float,
    /// Stability of ensemble weights, when weights are analyzed.
    pub weight_stability: Option<Float>,
    /// Stability of the performance score across resamples.
    pub performance_stability: Float,
    /// (lower, upper) interval per stability metric name.
    pub stability_confidence_intervals: HashMap<String, (Float, Float)>,
}
193
/// Individual member contribution analysis
#[derive(Debug, Clone)]
pub struct MemberContribution {
    /// Index of the member within the ensemble.
    pub member_id: usize,
    /// Human-readable member name.
    pub member_name: String,
    /// Score of this member in isolation.
    pub individual_performance: Float,
    /// Performance change when this member is added.
    pub marginal_contribution: Float,
    /// Shapley value of the member, when computed.
    pub shapley_value: Option<Float>,
    /// Performance change when this member is removed.
    pub removal_impact: Float,
    /// Contribution of this member to ensemble diversity.
    pub diversity_contribution: Float,
}
205
/// Out-of-bag evaluation scores
#[derive(Debug, Clone)]
pub struct OutOfBagScores {
    /// Mean OOB score across bootstrap rounds.
    pub oob_score: Float,
    /// (lower, upper) interval around the mean OOB score.
    pub oob_confidence_interval: (Float, Float),
    /// Feature importances; `None` in the visible path (no feature data).
    pub feature_importance: Option<Array1<Float>>,
    /// Per-sample prediction intervals; `None` in the visible path.
    pub prediction_intervals: Option<Array2<Float>>,
    /// OOB score from each bootstrap round.
    pub individual_oob_scores: Vec<Float>,
}
215
/// Progressive performance analysis
#[derive(Debug, Clone)]
pub struct ProgressivePerformance {
    /// Evaluated ensemble sizes, in order.
    pub ensemble_sizes: Vec<usize>,
    /// Performance at each evaluated size.
    pub performance_curve: Vec<Float>,
    /// Diversity at each evaluated size.
    pub diversity_curve: Vec<Float>,
    /// Efficiency at each evaluated size.
    pub efficiency_curve: Vec<Float>,
    /// Size with the best observed performance.
    pub optimal_size: usize,
    /// Size beyond which gains diminish, when detected.
    pub diminishing_returns_threshold: Option<usize>,
}
226
/// Multi-objective analysis results
#[derive(Debug, Clone)]
pub struct MultiObjectiveAnalysis {
    /// Non-dominated (objective, objective) points.
    pub pareto_front: Vec<(Float, Float)>,
    /// Score per objective name.
    pub objective_scores: HashMap<String, Float>,
    /// Trade-off statistic per objective pair/name.
    pub trade_off_analysis: HashMap<String, Float>,
    /// Indices of dominated candidate solutions.
    pub dominated_solutions: Vec<usize>,
    /// Index of a compromise solution, when one is selected.
    pub compromise_solution: Option<usize>,
}
236
/// Ensemble evaluator
///
/// Runs the evaluation strategy described by its configuration. Holds an RNG
/// so bootstrap-based strategies are reproducible when `random_state` is set.
#[derive(Debug)]
pub struct EnsembleEvaluator {
    /// Evaluation configuration (strategy, metrics, seed, ...).
    config: EnsembleEvaluationConfig,
    /// RNG used for bootstrap resampling.
    rng: StdRng,
}
243
244impl Default for EnsembleEvaluationConfig {
245    fn default() -> Self {
246        Self {
247            strategy: EnsembleEvaluationStrategy::EnsembleCrossValidation {
248                cv_strategy: EnsembleCVStrategy::KFoldEnsemble,
249                n_folds: 5,
250            },
251            evaluation_metrics: vec!["accuracy".to_string(), "f1_score".to_string()],
252            confidence_level: 0.95,
253            n_repetitions: 1,
254            parallel_evaluation: false,
255            random_state: None,
256            verbose: false,
257        }
258    }
259}
260
261impl EnsembleEvaluator {
262    /// Create a new ensemble evaluator
263    pub fn new(config: EnsembleEvaluationConfig) -> Self {
264        let rng = match config.random_state {
265            Some(seed) => StdRng::seed_from_u64(seed),
266            None => {
267                use scirs2_core::random::thread_rng;
268                StdRng::from_rng(&mut thread_rng())
269            }
270        };
271
272        Self { config, rng }
273    }
274
275    /// Evaluate ensemble using specified strategy
276    pub fn evaluate<F>(
277        &mut self,
278        ensemble_predictions: &Array2<Float>,
279        true_labels: &Array1<Float>,
280        ensemble_weights: Option<&Array1<Float>>,
281        model_predictions: Option<&Array2<Float>>,
282        evaluation_fn: F,
283    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
284    where
285        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
286    {
287        match &self.config.strategy {
288            EnsembleEvaluationStrategy::OutOfBag { .. } => self.evaluate_out_of_bag(
289                ensemble_predictions,
290                true_labels,
291                ensemble_weights,
292                &evaluation_fn,
293            ),
294            EnsembleEvaluationStrategy::EnsembleCrossValidation { .. } => self
295                .evaluate_cross_validation(
296                    ensemble_predictions,
297                    true_labels,
298                    ensemble_weights,
299                    model_predictions,
300                    &evaluation_fn,
301                ),
302            EnsembleEvaluationStrategy::DiversityEvaluation { .. } => self.evaluate_diversity(
303                ensemble_predictions,
304                true_labels,
305                model_predictions,
306                &evaluation_fn,
307            ),
308            EnsembleEvaluationStrategy::StabilityAnalysis { .. } => self.evaluate_stability(
309                ensemble_predictions,
310                true_labels,
311                ensemble_weights,
312                &evaluation_fn,
313            ),
314            EnsembleEvaluationStrategy::ProgressiveEvaluation { .. } => self.evaluate_progressive(
315                ensemble_predictions,
316                true_labels,
317                model_predictions,
318                &evaluation_fn,
319            ),
320            EnsembleEvaluationStrategy::MultiObjectiveEvaluation { .. } => self
321                .evaluate_multi_objective(
322                    ensemble_predictions,
323                    true_labels,
324                    ensemble_weights,
325                    &evaluation_fn,
326                ),
327        }
328    }
329
330    /// Out-of-bag evaluation implementation
331    fn evaluate_out_of_bag<F>(
332        &mut self,
333        ensemble_predictions: &Array2<Float>,
334        true_labels: &Array1<Float>,
335        ensemble_weights: Option<&Array1<Float>>,
336        evaluation_fn: &F,
337    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
338    where
339        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
340    {
341        let (bootstrap_samples, confidence_level) = match &self.config.strategy {
342            EnsembleEvaluationStrategy::OutOfBag {
343                bootstrap_samples,
344                confidence_level,
345            } => (*bootstrap_samples, *confidence_level),
346            _ => unreachable!(),
347        };
348
349        let n_samples = ensemble_predictions.nrows();
350        let n_models = ensemble_predictions.ncols();
351
352        let mut oob_scores = Vec::new();
353        let mut oob_predictions_all = Vec::new();
354
355        for _ in 0..bootstrap_samples {
356            // Generate bootstrap sample
357            let bootstrap_indices: Vec<usize> = (0..n_samples)
358                .map(|_| self.rng.random_range(0..n_samples))
359                .collect();
360
361            // Find out-of-bag samples
362            let mut oob_indices = Vec::new();
363            for i in 0..n_samples {
364                if !bootstrap_indices.contains(&i) {
365                    oob_indices.push(i);
366                }
367            }
368
369            if oob_indices.is_empty() {
370                continue;
371            }
372
373            // Calculate OOB predictions
374            let oob_ensemble_preds = self.calculate_ensemble_predictions(
375                ensemble_predictions,
376                &oob_indices,
377                ensemble_weights,
378            )?;
379
380            let oob_true_labels =
381                Array1::from_vec(oob_indices.iter().map(|&i| true_labels[i]).collect());
382
383            let oob_score = evaluation_fn(&oob_ensemble_preds, &oob_true_labels)?;
384            oob_scores.push(oob_score);
385            oob_predictions_all.push(oob_ensemble_preds);
386        }
387
388        let mean_oob_score = oob_scores.iter().sum::<Float>() / oob_scores.len() as Float;
389        let std_oob_score = {
390            let variance = oob_scores
391                .iter()
392                .map(|&score| (score - mean_oob_score).powi(2))
393                .sum::<Float>()
394                / oob_scores.len() as Float;
395            variance.sqrt()
396        };
397
398        let _alpha = 1.0 - confidence_level;
399        let z_score = 1.96; // Approximate for 95% confidence
400        let margin_of_error = z_score * std_oob_score / (oob_scores.len() as Float).sqrt();
401        let confidence_interval = (
402            mean_oob_score - margin_of_error,
403            mean_oob_score + margin_of_error,
404        );
405
406        let oob_scores_result = OutOfBagScores {
407            oob_score: mean_oob_score,
408            oob_confidence_interval: confidence_interval,
409            feature_importance: None, // Could be calculated if feature data available
410            prediction_intervals: None, // Could be calculated from OOB predictions
411            individual_oob_scores: oob_scores,
412        };
413
414        // Calculate basic ensemble performance
415        let ensemble_preds = self.calculate_ensemble_predictions(
416            ensemble_predictions,
417            &(0..n_samples).collect::<Vec<_>>(),
418            ensemble_weights,
419        )?;
420        let ensemble_score = evaluation_fn(&ensemble_preds, true_labels)?;
421
422        let ensemble_performance = EnsemblePerformanceMetrics {
423            mean_performance: ensemble_score,
424            std_performance: std_oob_score,
425            confidence_interval,
426            individual_fold_scores: vec![ensemble_score],
427            ensemble_vs_best_member: 0.0, // Would need individual model scores
428            ensemble_vs_average_member: 0.0,
429            performance_gain: 0.0,
430        };
431
432        Ok(EnsembleEvaluationResult {
433            ensemble_performance,
434            diversity_analysis: DiversityAnalysis {
435                overall_diversity: 0.0,
436                pairwise_diversities: Array2::zeros((n_models, n_models)),
437                diversity_by_measure: HashMap::new(),
438                diversity_distribution: Vec::new(),
439                optimal_diversity_size: None,
440            },
441            stability_analysis: None,
442            member_contributions: Vec::new(),
443            out_of_bag_scores: Some(oob_scores_result),
444            progressive_performance: None,
445            multi_objective_analysis: None,
446        })
447    }
448
449    /// Cross-validation evaluation implementation
450    fn evaluate_cross_validation<F>(
451        &mut self,
452        ensemble_predictions: &Array2<Float>,
453        true_labels: &Array1<Float>,
454        ensemble_weights: Option<&Array1<Float>>,
455        model_predictions: Option<&Array2<Float>>,
456        evaluation_fn: &F,
457    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
458    where
459        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
460    {
461        let (_cv_strategy, n_folds) = match &self.config.strategy {
462            EnsembleEvaluationStrategy::EnsembleCrossValidation {
463                cv_strategy,
464                n_folds,
465            } => (cv_strategy, *n_folds),
466            _ => unreachable!(),
467        };
468
469        let n_samples = ensemble_predictions.nrows();
470        let fold_size = n_samples / n_folds;
471        let mut fold_scores = Vec::new();
472        let mut diversity_scores = Vec::new();
473
474        for fold in 0..n_folds {
475            let test_start = fold * fold_size;
476            let test_end = if fold == n_folds - 1 {
477                n_samples
478            } else {
479                (fold + 1) * fold_size
480            };
481            let test_indices: Vec<usize> = (test_start..test_end).collect();
482
483            // Calculate ensemble predictions for test fold
484            let test_ensemble_preds = self.calculate_ensemble_predictions(
485                ensemble_predictions,
486                &test_indices,
487                ensemble_weights,
488            )?;
489
490            let test_true_labels =
491                Array1::from_vec(test_indices.iter().map(|&i| true_labels[i]).collect());
492
493            let fold_score = evaluation_fn(&test_ensemble_preds, &test_true_labels)?;
494            fold_scores.push(fold_score);
495
496            // Calculate diversity for this fold if model predictions available
497            if let Some(model_preds) = model_predictions {
498                let mut fold_data = Vec::new();
499                for &i in test_indices.iter() {
500                    fold_data.extend(model_preds.row(i).iter().cloned());
501                }
502                let fold_model_preds =
503                    Array2::from_shape_vec((test_indices.len(), model_preds.ncols()), fold_data)?;
504
505                let diversity = self.calculate_q_statistic(&fold_model_preds)?;
506                diversity_scores.push(diversity);
507            }
508        }
509
510        let mean_performance = fold_scores.iter().sum::<Float>() / fold_scores.len() as Float;
511        let std_performance = {
512            let variance = fold_scores
513                .iter()
514                .map(|&score| (score - mean_performance).powi(2))
515                .sum::<Float>()
516                / fold_scores.len() as Float;
517            variance.sqrt()
518        };
519
520        let z_score = 1.96; // 95% confidence
521        let margin_of_error = z_score * std_performance / (fold_scores.len() as Float).sqrt();
522        let confidence_interval = (
523            mean_performance - margin_of_error,
524            mean_performance + margin_of_error,
525        );
526
527        let ensemble_performance = EnsemblePerformanceMetrics {
528            mean_performance,
529            std_performance,
530            confidence_interval,
531            individual_fold_scores: fold_scores,
532            ensemble_vs_best_member: 0.0, // Would need individual model analysis
533            ensemble_vs_average_member: 0.0,
534            performance_gain: 0.0,
535        };
536
537        let diversity_analysis = if !diversity_scores.is_empty() {
538            let mean_diversity =
539                diversity_scores.iter().sum::<Float>() / diversity_scores.len() as Float;
540            DiversityAnalysis {
541                overall_diversity: mean_diversity,
542                pairwise_diversities: Array2::zeros((0, 0)), // Would calculate pairwise if needed
543                diversity_by_measure: {
544                    let mut map = HashMap::new();
545                    map.insert("q_statistic".to_string(), mean_diversity);
546                    map
547                },
548                diversity_distribution: diversity_scores,
549                optimal_diversity_size: None,
550            }
551        } else {
552            DiversityAnalysis {
553                overall_diversity: 0.0,
554                pairwise_diversities: Array2::zeros((0, 0)),
555                diversity_by_measure: HashMap::new(),
556                diversity_distribution: Vec::new(),
557                optimal_diversity_size: None,
558            }
559        };
560
561        Ok(EnsembleEvaluationResult {
562            ensemble_performance,
563            diversity_analysis,
564            stability_analysis: None,
565            member_contributions: Vec::new(),
566            out_of_bag_scores: None,
567            progressive_performance: None,
568            multi_objective_analysis: None,
569        })
570    }
571
572    /// Diversity evaluation implementation
573    fn evaluate_diversity<F>(
574        &mut self,
575        ensemble_predictions: &Array2<Float>,
576        true_labels: &Array1<Float>,
577        model_predictions: Option<&Array2<Float>>,
578        evaluation_fn: &F,
579    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
580    where
581        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
582    {
583        let (diversity_measures, _diversity_threshold) = match &self.config.strategy {
584            EnsembleEvaluationStrategy::DiversityEvaluation {
585                diversity_measures,
586                diversity_threshold,
587            } => (diversity_measures, *diversity_threshold),
588            _ => unreachable!(),
589        };
590
591        if let Some(model_preds) = model_predictions {
592            let n_models = model_preds.ncols();
593            let mut diversity_by_measure = HashMap::new();
594            let mut pairwise_diversities = Array2::zeros((n_models, n_models));
595
596            for measure in diversity_measures {
597                let diversity_value = match measure {
598                    DiversityMeasure::QStatistic => self.calculate_q_statistic(model_preds)?,
599                    DiversityMeasure::CorrelationCoefficient => {
600                        self.calculate_correlation_coefficient(model_preds)?
601                    }
602                    DiversityMeasure::DisagreementMeasure => {
603                        self.calculate_disagreement_measure(model_preds)?
604                    }
605                    DiversityMeasure::DoubleFaultMeasure => {
606                        self.calculate_double_fault_measure(model_preds, true_labels)?
607                    }
608                    DiversityMeasure::EntropyDiversity => {
609                        self.calculate_entropy_diversity(model_preds)?
610                    }
611                    DiversityMeasure::KohaviWolpertVariance => {
612                        self.calculate_kw_variance(model_preds, true_labels)?
613                    }
614                    DiversityMeasure::InterraterAgreement => {
615                        self.calculate_interrater_agreement(model_preds)?
616                    }
617                    DiversityMeasure::DifficultyMeasure => {
618                        self.calculate_difficulty_measure(model_preds, true_labels)?
619                    }
620                    DiversityMeasure::GeneralizedDiversity { alpha } => {
621                        self.calculate_generalized_diversity(model_preds, *alpha)?
622                    }
623                };
624
625                diversity_by_measure.insert(format!("{:?}", measure), diversity_value);
626            }
627
628            // Calculate pairwise diversities
629            for i in 0..n_models {
630                for j in i + 1..n_models {
631                    let pair_preds = Array2::from_shape_vec(
632                        (model_preds.nrows(), 2),
633                        model_preds
634                            .column(i)
635                            .iter()
636                            .cloned()
637                            .chain(model_preds.column(j).iter().cloned())
638                            .collect(),
639                    )?;
640                    let pair_diversity = self.calculate_q_statistic(&pair_preds)?;
641                    pairwise_diversities[[i, j]] = pair_diversity;
642                    pairwise_diversities[[j, i]] = pair_diversity;
643                }
644            }
645
646            let overall_diversity =
647                diversity_by_measure.values().sum::<Float>() / diversity_by_measure.len() as Float;
648
649            let diversity_analysis = DiversityAnalysis {
650                overall_diversity,
651                pairwise_diversities,
652                diversity_by_measure,
653                diversity_distribution: Vec::new(), // Could add distribution analysis
654                optimal_diversity_size: None,       // Could calculate optimal size
655            };
656
657            // Calculate basic ensemble performance
658            let ensemble_preds = self.calculate_ensemble_predictions(
659                ensemble_predictions,
660                &(0..ensemble_predictions.nrows()).collect::<Vec<_>>(),
661                None,
662            )?;
663            let ensemble_score = evaluation_fn(&ensemble_preds, true_labels)?;
664
665            let ensemble_performance = EnsemblePerformanceMetrics {
666                mean_performance: ensemble_score,
667                std_performance: 0.0,
668                confidence_interval: (ensemble_score, ensemble_score),
669                individual_fold_scores: vec![ensemble_score],
670                ensemble_vs_best_member: 0.0,
671                ensemble_vs_average_member: 0.0,
672                performance_gain: 0.0,
673            };
674
675            Ok(EnsembleEvaluationResult {
676                ensemble_performance,
677                diversity_analysis,
678                stability_analysis: None,
679                member_contributions: Vec::new(),
680                out_of_bag_scores: None,
681                progressive_performance: None,
682                multi_objective_analysis: None,
683            })
684        } else {
685            Err("Model predictions required for diversity evaluation".into())
686        }
687    }
688
689    /// Stability analysis implementation
690    fn evaluate_stability<F>(
691        &mut self,
692        ensemble_predictions: &Array2<Float>,
693        true_labels: &Array1<Float>,
694        ensemble_weights: Option<&Array1<Float>>,
695        evaluation_fn: &F,
696    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
697    where
698        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
699    {
700        let (n_bootstrap_samples, _stability_metrics) = match &self.config.strategy {
701            EnsembleEvaluationStrategy::StabilityAnalysis {
702                n_bootstrap_samples,
703                stability_metrics,
704            } => (*n_bootstrap_samples, stability_metrics),
705            _ => unreachable!(),
706        };
707
708        let n_samples = ensemble_predictions.nrows();
709        let mut bootstrap_scores = Vec::new();
710        let mut bootstrap_predictions = Vec::new();
711
712        for _ in 0..n_bootstrap_samples {
713            // Generate bootstrap sample
714            let bootstrap_indices: Vec<usize> = (0..n_samples)
715                .map(|_| self.rng.random_range(0..n_samples))
716                .collect();
717
718            let bootstrap_preds = self.calculate_ensemble_predictions(
719                ensemble_predictions,
720                &bootstrap_indices,
721                ensemble_weights,
722            )?;
723
724            let bootstrap_labels =
725                Array1::from_vec(bootstrap_indices.iter().map(|&i| true_labels[i]).collect());
726
727            let bootstrap_score = evaluation_fn(&bootstrap_preds, &bootstrap_labels)?;
728            bootstrap_scores.push(bootstrap_score);
729            bootstrap_predictions.push(bootstrap_preds);
730        }
731
732        // Calculate prediction stability
733        let prediction_stability = self.calculate_prediction_stability(&bootstrap_predictions)?;
734
735        // Calculate performance stability
736        let mean_score = bootstrap_scores.iter().sum::<Float>() / bootstrap_scores.len() as Float;
737        let score_variance = bootstrap_scores
738            .iter()
739            .map(|&score| (score - mean_score).powi(2))
740            .sum::<Float>()
741            / bootstrap_scores.len() as Float;
742        let performance_stability = 1.0 / (1.0 + score_variance); // Higher variance = lower stability
743
744        let stability_analysis = StabilityAnalysis {
745            prediction_stability,
746            model_selection_stability: 0.8, // Placeholder - would need model selection data
747            weight_stability: None,         // Would calculate if weights provided
748            performance_stability,
749            stability_confidence_intervals: HashMap::new(), // Could add CIs
750        };
751
752        let ensemble_performance = EnsemblePerformanceMetrics {
753            mean_performance: mean_score,
754            std_performance: score_variance.sqrt(),
755            confidence_interval: (
756                mean_score - score_variance.sqrt(),
757                mean_score + score_variance.sqrt(),
758            ),
759            individual_fold_scores: bootstrap_scores,
760            ensemble_vs_best_member: 0.0,
761            ensemble_vs_average_member: 0.0,
762            performance_gain: 0.0,
763        };
764
765        Ok(EnsembleEvaluationResult {
766            ensemble_performance,
767            diversity_analysis: DiversityAnalysis {
768                overall_diversity: 0.0,
769                pairwise_diversities: Array2::zeros((0, 0)),
770                diversity_by_measure: HashMap::new(),
771                diversity_distribution: Vec::new(),
772                optimal_diversity_size: None,
773            },
774            stability_analysis: Some(stability_analysis),
775            member_contributions: Vec::new(),
776            out_of_bag_scores: None,
777            progressive_performance: None,
778            multi_objective_analysis: None,
779        })
780    }
781
782    /// Progressive evaluation implementation
783    fn evaluate_progressive<F>(
784        &mut self,
785        _ensemble_predictions: &Array2<Float>,
786        true_labels: &Array1<Float>,
787        model_predictions: Option<&Array2<Float>>,
788        evaluation_fn: &F,
789    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
790    where
791        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
792    {
793        let (ensemble_sizes, _selection_strategy) = match &self.config.strategy {
794            EnsembleEvaluationStrategy::ProgressiveEvaluation {
795                ensemble_sizes,
796                selection_strategy,
797            } => (ensemble_sizes, selection_strategy),
798            _ => unreachable!(),
799        };
800
801        if let Some(model_preds) = model_predictions {
802            let mut performance_curve = Vec::new();
803            let mut diversity_curve = Vec::new();
804            let n_models = model_preds.ncols();
805
806            for &size in ensemble_sizes {
807                if size <= n_models {
808                    // Select top models (simplified - could use more sophisticated selection)
809                    let selected_indices: Vec<usize> = (0..size).collect();
810
811                    // Calculate ensemble predictions for selected models
812                    let mut selected_data = Vec::new();
813                    for &i in selected_indices.iter() {
814                        selected_data.extend(model_preds.column(i).iter().cloned());
815                    }
816                    let selected_predictions =
817                        Array2::from_shape_vec((model_preds.nrows(), size), selected_data)?;
818
819                    let ensemble_preds = selected_predictions
820                        .mean_axis(Axis(1))
821                        .expect("operation should succeed");
822                    let performance = evaluation_fn(&ensemble_preds, true_labels)?;
823                    performance_curve.push(performance);
824
825                    // Calculate diversity for selected models
826                    let diversity = self.calculate_q_statistic(&selected_predictions)?;
827                    diversity_curve.push(diversity);
828                }
829            }
830
831            // Find optimal size (highest performance)
832            let optimal_size_idx = performance_curve
833                .iter()
834                .enumerate()
835                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
836                .map(|(idx, _)| idx)
837                .unwrap_or(0);
838            let optimal_size = ensemble_sizes[optimal_size_idx];
839
840            let progressive_performance = ProgressivePerformance {
841                ensemble_sizes: ensemble_sizes.clone(),
842                performance_curve,
843                diversity_curve,
844                efficiency_curve: vec![1.0; ensemble_sizes.len()], // Placeholder
845                optimal_size,
846                diminishing_returns_threshold: None, // Could calculate
847            };
848
849            let ensemble_performance = EnsemblePerformanceMetrics {
850                mean_performance: progressive_performance.performance_curve[optimal_size_idx],
851                std_performance: 0.0,
852                confidence_interval: (0.0, 0.0),
853                individual_fold_scores: vec![],
854                ensemble_vs_best_member: 0.0,
855                ensemble_vs_average_member: 0.0,
856                performance_gain: 0.0,
857            };
858
859            Ok(EnsembleEvaluationResult {
860                ensemble_performance,
861                diversity_analysis: DiversityAnalysis {
862                    overall_diversity: 0.0,
863                    pairwise_diversities: Array2::zeros((0, 0)),
864                    diversity_by_measure: HashMap::new(),
865                    diversity_distribution: Vec::new(),
866                    optimal_diversity_size: Some(optimal_size),
867                },
868                stability_analysis: None,
869                member_contributions: Vec::new(),
870                out_of_bag_scores: None,
871                progressive_performance: Some(progressive_performance),
872                multi_objective_analysis: None,
873            })
874        } else {
875            Err("Model predictions required for progressive evaluation".into())
876        }
877    }
878
879    /// Multi-objective evaluation implementation
880    fn evaluate_multi_objective<F>(
881        &mut self,
882        ensemble_predictions: &Array2<Float>,
883        true_labels: &Array1<Float>,
884        ensemble_weights: Option<&Array1<Float>>,
885        evaluation_fn: &F,
886    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
887    where
888        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
889    {
890        // Simplified multi-objective evaluation
891        let ensemble_preds = self.calculate_ensemble_predictions(
892            ensemble_predictions,
893            &(0..ensemble_predictions.nrows()).collect::<Vec<_>>(),
894            ensemble_weights,
895        )?;
896        let performance = evaluation_fn(&ensemble_preds, true_labels)?;
897
898        let mut objective_scores = HashMap::new();
899        objective_scores.insert("accuracy".to_string(), performance);
900        objective_scores.insert("diversity".to_string(), 0.5); // Placeholder
901        objective_scores.insert("efficiency".to_string(), 0.8); // Placeholder
902
903        let multi_objective_analysis = MultiObjectiveAnalysis {
904            pareto_front: vec![(performance, 0.5)], // (accuracy, diversity)
905            objective_scores,
906            trade_off_analysis: HashMap::new(),
907            dominated_solutions: Vec::new(),
908            compromise_solution: Some(0),
909        };
910
911        let ensemble_performance = EnsemblePerformanceMetrics {
912            mean_performance: performance,
913            std_performance: 0.0,
914            confidence_interval: (performance, performance),
915            individual_fold_scores: vec![performance],
916            ensemble_vs_best_member: 0.0,
917            ensemble_vs_average_member: 0.0,
918            performance_gain: 0.0,
919        };
920
921        Ok(EnsembleEvaluationResult {
922            ensemble_performance,
923            diversity_analysis: DiversityAnalysis {
924                overall_diversity: 0.0,
925                pairwise_diversities: Array2::zeros((0, 0)),
926                diversity_by_measure: HashMap::new(),
927                diversity_distribution: Vec::new(),
928                optimal_diversity_size: None,
929            },
930            stability_analysis: None,
931            member_contributions: Vec::new(),
932            out_of_bag_scores: None,
933            progressive_performance: None,
934            multi_objective_analysis: Some(multi_objective_analysis),
935        })
936    }
937
938    /// Calculate ensemble predictions for given indices
939    fn calculate_ensemble_predictions(
940        &self,
941        ensemble_predictions: &Array2<Float>,
942        indices: &[usize],
943        weights: Option<&Array1<Float>>,
944    ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
945        let mut selected_data = Vec::new();
946        for &i in indices.iter() {
947            selected_data.extend(ensemble_predictions.row(i).iter().cloned());
948        }
949        let selected_predictions =
950            Array2::from_shape_vec((indices.len(), ensemble_predictions.ncols()), selected_data)?;
951
952        if let Some(w) = weights {
953            // Weighted average
954            Ok(selected_predictions.dot(w))
955        } else {
956            // Simple average
957            Ok(selected_predictions
958                .mean_axis(Axis(1))
959                .expect("operation should succeed"))
960        }
961    }
962
963    /// Calculate Q-statistic diversity measure
964    fn calculate_q_statistic(
965        &self,
966        predictions: &Array2<Float>,
967    ) -> Result<Float, Box<dyn std::error::Error>> {
968        let n_models = predictions.ncols();
969        if n_models < 2 {
970            return Ok(0.0);
971        }
972
973        let mut q_sum = 0.0;
974        let mut pairs = 0;
975
976        for i in 0..n_models {
977            for j in i + 1..n_models {
978                let pred_i = predictions.column(i);
979                let pred_j = predictions.column(j);
980
981                let mut n11 = 0; // Both correct
982                let mut n10 = 0; // i correct, j wrong
983                let mut n01 = 0; // i wrong, j correct
984                let mut n00 = 0; // Both wrong
985
986                for k in 0..predictions.nrows() {
987                    let i_correct = pred_i[k] > 0.5;
988                    let j_correct = pred_j[k] > 0.5;
989
990                    match (i_correct, j_correct) {
991                        (true, true) => n11 += 1,
992                        (true, false) => n10 += 1,
993                        (false, true) => n01 += 1,
994                        (false, false) => n00 += 1,
995                    }
996                }
997
998                let numerator = (n11 * n00 - n01 * n10) as Float;
999                let denominator = (n11 * n00 + n01 * n10) as Float;
1000
1001                if denominator != 0.0 {
1002                    q_sum += numerator / denominator;
1003                    pairs += 1;
1004                }
1005            }
1006        }
1007
1008        Ok(if pairs > 0 {
1009            q_sum / pairs as Float
1010        } else {
1011            0.0
1012        })
1013    }
1014
1015    /// Calculate correlation coefficient diversity measure
1016    fn calculate_correlation_coefficient(
1017        &self,
1018        predictions: &Array2<Float>,
1019    ) -> Result<Float, Box<dyn std::error::Error>> {
1020        let n_models = predictions.ncols();
1021        if n_models < 2 {
1022            return Ok(0.0);
1023        }
1024
1025        let mut correlations = Vec::new();
1026        for i in 0..n_models {
1027            for j in i + 1..n_models {
1028                let pred_i = predictions.column(i);
1029                let pred_j = predictions.column(j);
1030
1031                let mean_i = pred_i.mean().unwrap_or(0.0);
1032                let mean_j = pred_j.mean().unwrap_or(0.0);
1033
1034                let mut covariance = 0.0;
1035                let mut var_i = 0.0;
1036                let mut var_j = 0.0;
1037
1038                for k in 0..predictions.nrows() {
1039                    let diff_i = pred_i[k] - mean_i;
1040                    let diff_j = pred_j[k] - mean_j;
1041                    covariance += diff_i * diff_j;
1042                    var_i += diff_i * diff_i;
1043                    var_j += diff_j * diff_j;
1044                }
1045
1046                let correlation = if var_i > 0.0 && var_j > 0.0 {
1047                    covariance / (var_i.sqrt() * var_j.sqrt())
1048                } else {
1049                    0.0
1050                };
1051
1052                correlations.push(correlation.abs());
1053            }
1054        }
1055
1056        Ok(1.0 - correlations.iter().sum::<Float>() / correlations.len() as Float)
1057    }
1058
1059    /// Calculate disagreement measure
1060    fn calculate_disagreement_measure(
1061        &self,
1062        predictions: &Array2<Float>,
1063    ) -> Result<Float, Box<dyn std::error::Error>> {
1064        let n_models = predictions.ncols();
1065        if n_models < 2 {
1066            return Ok(0.0);
1067        }
1068
1069        let mut disagreement_sum = 0.0;
1070        let mut pairs = 0;
1071
1072        for i in 0..n_models {
1073            for j in i + 1..n_models {
1074                let pred_i = predictions.column(i);
1075                let pred_j = predictions.column(j);
1076
1077                let mut disagreements = 0;
1078                for k in 0..predictions.nrows() {
1079                    if (pred_i[k] > 0.5) != (pred_j[k] > 0.5) {
1080                        disagreements += 1;
1081                    }
1082                }
1083
1084                disagreement_sum += disagreements as Float / predictions.nrows() as Float;
1085                pairs += 1;
1086            }
1087        }
1088
1089        Ok(if pairs > 0 {
1090            disagreement_sum / pairs as Float
1091        } else {
1092            0.0
1093        })
1094    }
1095
1096    /// Calculate double fault measure
1097    fn calculate_double_fault_measure(
1098        &self,
1099        predictions: &Array2<Float>,
1100        true_labels: &Array1<Float>,
1101    ) -> Result<Float, Box<dyn std::error::Error>> {
1102        let n_models = predictions.ncols();
1103        if n_models < 2 {
1104            return Ok(0.0);
1105        }
1106
1107        let mut double_fault_sum = 0.0;
1108        let mut pairs = 0;
1109
1110        for i in 0..n_models {
1111            for j in i + 1..n_models {
1112                let pred_i = predictions.column(i);
1113                let pred_j = predictions.column(j);
1114
1115                let mut double_faults = 0;
1116                for k in 0..predictions.nrows() {
1117                    let i_wrong = (pred_i[k] > 0.5) != (true_labels[k] > 0.5);
1118                    let j_wrong = (pred_j[k] > 0.5) != (true_labels[k] > 0.5);
1119
1120                    if i_wrong && j_wrong {
1121                        double_faults += 1;
1122                    }
1123                }
1124
1125                double_fault_sum += double_faults as Float / predictions.nrows() as Float;
1126                pairs += 1;
1127            }
1128        }
1129
1130        Ok(if pairs > 0 {
1131            double_fault_sum / pairs as Float
1132        } else {
1133            0.0
1134        })
1135    }
1136
1137    /// Calculate entropy-based diversity
1138    fn calculate_entropy_diversity(
1139        &self,
1140        predictions: &Array2<Float>,
1141    ) -> Result<Float, Box<dyn std::error::Error>> {
1142        let n_samples = predictions.nrows();
1143        let n_models = predictions.ncols();
1144
1145        let mut entropy_sum = 0.0;
1146
1147        for i in 0..n_samples {
1148            let correct_count = predictions
1149                .row(i)
1150                .iter()
1151                .filter(|&&pred| pred > 0.5)
1152                .count() as Float;
1153
1154            let p = correct_count / n_models as Float;
1155            if p > 0.0 && p < 1.0 {
1156                entropy_sum += -p * p.log2() - (1.0 - p) * (1.0 - p).log2();
1157            }
1158        }
1159
1160        Ok(entropy_sum / n_samples as Float)
1161    }
1162
1163    /// Calculate Kohavi-Wolpert variance
1164    fn calculate_kw_variance(
1165        &self,
1166        predictions: &Array2<Float>,
1167        true_labels: &Array1<Float>,
1168    ) -> Result<Float, Box<dyn std::error::Error>> {
1169        let n_samples = predictions.nrows();
1170        let n_models = predictions.ncols();
1171
1172        let mut variance_sum = 0.0;
1173
1174        for i in 0..n_samples {
1175            let correct_count = predictions
1176                .row(i)
1177                .iter()
1178                .filter(|&&pred| (pred > 0.5) == (true_labels[i] > 0.5))
1179                .count() as Float;
1180
1181            let l = correct_count / n_models as Float;
1182            variance_sum += l * (1.0 - l);
1183        }
1184
1185        Ok(variance_sum / n_samples as Float)
1186    }
1187
1188    /// Calculate interrater agreement (simplified Kappa)
1189    fn calculate_interrater_agreement(
1190        &self,
1191        predictions: &Array2<Float>,
1192    ) -> Result<Float, Box<dyn std::error::Error>> {
1193        let n_models = predictions.ncols();
1194        if n_models < 2 {
1195            return Ok(0.0);
1196        }
1197
1198        let mut agreement_sum = 0.0;
1199        let mut pairs = 0;
1200
1201        for i in 0..n_models {
1202            for j in i + 1..n_models {
1203                let pred_i = predictions.column(i);
1204                let pred_j = predictions.column(j);
1205
1206                let mut agreements = 0;
1207                for k in 0..predictions.nrows() {
1208                    if (pred_i[k] > 0.5) == (pred_j[k] > 0.5) {
1209                        agreements += 1;
1210                    }
1211                }
1212
1213                agreement_sum += agreements as Float / predictions.nrows() as Float;
1214                pairs += 1;
1215            }
1216        }
1217
1218        Ok(if pairs > 0 {
1219            agreement_sum / pairs as Float
1220        } else {
1221            0.0
1222        })
1223    }
1224
1225    /// Calculate difficulty measure
1226    fn calculate_difficulty_measure(
1227        &self,
1228        predictions: &Array2<Float>,
1229        true_labels: &Array1<Float>,
1230    ) -> Result<Float, Box<dyn std::error::Error>> {
1231        let n_samples = predictions.nrows();
1232        let n_models = predictions.ncols();
1233
1234        let mut difficulty_sum = 0.0;
1235
1236        for i in 0..n_samples {
1237            let error_count = predictions
1238                .row(i)
1239                .iter()
1240                .filter(|&&pred| (pred > 0.5) != (true_labels[i] > 0.5))
1241                .count() as Float;
1242
1243            difficulty_sum += error_count / n_models as Float;
1244        }
1245
1246        Ok(difficulty_sum / n_samples as Float)
1247    }
1248
    /// Calculate generalized diversity index
    ///
    /// Per sample, `p` is the fraction of models voting positive
    /// (prediction > 0.5; no ground-truth labels are used). For
    /// `alpha != 1.0` a generalized index is accumulated:
    /// (1 - p^alpha - (1-p)^alpha) / (2^(1-alpha) - 1); for `alpha == 1.0`
    /// the binary Shannon entropy is used as the limiting case.
    /// NOTE(review): for alpha values near (but not exactly) 1.0 the
    /// denominator 2^(1-alpha) - 1 approaches zero, so the index can blow
    /// up — confirm callers only pass well-separated alpha values. Also
    /// assumes non-empty input; n_samples == 0 or n_models == 0 divides
    /// by zero.
    fn calculate_generalized_diversity(
        &self,
        predictions: &Array2<Float>,
        alpha: Float,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let n_samples = predictions.nrows();
        let n_models = predictions.ncols();

        let mut diversity_sum = 0.0;

        for i in 0..n_samples {
            // Counts positive votes (pred > 0.5); the name "correct_count"
            // is historical — no labels are consulted here.
            let correct_count = predictions
                .row(i)
                .iter()
                .filter(|&&pred| pred > 0.5)
                .count() as Float;

            let p = correct_count / n_models as Float;
            if alpha != 1.0 {
                // Generalized index for alpha != 1.
                diversity_sum += (1.0 - p.powf(alpha) - (1.0 - p).powf(alpha))
                    / (2.0_f64.powf(1.0 - alpha) as Float - 1.0);
            } else {
                // Shannon entropy case (alpha = 1)
                if p > 0.0 && p < 1.0 {
                    diversity_sum += -p * p.log2() - (1.0 - p) * (1.0 - p).log2();
                }
            }
        }

        Ok(diversity_sum / n_samples as Float)
    }
1281
1282    /// Calculate prediction stability across bootstrap samples
1283    fn calculate_prediction_stability(
1284        &self,
1285        predictions: &[Array1<Float>],
1286    ) -> Result<Float, Box<dyn std::error::Error>> {
1287        if predictions.len() < 2 {
1288            return Ok(1.0);
1289        }
1290
1291        let n_samples = predictions[0].len();
1292        let mut stability_sum = 0.0;
1293
1294        for i in 0..n_samples {
1295            let sample_predictions: Vec<Float> = predictions.iter().map(|p| p[i]).collect();
1296            let mean_pred =
1297                sample_predictions.iter().sum::<Float>() / sample_predictions.len() as Float;
1298            let variance = sample_predictions
1299                .iter()
1300                .map(|&pred| (pred - mean_pred).powi(2))
1301                .sum::<Float>()
1302                / sample_predictions.len() as Float;
1303
1304            stability_sum += 1.0 / (1.0 + variance); // Higher variance = lower stability
1305        }
1306
1307        Ok(stability_sum / n_samples as Float)
1308    }
1309}
1310
1311/// Convenience function for ensemble evaluation
1312pub fn evaluate_ensemble<F>(
1313    ensemble_predictions: &Array2<Float>,
1314    true_labels: &Array1<Float>,
1315    ensemble_weights: Option<&Array1<Float>>,
1316    model_predictions: Option<&Array2<Float>>,
1317    evaluation_fn: F,
1318    config: Option<EnsembleEvaluationConfig>,
1319) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
1320where
1321    F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
1322{
1323    let config = config.unwrap_or_default();
1324    let mut evaluator = EnsembleEvaluator::new(config);
1325    evaluator.evaluate(
1326        ensemble_predictions,
1327        true_labels,
1328        ensemble_weights,
1329        model_predictions,
1330        evaluation_fn,
1331    )
1332}
1333
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    //! Unit tests for the ensemble evaluation strategies.

    use super::*;

    /// Toy accuracy metric: fraction of samples whose binarized
    /// (0.5 threshold) prediction matches the binarized label.
    fn mock_evaluation_function(
        predictions: &Array1<Float>,
        labels: &Array1<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let correct = predictions
            .iter()
            .zip(labels.iter())
            .filter(|(&pred, &label)| (pred > 0.5) == (label > 0.5))
            .count();
        Ok(correct as Float / predictions.len() as Float)
    }

    /// The default configuration should carry a 95% confidence level.
    #[test]
    fn test_ensemble_evaluator_creation() {
        let config = EnsembleEvaluationConfig::default();
        let evaluator = EnsembleEvaluator::new(config);
        assert_eq!(evaluator.config.confidence_level, 0.95);
    }

    /// Out-of-bag evaluation should yield a score in [0, 1].
    #[test]
    fn test_out_of_bag_evaluation() {
        // 10 samples x 3 models of soft predictions.
        let ensemble_predictions = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .expect("operation should succeed");
        let true_labels = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::OutOfBag {
                bootstrap_samples: 10,
                confidence_level: 0.95,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            None,
            mock_evaluation_function,
            Some(config),
        )
        .expect("operation should succeed");

        assert!(result.out_of_bag_scores.is_some());
        let oob_scores = result.out_of_bag_scores.expect("operation should succeed");
        assert!(oob_scores.oob_score >= 0.0 && oob_scores.oob_score <= 1.0);
    }

    /// Diversity evaluation should populate the per-measure map and report
    /// a non-negative overall diversity.
    #[test]
    fn test_diversity_evaluation() {
        // 10 samples x 3 models of soft predictions.
        let ensemble_predictions = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .expect("operation should succeed");
        let true_labels = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);
        // Per-model predictions are required for diversity measures.
        let model_predictions = ensemble_predictions.clone();

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::DiversityEvaluation {
                diversity_measures: vec![
                    DiversityMeasure::QStatistic,
                    DiversityMeasure::DisagreementMeasure,
                ],
                diversity_threshold: 0.5,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            Some(&model_predictions),
            mock_evaluation_function,
            Some(config),
        )
        .expect("operation should succeed");

        assert!(!result.diversity_analysis.diversity_by_measure.is_empty());
        assert!(result.diversity_analysis.overall_diversity >= 0.0);
    }

    /// K-fold ensemble cross-validation should produce per-fold scores and
    /// a non-negative mean performance.
    #[test]
    fn test_cross_validation_evaluation() {
        // 10 samples x 3 models of soft predictions.
        let ensemble_predictions = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .expect("operation should succeed");
        let true_labels = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::EnsembleCrossValidation {
                cv_strategy: EnsembleCVStrategy::KFoldEnsemble,
                n_folds: 5,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            None,
            mock_evaluation_function,
            Some(config),
        )
        .expect("operation should succeed");

        assert!(!result
            .ensemble_performance
            .individual_fold_scores
            .is_empty());
        assert!(result.ensemble_performance.mean_performance >= 0.0);
    }
}