1use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::Rng;
10use scirs2_core::random::SeedableRng;
11use sklears_core::types::Float;
12use std::collections::HashMap;
13
/// Strategy for evaluating an ensemble of models.
#[derive(Debug, Clone)]
pub enum EnsembleEvaluationStrategy {
    /// Out-of-bag evaluation: score on rows left out of bootstrap resamples.
    OutOfBag {
        /// Number of bootstrap resamples to draw.
        bootstrap_samples: usize,

        /// Confidence level for the OOB confidence interval (e.g. 0.95).
        confidence_level: Float,
    },
    /// Cross-validation applied at the ensemble level.
    EnsembleCrossValidation {
        /// Which ensemble CV scheme to use.
        cv_strategy: EnsembleCVStrategy,

        /// Number of folds.
        n_folds: usize,
    },
    /// Evaluation focused on diversity among ensemble members.
    DiversityEvaluation {
        /// Diversity measures to compute.
        diversity_measures: Vec<DiversityMeasure>,
        /// Threshold used when judging whether the ensemble is diverse enough.
        diversity_threshold: Float,
    },
    /// Bootstrap-based analysis of how stable the ensemble's behavior is.
    StabilityAnalysis {
        /// Number of bootstrap resamples.
        n_bootstrap_samples: usize,
        /// Stability metrics to compute.
        stability_metrics: Vec<StabilityMetric>,
    },
    /// Evaluate performance as a function of ensemble size.
    ProgressiveEvaluation {
        /// Ensemble sizes to evaluate.
        ensemble_sizes: Vec<usize>,
        /// How members are chosen when growing/shrinking the ensemble.
        selection_strategy: ProgressiveSelectionStrategy,
    },
    /// Evaluate several objectives (accuracy, diversity, ...) simultaneously.
    MultiObjectiveEvaluation {
        /// Objectives to score.
        objectives: Vec<EvaluationObjective>,
        /// Whether to perform trade-off analysis between objectives.
        trade_off_analysis: bool,
    },
}
50
/// Cross-validation scheme used by `EnsembleEvaluationStrategy::EnsembleCrossValidation`.
///
/// NOTE(review): the visible `evaluate_cross_validation` implementation ignores
/// the chosen variant and always performs contiguous k-fold splitting — confirm
/// whether the other variants are handled elsewhere.
#[derive(Debug, Clone)]
pub enum EnsembleCVStrategy {
    /// Plain k-fold split of the samples.
    KFoldEnsemble,
    /// Stratified k-fold split.
    StratifiedEnsemble,
    /// Leave one ensemble member out at a time.
    LeaveOneModelOut,
    /// Bootstrap resampling with `n_bootstrap` resamples.
    BootstrapEnsemble { n_bootstrap: usize },
    /// Nested CV with `inner_cv` inner and `outer_cv` outer folds.
    NestedEnsemble { inner_cv: usize, outer_cv: usize },
    /// Forward-chaining splits for time-series data.
    TimeSeriesEnsemble { n_splits: usize, test_size: Float },
}
67
/// Diversity measures computed over per-model predictions (see `evaluate_diversity`).
#[derive(Debug, Clone)]
pub enum DiversityMeasure {
    /// Mean pairwise Yule Q statistic over binarized predictions.
    QStatistic,
    /// One minus the mean absolute pairwise Pearson correlation.
    CorrelationCoefficient,
    /// Fraction of samples on which a model pair disagrees.
    DisagreementMeasure,
    /// Fraction of samples that both models of a pair get wrong.
    DoubleFaultMeasure,
    /// Mean per-sample binary entropy of the members' votes.
    EntropyDiversity,
    /// Kohavi-Wolpert variance of per-sample correctness.
    KohaviWolpertVariance,
    /// Mean pairwise agreement rate between members.
    InterraterAgreement,
    /// Mean per-sample error rate across members.
    DifficultyMeasure,
    /// Generalized diversity with order parameter `alpha`.
    GeneralizedDiversity { alpha: Float },
}
90
/// Aspects of ensemble stability requested by `EnsembleEvaluationStrategy::StabilityAnalysis`.
#[derive(Debug, Clone)]
pub enum StabilityMetric {
    /// Consistency of predictions across bootstrap resamples.
    PredictionStability,
    /// Consistency of which models end up selected.
    ModelSelectionStability,
    /// Consistency of member weights.
    WeightStability,
    /// Consistency of the performance score.
    PerformanceStability,
    /// Consistency of member rankings.
    RankingStability,
}
105
/// How members are picked when evaluating progressively larger ensembles.
///
/// NOTE(review): the visible `evaluate_progressive` implementation ignores this
/// and always takes the first `size` model columns — confirm intended behavior.
#[derive(Debug, Clone)]
pub enum ProgressiveSelectionStrategy {
    /// Add members one at a time, best-first.
    ForwardSelection,
    /// Start from the full ensemble and remove members.
    BackwardElimination,
    /// Select members that maximize diversity.
    DiversityDriven,
    /// Balance performance vs. diversity with mixing factor `alpha`.
    PerformanceDiversityTradeoff { alpha: Float },
}
118
/// Objectives that can be scored in multi-objective evaluation.
#[derive(Debug, Clone)]
pub enum EvaluationObjective {
    /// Predictive accuracy.
    Accuracy,
    /// Diversity among ensemble members.
    Diversity,
    /// Computational efficiency.
    Efficiency,
    /// Memory footprint.
    MemoryUsage,
    /// Robustness to perturbations.
    Robustness,
    /// Model interpretability.
    Interpretability,
    /// Fairness of predictions.
    Fairness,
}
137
/// Configuration for `EnsembleEvaluator`.
#[derive(Debug, Clone)]
pub struct EnsembleEvaluationConfig {
    /// Evaluation strategy to run.
    pub strategy: EnsembleEvaluationStrategy,
    /// Names of metrics to report (e.g. "accuracy", "f1_score").
    /// NOTE(review): not consulted by the visible evaluation paths.
    pub evaluation_metrics: Vec<String>,
    /// Confidence level for interval estimates (e.g. 0.95).
    pub confidence_level: Float,
    /// How many times to repeat the evaluation.
    /// NOTE(review): not consulted by the visible evaluation paths.
    pub n_repetitions: usize,
    /// Whether to evaluate folds/resamples in parallel.
    /// NOTE(review): not consulted by the visible evaluation paths.
    pub parallel_evaluation: bool,
    /// RNG seed for reproducibility; `None` seeds from the thread RNG.
    pub random_state: Option<u64>,
    /// Emit progress output when true.
    pub verbose: bool,
}
149
/// Aggregated output of `EnsembleEvaluator::evaluate`.
///
/// Only the fields relevant to the executed strategy are populated; the rest
/// are `None` or empty defaults.
#[derive(Debug, Clone)]
pub struct EnsembleEvaluationResult {
    /// Headline performance statistics.
    pub ensemble_performance: EnsemblePerformanceMetrics,
    /// Diversity statistics (zeroed when the strategy does not compute them).
    pub diversity_analysis: DiversityAnalysis,
    /// Stability statistics (`StabilityAnalysis` strategy only).
    pub stability_analysis: Option<StabilityAnalysis>,
    /// Per-member contribution breakdown (never populated in the visible paths).
    pub member_contributions: Vec<MemberContribution>,
    /// Out-of-bag scores (`OutOfBag` strategy only).
    pub out_of_bag_scores: Option<OutOfBagScores>,
    /// Size-vs-performance curves (`ProgressiveEvaluation` strategy only).
    pub progressive_performance: Option<ProgressivePerformance>,
    /// Multi-objective scores (`MultiObjectiveEvaluation` strategy only).
    pub multi_objective_analysis: Option<MultiObjectiveAnalysis>,
}
161
/// Summary statistics for the ensemble's predictive performance.
#[derive(Debug, Clone)]
pub struct EnsemblePerformanceMetrics {
    /// Mean score across folds/resamples (or the single score).
    pub mean_performance: Float,
    /// Standard deviation of the scores.
    pub std_performance: Float,
    /// (lower, upper) bounds of the confidence interval around the mean.
    pub confidence_interval: (Float, Float),
    /// Raw per-fold / per-resample scores.
    pub individual_fold_scores: Vec<Float>,
    /// Ensemble score minus the best member's score.
    /// NOTE(review): always 0.0 in the visible code paths.
    pub ensemble_vs_best_member: Float,
    /// Ensemble score minus the average member score (also 0.0 currently).
    pub ensemble_vs_average_member: Float,
    /// Overall improvement attributed to ensembling (also 0.0 currently).
    pub performance_gain: Float,
}
173
/// Diversity statistics for the ensemble members.
#[derive(Debug, Clone)]
pub struct DiversityAnalysis {
    /// Aggregate diversity (mean over the computed measures).
    pub overall_diversity: Float,
    /// Symmetric model-by-model Q-statistic matrix (may be 0x0 when unused).
    pub pairwise_diversities: Array2<Float>,
    /// Diversity value keyed by measure name.
    pub diversity_by_measure: HashMap<String, Float>,
    /// Distribution of diversity values (e.g. per CV fold).
    pub diversity_distribution: Vec<Float>,
    /// Ensemble size judged optimal for diversity, when known.
    pub optimal_diversity_size: Option<usize>,
}
183
/// Stability statistics derived from bootstrap resampling.
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    /// Agreement of predictions across resamples (higher = more stable).
    pub prediction_stability: Float,
    /// Stability of model selection.
    /// NOTE(review): a fixed placeholder (0.8) in the visible code path.
    pub model_selection_stability: Float,
    /// Stability of member weights, when weights are analyzed.
    pub weight_stability: Option<Float>,
    /// 1 / (1 + score variance): stability of the performance score.
    pub performance_stability: Float,
    /// Confidence intervals keyed by stability metric name.
    pub stability_confidence_intervals: HashMap<String, (Float, Float)>,
}
193
/// Attribution of ensemble quality to a single member.
#[derive(Debug, Clone)]
pub struct MemberContribution {
    /// Index of the member within the ensemble.
    pub member_id: usize,
    /// Human-readable member name.
    pub member_name: String,
    /// Score of the member evaluated on its own.
    pub individual_performance: Float,
    /// Change in ensemble score from adding this member.
    pub marginal_contribution: Float,
    /// Shapley-value attribution, when computed.
    pub shapley_value: Option<Float>,
    /// Change in ensemble score from removing this member.
    pub removal_impact: Float,
    /// This member's contribution to overall diversity.
    pub diversity_contribution: Float,
}
205
/// Results of out-of-bag evaluation.
#[derive(Debug, Clone)]
pub struct OutOfBagScores {
    /// Mean score over all OOB sets.
    pub oob_score: Float,
    /// (lower, upper) confidence interval for the OOB score.
    pub oob_confidence_interval: (Float, Float),
    /// Per-feature importances, when computed (not computed by the visible path).
    pub feature_importance: Option<Array1<Float>>,
    /// Per-sample prediction intervals, when computed.
    pub prediction_intervals: Option<Array2<Float>>,
    /// Score of each individual bootstrap's OOB set.
    pub individual_oob_scores: Vec<Float>,
}
215
/// Curves produced by progressive (size-sweep) evaluation.
#[derive(Debug, Clone)]
pub struct ProgressivePerformance {
    /// Ensemble sizes that were evaluated (parallel to the curves below).
    pub ensemble_sizes: Vec<usize>,
    /// Score at each evaluated size.
    pub performance_curve: Vec<Float>,
    /// Q-statistic diversity at each evaluated size.
    pub diversity_curve: Vec<Float>,
    /// Efficiency proxy at each evaluated size (placeholder 1.0 in the visible path).
    pub efficiency_curve: Vec<Float>,
    /// Size with the highest score.
    pub optimal_size: usize,
    /// Size beyond which gains flatten out, when detected.
    pub diminishing_returns_threshold: Option<usize>,
}
226
/// Results of multi-objective evaluation.
#[derive(Debug, Clone)]
pub struct MultiObjectiveAnalysis {
    /// Non-dominated (objective-1, objective-2) points.
    pub pareto_front: Vec<(Float, Float)>,
    /// Score keyed by objective name.
    pub objective_scores: HashMap<String, Float>,
    /// Quantified trade-offs between objective pairs.
    pub trade_off_analysis: HashMap<String, Float>,
    /// Indices of solutions dominated on every objective.
    pub dominated_solutions: Vec<usize>,
    /// Index of a balanced compromise solution, when one is chosen.
    pub compromise_solution: Option<usize>,
}
236
/// Evaluator that runs the strategy configured in `EnsembleEvaluationConfig`.
#[derive(Debug)]
pub struct EnsembleEvaluator {
    /// Strategy and general evaluation settings.
    config: EnsembleEvaluationConfig,
    /// RNG for bootstrap resampling; seeded from `config.random_state`.
    rng: StdRng,
}
243
impl Default for EnsembleEvaluationConfig {
    /// Defaults: 5-fold ensemble cross-validation, accuracy + F1 metrics,
    /// 95% confidence level, single repetition, serial execution, unseeded
    /// RNG, quiet output.
    fn default() -> Self {
        Self {
            strategy: EnsembleEvaluationStrategy::EnsembleCrossValidation {
                cv_strategy: EnsembleCVStrategy::KFoldEnsemble,
                n_folds: 5,
            },
            evaluation_metrics: vec!["accuracy".to_string(), "f1_score".to_string()],
            confidence_level: 0.95,
            n_repetitions: 1,
            parallel_evaluation: false,
            random_state: None,
            verbose: false,
        }
    }
}
260
261impl EnsembleEvaluator {
262 pub fn new(config: EnsembleEvaluationConfig) -> Self {
264 let rng = match config.random_state {
265 Some(seed) => StdRng::seed_from_u64(seed),
266 None => {
267 use scirs2_core::random::thread_rng;
268 StdRng::from_rng(&mut thread_rng())
269 }
270 };
271
272 Self { config, rng }
273 }
274
    /// Run the configured evaluation strategy on the given predictions.
    ///
    /// * `ensemble_predictions` - prediction matrix; rows are samples, columns
    ///   are ensemble members (the helpers average or weight across columns).
    /// * `true_labels` - one ground-truth value per sample row.
    /// * `ensemble_weights` - optional per-member combination weights.
    /// * `model_predictions` - optional per-model prediction matrix; required
    ///   by the diversity and progressive strategies.
    /// * `evaluation_fn` - scoring function `(predictions, labels) -> score`.
    ///
    /// # Errors
    /// Propagates any error from the strategy-specific helper or `evaluation_fn`.
    pub fn evaluate<F>(
        &mut self,
        ensemble_predictions: &Array2<Float>,
        true_labels: &Array1<Float>,
        ensemble_weights: Option<&Array1<Float>>,
        model_predictions: Option<&Array2<Float>>,
        evaluation_fn: F,
    ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
    where
        F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
    {
        // Dispatch on the configured strategy; each arm ignores the payload
        // here because the helper re-matches it to extract its parameters.
        match &self.config.strategy {
            EnsembleEvaluationStrategy::OutOfBag { .. } => self.evaluate_out_of_bag(
                ensemble_predictions,
                true_labels,
                ensemble_weights,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::EnsembleCrossValidation { .. } => self
                .evaluate_cross_validation(
                    ensemble_predictions,
                    true_labels,
                    ensemble_weights,
                    model_predictions,
                    &evaluation_fn,
                ),
            EnsembleEvaluationStrategy::DiversityEvaluation { .. } => self.evaluate_diversity(
                ensemble_predictions,
                true_labels,
                model_predictions,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::StabilityAnalysis { .. } => self.evaluate_stability(
                ensemble_predictions,
                true_labels,
                ensemble_weights,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::ProgressiveEvaluation { .. } => self.evaluate_progressive(
                ensemble_predictions,
                true_labels,
                model_predictions,
                &evaluation_fn,
            ),
            EnsembleEvaluationStrategy::MultiObjectiveEvaluation { .. } => self
                .evaluate_multi_objective(
                    ensemble_predictions,
                    true_labels,
                    ensemble_weights,
                    &evaluation_fn,
                ),
        }
    }
329
330 fn evaluate_out_of_bag<F>(
332 &mut self,
333 ensemble_predictions: &Array2<Float>,
334 true_labels: &Array1<Float>,
335 ensemble_weights: Option<&Array1<Float>>,
336 evaluation_fn: &F,
337 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
338 where
339 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
340 {
341 let (bootstrap_samples, confidence_level) = match &self.config.strategy {
342 EnsembleEvaluationStrategy::OutOfBag {
343 bootstrap_samples,
344 confidence_level,
345 } => (*bootstrap_samples, *confidence_level),
346 _ => unreachable!(),
347 };
348
349 let n_samples = ensemble_predictions.nrows();
350 let n_models = ensemble_predictions.ncols();
351
352 let mut oob_scores = Vec::new();
353 let mut oob_predictions_all = Vec::new();
354
355 for _ in 0..bootstrap_samples {
356 let bootstrap_indices: Vec<usize> = (0..n_samples)
358 .map(|_| self.rng.gen_range(0..n_samples))
359 .collect();
360
361 let mut oob_indices = Vec::new();
363 for i in 0..n_samples {
364 if !bootstrap_indices.contains(&i) {
365 oob_indices.push(i);
366 }
367 }
368
369 if oob_indices.is_empty() {
370 continue;
371 }
372
373 let oob_ensemble_preds = self.calculate_ensemble_predictions(
375 ensemble_predictions,
376 &oob_indices,
377 ensemble_weights,
378 )?;
379
380 let oob_true_labels =
381 Array1::from_vec(oob_indices.iter().map(|&i| true_labels[i]).collect());
382
383 let oob_score = evaluation_fn(&oob_ensemble_preds, &oob_true_labels)?;
384 oob_scores.push(oob_score);
385 oob_predictions_all.push(oob_ensemble_preds);
386 }
387
388 let mean_oob_score = oob_scores.iter().sum::<Float>() / oob_scores.len() as Float;
389 let std_oob_score = {
390 let variance = oob_scores
391 .iter()
392 .map(|&score| (score - mean_oob_score).powi(2))
393 .sum::<Float>()
394 / oob_scores.len() as Float;
395 variance.sqrt()
396 };
397
398 let _alpha = 1.0 - confidence_level;
399 let z_score = 1.96; let margin_of_error = z_score * std_oob_score / (oob_scores.len() as Float).sqrt();
401 let confidence_interval = (
402 mean_oob_score - margin_of_error,
403 mean_oob_score + margin_of_error,
404 );
405
406 let oob_scores_result = OutOfBagScores {
407 oob_score: mean_oob_score,
408 oob_confidence_interval: confidence_interval,
409 feature_importance: None, prediction_intervals: None, individual_oob_scores: oob_scores,
412 };
413
414 let ensemble_preds = self.calculate_ensemble_predictions(
416 ensemble_predictions,
417 &(0..n_samples).collect::<Vec<_>>(),
418 ensemble_weights,
419 )?;
420 let ensemble_score = evaluation_fn(&ensemble_preds, true_labels)?;
421
422 let ensemble_performance = EnsemblePerformanceMetrics {
423 mean_performance: ensemble_score,
424 std_performance: std_oob_score,
425 confidence_interval,
426 individual_fold_scores: vec![ensemble_score],
427 ensemble_vs_best_member: 0.0, ensemble_vs_average_member: 0.0,
429 performance_gain: 0.0,
430 };
431
432 Ok(EnsembleEvaluationResult {
433 ensemble_performance,
434 diversity_analysis: DiversityAnalysis {
435 overall_diversity: 0.0,
436 pairwise_diversities: Array2::zeros((n_models, n_models)),
437 diversity_by_measure: HashMap::new(),
438 diversity_distribution: Vec::new(),
439 optimal_diversity_size: None,
440 },
441 stability_analysis: None,
442 member_contributions: Vec::new(),
443 out_of_bag_scores: Some(oob_scores_result),
444 progressive_performance: None,
445 multi_objective_analysis: None,
446 })
447 }
448
449 fn evaluate_cross_validation<F>(
451 &mut self,
452 ensemble_predictions: &Array2<Float>,
453 true_labels: &Array1<Float>,
454 ensemble_weights: Option<&Array1<Float>>,
455 model_predictions: Option<&Array2<Float>>,
456 evaluation_fn: &F,
457 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
458 where
459 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
460 {
461 let (_cv_strategy, n_folds) = match &self.config.strategy {
462 EnsembleEvaluationStrategy::EnsembleCrossValidation {
463 cv_strategy,
464 n_folds,
465 } => (cv_strategy, *n_folds),
466 _ => unreachable!(),
467 };
468
469 let n_samples = ensemble_predictions.nrows();
470 let fold_size = n_samples / n_folds;
471 let mut fold_scores = Vec::new();
472 let mut diversity_scores = Vec::new();
473
474 for fold in 0..n_folds {
475 let test_start = fold * fold_size;
476 let test_end = if fold == n_folds - 1 {
477 n_samples
478 } else {
479 (fold + 1) * fold_size
480 };
481 let test_indices: Vec<usize> = (test_start..test_end).collect();
482
483 let test_ensemble_preds = self.calculate_ensemble_predictions(
485 ensemble_predictions,
486 &test_indices,
487 ensemble_weights,
488 )?;
489
490 let test_true_labels =
491 Array1::from_vec(test_indices.iter().map(|&i| true_labels[i]).collect());
492
493 let fold_score = evaluation_fn(&test_ensemble_preds, &test_true_labels)?;
494 fold_scores.push(fold_score);
495
496 if let Some(model_preds) = model_predictions {
498 let mut fold_data = Vec::new();
499 for &i in test_indices.iter() {
500 fold_data.extend(model_preds.row(i).iter().cloned());
501 }
502 let fold_model_preds =
503 Array2::from_shape_vec((test_indices.len(), model_preds.ncols()), fold_data)?;
504
505 let diversity = self.calculate_q_statistic(&fold_model_preds)?;
506 diversity_scores.push(diversity);
507 }
508 }
509
510 let mean_performance = fold_scores.iter().sum::<Float>() / fold_scores.len() as Float;
511 let std_performance = {
512 let variance = fold_scores
513 .iter()
514 .map(|&score| (score - mean_performance).powi(2))
515 .sum::<Float>()
516 / fold_scores.len() as Float;
517 variance.sqrt()
518 };
519
520 let z_score = 1.96; let margin_of_error = z_score * std_performance / (fold_scores.len() as Float).sqrt();
522 let confidence_interval = (
523 mean_performance - margin_of_error,
524 mean_performance + margin_of_error,
525 );
526
527 let ensemble_performance = EnsemblePerformanceMetrics {
528 mean_performance,
529 std_performance,
530 confidence_interval,
531 individual_fold_scores: fold_scores,
532 ensemble_vs_best_member: 0.0, ensemble_vs_average_member: 0.0,
534 performance_gain: 0.0,
535 };
536
537 let diversity_analysis = if !diversity_scores.is_empty() {
538 let mean_diversity =
539 diversity_scores.iter().sum::<Float>() / diversity_scores.len() as Float;
540 DiversityAnalysis {
541 overall_diversity: mean_diversity,
542 pairwise_diversities: Array2::zeros((0, 0)), diversity_by_measure: {
544 let mut map = HashMap::new();
545 map.insert("q_statistic".to_string(), mean_diversity);
546 map
547 },
548 diversity_distribution: diversity_scores,
549 optimal_diversity_size: None,
550 }
551 } else {
552 DiversityAnalysis {
553 overall_diversity: 0.0,
554 pairwise_diversities: Array2::zeros((0, 0)),
555 diversity_by_measure: HashMap::new(),
556 diversity_distribution: Vec::new(),
557 optimal_diversity_size: None,
558 }
559 };
560
561 Ok(EnsembleEvaluationResult {
562 ensemble_performance,
563 diversity_analysis,
564 stability_analysis: None,
565 member_contributions: Vec::new(),
566 out_of_bag_scores: None,
567 progressive_performance: None,
568 multi_objective_analysis: None,
569 })
570 }
571
572 fn evaluate_diversity<F>(
574 &mut self,
575 ensemble_predictions: &Array2<Float>,
576 true_labels: &Array1<Float>,
577 model_predictions: Option<&Array2<Float>>,
578 evaluation_fn: &F,
579 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
580 where
581 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
582 {
583 let (diversity_measures, _diversity_threshold) = match &self.config.strategy {
584 EnsembleEvaluationStrategy::DiversityEvaluation {
585 diversity_measures,
586 diversity_threshold,
587 } => (diversity_measures, *diversity_threshold),
588 _ => unreachable!(),
589 };
590
591 if let Some(model_preds) = model_predictions {
592 let n_models = model_preds.ncols();
593 let mut diversity_by_measure = HashMap::new();
594 let mut pairwise_diversities = Array2::zeros((n_models, n_models));
595
596 for measure in diversity_measures {
597 let diversity_value = match measure {
598 DiversityMeasure::QStatistic => self.calculate_q_statistic(model_preds)?,
599 DiversityMeasure::CorrelationCoefficient => {
600 self.calculate_correlation_coefficient(model_preds)?
601 }
602 DiversityMeasure::DisagreementMeasure => {
603 self.calculate_disagreement_measure(model_preds)?
604 }
605 DiversityMeasure::DoubleFaultMeasure => {
606 self.calculate_double_fault_measure(model_preds, true_labels)?
607 }
608 DiversityMeasure::EntropyDiversity => {
609 self.calculate_entropy_diversity(model_preds)?
610 }
611 DiversityMeasure::KohaviWolpertVariance => {
612 self.calculate_kw_variance(model_preds, true_labels)?
613 }
614 DiversityMeasure::InterraterAgreement => {
615 self.calculate_interrater_agreement(model_preds)?
616 }
617 DiversityMeasure::DifficultyMeasure => {
618 self.calculate_difficulty_measure(model_preds, true_labels)?
619 }
620 DiversityMeasure::GeneralizedDiversity { alpha } => {
621 self.calculate_generalized_diversity(model_preds, *alpha)?
622 }
623 };
624
625 diversity_by_measure.insert(format!("{:?}", measure), diversity_value);
626 }
627
628 for i in 0..n_models {
630 for j in i + 1..n_models {
631 let pair_preds = Array2::from_shape_vec(
632 (model_preds.nrows(), 2),
633 model_preds
634 .column(i)
635 .iter()
636 .cloned()
637 .chain(model_preds.column(j).iter().cloned())
638 .collect(),
639 )?;
640 let pair_diversity = self.calculate_q_statistic(&pair_preds)?;
641 pairwise_diversities[[i, j]] = pair_diversity;
642 pairwise_diversities[[j, i]] = pair_diversity;
643 }
644 }
645
646 let overall_diversity =
647 diversity_by_measure.values().sum::<Float>() / diversity_by_measure.len() as Float;
648
649 let diversity_analysis = DiversityAnalysis {
650 overall_diversity,
651 pairwise_diversities,
652 diversity_by_measure,
653 diversity_distribution: Vec::new(), optimal_diversity_size: None, };
656
657 let ensemble_preds = self.calculate_ensemble_predictions(
659 ensemble_predictions,
660 &(0..ensemble_predictions.nrows()).collect::<Vec<_>>(),
661 None,
662 )?;
663 let ensemble_score = evaluation_fn(&ensemble_preds, true_labels)?;
664
665 let ensemble_performance = EnsemblePerformanceMetrics {
666 mean_performance: ensemble_score,
667 std_performance: 0.0,
668 confidence_interval: (ensemble_score, ensemble_score),
669 individual_fold_scores: vec![ensemble_score],
670 ensemble_vs_best_member: 0.0,
671 ensemble_vs_average_member: 0.0,
672 performance_gain: 0.0,
673 };
674
675 Ok(EnsembleEvaluationResult {
676 ensemble_performance,
677 diversity_analysis,
678 stability_analysis: None,
679 member_contributions: Vec::new(),
680 out_of_bag_scores: None,
681 progressive_performance: None,
682 multi_objective_analysis: None,
683 })
684 } else {
685 Err("Model predictions required for diversity evaluation".into())
686 }
687 }
688
689 fn evaluate_stability<F>(
691 &mut self,
692 ensemble_predictions: &Array2<Float>,
693 true_labels: &Array1<Float>,
694 ensemble_weights: Option<&Array1<Float>>,
695 evaluation_fn: &F,
696 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
697 where
698 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
699 {
700 let (n_bootstrap_samples, _stability_metrics) = match &self.config.strategy {
701 EnsembleEvaluationStrategy::StabilityAnalysis {
702 n_bootstrap_samples,
703 stability_metrics,
704 } => (*n_bootstrap_samples, stability_metrics),
705 _ => unreachable!(),
706 };
707
708 let n_samples = ensemble_predictions.nrows();
709 let mut bootstrap_scores = Vec::new();
710 let mut bootstrap_predictions = Vec::new();
711
712 for _ in 0..n_bootstrap_samples {
713 let bootstrap_indices: Vec<usize> = (0..n_samples)
715 .map(|_| self.rng.gen_range(0..n_samples))
716 .collect();
717
718 let bootstrap_preds = self.calculate_ensemble_predictions(
719 ensemble_predictions,
720 &bootstrap_indices,
721 ensemble_weights,
722 )?;
723
724 let bootstrap_labels =
725 Array1::from_vec(bootstrap_indices.iter().map(|&i| true_labels[i]).collect());
726
727 let bootstrap_score = evaluation_fn(&bootstrap_preds, &bootstrap_labels)?;
728 bootstrap_scores.push(bootstrap_score);
729 bootstrap_predictions.push(bootstrap_preds);
730 }
731
732 let prediction_stability = self.calculate_prediction_stability(&bootstrap_predictions)?;
734
735 let mean_score = bootstrap_scores.iter().sum::<Float>() / bootstrap_scores.len() as Float;
737 let score_variance = bootstrap_scores
738 .iter()
739 .map(|&score| (score - mean_score).powi(2))
740 .sum::<Float>()
741 / bootstrap_scores.len() as Float;
742 let performance_stability = 1.0 / (1.0 + score_variance); let stability_analysis = StabilityAnalysis {
745 prediction_stability,
746 model_selection_stability: 0.8, weight_stability: None, performance_stability,
749 stability_confidence_intervals: HashMap::new(), };
751
752 let ensemble_performance = EnsemblePerformanceMetrics {
753 mean_performance: mean_score,
754 std_performance: score_variance.sqrt(),
755 confidence_interval: (
756 mean_score - score_variance.sqrt(),
757 mean_score + score_variance.sqrt(),
758 ),
759 individual_fold_scores: bootstrap_scores,
760 ensemble_vs_best_member: 0.0,
761 ensemble_vs_average_member: 0.0,
762 performance_gain: 0.0,
763 };
764
765 Ok(EnsembleEvaluationResult {
766 ensemble_performance,
767 diversity_analysis: DiversityAnalysis {
768 overall_diversity: 0.0,
769 pairwise_diversities: Array2::zeros((0, 0)),
770 diversity_by_measure: HashMap::new(),
771 diversity_distribution: Vec::new(),
772 optimal_diversity_size: None,
773 },
774 stability_analysis: Some(stability_analysis),
775 member_contributions: Vec::new(),
776 out_of_bag_scores: None,
777 progressive_performance: None,
778 multi_objective_analysis: None,
779 })
780 }
781
782 fn evaluate_progressive<F>(
784 &mut self,
785 _ensemble_predictions: &Array2<Float>,
786 true_labels: &Array1<Float>,
787 model_predictions: Option<&Array2<Float>>,
788 evaluation_fn: &F,
789 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
790 where
791 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
792 {
793 let (ensemble_sizes, _selection_strategy) = match &self.config.strategy {
794 EnsembleEvaluationStrategy::ProgressiveEvaluation {
795 ensemble_sizes,
796 selection_strategy,
797 } => (ensemble_sizes, selection_strategy),
798 _ => unreachable!(),
799 };
800
801 if let Some(model_preds) = model_predictions {
802 let mut performance_curve = Vec::new();
803 let mut diversity_curve = Vec::new();
804 let n_models = model_preds.ncols();
805
806 for &size in ensemble_sizes {
807 if size <= n_models {
808 let selected_indices: Vec<usize> = (0..size).collect();
810
811 let mut selected_data = Vec::new();
813 for &i in selected_indices.iter() {
814 selected_data.extend(model_preds.column(i).iter().cloned());
815 }
816 let selected_predictions =
817 Array2::from_shape_vec((model_preds.nrows(), size), selected_data)?;
818
819 let ensemble_preds = selected_predictions.mean_axis(Axis(1)).unwrap();
820 let performance = evaluation_fn(&ensemble_preds, true_labels)?;
821 performance_curve.push(performance);
822
823 let diversity = self.calculate_q_statistic(&selected_predictions)?;
825 diversity_curve.push(diversity);
826 }
827 }
828
829 let optimal_size_idx = performance_curve
831 .iter()
832 .enumerate()
833 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
834 .map(|(idx, _)| idx)
835 .unwrap_or(0);
836 let optimal_size = ensemble_sizes[optimal_size_idx];
837
838 let progressive_performance = ProgressivePerformance {
839 ensemble_sizes: ensemble_sizes.clone(),
840 performance_curve,
841 diversity_curve,
842 efficiency_curve: vec![1.0; ensemble_sizes.len()], optimal_size,
844 diminishing_returns_threshold: None, };
846
847 let ensemble_performance = EnsemblePerformanceMetrics {
848 mean_performance: progressive_performance.performance_curve[optimal_size_idx],
849 std_performance: 0.0,
850 confidence_interval: (0.0, 0.0),
851 individual_fold_scores: vec![],
852 ensemble_vs_best_member: 0.0,
853 ensemble_vs_average_member: 0.0,
854 performance_gain: 0.0,
855 };
856
857 Ok(EnsembleEvaluationResult {
858 ensemble_performance,
859 diversity_analysis: DiversityAnalysis {
860 overall_diversity: 0.0,
861 pairwise_diversities: Array2::zeros((0, 0)),
862 diversity_by_measure: HashMap::new(),
863 diversity_distribution: Vec::new(),
864 optimal_diversity_size: Some(optimal_size),
865 },
866 stability_analysis: None,
867 member_contributions: Vec::new(),
868 out_of_bag_scores: None,
869 progressive_performance: Some(progressive_performance),
870 multi_objective_analysis: None,
871 })
872 } else {
873 Err("Model predictions required for progressive evaluation".into())
874 }
875 }
876
877 fn evaluate_multi_objective<F>(
879 &mut self,
880 ensemble_predictions: &Array2<Float>,
881 true_labels: &Array1<Float>,
882 ensemble_weights: Option<&Array1<Float>>,
883 evaluation_fn: &F,
884 ) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
885 where
886 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
887 {
888 let ensemble_preds = self.calculate_ensemble_predictions(
890 ensemble_predictions,
891 &(0..ensemble_predictions.nrows()).collect::<Vec<_>>(),
892 ensemble_weights,
893 )?;
894 let performance = evaluation_fn(&ensemble_preds, true_labels)?;
895
896 let mut objective_scores = HashMap::new();
897 objective_scores.insert("accuracy".to_string(), performance);
898 objective_scores.insert("diversity".to_string(), 0.5); objective_scores.insert("efficiency".to_string(), 0.8); let multi_objective_analysis = MultiObjectiveAnalysis {
902 pareto_front: vec![(performance, 0.5)], objective_scores,
904 trade_off_analysis: HashMap::new(),
905 dominated_solutions: Vec::new(),
906 compromise_solution: Some(0),
907 };
908
909 let ensemble_performance = EnsemblePerformanceMetrics {
910 mean_performance: performance,
911 std_performance: 0.0,
912 confidence_interval: (performance, performance),
913 individual_fold_scores: vec![performance],
914 ensemble_vs_best_member: 0.0,
915 ensemble_vs_average_member: 0.0,
916 performance_gain: 0.0,
917 };
918
919 Ok(EnsembleEvaluationResult {
920 ensemble_performance,
921 diversity_analysis: DiversityAnalysis {
922 overall_diversity: 0.0,
923 pairwise_diversities: Array2::zeros((0, 0)),
924 diversity_by_measure: HashMap::new(),
925 diversity_distribution: Vec::new(),
926 optimal_diversity_size: None,
927 },
928 stability_analysis: None,
929 member_contributions: Vec::new(),
930 out_of_bag_scores: None,
931 progressive_performance: None,
932 multi_objective_analysis: Some(multi_objective_analysis),
933 })
934 }
935
936 fn calculate_ensemble_predictions(
938 &self,
939 ensemble_predictions: &Array2<Float>,
940 indices: &[usize],
941 weights: Option<&Array1<Float>>,
942 ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
943 let mut selected_data = Vec::new();
944 for &i in indices.iter() {
945 selected_data.extend(ensemble_predictions.row(i).iter().cloned());
946 }
947 let selected_predictions =
948 Array2::from_shape_vec((indices.len(), ensemble_predictions.ncols()), selected_data)?;
949
950 if let Some(w) = weights {
951 Ok(selected_predictions.dot(w))
953 } else {
954 Ok(selected_predictions.mean_axis(Axis(1)).unwrap())
956 }
957 }
958
959 fn calculate_q_statistic(
961 &self,
962 predictions: &Array2<Float>,
963 ) -> Result<Float, Box<dyn std::error::Error>> {
964 let n_models = predictions.ncols();
965 if n_models < 2 {
966 return Ok(0.0);
967 }
968
969 let mut q_sum = 0.0;
970 let mut pairs = 0;
971
972 for i in 0..n_models {
973 for j in i + 1..n_models {
974 let pred_i = predictions.column(i);
975 let pred_j = predictions.column(j);
976
977 let mut n11 = 0; let mut n10 = 0; let mut n01 = 0; let mut n00 = 0; for k in 0..predictions.nrows() {
983 let i_correct = pred_i[k] > 0.5;
984 let j_correct = pred_j[k] > 0.5;
985
986 match (i_correct, j_correct) {
987 (true, true) => n11 += 1,
988 (true, false) => n10 += 1,
989 (false, true) => n01 += 1,
990 (false, false) => n00 += 1,
991 }
992 }
993
994 let numerator = (n11 * n00 - n01 * n10) as Float;
995 let denominator = (n11 * n00 + n01 * n10) as Float;
996
997 if denominator != 0.0 {
998 q_sum += numerator / denominator;
999 pairs += 1;
1000 }
1001 }
1002 }
1003
1004 Ok(if pairs > 0 {
1005 q_sum / pairs as Float
1006 } else {
1007 0.0
1008 })
1009 }
1010
1011 fn calculate_correlation_coefficient(
1013 &self,
1014 predictions: &Array2<Float>,
1015 ) -> Result<Float, Box<dyn std::error::Error>> {
1016 let n_models = predictions.ncols();
1017 if n_models < 2 {
1018 return Ok(0.0);
1019 }
1020
1021 let mut correlations = Vec::new();
1022 for i in 0..n_models {
1023 for j in i + 1..n_models {
1024 let pred_i = predictions.column(i);
1025 let pred_j = predictions.column(j);
1026
1027 let mean_i = pred_i.mean().unwrap_or(0.0);
1028 let mean_j = pred_j.mean().unwrap_or(0.0);
1029
1030 let mut covariance = 0.0;
1031 let mut var_i = 0.0;
1032 let mut var_j = 0.0;
1033
1034 for k in 0..predictions.nrows() {
1035 let diff_i = pred_i[k] - mean_i;
1036 let diff_j = pred_j[k] - mean_j;
1037 covariance += diff_i * diff_j;
1038 var_i += diff_i * diff_i;
1039 var_j += diff_j * diff_j;
1040 }
1041
1042 let correlation = if var_i > 0.0 && var_j > 0.0 {
1043 covariance / (var_i.sqrt() * var_j.sqrt())
1044 } else {
1045 0.0
1046 };
1047
1048 correlations.push(correlation.abs());
1049 }
1050 }
1051
1052 Ok(1.0 - correlations.iter().sum::<Float>() / correlations.len() as Float)
1053 }
1054
1055 fn calculate_disagreement_measure(
1057 &self,
1058 predictions: &Array2<Float>,
1059 ) -> Result<Float, Box<dyn std::error::Error>> {
1060 let n_models = predictions.ncols();
1061 if n_models < 2 {
1062 return Ok(0.0);
1063 }
1064
1065 let mut disagreement_sum = 0.0;
1066 let mut pairs = 0;
1067
1068 for i in 0..n_models {
1069 for j in i + 1..n_models {
1070 let pred_i = predictions.column(i);
1071 let pred_j = predictions.column(j);
1072
1073 let mut disagreements = 0;
1074 for k in 0..predictions.nrows() {
1075 if (pred_i[k] > 0.5) != (pred_j[k] > 0.5) {
1076 disagreements += 1;
1077 }
1078 }
1079
1080 disagreement_sum += disagreements as Float / predictions.nrows() as Float;
1081 pairs += 1;
1082 }
1083 }
1084
1085 Ok(if pairs > 0 {
1086 disagreement_sum / pairs as Float
1087 } else {
1088 0.0
1089 })
1090 }
1091
1092 fn calculate_double_fault_measure(
1094 &self,
1095 predictions: &Array2<Float>,
1096 true_labels: &Array1<Float>,
1097 ) -> Result<Float, Box<dyn std::error::Error>> {
1098 let n_models = predictions.ncols();
1099 if n_models < 2 {
1100 return Ok(0.0);
1101 }
1102
1103 let mut double_fault_sum = 0.0;
1104 let mut pairs = 0;
1105
1106 for i in 0..n_models {
1107 for j in i + 1..n_models {
1108 let pred_i = predictions.column(i);
1109 let pred_j = predictions.column(j);
1110
1111 let mut double_faults = 0;
1112 for k in 0..predictions.nrows() {
1113 let i_wrong = (pred_i[k] > 0.5) != (true_labels[k] > 0.5);
1114 let j_wrong = (pred_j[k] > 0.5) != (true_labels[k] > 0.5);
1115
1116 if i_wrong && j_wrong {
1117 double_faults += 1;
1118 }
1119 }
1120
1121 double_fault_sum += double_faults as Float / predictions.nrows() as Float;
1122 pairs += 1;
1123 }
1124 }
1125
1126 Ok(if pairs > 0 {
1127 double_fault_sum / pairs as Float
1128 } else {
1129 0.0
1130 })
1131 }
1132
1133 fn calculate_entropy_diversity(
1135 &self,
1136 predictions: &Array2<Float>,
1137 ) -> Result<Float, Box<dyn std::error::Error>> {
1138 let n_samples = predictions.nrows();
1139 let n_models = predictions.ncols();
1140
1141 let mut entropy_sum = 0.0;
1142
1143 for i in 0..n_samples {
1144 let correct_count = predictions
1145 .row(i)
1146 .iter()
1147 .filter(|&&pred| pred > 0.5)
1148 .count() as Float;
1149
1150 let p = correct_count / n_models as Float;
1151 if p > 0.0 && p < 1.0 {
1152 entropy_sum += -p * p.log2() - (1.0 - p) * (1.0 - p).log2();
1153 }
1154 }
1155
1156 Ok(entropy_sum / n_samples as Float)
1157 }
1158
1159 fn calculate_kw_variance(
1161 &self,
1162 predictions: &Array2<Float>,
1163 true_labels: &Array1<Float>,
1164 ) -> Result<Float, Box<dyn std::error::Error>> {
1165 let n_samples = predictions.nrows();
1166 let n_models = predictions.ncols();
1167
1168 let mut variance_sum = 0.0;
1169
1170 for i in 0..n_samples {
1171 let correct_count = predictions
1172 .row(i)
1173 .iter()
1174 .filter(|&&pred| (pred > 0.5) == (true_labels[i] > 0.5))
1175 .count() as Float;
1176
1177 let l = correct_count / n_models as Float;
1178 variance_sum += l * (1.0 - l);
1179 }
1180
1181 Ok(variance_sum / n_samples as Float)
1182 }
1183
1184 fn calculate_interrater_agreement(
1186 &self,
1187 predictions: &Array2<Float>,
1188 ) -> Result<Float, Box<dyn std::error::Error>> {
1189 let n_models = predictions.ncols();
1190 if n_models < 2 {
1191 return Ok(0.0);
1192 }
1193
1194 let mut agreement_sum = 0.0;
1195 let mut pairs = 0;
1196
1197 for i in 0..n_models {
1198 for j in i + 1..n_models {
1199 let pred_i = predictions.column(i);
1200 let pred_j = predictions.column(j);
1201
1202 let mut agreements = 0;
1203 for k in 0..predictions.nrows() {
1204 if (pred_i[k] > 0.5) == (pred_j[k] > 0.5) {
1205 agreements += 1;
1206 }
1207 }
1208
1209 agreement_sum += agreements as Float / predictions.nrows() as Float;
1210 pairs += 1;
1211 }
1212 }
1213
1214 Ok(if pairs > 0 {
1215 agreement_sum / pairs as Float
1216 } else {
1217 0.0
1218 })
1219 }
1220
1221 fn calculate_difficulty_measure(
1223 &self,
1224 predictions: &Array2<Float>,
1225 true_labels: &Array1<Float>,
1226 ) -> Result<Float, Box<dyn std::error::Error>> {
1227 let n_samples = predictions.nrows();
1228 let n_models = predictions.ncols();
1229
1230 let mut difficulty_sum = 0.0;
1231
1232 for i in 0..n_samples {
1233 let error_count = predictions
1234 .row(i)
1235 .iter()
1236 .filter(|&&pred| (pred > 0.5) != (true_labels[i] > 0.5))
1237 .count() as Float;
1238
1239 difficulty_sum += error_count / n_models as Float;
1240 }
1241
1242 Ok(difficulty_sum / n_samples as Float)
1243 }
1244
1245 fn calculate_generalized_diversity(
1247 &self,
1248 predictions: &Array2<Float>,
1249 alpha: Float,
1250 ) -> Result<Float, Box<dyn std::error::Error>> {
1251 let n_samples = predictions.nrows();
1252 let n_models = predictions.ncols();
1253
1254 let mut diversity_sum = 0.0;
1255
1256 for i in 0..n_samples {
1257 let correct_count = predictions
1258 .row(i)
1259 .iter()
1260 .filter(|&&pred| pred > 0.5)
1261 .count() as Float;
1262
1263 let p = correct_count / n_models as Float;
1264 if alpha != 1.0 {
1265 diversity_sum += (1.0 - p.powf(alpha) - (1.0 - p).powf(alpha))
1266 / (2.0_f64.powf(1.0 - alpha) as Float - 1.0);
1267 } else {
1268 if p > 0.0 && p < 1.0 {
1270 diversity_sum += -p * p.log2() - (1.0 - p) * (1.0 - p).log2();
1271 }
1272 }
1273 }
1274
1275 Ok(diversity_sum / n_samples as Float)
1276 }
1277
1278 fn calculate_prediction_stability(
1280 &self,
1281 predictions: &[Array1<Float>],
1282 ) -> Result<Float, Box<dyn std::error::Error>> {
1283 if predictions.len() < 2 {
1284 return Ok(1.0);
1285 }
1286
1287 let n_samples = predictions[0].len();
1288 let mut stability_sum = 0.0;
1289
1290 for i in 0..n_samples {
1291 let sample_predictions: Vec<Float> = predictions.iter().map(|p| p[i]).collect();
1292 let mean_pred =
1293 sample_predictions.iter().sum::<Float>() / sample_predictions.len() as Float;
1294 let variance = sample_predictions
1295 .iter()
1296 .map(|&pred| (pred - mean_pred).powi(2))
1297 .sum::<Float>()
1298 / sample_predictions.len() as Float;
1299
1300 stability_sum += 1.0 / (1.0 + variance); }
1302
1303 Ok(stability_sum / n_samples as Float)
1304 }
1305}
1306
1307pub fn evaluate_ensemble<F>(
1309 ensemble_predictions: &Array2<Float>,
1310 true_labels: &Array1<Float>,
1311 ensemble_weights: Option<&Array1<Float>>,
1312 model_predictions: Option<&Array2<Float>>,
1313 evaluation_fn: F,
1314 config: Option<EnsembleEvaluationConfig>,
1315) -> Result<EnsembleEvaluationResult, Box<dyn std::error::Error>>
1316where
1317 F: Fn(&Array1<Float>, &Array1<Float>) -> Result<Float, Box<dyn std::error::Error>>,
1318{
1319 let config = config.unwrap_or_default();
1320 let mut evaluator = EnsembleEvaluator::new(config);
1321 evaluator.evaluate(
1322 ensemble_predictions,
1323 true_labels,
1324 ensemble_weights,
1325 model_predictions,
1326 evaluation_fn,
1327 )
1328}
1329
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared 10-sample x 3-model prediction fixture; previously this matrix
    /// was copy-pasted into each test.
    fn sample_predictions() -> Array2<Float> {
        Array2::from_shape_vec(
            (10, 3),
            vec![
                0.1, 0.8, 0.3, 0.9, 0.2, 0.7, 0.4, 0.6, 0.8, 0.1, 0.2, 0.9, 0.1, 0.8, 0.3, 0.6,
                0.5, 0.7, 0.9, 0.2, 0.3, 0.7, 0.2, 0.9, 0.1, 0.8, 0.4, 0.5, 0.6, 0.3,
            ],
        )
        .unwrap()
    }

    /// Binary ground-truth labels matching `sample_predictions`.
    fn sample_labels() -> Array1<Float> {
        Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0])
    }

    /// Accuracy at a 0.5 decision threshold.
    fn mock_evaluation_function(
        predictions: &Array1<Float>,
        labels: &Array1<Float>,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        let correct = predictions
            .iter()
            .zip(labels.iter())
            .filter(|(&pred, &label)| (pred > 0.5) == (label > 0.5))
            .count();
        Ok(correct as Float / predictions.len() as Float)
    }

    #[test]
    fn test_ensemble_evaluator_creation() {
        let config = EnsembleEvaluationConfig::default();
        let evaluator = EnsembleEvaluator::new(config);
        assert_eq!(evaluator.config.confidence_level, 0.95);
    }

    #[test]
    fn test_out_of_bag_evaluation() {
        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::OutOfBag {
                bootstrap_samples: 10,
                confidence_level: 0.95,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &sample_predictions(),
            &sample_labels(),
            None,
            None,
            mock_evaluation_function,
            Some(config),
        )
        .unwrap();

        assert!(result.out_of_bag_scores.is_some());
        let oob_scores = result.out_of_bag_scores.unwrap();
        assert!(oob_scores.oob_score >= 0.0 && oob_scores.oob_score <= 1.0);
    }

    #[test]
    fn test_diversity_evaluation() {
        let ensemble_predictions = sample_predictions();
        let model_predictions = ensemble_predictions.clone();

        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::DiversityEvaluation {
                diversity_measures: vec![
                    DiversityMeasure::QStatistic,
                    DiversityMeasure::DisagreementMeasure,
                ],
                diversity_threshold: 0.5,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &ensemble_predictions,
            &sample_labels(),
            None,
            Some(&model_predictions),
            mock_evaluation_function,
            Some(config),
        )
        .unwrap();

        assert!(!result.diversity_analysis.diversity_by_measure.is_empty());
        assert!(result.diversity_analysis.overall_diversity >= 0.0);
    }

    #[test]
    fn test_cross_validation_evaluation() {
        let config = EnsembleEvaluationConfig {
            strategy: EnsembleEvaluationStrategy::EnsembleCrossValidation {
                cv_strategy: EnsembleCVStrategy::KFoldEnsemble,
                n_folds: 5,
            },
            ..Default::default()
        };

        let result = evaluate_ensemble(
            &sample_predictions(),
            &sample_labels(),
            None,
            None,
            mock_evaluation_function,
            Some(config),
        )
        .unwrap();

        assert!(!result
            .ensemble_performance
            .individual_fold_scores
            .is_empty());
        assert!(result.ensemble_performance.mean_performance >= 0.0);
    }
}