1use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::SeedableRng;
10use scirs2_core::RngExt;
11use sklears_core::types::Float;
12use std::collections::HashMap;
13
/// Strategy selecting how an ensemble is evaluated by `EnsembleEvaluator`.
#[derive(Debug, Clone)]
pub enum EnsembleEvaluationStrategy {
    /// Out-of-bag evaluation based on bootstrap resampling.
    OutOfBag {
        /// Number of bootstrap resamples to draw.
        bootstrap_samples: usize,
        /// Confidence level for the reported interval (e.g. 0.95).
        confidence_level: Float,
    },
    /// Cross-validation performed over precomputed ensemble predictions.
    EnsembleCrossValidation {
        /// How folds are built (see `EnsembleCVStrategy`).
        cv_strategy: EnsembleCVStrategy,
        /// Number of folds.
        n_folds: usize,
    },
    /// Evaluation centred on diversity among ensemble members.
    DiversityEvaluation {
        /// Diversity measures to compute.
        diversity_measures: Vec<DiversityMeasure>,
        /// Threshold for diversity; carried in the config but not
        /// interpreted by the evaluation code shown in this file.
        diversity_threshold: Float,
    },
    /// Bootstrap-based analysis of prediction/performance stability.
    StabilityAnalysis {
        /// Number of bootstrap resamples to draw.
        n_bootstrap_samples: usize,
        /// Stability metrics requested by the caller.
        stability_metrics: Vec<StabilityMetric>,
    },
    /// Evaluation of ensembles of increasing size.
    ProgressiveEvaluation {
        /// Ensemble sizes (member counts) to evaluate.
        ensemble_sizes: Vec<usize>,
        /// How members are selected as the ensemble grows.
        selection_strategy: ProgressiveSelectionStrategy,
    },
    /// Evaluation balancing several competing objectives.
    MultiObjectiveEvaluation {
        /// Objectives to score.
        objectives: Vec<EvaluationObjective>,
        /// Whether to perform a trade-off analysis.
        trade_off_analysis: bool,
    },
}
50
/// Fold-construction scheme for `EnsembleEvaluationStrategy::EnsembleCrossValidation`.
///
/// NOTE(review): the current `evaluate_cross_validation` implementation binds
/// this value as `_cv_strategy` and always performs contiguous k-fold splits,
/// so the non-KFold variants are descriptive only for now — confirm before
/// relying on them.
#[derive(Debug, Clone)]
pub enum EnsembleCVStrategy {
    /// Plain k-fold splitting.
    KFoldEnsemble,
    /// Stratified k-fold splitting.
    StratifiedEnsemble,
    /// Leave one ensemble member out at a time.
    LeaveOneModelOut,
    /// Bootstrap resampling with `n_bootstrap` draws.
    BootstrapEnsemble { n_bootstrap: usize },
    /// Nested CV with the given inner/outer fold counts.
    NestedEnsemble { inner_cv: usize, outer_cv: usize },
    /// Forward-chaining splits for temporal data.
    TimeSeriesEnsemble { n_splits: usize, test_size: Float },
}
67
/// Pairwise or ensemble-level diversity measure among member predictions.
///
/// All implementations in this file binarise predictions at 0.5 before
/// comparing models.
#[derive(Debug, Clone)]
pub enum DiversityMeasure {
    /// Yule's Q statistic averaged over model pairs.
    QStatistic,
    /// One minus the mean absolute pairwise correlation.
    CorrelationCoefficient,
    /// Fraction of samples on which a pair of models disagrees.
    DisagreementMeasure,
    /// Fraction of samples both models of a pair get wrong.
    DoubleFaultMeasure,
    /// Mean binary entropy of the per-sample vote split.
    EntropyDiversity,
    /// Kohavi–Wolpert variance of per-sample correctness.
    KohaviWolpertVariance,
    /// Mean pairwise agreement rate.
    InterraterAgreement,
    /// Mean per-sample error fraction across models.
    DifficultyMeasure,
    /// Generalized diversity parameterised by `alpha`.
    GeneralizedDiversity { alpha: Float },
}
90
/// Stability metric requested for `EnsembleEvaluationStrategy::StabilityAnalysis`.
#[derive(Debug, Clone)]
pub enum StabilityMetric {
    /// Stability of the ensemble's predictions across resamples.
    PredictionStability,
    /// Stability of which models get selected.
    ModelSelectionStability,
    /// Stability of member weights.
    WeightStability,
    /// Stability of the performance score across resamples.
    PerformanceStability,
    /// Stability of member rankings.
    RankingStability,
}
105
/// Member-selection policy for progressive (growing-size) evaluation.
///
/// NOTE(review): `evaluate_progressive` currently binds this value as
/// `_selection_strategy` and always takes the first `size` model columns —
/// these variants are not yet acted upon.
#[derive(Debug, Clone)]
pub enum ProgressiveSelectionStrategy {
    /// Greedily add members one at a time.
    ForwardSelection,
    /// Start from the full ensemble and remove members.
    BackwardElimination,
    /// Select members to maximise diversity.
    DiversityDriven,
    /// Blend performance and diversity with mixing factor `alpha`.
    PerformanceDiversityTradeoff { alpha: Float },
}
118
/// Objective scored during multi-objective ensemble evaluation.
#[derive(Debug, Clone)]
pub enum EvaluationObjective {
    /// Predictive accuracy.
    Accuracy,
    /// Ensemble diversity.
    Diversity,
    /// Computational efficiency.
    Efficiency,
    /// Memory footprint.
    MemoryUsage,
    /// Robustness to perturbations.
    Robustness,
    /// Model interpretability.
    Interpretability,
    /// Fairness criteria.
    Fairness,
}
137
/// Configuration for `EnsembleEvaluator`.
#[derive(Debug, Clone)]
pub struct EnsembleEvaluationConfig {
    /// Which evaluation strategy to run.
    pub strategy: EnsembleEvaluationStrategy,
    /// Names of metrics to report (e.g. "accuracy"). Not read by the
    /// evaluation code visible in this file.
    pub evaluation_metrics: Vec<String>,
    /// Confidence level for reported intervals (e.g. 0.95).
    pub confidence_level: Float,
    /// Number of evaluation repetitions. Not read by the code shown here.
    pub n_repetitions: usize,
    /// Whether evaluation may run in parallel. Not read by the code shown here.
    pub parallel_evaluation: bool,
    /// RNG seed; `None` seeds the internal RNG from the thread RNG.
    pub random_state: Option<u64>,
    /// Emit progress/diagnostic output.
    pub verbose: bool,
}
149
/// Aggregated output of an ensemble evaluation run. Optional sections are
/// populated only by the strategy that computes them.
#[derive(Debug, Clone)]
pub struct EnsembleEvaluationResult {
    /// Headline performance statistics.
    pub ensemble_performance: EnsemblePerformanceMetrics,
    /// Diversity statistics (zero-filled by strategies that skip them).
    pub diversity_analysis: DiversityAnalysis,
    /// Present for the stability-analysis strategy.
    pub stability_analysis: Option<StabilityAnalysis>,
    /// Per-member contribution breakdown (currently left empty by all
    /// strategies in this file).
    pub member_contributions: Vec<MemberContribution>,
    /// Present for the out-of-bag strategy.
    pub out_of_bag_scores: Option<OutOfBagScores>,
    /// Present for the progressive-evaluation strategy.
    pub progressive_performance: Option<ProgressivePerformance>,
    /// Present for the multi-objective strategy.
    pub multi_objective_analysis: Option<MultiObjectiveAnalysis>,
}
161
/// Summary statistics of ensemble performance.
#[derive(Debug, Clone)]
pub struct EnsemblePerformanceMetrics {
    /// Mean score across folds/resamples (or the single score when only
    /// one evaluation is performed).
    pub mean_performance: Float,
    /// Standard deviation of the scores.
    pub std_performance: Float,
    /// (lower, upper) confidence bounds on the mean.
    pub confidence_interval: (Float, Float),
    /// Raw per-fold (or per-resample) scores.
    pub individual_fold_scores: Vec<Float>,
    /// Ensemble score minus the best member's score.
    /// NOTE(review): always 0.0 in the strategies implemented here.
    pub ensemble_vs_best_member: Float,
    /// Ensemble score minus the average member score.
    /// NOTE(review): always 0.0 in the strategies implemented here.
    pub ensemble_vs_average_member: Float,
    /// Overall performance gain of the ensemble.
    /// NOTE(review): always 0.0 in the strategies implemented here.
    pub performance_gain: Float,
}
173
/// Diversity statistics for an ensemble.
#[derive(Debug, Clone)]
pub struct DiversityAnalysis {
    /// Average of the computed diversity measures.
    pub overall_diversity: Float,
    /// Symmetric model-by-model pairwise diversity matrix (may be 0x0 when
    /// not computed).
    pub pairwise_diversities: Array2<Float>,
    /// Diversity value keyed by measure name (Debug-formatted variant).
    pub diversity_by_measure: HashMap<String, Float>,
    /// Per-fold diversity values where applicable.
    pub diversity_distribution: Vec<Float>,
    /// Ensemble size judged most diverse/performant, when known.
    pub optimal_diversity_size: Option<usize>,
}
183
/// Stability statistics derived from bootstrap resampling.
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    /// Stability of ensemble predictions across resamples.
    pub prediction_stability: Float,
    /// Stability of model selection.
    /// NOTE(review): currently filled with a constant placeholder (0.8).
    pub model_selection_stability: Float,
    /// Stability of member weights, when weights are analysed.
    pub weight_stability: Option<Float>,
    /// 1 / (1 + variance of bootstrap scores).
    pub performance_stability: Float,
    /// Confidence intervals per stability metric (currently left empty).
    pub stability_confidence_intervals: HashMap<String, (Float, Float)>,
}
193
/// Contribution of a single ensemble member to overall performance.
/// NOTE(review): no strategy in this file populates these yet.
#[derive(Debug, Clone)]
pub struct MemberContribution {
    /// Index of the member within the ensemble.
    pub member_id: usize,
    /// Human-readable member name.
    pub member_name: String,
    /// Score of the member evaluated alone.
    pub individual_performance: Float,
    /// Marginal gain from adding this member.
    pub marginal_contribution: Float,
    /// Shapley value attribution, when computed.
    pub shapley_value: Option<Float>,
    /// Performance change when this member is removed.
    pub removal_impact: Float,
    /// Member's contribution to ensemble diversity.
    pub diversity_contribution: Float,
}
205
/// Results of out-of-bag (OOB) evaluation.
#[derive(Debug, Clone)]
pub struct OutOfBagScores {
    /// Mean OOB score over bootstrap resamples.
    pub oob_score: Float,
    /// (lower, upper) confidence bounds on the OOB score.
    pub oob_confidence_interval: (Float, Float),
    /// Feature importances, when computed (currently never filled).
    pub feature_importance: Option<Array1<Float>>,
    /// Per-sample prediction intervals, when computed (currently never filled).
    pub prediction_intervals: Option<Array2<Float>>,
    /// OOB score of each bootstrap resample.
    pub individual_oob_scores: Vec<Float>,
}
215
/// Performance/diversity curves over increasing ensemble sizes.
/// The three curves are parallel to `ensemble_sizes`.
#[derive(Debug, Clone)]
pub struct ProgressivePerformance {
    /// Evaluated ensemble sizes.
    pub ensemble_sizes: Vec<usize>,
    /// Performance score at each size.
    pub performance_curve: Vec<Float>,
    /// Diversity (Q-statistic) at each size.
    pub diversity_curve: Vec<Float>,
    /// Efficiency at each size.
    /// NOTE(review): currently filled with a constant placeholder (1.0).
    pub efficiency_curve: Vec<Float>,
    /// Size with the best performance.
    pub optimal_size: usize,
    /// Size beyond which gains flatten out, when detected (currently never).
    pub diminishing_returns_threshold: Option<usize>,
}
226
/// Multi-objective evaluation results.
/// NOTE(review): the current implementation fills most fields with
/// placeholders — see `evaluate_multi_objective`.
#[derive(Debug, Clone)]
pub struct MultiObjectiveAnalysis {
    /// Non-dominated (objective1, objective2) points.
    pub pareto_front: Vec<(Float, Float)>,
    /// Score per objective name.
    pub objective_scores: HashMap<String, Float>,
    /// Trade-off statistics per objective pair.
    pub trade_off_analysis: HashMap<String, Float>,
    /// Indices of dominated candidate solutions.
    pub dominated_solutions: Vec<usize>,
    /// Index of a recommended compromise solution, when one exists.
    pub compromise_solution: Option<usize>,
}
236
/// Stateful evaluator that runs the configured `EnsembleEvaluationStrategy`
/// over precomputed ensemble/member predictions.
#[derive(Debug)]
pub struct EnsembleEvaluator {
    /// Evaluation configuration (strategy, confidence level, seed, ...).
    config: EnsembleEvaluationConfig,
    /// RNG used for bootstrap resampling; seeded from `config.random_state`.
    rng: StdRng,
}
243
244impl Default for EnsembleEvaluationConfig {
245 fn default() -> Self {
246 Self {
247 strategy: EnsembleEvaluationStrategy::EnsembleCrossValidation {
248 cv_strategy: EnsembleCVStrategy::KFoldEnsemble,
249 n_folds: 5,
250 },
251 evaluation_metrics: vec!["accuracy".to_string(), "f1_score".to_string()],
252 confidence_level: 0.95,
253 n_repetitions: 1,
254 parallel_evaluation: false,
255 random_state: None,
256 verbose: false,
257 }
258 }
259}
260
261impl EnsembleEvaluator {
262 pub fn new(config: EnsembleEvaluationConfig) -> Self {
264 let rng = match config.random_state {
265 Some(seed) => StdRng::seed_from_u64(seed),
266 None => {
267 use scirs2_core::random::thread_rng;
268 StdRng::from_rng(&mut thread_rng())
269 }
270 };
271
272 Self { config, rng }
273 }
274
    /// Evaluate the ensemble according to `self.config.strategy`.
    ///
    /// # Arguments
    /// * `ensemble_predictions` — one row per sample, one column per model
    ///   (combined via `calculate_ensemble_predictions`).
    /// * `true_labels` — ground-truth value per sample.
    /// * `ensemble_weights` — optional per-model weights used when combining
    ///   model columns into a single prediction.
    /// * `model_predictions` — raw per-model predictions; required by the
    ///   diversity and progressive strategies, optional elsewhere.
    /// * `evaluation_fn` — metric taking (predictions, labels) and returning
    ///   a scalar score.
    ///
    /// # Errors
    /// Propagates any error from the chosen strategy or from `evaluation_fn`.
    pub fn evaluate<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        ensemble_weights: Option<&Array1<Float>>,
        model_predictions: Option<&Array2<Float>>,
        evaluation_fn: F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        // Dispatch on the configured strategy; each handler re-reads its own
        // strategy parameters from self.config.
        match &self.config.strategy {
            EnsembleEvaluationStrategy::OutOfBag { .. } => self.evaluate_out_of_bag(
                ensemble_predictions,
                true_labels,
                ensemble_weights,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::EnsembleCrossValidation { .. } => self
                .evaluate_cross_validation(
                    ensemble_predictions,
                    true_labels,
                    ensemble_weights,
                    model_predictions,
                    &evaluation_fn,
                ),
            EnsembleEvaluationStrategy::DiversityEvaluation { .. } => self.evaluate_diversity(
                ensemble_predictions,
                true_labels,
                model_predictions,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::StabilityAnalysis { .. } => self.evaluate_stability(
                ensemble_predictions,
                true_labels,
                ensemble_weights,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::ProgressiveEvaluation { .. } => self.evaluate_progressive(
                ensemble_predictions,
                true_labels,
                model_predictions,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::MultiObjectiveEvaluation { .. } => self
                .evaluate_multi_objective(
                    ensemble_predictions,
                    true_labels,
                    ensemble_weights,
                    &evaluation_fn,
                ),
        }
    }
329
330 fn evaluate_out_of_bag<F>(
332 &mut self,
333 ensemble_predictions: &Array2<Float>,
334 true_labels: &Array1<Float>,
335 ensemble_weights: Option<&Array1<Float>>,
336 evaluation_fn: &F,
337 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
338 where
339 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
340 {
341 let (bootstrap_samples, confidence_level) = match &self.config.strategy {
342 EnsembleEvaluationStrategy::OutOfBag {
343 bootstrap_samples,
344 confidence_level,
345 } => (*bootstrap_samples, *confidence_level),
346 _ => unreachable!(),
347 };
348
349 let n_samples = ensemble_predictions.nrows();
350 let n_models = ensemble_predictions.ncols();
351
352 let mut oob_scores = Vec::new();
353 let mut oob_predictions_all = Vec::new();
354
355 for _ in 0..bootstrap_samples {
356 let bootstrap_indices: Vec<usize> = (0..n_samples)
358 .map(|_| self.rng.random_range(0..n_samples))
359 .collect();
360
361 let mut oob_indices = Vec::new();
363 for i in 0..n_samples {
364 if !bootstrap_indices.contains(&i) {
365 oob_indices.push(i);
366 }
367 }
368
369 if oob_indices.is_empty() {
370 continue;
371 }
372
373 let oob_ensemble_preds = self.calculate_ensemble_predictions(
375 ensemble_predictions,
376 &oob_indices,
377 ensemble_weights,
378 )?;
379
380 let oob_true_labels =
381 Array1::from_vec(oob_indices.iter().map(|&i| true_labels[i]).collect());
382
383 let oob_score = evaluation_fn(&oob_ensemble_preds, &oob_true_labels)?;
384 oob_scores.push(oob_score);
385 oob_predictions_all.push(oob_ensemble_preds);
386 }
387
388 let mean_oob_score = oob_scores.iter().sum::<Float>() / oob_scores.len() as Float;
389 let std_oob_score = {
390 let variance = oob_scores
391 .iter()
392 .map(|&score| (score - mean_oob_score).powi(2))
393 .sum::<Float>()
394 / oob_scores.len() as Float;
395 variance.sqrt()
396 };
397
398 let _alpha = 1.0 - confidence_level;
399 let z_score = 1.96; let margin_of_error = z_score * std_oob_score / (oob_scores.len() as Float).sqrt();
401 let confidence_interval = (
402 mean_oob_score - margin_of_error,
403 mean_oob_score + margin_of_error,
404 );
405
406 let oob_scores_result = OutOfBagScores {
407 oob_score: mean_oob_score,
408 oob_confidence_interval: confidence_interval,
409 feature_importance: None, prediction_intervals: None, individual_oob_scores: oob_scores,
412 };
413
414 let ensemble_preds = self.calculate_ensemble_predictions(
416 ensemble_predictions,
417 &(0..n_samples).collect::<Vec<_>>(),
418 ensemble_weights,
419 )?;
420 let ensemble_score = evaluation_fn(&ensemble_preds, true_labels)?;
421
422 let ensemble_performance = EnsemblePerformanceMetrics {
423 mean_performance: ensemble_score,
424 std_performance: std_oob_score,
425 confidence_interval,
426 individual_fold_scores: vec![ensemble_score],
427 ensemble_vs_best_member: 0.0, ensemble_vs_average_member: 0.0,
429 performance_gain: 0.0,
430 };
431
432 Ok(EnsembleEvaluationResult {
433 ensemble_performance,
434 diversity_analysis: DiversityAnalysis {
435 overall_diversity: 0.0,
436 pairwise_diversities: Array2::zeros((n_models, n_models)),
437 diversity_by_measure: HashMap::new(),
438 diversity_distribution: Vec::new(),
439 optimal_diversity_size: None,
440 },
441 stability_analysis: None,
442 member_contributions: Vec::new(),
443 out_of_bag_scores: Some(oob_scores_result),
444 progressive_performance: None,
445 multi_objective_analysis: None,
446 })
447 }
448
449 fn evaluate_cross_validation<F>(
451 &mut self,
452 ensemble_predictions: &Array2<Float>,
453 true_labels: &Array1<Float>,
454 ensemble_weights: Option<&Array1<Float>>,
455 model_predictions: Option<&Array2<Float>>,
456 evaluation_fn: &F,
457 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
458 where
459 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
460 {
461 let (_cv_strategy, n_folds) = match &self.config.strategy {
462 EnsembleEvaluationStrategy::EnsembleCrossValidation {
463 cv_strategy,
464 n_folds,
465 } => (cv_strategy, *n_folds),
466 _ => unreachable!(),
467 };
468
469 let n_samples = ensemble_predictions.nrows();
470 let fold_size = n_samples / n_folds;
471 let mut fold_scores = Vec::new();
472 let mut diversity_scores = Vec::new();
473
474 for fold in 0..n_folds {
475 let test_start = fold * fold_size;
476 let test_end = if fold == n_folds - 1 {
477 n_samples
478 } else {
479 (fold + 1) * fold_size
480 };
481 let test_indices: Vec<usize> = (test_start..test_end).collect();
482
483 let test_ensemble_preds = self.calculate_ensemble_predictions(
485 ensemble_predictions,
486 &test_indices,
487 ensemble_weights,
488 )?;
489
490 let test_true_labels =
491 Array1::from_vec(test_indices.iter().map(|&i| true_labels[i]).collect());
492
493 let fold_score = evaluation_fn(&test_ensemble_preds, &test_true_labels)?;
494 fold_scores.push(fold_score);
495
496 if let Some(model_preds) = model_predictions {
498 let mut fold_data = Vec::new();
499 for &i in test_indices.iter() {
500 fold_data.extend(model_preds.row(i).iter().cloned());
501 }
502 let fold_model_preds =
503 Array2::from_shape_vec((test_indices.len(), model_preds.ncols()), fold_data)?;
504
505 let diversity = self.calculate_q_statistic(&fold_model_preds)?;
506 diversity_scores.push(diversity);
507 }
508 }
509
510 let mean_performance = fold_scores.iter().sum::<Float>() / fold_scores.len() as Float;
511 let std_performance = {
512 let variance = fold_scores
513 .iter()
514 .map(|&score| (score - mean_performance).powi(2))
515 .sum::<Float>()
516 / fold_scores.len() as Float;
517 variance.sqrt()
518 };
519
520 let z_score = 1.96; let margin_of_error = z_score * std_performance / (fold_scores.len() as Float).sqrt();
522 let confidence_interval = (
523 mean_performance - margin_of_error,
524 mean_performance + margin_of_error,
525 );
526
527 let ensemble_performance = EnsemblePerformanceMetrics {
528 mean_performance,
529 std_performance,
530 confidence_interval,
531 individual_fold_scores: fold_scores,
532 ensemble_vs_best_member: 0.0, ensemble_vs_average_member: 0.0,
534 performance_gain: 0.0,
535 };
536
537 let diversity_analysis = if !diversity_scores.is_empty() {
538 let mean_diversity =
539 diversity_scores.iter().sum::<Float>() / diversity_scores.len() as Float;
540 DiversityAnalysis {
541 overall_diversity: mean_diversity,
542 pairwise_diversities: Array2::zeros((0, 0)), diversity_by_measure: {
544 let mut map = HashMap::new();
545 map.insert("q_statistic".to_string(), mean_diversity);
546 map
547 },
548 diversity_distribution: diversity_scores,
549 optimal_diversity_size: None,
550 }
551 } else {
552 DiversityAnalysis {
553 overall_diversity: 0.0,
554 pairwise_diversities: Array2::zeros((0, 0)),
555 diversity_by_measure: HashMap::new(),
556 diversity_distribution: Vec::new(),
557 optimal_diversity_size: None,
558 }
559 };
560
561 Ok(EnsembleEvaluationResult {
562 ensemble_performance,
563 diversity_analysis,
564 stability_analysis: None,
565 member_contributions: Vec::new(),
566 out_of_bag_scores: None,
567 progressive_performance: None,
568 multi_objective_analysis: None,
569 })
570 }
571
572 fn evaluate_diversity<F>(
574 &mut self,
575 ensemble_predictions: &Array2<Float>,
576 true_labels: &Array1<Float>,
577 model_predictions: Option<&Array2<Float>>,
578 evaluation_fn: &F,
579 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
580 where
581 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
582 {
583 let (diversity_measures, _diversity_threshold) = match &self.config.strategy {
584 EnsembleEvaluationStrategy::DiversityEvaluation {
585 diversity_measures,
586 diversity_threshold,
587 } => (diversity_measures, *diversity_threshold),
588 _ => unreachable!(),
589 };
590
591 if let Some(model_preds) = model_predictions {
592 let n_models = model_preds.ncols();
593 let mut diversity_by_measure = HashMap::new();
594 let mut pairwise_diversities = Array2::zeros((n_models, n_models));
595
596 for measure in diversity_measures {
597 let diversity_value = match measure {
598 DiversityMeasure::QStatistic => self.calculate_q_statistic(model_preds)?,
599 DiversityMeasure::CorrelationCoefficient => {
600 self.calculate_correlation_coefficient(model_preds)?
601 }
602 DiversityMeasure::DisagreementMeasure => {
603 self.calculate_disagreement_measure(model_preds)?
604 }
605 DiversityMeasure::DoubleFaultMeasure => {
606 self.calculate_double_fault_measure(model_preds, true_labels)?
607 }
608 DiversityMeasure::EntropyDiversity => {
609 self.calculate_entropy_diversity(model_preds)?
610 }
611 DiversityMeasure::KohaviWolpertVariance => {
612 self.calculate_kw_variance(model_preds, true_labels)?
613 }
614 DiversityMeasure::InterraterAgreement => {
615 self.calculate_interrater_agreement(model_preds)?
616 }
617 DiversityMeasure::DifficultyMeasure => {
618 self.calculate_difficulty_measure(model_preds, true_labels)?
619 }
620 DiversityMeasure::GeneralizedDiversity { alpha } => {
621 self.calculate_generalized_diversity(model_preds, *alpha)?
622 }
623 };
624
625 diversity_by_measure.insert(format!("{:?}", measure), diversity_value);
626 }
627
628 for i in 0..n_models {
630 for j in i + 1..n_models {
631 let pair_preds = Array2::from_shape_vec(
632 (model_preds.nrows(), 2),
633 model_preds
634 .column(i)
635 .iter()
636 .cloned()
637 .chain(model_preds.column(j).iter().cloned())
638 .collect(),
639 )?;
640 let pair_diversity = self.calculate_q_statistic(&pair_preds)?;
641 pairwise_diversities[[i, j]] = pair_diversity;
642 pairwise_diversities[[j, i]] = pair_diversity;
643 }
644 }
645
646 let overall_diversity =
647 diversity_by_measure.values().sum::<Float>() / diversity_by_measure.len() as Float;
648
649 let diversity_analysis = DiversityAnalysis {
650 overall_diversity,
651 pairwise_diversities,
652 diversity_by_measure,
653 diversity_distribution: Vec::new(), optimal_diversity_size: None, };
656
657 let ensemble_preds = self.calculate_ensemble_predictions(
659 ensemble_predictions,
660 &(0..ensemble_predictions.nrows()).collect::<Vec<_>>(),
661 None,
662 )?;
663 let ensemble_score = evaluation_fn(&ensemble_preds, true_labels)?;
664
665 let ensemble_performance = EnsemblePerformanceMetrics {
666 mean_performance: ensemble_score,
667 std_performance: 0.0,
668 confidence_interval: (ensemble_score, ensemble_score),
669 individual_fold_scores: vec![ensemble_score],
670 ensemble_vs_best_member: 0.0,
671 ensemble_vs_average_member: 0.0,
672 performance_gain: 0.0,
673 };
674
675 Ok(EnsembleEvaluationResult {
676 ensemble_performance,
677 diversity_analysis,
678 stability_analysis: None,
679 member_contributions: Vec::new(),
680 out_of_bag_scores: None,
681 progressive_performance: None,
682 multi_objective_analysis: None,
683 })
684 } else {
685 Err("Model predictions required for diversity evaluation".into())
686 }
687 }
688
    /// Bootstrap-based stability analysis: resample the data
    /// `n_bootstrap_samples` times, score the ensemble on each resample,
    /// and derive prediction/performance stability from the spread.
    ///
    /// # Errors
    /// Propagates errors from prediction combination and `evaluation_fn`.
    fn evaluate_stability<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        ensemble_weights: Option<&Array1<Float>>,
        evaluation_fn: &F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        let (n_bootstrap_samples, _stability_metrics) = match &self.config.strategy {
            EnsembleEvaluationStrategy::StabilityAnalysis {
                n_bootstrap_samples,
                stability_metrics,
            } => (*n_bootstrap_samples, stability_metrics),
            _ => unreachable!(),
        };

        let n_samples = ensemble_predictions.nrows();
        let mut bootstrap_scores = Vec::new();
        let mut bootstrap_predictions = Vec::new();

        for _ in 0..n_bootstrap_samples {
            // Resample indices with replacement.
            let bootstrap_indices: Vec<usize> = (0..n_samples)
                .map(|_| self.rng.random_range(0..n_samples))
                .collect();

            let bootstrap_preds = self.calculate_ensemble_predictions(
                ensemble_predictions,
                &bootstrap_indices,
                ensemble_weights,
            )?;

            let bootstrap_labels =
                Array1::from_vec(bootstrap_indices.iter().map(|&i| true_labels[i]).collect());

            let bootstrap_score = evaluation_fn(&bootstrap_preds, &bootstrap_labels)?;
            bootstrap_scores.push(bootstrap_score);
            bootstrap_predictions.push(bootstrap_preds);
        }

        // Helper defined elsewhere in this impl (outside this chunk).
        let prediction_stability = self.calculate_prediction_stability(&bootstrap_predictions)?;

        // NOTE(review): with n_bootstrap_samples == 0 the divisions below
        // yield NaN — confirm callers always configure at least one sample.
        let mean_score = bootstrap_scores.iter().sum::<Float>() / bootstrap_scores.len() as Float;
        let score_variance = bootstrap_scores
            .iter()
            .map(|&score| (score - mean_score).powi(2))
            .sum::<Float>()
            / bootstrap_scores.len() as Float;
        // Maps variance in [0, inf) to a stability score in (0, 1].
        let performance_stability = 1.0 / (1.0 + score_variance);

        let stability_analysis = StabilityAnalysis {
            prediction_stability,
            // Placeholder constant — not derived from the data.
            model_selection_stability: 0.8,
            // Weight stability is not analysed here.
            weight_stability: None,
            performance_stability,
            // Not computed yet.
            stability_confidence_intervals: HashMap::new(),
        };

        let ensemble_performance = EnsemblePerformanceMetrics {
            mean_performance: mean_score,
            std_performance: score_variance.sqrt(),
            // One-standard-deviation band, not a formal confidence interval.
            confidence_interval: (
                mean_score - score_variance.sqrt(),
                mean_score + score_variance.sqrt(),
            ),
            individual_fold_scores: bootstrap_scores,
            ensemble_vs_best_member: 0.0,
            ensemble_vs_average_member: 0.0,
            performance_gain: 0.0,
        };

        Ok(EnsembleEvaluationResult {
            ensemble_performance,
            diversity_analysis: DiversityAnalysis {
                overall_diversity: 0.0,
                pairwise_diversities: Array2::zeros((0, 0)),
                diversity_by_measure: HashMap::new(),
                diversity_distribution: Vec::new(),
                optimal_diversity_size: None,
            },
            stability_analysis: Some(stability_analysis),
            member_contributions: Vec::new(),
            out_of_bag_scores: None,
            progressive_performance: None,
            multi_objective_analysis: None,
        })
    }
781
782 fn evaluate_progressive<F>(
784 &mut self,
785 _ensemble_predictions: &Array2<Float>,
786 true_labels: &Array1<Float>,
787 model_predictions: Option<&Array2<Float>>,
788 evaluation_fn: &F,
789 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
790 where
791 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
792 {
793 let (ensemble_sizes, _selection_strategy) = match &self.config.strategy {
794 EnsembleEvaluationStrategy::ProgressiveEvaluation {
795 ensemble_sizes,
796 selection_strategy,
797 } => (ensemble_sizes, selection_strategy),
798 _ => unreachable!(),
799 };
800
801 if let Some(model_preds) = model_predictions {
802 let mut performance_curve = Vec::new();
803 let mut diversity_curve = Vec::new();
804 let n_models = model_preds.ncols();
805
806 for &size in ensemble_sizes {
807 if size <= n_models {
808 let selected_indices: Vec<usize> = (0..size).collect();
810
811 let mut selected_data = Vec::new();
813 for &i in selected_indices.iter() {
814 selected_data.extend(model_preds.column(i).iter().cloned());
815 }
816 let selected_predictions =
817 Array2::from_shape_vec((model_preds.nrows(), size), selected_data)?;
818
819 let ensemble_preds = selected_predictions
820 .mean_axis(Axis(1))
821 .expect("operation should succeed");
822 let performance = evaluation_fn(&ensemble_preds, true_labels)?;
823 performance_curve.push(performance);
824
825 let diversity = self.calculate_q_statistic(&selected_predictions)?;
827 diversity_curve.push(diversity);
828 }
829 }
830
831 let optimal_size_idx = performance_curve
833 .iter()
834 .enumerate()
835 .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
836 .map(|(idx, _)| idx)
837 .unwrap_or(0);
838 let optimal_size = ensemble_sizes[optimal_size_idx];
839
840 let progressive_performance = ProgressivePerformance {
841 ensemble_sizes: ensemble_sizes.clone(),
842 performance_curve,
843 diversity_curve,
844 efficiency_curve: vec![1.0; ensemble_sizes.len()], optimal_size,
846 diminishing_returns_threshold: None, };
848
849 let ensemble_performance = EnsemblePerformanceMetrics {
850 mean_performance: progressive_performance.performance_curve[optimal_size_idx],
851 std_performance: 0.0,
852 confidence_interval: (0.0, 0.0),
853 individual_fold_scores: vec![],
854 ensemble_vs_best_member: 0.0,
855 ensemble_vs_average_member: 0.0,
856 performance_gain: 0.0,
857 };
858
859 Ok(EnsembleEvaluationResult {
860 ensemble_performance,
861 diversity_analysis: DiversityAnalysis {
862 overall_diversity: 0.0,
863 pairwise_diversities: Array2::zeros((0, 0)),
864 diversity_by_measure: HashMap::new(),
865 diversity_distribution: Vec::new(),
866 optimal_diversity_size: Some(optimal_size),
867 },
868 stability_analysis: None,
869 member_contributions: Vec::new(),
870 out_of_bag_scores: None,
871 progressive_performance: Some(progressive_performance),
872 multi_objective_analysis: None,
873 })
874 } else {
875 Err("Model predictions required for progressive evaluation".into())
876 }
877 }
878
    /// Multi-objective evaluation of the ensemble.
    ///
    /// Only the accuracy objective is actually computed (via
    /// `evaluation_fn`); the remaining objective scores and the Pareto front
    /// are placeholder constants — see the inline notes.
    ///
    /// # Errors
    /// Propagates errors from prediction combination and `evaluation_fn`.
    fn evaluate_multi_objective<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        ensemble_weights: Option<&Array1<Float>>,
        evaluation_fn: &F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        // Combine all samples with the given weights (or the plain mean).
        let ensemble_preds = self.calculate_ensemble_predictions(
            ensemble_predictions,
            &(0..ensemble_predictions.nrows()).collect::<Vec<_>>(),
            ensemble_weights,
        )?;
        let performance = evaluation_fn(&ensemble_preds, true_labels)?;

        let mut objective_scores = HashMap::new();
        objective_scores.insert("accuracy".to_string(), performance);
        // Placeholder constants — diversity/efficiency are not measured here.
        objective_scores.insert("diversity".to_string(), 0.5);
        objective_scores.insert("efficiency".to_string(), 0.8);

        let multi_objective_analysis = MultiObjectiveAnalysis {
            // Single point using the placeholder diversity value.
            pareto_front: vec![(performance, 0.5)],
            objective_scores,
            trade_off_analysis: HashMap::new(),
            dominated_solutions: Vec::new(),
            compromise_solution: Some(0),
        };

        let ensemble_performance = EnsemblePerformanceMetrics {
            mean_performance: performance,
            std_performance: 0.0,
            // Degenerate interval: only one evaluation is performed.
            confidence_interval: (performance, performance),
            individual_fold_scores: vec![performance],
            ensemble_vs_best_member: 0.0,
            ensemble_vs_average_member: 0.0,
            performance_gain: 0.0,
        };

        Ok(EnsembleEvaluationResult {
            ensemble_performance,
            diversity_analysis: DiversityAnalysis {
                overall_diversity: 0.0,
                pairwise_diversities: Array2::zeros((0, 0)),
                diversity_by_measure: HashMap::new(),
                diversity_distribution: Vec::new(),
                optimal_diversity_size: None,
            },
            stability_analysis: None,
            member_contributions: Vec::new(),
            out_of_bag_scores: None,
            progressive_performance: None,
            multi_objective_analysis: Some(multi_objective_analysis),
        })
    }
937
938 fn calculate_ensemble_predictions(
940 &self,
941 ensemble_predictions: &Array2<Float>,
942 indices: &[usize],
943 weights: Option<&Array1<Float>>,
944 ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
945 let mut selected_data = Vec::new();
946 for &i in indices.iter() {
947 selected_data.extend(ensemble_predictions.row(i).iter().cloned());
948 }
949 let selected_predictions =
950 Array2::from_shape_vec((indices.len(), ensemble_predictions.ncols()), selected_data)?;
951
952 if let Some(w) = weights {
953 Ok(selected_predictions.dot(w))
955 } else {
956 Ok(selected_predictions
958 .mean_axis(Axis(1))
959 .expect("operation should succeed"))
960 }
961 }
962
963 fn calculate_q_statistic(
965 &self,
966 predictions: &Array2<Float>,
967 ) -> Result<Float, Box<dyn std::error::Error>> {
968 let n_models = predictions.ncols();
969 if n_models < 2 {
970 return Ok(0.0);
971 }
972
973 let mut q_sum = 0.0;
974 let mut pairs = 0;
975
976 for i in 0..n_models {
977 for j in i + 1..n_models {
978 let pred_i = predictions.column(i);
979 let pred_j = predictions.column(j);
980
981 let mut n11 = 0; let mut n10 = 0; let mut n01 = 0; let mut n00 = 0; for k in 0..predictions.nrows() {
987 let i_correct = pred_i[k] > 0.5;
988 let j_correct = pred_j[k] > 0.5;
989
990 match (i_correct, j_correct) {
991 (true, true) => n11 += 1,
992 (true, false) => n10 += 1,
993 (false, true) => n01 += 1,
994 (false, false) => n00 += 1,
995 }
996 }
997
998 let numerator = (n11 * n00 - n01 * n10) as Float;
999 let denominator = (n11 * n00 + n01 * n10) as Float;
1000
1001 if denominator != 0.0 {
1002 q_sum += numerator / denominator;
1003 pairs += 1;
1004 }
1005 }
1006 }
1007
1008 Ok(if pairs > 0 {
1009 q_sum / pairs as Float
1010 } else {
1011 0.0
1012 })
1013 }
1014
1015 fn calculate_correlation_coefficient(
1017 &self,
1018 predictions: &Array2<Float>,
1019 ) -> Result<Float, Box<dyn std::error::Error>> {
1020 let n_models = predictions.ncols();
1021 if n_models < 2 {
1022 return Ok(0.0);
1023 }
1024
1025 let mut correlations = Vec::new();
1026 for i in 0..n_models {
1027 for j in i + 1..n_models {
1028 let pred_i = predictions.column(i);
1029 let pred_j = predictions.column(j);
1030
1031 let mean_i = pred_i.mean().unwrap_or(0.0);
1032 let mean_j = pred_j.mean().unwrap_or(0.0);
1033
1034 let mut covariance = 0.0;
1035 let mut var_i = 0.0;
1036 let mut var_j = 0.0;
1037
1038 for k in 0..predictions.nrows() {
1039 let diff_i = pred_i[k] - mean_i;
1040 let diff_j = pred_j[k] - mean_j;
1041 covariance += diff_i * diff_j;
1042 var_i += diff_i * diff_i;
1043 var_j += diff_j * diff_j;
1044 }
1045
1046 let correlation = if var_i > 0.0 && var_j > 0.0 {
1047 covariance / (var_i.sqrt() * var_j.sqrt())
1048 } else {
1049 0.0
1050 };
1051
1052 correlations.push(correlation.abs());
1053 }
1054 }
1055
1056 Ok(1.0 - correlations.iter().sum::<Float>() / correlations.len() as Float)
1057 }
1058
1059 fn calculate_disagreement_measure(
1061 &self,
1062 predictions: &Array2<Float>,
1063 ) -> Result<Float, Box<dyn std::error::Error>> {
1064 let n_models = predictions.ncols();
1065 if n_models < 2 {
1066 return Ok(0.0);
1067 }
1068
1069 let mut disagreement_sum = 0.0;
1070 let mut pairs = 0;
1071
1072 for i in 0..n_models {
1073 for j in i + 1..n_models {
1074 let pred_i = predictions.column(i);
1075 let pred_j = predictions.column(j);
1076
1077 let mut disagreements = 0;
1078 for k in 0..predictions.nrows() {
1079 if (pred_i[k] > 0.5) != (pred_j[k] > 0.5) {
1080 disagreements += 1;
1081 }
1082 }
1083
1084 disagreement_sum += disagreements as Float / predictions.nrows() as Float;
1085 pairs += 1;
1086 }
1087 }
1088
1089 Ok(if pairs > 0 {
1090 disagreement_sum / pairs as Float
1091 } else {
1092 0.0
1093 })
1094 }
1095
1096 fn calculate_double_fault_measure(
1098 &self,
1099 predictions: &Array2<Float>,
1100 true_labels: &Array1<Float>,
1101 ) -> Result<Float, Box<dyn std::error::Error>> {
1102 let n_models = predictions.ncols();
1103 if n_models < 2 {
1104 return Ok(0.0);
1105 }
1106
1107 let mut double_fault_sum = 0.0;
1108 let mut pairs = 0;
1109
1110 for i in 0..n_models {
1111 for j in i + 1..n_models {
1112 let pred_i = predictions.column(i);
1113 let pred_j = predictions.column(j);
1114
1115 let mut double_faults = 0;
1116 for k in 0..predictions.nrows() {
1117 let i_wrong = (pred_i[k] > 0.5) != (true_labels[k] > 0.5);
1118 let j_wrong = (pred_j[k] > 0.5) != (true_labels[k] > 0.5);
1119
1120 if i_wrong && j_wrong {
1121 double_faults += 1;
1122 }
1123 }
1124
1125 double_fault_sum += double_faults as Float / predictions.nrows() as Float;
1126 pairs += 1;
1127 }
1128 }
1129
1130 Ok(if pairs > 0 {
1131 double_fault_sum / pairs as Float
1132 } else {
1133 0.0
1134 })
1135 }
1136
1137 fn calculate_entropy_diversity(
1139 &self,
1140 predictions: &Array2<Float>,
1141 ) -> Result<Float, Box<dyn std::error::Error>> {
1142 let n_samples = predictions.nrows();
1143 let n_models = predictions.ncols();
1144
1145 let mut entropy_sum = 0.0;
1146
1147 for i in 0..n_samples {
1148 let correct_count = predictions
1149 .row(i)
1150 .iter()
1151 .filter(|&&pred| pred > 0.5)
1152 .count() as Float;
1153
1154 let p = correct_count / n_models as Float;
1155 if p > 0.0 && p < 1.0 {
1156 entropy_sum += -p * p.log2() - (1.0 - p) * (1.0 - p).log2();
1157 }
1158 }
1159
1160 Ok(entropy_sum / n_samples as Float)
1161 }
1162
1163 fn calculate_kw_variance(
1165 &self,
1166 predictions: &Array2<Float>,
1167 true_labels: &Array1<Float>,
1168 ) -> Result<Float, Box<dyn std::error::Error>> {
1169 let n_samples = predictions.nrows();
1170 let n_models = predictions.ncols();
1171
1172 let mut variance_sum = 0.0;
1173
1174 for i in 0..n_samples {
1175 let correct_count = predictions
1176 .row(i)
1177 .iter()
1178 .filter(|&&pred| (pred > 0.5) == (true_labels[i] > 0.5))
1179 .count() as Float;
1180
1181 let l = correct_count / n_models as Float;
1182 variance_sum += l * (1.0 - l);
1183 }
1184
1185 Ok(variance_sum / n_samples as Float)
1186 }
1187
1188 fn calculate_interrater_agreement(
1190 &self,
1191 predictions: &Array2<Float>,
1192 ) -> Result<Float, Box<dyn std::error::Error>> {
1193 let n_models = predictions.ncols();
1194 if n_models < 2 {
1195 return Ok(0.0);
1196 }
1197
1198 let mut agreement_sum = 0.0;
1199 let mut pairs = 0;
1200
1201 for i in 0..n_models {
1202 for j in i + 1..n_models {
1203 let pred_i = predictions.column(i);
1204 let pred_j = predictions.column(j);
1205
1206 let mut agreements = 0;
1207 for k in 0..predictions.nrows() {
1208 if (pred_i[k] > 0.5) == (pred_j[k] > 0.5) {
1209 agreements += 1;
1210 }
1211 }
1212
1213 agreement_sum += agreements as Float / predictions.nrows() as Float;
1214 pairs += 1;
1215 }
1216 }
1217
1218 Ok(if pairs > 0 {
1219 agreement_sum / pairs as Float
1220 } else {
1221 0.0
1222 })
1223 }
1224
1225 fn calculate_difficulty_measure(
1227 &self,
1228 predictions: &Array2<Float>,
1229 true_labels: &Array1<Float>,
1230 ) -> Result<Float, Box<dyn std::error::Error>> {
1231 let n_samples = predictions.nrows();
1232 let n_models = predictions.ncols();
1233
1234 let mut difficulty_sum = 0.0;
1235
1236 for i in 0..n_samples {
1237 let error_count = predictions
1238 .row(i)
1239 .iter()
1240 .filter(|&&pred| (pred > 0.5) != (true_labels[i] > 0.5))
1241 .count() as Float;
1242
1243 difficulty_sum += error_count / n_models as Float;
1244 }
1245
1246 Ok(difficulty_sum / n_samples as Float)
1247 }
1248
1249 fn calculate_generalized_diversity(
1251 &self,
1252 predictions: &Array2<Float>,
1253 alpha: Float,
1254 ) -> Result<Float, Box<dyn std::error::Error>> {
1255 let n_samples = predictions.nrows();
1256 let n_models = predictions.ncols();
1257
1258 let mut diversity_sum = 0.0;
1259
1260 for i in 0..n_samples {
1261 let correct_count = predictions
1262 .row(i)
1263 .iter()
1264 .filter(|&&pred| pred > 0.5)
1265 .count() as Float;
1266
1267 let p = correct_count / n_models as Float;
1268 if alpha != 1.0 {
1269 diversity_sum += (1.0 - p.powf(alpha) - (1.0 - p).powf(alpha))
1270 / (2.0_f64.powf(1.0 - alpha) as Float - 1.0);
1271 } else {
1272 if p > 0.0 && p < 1.0 {
1274 diversity_sum += -p * p.log2() - (1.0 - p) * (1.0 - p).log2();
1275 }
1276 }
1277 }
1278
1279 Ok(diversity_sum / n_samples as Float)
1280 }
1281
1282 fn calculate_prediction_stability(
1284 &self,
1285 predictions: &[Array1<Float>],
1286 ) -> Result<Float, Box<dyn std::error::Error>> {
1287 if predictions.len() < 2 {
1288 return Ok(1.0);
1289 }
1290
1291 let n_samples = predictions[0].len();
1292 let mut stability_sum = 0.0;
1293
1294 for i in 0..n_samples {
1295 let sample_predictions: Vec<Float> = predictions.iter().map(|p| p[i]).collect();
1296 let mean_pred =
1297 sample_predictions.iter().sum::<Float>() / sample_predictions.len() as Float;
1298 let variance = sample_predictions
1299 .iter()
1300 .map(|&pred| (pred - mean_pred).powi(2))
1301 .sum::<Float>()
1302 / sample_predictions.len() as Float;
1303
1304 stability_sum += 1.0 / (1.0 + variance); }
1306
1307 Ok(stability_sum / n_samples as Float)
1308 }
1309}
1310
1311pub fn evaluate_ensemble<F>(
1313 ensemble_predictions: &Array2<Float>,
1314 true_labels: &Array1<Float>,
1315 ensemble_weights: Option<&Array1<Float>>,
1316 model_predictions: Option<&Array2<Float>>,
1317 evaluation_fn: F,
1318 config: Option<EnsembleEvaluationConfig>,
1319) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
1320where
1321 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
1322{
1323 let config = config.unwrap_or_default();
1324 let mut evaluator = EnsembleEvaluator::new(config);
1325 evaluator.evaluate(
1326 ensemble_predictions,
1327 true_labels,
1328 ensemble_weights,
1329 model_predictions,
1330 evaluation_fn,
1331 )
1332}
1333
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Simple accuracy metric used as the evaluation callback in tests:
    /// the fraction of samples where prediction and label fall on the same
    /// side of 0.5.
    fn mock_evaluation_function(
        predictions: &Array1<Float>,
        labels: &Array1<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let correct = predictions
            .iter()
            .zip(labels.iter())
            .filter(|(&pred, &label)| (pred > 0.5) == (label > 0.5))
            .count();
        Ok(correct as Float / predictions.len() as Float)
    }

    /// Shared fixture: a 10-sample x 3-model prediction matrix and matching
    /// binary labels. Extracted so the individual tests don't each duplicate
    /// the literal data.
    fn fixture() -> (Array2<Float>, Array1<Float>) {
        let predictions = Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .expect("operation should succeed");
        let labels = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);
        (predictions, labels)
    }

    #[test]
    fn test_ensemble_evaluator_creation() {
        let config = EnsembleEvaluationConfig::default();
        let evaluator = EnsembleEvaluator::new(config);
        // The default confidence level is part of the public contract.
        assert_eq!(evaluator.config.confidence_level, 0.95);
    }

    #[test]
    fn test_out_of_bag_evaluation() {
        let (ensemble_predictions, true_labels) = fixture();

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::OutOfBag {
                bootstrap_samples: 10,
                confidence_level: 0.95,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            None,
            mock_evaluation_function,
            Some(config),
        )
        .expect("operation should succeed");

        // OOB evaluation must populate the OOB scores with a valid accuracy.
        assert!(result.out_of_bag_scores.is_some());
        let oob_scores = result.out_of_bag_scores.expect("operation should succeed");
        assert!(oob_scores.oob_score >= 0.0 && oob_scores.oob_score <= 1.0);
    }

    #[test]
    fn test_diversity_evaluation() {
        let (ensemble_predictions, true_labels) = fixture();
        let model_predictions = ensemble_predictions.clone();

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::DiversityEvaluation {
                diversity_measures: vec![
                    DiversityMeasure::QStatistic,
                    DiversityMeasure::DisagreementMeasure,
                ],
                diversity_threshold: 0.5,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            Some(&model_predictions),
            mock_evaluation_function,
            Some(config),
        )
        .expect("operation should succeed");

        // Each requested measure must be computed, and the aggregate must be
        // non-negative.
        assert!(!result.diversity_analysis.diversity_by_measure.is_empty());
        assert!(result.diversity_analysis.overall_diversity >= 0.0);
    }

    #[test]
    fn test_cross_validation_evaluation() {
        let (ensemble_predictions, true_labels) = fixture();

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::EnsembleCrossValidation {
                cv_strategy: EnsembleCVStrategy::KFoldEnsemble,
                n_folds: 5,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &true_labels,
            None,
            None,
            mock_evaluation_function,
            Some(config),
        )
        .expect("operation should succeed");

        // CV evaluation must report per-fold scores and a sane mean.
        assert!(!result
            .ensemble_performance
            .individual_fold_scores
            .is_empty());
        assert!(result.ensemble_performance.mean_performance >= 0.0);
    }
}