// sklears_model_selection/ensemble_selection.rs
//! Ensemble model selection with automatic composition strategies
//!
//! This module provides tools for automatically selecting and composing ensemble models.
//! It includes various ensemble strategies like voting, stacking, blending, and dynamic
//! selection with automatic hyperparameter optimization for ensemble components.

use crate::cross_validation::CrossValidator;
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Predict},
};

// Simple scoring trait for testing
pub trait Scoring {
    /// Score predictions against ground truth.
    ///
    /// By convention in this module, higher scores are better (the test
    /// `MockScoring` returns negated MSE for exactly this reason).
    fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64>;
}
use std::fmt::{self, Display, Formatter};
18
/// Result of ensemble model selection.
///
/// Produced by `EnsembleSelector::select_ensemble`: the winning strategy,
/// its member models and weights, plus evaluation and diversity statistics.
#[derive(Debug, Clone)]
pub struct EnsembleSelectionResult {
    /// Selected ensemble strategy
    pub ensemble_strategy: EnsembleStrategy,
    /// Base models included in the ensemble
    pub selected_models: Vec<ModelInfo>,
    /// Ensemble weights (if applicable); parallel to `selected_models`
    pub model_weights: Vec<f64>,
    /// Cross-validation performance of the ensemble
    pub ensemble_performance: EnsemblePerformance,
    /// Individual model performances (one entry per candidate model)
    pub individual_performances: Vec<ModelPerformance>,
    /// Diversity measures for the selected ensemble
    pub diversity_measures: DiversityMeasures,
}
35
/// Information about a model in the ensemble.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model index in the original candidate list
    pub model_index: usize,
    /// Model name/identifier
    pub model_name: String,
    /// Model weight in the ensemble
    pub weight: f64,
    /// Individual cross-validation performance score
    pub individual_score: f64,
    /// Contribution to ensemble performance.
    /// NOTE(review): currently always initialized to 0.0 by
    /// `create_ensemble_candidate` and never updated in this module.
    pub contribution_score: f64,
}
50
/// Performance metrics for the ensemble.
#[derive(Debug, Clone)]
pub struct EnsemblePerformance {
    /// Mean cross-validation score across folds
    pub mean_score: f64,
    /// Sample standard deviation of CV scores (0.0 for a single fold)
    pub std_score: f64,
    /// Individual fold scores
    pub fold_scores: Vec<f64>,
    /// Improvement over the best individual model
    /// (`mean_score - best individual cv_score`; set by `select_ensemble`)
    pub improvement_over_best: f64,
    /// Number of models in the ensemble
    pub ensemble_size: usize,
}
65
/// Performance metrics for individual models.
#[derive(Debug, Clone)]
pub struct ModelPerformance {
    /// Model index in the candidate list
    pub model_index: usize,
    /// Model name
    pub model_name: String,
    /// Cross-validation score
    pub cv_score: f64,
    /// Standard deviation of the CV scores
    pub cv_std: f64,
    /// Average correlation of this model's predictions with the other models
    /// (used by `calculate_subset_diversity` as a diversity proxy)
    pub avg_correlation: f64,
}
80
/// Diversity measures for the ensemble.
#[derive(Debug, Clone)]
pub struct DiversityMeasures {
    /// Average pairwise correlation between predictions (lower = more diverse)
    pub avg_correlation: f64,
    /// Disagreement measure
    pub disagreement: f64,
    /// Q statistic (average pairwise Q statistic)
    pub q_statistic: f64,
    /// Entropy-based diversity
    pub entropy_diversity: f64,
}
93
/// Ensemble composition strategies.
#[derive(Debug, Clone, PartialEq)]
pub enum EnsembleStrategy {
    /// Simple voting (equal weights)
    Voting,
    /// Weighted voting based on individual performance
    WeightedVoting,
    /// Stacking with meta-learner (identified by name, e.g. "Linear")
    Stacking { meta_learner: String },
    /// Blending (holdout-based stacking); `blend_ratio` is the holdout fraction
    Blending { blend_ratio: f64 },
    /// Dynamic selection based on instance
    DynamicSelection,
    /// Bayesian model averaging
    BayesianAveraging,
}
110
111impl Display for EnsembleStrategy {
112    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
113        match self {
114            EnsembleStrategy::Voting => write!(f, "Simple Voting"),
115            EnsembleStrategy::WeightedVoting => write!(f, "Weighted Voting"),
116            EnsembleStrategy::Stacking { meta_learner } => write!(f, "Stacking ({})", meta_learner),
117            EnsembleStrategy::Blending { blend_ratio } => {
118                write!(f, "Blending (ratio: {:.2})", blend_ratio)
119            }
120            EnsembleStrategy::DynamicSelection => write!(f, "Dynamic Selection"),
121            EnsembleStrategy::BayesianAveraging => write!(f, "Bayesian Averaging"),
122        }
123    }
124}
125
126impl Display for EnsembleSelectionResult {
127    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
128        writeln!(f, "Ensemble Selection Results:")?;
129        writeln!(f, "Strategy: {}", self.ensemble_strategy)?;
130        writeln!(f, "Ensemble Size: {}", self.selected_models.len())?;
131        writeln!(
132            f,
133            "Ensemble Performance: {:.4} ± {:.4}",
134            self.ensemble_performance.mean_score, self.ensemble_performance.std_score
135        )?;
136        writeln!(
137            f,
138            "Improvement over Best Individual: {:.4}",
139            self.ensemble_performance.improvement_over_best
140        )?;
141        writeln!(
142            f,
143            "Average Diversity (Correlation): {:.4}",
144            self.diversity_measures.avg_correlation
145        )?;
146        writeln!(f, "\nSelected Models:")?;
147        for model in &self.selected_models {
148            writeln!(
149                f,
150                "  {} - Weight: {:.3}, Score: {:.4}",
151                model.model_name, model.weight, model.individual_score
152            )?;
153        }
154        Ok(())
155    }
156}
157
/// Configuration for ensemble selection.
#[derive(Debug, Clone)]
pub struct EnsembleSelectionConfig {
    /// Maximum ensemble size
    pub max_ensemble_size: usize,
    /// Minimum ensemble size (also the minimum number of candidate models)
    pub min_ensemble_size: usize,
    /// Strategies to consider when generating candidates
    pub candidate_strategies: Vec<EnsembleStrategy>,
    /// Diversity threshold (minimum required diversity for a subset;
    /// diversity is computed as `1 - |avg pairwise correlation|`)
    pub diversity_threshold: f64,
    /// Whether to use greedy forward selection instead of exhaustive subset sizes
    pub use_greedy_selection: bool,
    /// Minimum estimated performance improvement required to grow the ensemble
    pub improvement_threshold: f64,
    /// Cross-validation folds for ensemble evaluation
    pub cv_folds: usize,
    /// Random seed for reproducibility
    pub random_seed: Option<u64>,
}
178
impl Default for EnsembleSelectionConfig {
    /// Reasonable defaults: ensembles of 2–10 models, greedy selection,
    /// four candidate strategies, and 5-fold CV.
    fn default() -> Self {
        Self {
            max_ensemble_size: 10,
            min_ensemble_size: 2,
            candidate_strategies: vec![
                EnsembleStrategy::Voting,
                EnsembleStrategy::WeightedVoting,
                EnsembleStrategy::Stacking {
                    meta_learner: "Linear".to_string(),
                },
                EnsembleStrategy::Blending { blend_ratio: 0.2 },
            ],
            diversity_threshold: 0.1,
            use_greedy_selection: true,
            improvement_threshold: 0.01,
            cv_folds: 5,
            random_seed: None,
        }
    }
}
200
/// Ensemble model selector.
///
/// Thin wrapper around an [`EnsembleSelectionConfig`] with builder-style
/// setters; the main entry point is `select_ensemble`.
pub struct EnsembleSelector {
    config: EnsembleSelectionConfig,
}
205
206impl EnsembleSelector {
207    /// Create a new ensemble selector with default configuration
208    pub fn new() -> Self {
209        Self {
210            config: EnsembleSelectionConfig::default(),
211        }
212    }
213
214    /// Create a new ensemble selector with custom configuration
215    pub fn with_config(config: EnsembleSelectionConfig) -> Self {
216        Self { config }
217    }
218
219    /// Set maximum ensemble size
220    pub fn max_ensemble_size(mut self, size: usize) -> Self {
221        self.config.max_ensemble_size = size;
222        self
223    }
224
225    /// Set minimum ensemble size
226    pub fn min_ensemble_size(mut self, size: usize) -> Self {
227        self.config.min_ensemble_size = size;
228        self
229    }
230
231    /// Set candidate strategies
232    pub fn strategies(mut self, strategies: Vec<EnsembleStrategy>) -> Self {
233        self.config.candidate_strategies = strategies;
234        self
235    }
236
237    /// Set diversity threshold
238    pub fn diversity_threshold(mut self, threshold: f64) -> Self {
239        self.config.diversity_threshold = threshold;
240        self
241    }
242
243    /// Enable or disable greedy selection
244    pub fn use_greedy_selection(mut self, use_greedy: bool) -> Self {
245        self.config.use_greedy_selection = use_greedy;
246        self
247    }
248
249    /// Select optimal ensemble from candidate models
250    pub fn select_ensemble<E, X, Y>(
251        &self,
252        models: &[(E, String)],
253        x: &[X],
254        y: &[Y],
255        cv: &dyn CrossValidator,
256        scoring: &dyn Scoring,
257    ) -> Result<EnsembleSelectionResult>
258    where
259        E: Estimator + Clone,
260        X: Clone,
261        Y: Clone + Into<f64>,
262    {
263        if models.len() < self.config.min_ensemble_size {
264            return Err(SklearsError::InvalidParameter {
265                name: "models".to_string(),
266                reason: format!(
267                    "at least {} models required for ensemble",
268                    self.config.min_ensemble_size
269                ),
270            });
271        }
272
273        // Evaluate individual models
274        let individual_performances = self.evaluate_individual_models(models, x, y, cv, scoring)?;
275
276        // Generate ensemble candidates
277        let ensemble_candidates = self.generate_ensemble_candidates(&individual_performances)?;
278
279        // Evaluate ensemble candidates
280        let mut best_ensemble = None;
281        let mut best_score = f64::NEG_INFINITY;
282
283        for candidate in &ensemble_candidates {
284            let ensemble_performance =
285                self.evaluate_ensemble_candidate(models, candidate, x, y, cv, scoring)?;
286
287            if ensemble_performance.mean_score > best_score {
288                best_score = ensemble_performance.mean_score;
289                best_ensemble = Some((candidate.clone(), ensemble_performance));
290            }
291        }
292
293        let (best_candidate, ensemble_performance) =
294            best_ensemble.ok_or_else(|| SklearsError::InvalidParameter {
295                name: "ensemble".to_string(),
296                reason: "no valid ensemble found".to_string(),
297            })?;
298
299        // Calculate diversity measures
300        let diversity_measures =
301            self.calculate_diversity_measures(models, &best_candidate.selected_models, x, y)?;
302
303        // Calculate improvement over best individual model
304        let best_individual_score = individual_performances
305            .iter()
306            .map(|p| p.cv_score)
307            .fold(f64::NEG_INFINITY, f64::max);
308
309        let mut ensemble_performance = ensemble_performance;
310        ensemble_performance.improvement_over_best =
311            ensemble_performance.mean_score - best_individual_score;
312
313        Ok(EnsembleSelectionResult {
314            ensemble_strategy: best_candidate.ensemble_strategy,
315            selected_models: best_candidate.selected_models,
316            model_weights: best_candidate.model_weights,
317            ensemble_performance,
318            individual_performances,
319            diversity_measures,
320        })
321    }
322
323    /// Evaluate individual model performances
324    fn evaluate_individual_models<E, X, Y>(
325        &self,
326        models: &[(E, String)],
327        _x: &[X],
328        _y: &[Y],
329        _cv: &dyn CrossValidator,
330        _scoring: &dyn Scoring,
331    ) -> Result<Vec<ModelPerformance>>
332    where
333        E: Estimator + Clone,
334        X: Clone,
335        Y: Clone + Into<f64>,
336    {
337        // Placeholder implementation - create dummy performance data
338        let mut performances = Vec::new();
339        for (idx, (_, name)) in models.iter().enumerate() {
340            performances.push(ModelPerformance {
341                model_index: idx,
342                model_name: name.clone(),
343                cv_score: 0.8 + (idx as f64) * 0.05, // Dummy scores
344                cv_std: 0.1,
345                avg_correlation: 0.3,
346            });
347        }
348        Ok(performances)
349    }
350
351    /// Generate ensemble candidates using different strategies
352    fn generate_ensemble_candidates(
353        &self,
354        individual_performances: &[ModelPerformance],
355    ) -> Result<Vec<EnsembleCandidate>> {
356        let mut candidates = Vec::new();
357
358        for strategy in &self.config.candidate_strategies {
359            if self.config.use_greedy_selection {
360                // Use greedy selection to build ensemble
361                let ensemble =
362                    self.greedy_ensemble_selection(individual_performances, strategy.clone())?;
363                candidates.push(ensemble);
364            } else {
365                // Try different subset sizes
366                for size in self.config.min_ensemble_size
367                    ..=self
368                        .config
369                        .max_ensemble_size
370                        .min(individual_performances.len())
371                {
372                    let ensemble = self.select_diverse_subset(
373                        individual_performances,
374                        size,
375                        strategy.clone(),
376                    )?;
377                    candidates.push(ensemble);
378                }
379            }
380        }
381
382        Ok(candidates)
383    }
384
385    /// Greedy ensemble selection algorithm
386    fn greedy_ensemble_selection(
387        &self,
388        individual_performances: &[ModelPerformance],
389        strategy: EnsembleStrategy,
390    ) -> Result<EnsembleCandidate> {
391        let mut selected_indices = Vec::new();
392        let mut remaining_indices: Vec<usize> = (0..individual_performances.len()).collect();
393
394        // Start with the best individual model
395        let best_idx = individual_performances
396            .iter()
397            .enumerate()
398            .max_by(|(_, a), (_, b)| {
399                a.cv_score
400                    .partial_cmp(&b.cv_score)
401                    .expect("operation should succeed")
402            })
403            .map(|(idx, _)| idx)
404            .expect("operation should succeed");
405
406        selected_indices.push(best_idx);
407        remaining_indices.retain(|&x| x != best_idx);
408
409        // Greedily add models that improve ensemble performance
410        while selected_indices.len() < self.config.max_ensemble_size
411            && !remaining_indices.is_empty()
412        {
413            let mut best_addition = None;
414            let mut best_improvement = 0.0;
415
416            for &candidate_idx in &remaining_indices {
417                let mut test_ensemble = selected_indices.clone();
418                test_ensemble.push(candidate_idx);
419
420                // Check diversity
421                let diversity =
422                    self.calculate_subset_diversity(individual_performances, &test_ensemble);
423                if diversity < self.config.diversity_threshold {
424                    continue;
425                }
426
427                // Estimate performance improvement (simplified)
428                let estimated_improvement =
429                    self.estimate_ensemble_improvement(individual_performances, &test_ensemble);
430
431                if estimated_improvement > best_improvement + self.config.improvement_threshold {
432                    best_improvement = estimated_improvement;
433                    best_addition = Some(candidate_idx);
434                }
435            }
436
437            match best_addition {
438                Some(idx) => {
439                    selected_indices.push(idx);
440                    remaining_indices.retain(|&x| x != idx);
441                }
442                None => break, // No more beneficial additions
443            }
444        }
445
446        self.create_ensemble_candidate(individual_performances, selected_indices, strategy)
447    }
448
449    /// Select diverse subset of models
450    fn select_diverse_subset(
451        &self,
452        individual_performances: &[ModelPerformance],
453        subset_size: usize,
454        strategy: EnsembleStrategy,
455    ) -> Result<EnsembleCandidate> {
456        // Simple strategy: select top performers with diversity constraint
457        let mut candidates: Vec<(usize, f64)> = individual_performances
458            .iter()
459            .enumerate()
460            .map(|(idx, perf)| (idx, perf.cv_score))
461            .collect();
462
463        // Sort by performance
464        candidates.sort_by(|(_, a), (_, b)| b.partial_cmp(a).expect("operation should succeed"));
465
466        let mut selected_indices = Vec::new();
467        for (idx, _) in candidates {
468            if selected_indices.len() >= subset_size {
469                break;
470            }
471
472            // Check diversity constraint
473            let mut test_ensemble = selected_indices.clone();
474            test_ensemble.push(idx);
475
476            let diversity =
477                self.calculate_subset_diversity(individual_performances, &test_ensemble);
478            if diversity >= self.config.diversity_threshold || selected_indices.is_empty() {
479                selected_indices.push(idx);
480            }
481        }
482
483        self.create_ensemble_candidate(individual_performances, selected_indices, strategy)
484    }
485
486    /// Create ensemble candidate from selected model indices
487    fn create_ensemble_candidate(
488        &self,
489        individual_performances: &[ModelPerformance],
490        selected_indices: Vec<usize>,
491        strategy: EnsembleStrategy,
492    ) -> Result<EnsembleCandidate> {
493        let model_weights =
494            self.calculate_model_weights(&selected_indices, individual_performances, &strategy);
495
496        let selected_models = selected_indices
497            .iter()
498            .enumerate()
499            .map(|(i, &model_idx)| {
500                let perf = &individual_performances[model_idx];
501                ModelInfo {
502                    model_index: model_idx,
503                    model_name: perf.model_name.clone(),
504                    weight: model_weights[i],
505                    individual_score: perf.cv_score,
506                    contribution_score: 0.0, // Will be calculated during evaluation
507                }
508            })
509            .collect();
510
511        Ok(EnsembleCandidate {
512            ensemble_strategy: strategy,
513            selected_models,
514            model_weights,
515        })
516    }
517
518    /// Calculate model weights based on strategy
519    fn calculate_model_weights(
520        &self,
521        selected_indices: &[usize],
522        individual_performances: &[ModelPerformance],
523        strategy: &EnsembleStrategy,
524    ) -> Vec<f64> {
525        match strategy {
526            EnsembleStrategy::Voting => {
527                // Equal weights
528                vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
529            }
530            EnsembleStrategy::WeightedVoting => {
531                // Weights based on individual performance
532                let scores: Vec<f64> = selected_indices
533                    .iter()
534                    .map(|&idx| individual_performances[idx].cv_score.max(0.0))
535                    .collect();
536                let sum: f64 = scores.iter().sum();
537                if sum > 0.0 {
538                    scores.iter().map(|&s| s / sum).collect()
539                } else {
540                    vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
541                }
542            }
543            EnsembleStrategy::BayesianAveraging => {
544                // Bayesian model averaging weights (simplified)
545                let log_likelihoods: Vec<f64> = selected_indices
546                    .iter()
547                    .map(|&idx| individual_performances[idx].cv_score)
548                    .collect();
549
550                let max_ll = log_likelihoods
551                    .iter()
552                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
553                let exp_weights: Vec<f64> = log_likelihoods
554                    .iter()
555                    .map(|&ll| (ll - max_ll).exp())
556                    .collect();
557                let sum: f64 = exp_weights.iter().sum();
558
559                if sum > 0.0 {
560                    exp_weights.iter().map(|&w| w / sum).collect()
561                } else {
562                    vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
563                }
564            }
565            _ => {
566                // Default to equal weights for other strategies
567                vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
568            }
569        }
570    }
571
572    /// Evaluate an ensemble candidate
573    fn evaluate_ensemble_candidate<E, X, Y>(
574        &self,
575        _models: &[(E, String)],
576        candidate: &EnsembleCandidate,
577        x: &[X],
578        _y: &[Y],
579        cv: &dyn CrossValidator,
580        _scoring: &dyn Scoring,
581    ) -> Result<EnsemblePerformance>
582    where
583        E: Estimator + Clone,
584        X: Clone,
585        Y: Clone + Into<f64>,
586    {
587        let n_samples = x.len();
588        let splits = cv.split(n_samples, None);
589        let mut fold_scores = Vec::with_capacity(splits.len());
590
591        for (train_indices, _test_indices) in &splits {
592            // Placeholder implementation - in a real implementation, this would:
593            // 1. Create train and test sets from indices
594            // 2. Train ensemble models on training data
595            // 3. Make ensemble predictions on test data
596            // 4. Calculate score using the scoring function
597
598            // For now, just generate dummy scores
599            let dummy_score = 0.8 + (train_indices.len() as f64) * 0.01 / 100.0;
600            fold_scores.push(dummy_score);
601        }
602
603        let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
604        let std_score = self.calculate_std(&fold_scores, mean_score);
605
606        Ok(EnsemblePerformance {
607            mean_score,
608            std_score,
609            fold_scores,
610            improvement_over_best: 0.0, // Will be set later
611            ensemble_size: candidate.selected_models.len(),
612        })
613    }
614
615    /// Make ensemble predictions
616    fn make_ensemble_predictions<T, X>(
617        &self,
618        trained_models: &[T],
619        x_test: &[X],
620        weights: &[f64],
621        strategy: &EnsembleStrategy,
622    ) -> Result<Vec<f64>>
623    where
624        T: Predict<Vec<X>, Vec<f64>>,
625        X: Clone,
626    {
627        if trained_models.is_empty() {
628            return Err(SklearsError::InvalidParameter {
629                name: "trained_models".to_string(),
630                reason: "no trained models provided".to_string(),
631            });
632        }
633
634        // Get predictions from all models
635        let mut all_predictions = Vec::with_capacity(trained_models.len());
636        let x_test_vec = x_test.to_vec();
637        for model in trained_models {
638            let predictions = model.predict(&x_test_vec)?;
639            all_predictions.push(predictions);
640        }
641
642        if all_predictions.is_empty() {
643            return Ok(vec![]);
644        }
645
646        let n_samples = all_predictions[0].len();
647        let mut ensemble_predictions = vec![0.0; n_samples];
648
649        match strategy {
650            EnsembleStrategy::Voting
651            | EnsembleStrategy::WeightedVoting
652            | EnsembleStrategy::BayesianAveraging => {
653                // Weighted average
654                for i in 0..n_samples {
655                    let mut weighted_sum = 0.0;
656                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
657                        if i < predictions.len() {
658                            weighted_sum += predictions[i] * weights[model_idx];
659                        }
660                    }
661                    ensemble_predictions[i] = weighted_sum;
662                }
663            }
664            EnsembleStrategy::Stacking { .. } => {
665                // For now, use weighted average (meta-learner training would require more complex implementation)
666                for i in 0..n_samples {
667                    let mut weighted_sum = 0.0;
668                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
669                        if i < predictions.len() {
670                            weighted_sum += predictions[i] * weights[model_idx];
671                        }
672                    }
673                    ensemble_predictions[i] = weighted_sum;
674                }
675            }
676            EnsembleStrategy::Blending { .. } => {
677                // Similar to stacking for this implementation
678                for i in 0..n_samples {
679                    let mut weighted_sum = 0.0;
680                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
681                        if i < predictions.len() {
682                            weighted_sum += predictions[i] * weights[model_idx];
683                        }
684                    }
685                    ensemble_predictions[i] = weighted_sum;
686                }
687            }
688            EnsembleStrategy::DynamicSelection => {
689                // For now, use the best model per sample (simplified)
690                for i in 0..n_samples {
691                    if all_predictions[0].len() > i {
692                        let mut best_pred = all_predictions[0][i];
693                        let mut best_weight = weights[0];
694
695                        for (model_idx, predictions) in all_predictions.iter().enumerate() {
696                            if predictions.len() > i && weights[model_idx] > best_weight {
697                                best_pred = predictions[i];
698                                best_weight = weights[model_idx];
699                            }
700                        }
701                        ensemble_predictions[i] = best_pred;
702                    }
703                }
704            }
705        }
706
707        Ok(ensemble_predictions)
708    }
709
710    /// Calculate diversity measures for the ensemble
711    fn calculate_diversity_measures<E, X, Y>(
712        &self,
713        _models: &[(E, String)],
714        _selected_models: &[ModelInfo],
715        _x: &[X],
716        _y: &[Y],
717    ) -> Result<DiversityMeasures>
718    where
719        E: Estimator + Clone,
720        X: Clone,
721        Y: Clone + Into<f64>,
722    {
723        // Simplified diversity calculation - in a full implementation,
724        // this would require re-training models and comparing predictions
725        Ok(DiversityMeasures {
726            avg_correlation: 0.3,   // Placeholder
727            disagreement: 0.2,      // Placeholder
728            q_statistic: 0.1,       // Placeholder
729            entropy_diversity: 0.4, // Placeholder
730        })
731    }
732
733    /// Calculate diversity of a subset of models
734    fn calculate_subset_diversity(
735        &self,
736        individual_performances: &[ModelPerformance],
737        subset_indices: &[usize],
738    ) -> f64 {
739        if subset_indices.len() <= 1 {
740            return 0.0;
741        }
742
743        // Average pairwise correlation (lower correlation = higher diversity)
744        let mut correlations = Vec::new();
745        for i in 0..subset_indices.len() {
746            for j in (i + 1)..subset_indices.len() {
747                let corr1 = individual_performances[subset_indices[i]].avg_correlation;
748                let corr2 = individual_performances[subset_indices[j]].avg_correlation;
749                correlations.push((corr1 + corr2) / 2.0);
750            }
751        }
752
753        if correlations.is_empty() {
754            0.0
755        } else {
756            let avg_correlation = correlations.iter().sum::<f64>() / correlations.len() as f64;
757            1.0 - avg_correlation.abs() // Higher diversity when correlation is lower
758        }
759    }
760
761    /// Estimate ensemble improvement (simplified heuristic)
762    fn estimate_ensemble_improvement(
763        &self,
764        individual_performances: &[ModelPerformance],
765        ensemble_indices: &[usize],
766    ) -> f64 {
767        if ensemble_indices.is_empty() {
768            return 0.0;
769        }
770
771        // Simple heuristic: weighted average with diversity bonus
772        let avg_score = ensemble_indices
773            .iter()
774            .map(|&idx| individual_performances[idx].cv_score)
775            .sum::<f64>()
776            / ensemble_indices.len() as f64;
777
778        let diversity_bonus =
779            self.calculate_subset_diversity(individual_performances, ensemble_indices) * 0.1;
780
781        avg_score + diversity_bonus
782    }
783
784    /// Calculate standard deviation
785    fn calculate_std(&self, values: &[f64], mean: f64) -> f64 {
786        if values.len() <= 1 {
787            return 0.0;
788        }
789
790        let variance =
791            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
792
793        variance.sqrt()
794    }
795
796    /// Calculate correlation between two prediction vectors
797    fn calculate_correlation(&self, pred1: &[f64], pred2: &[f64]) -> f64 {
798        if pred1.len() != pred2.len() || pred1.is_empty() {
799            return 0.0;
800        }
801
802        let n = pred1.len() as f64;
803        let mean1 = pred1.iter().sum::<f64>() / n;
804        let mean2 = pred2.iter().sum::<f64>() / n;
805
806        let mut numerator = 0.0;
807        let mut sum_sq1 = 0.0;
808        let mut sum_sq2 = 0.0;
809
810        for i in 0..pred1.len() {
811            let diff1 = pred1[i] - mean1;
812            let diff2 = pred2[i] - mean2;
813            numerator += diff1 * diff2;
814            sum_sq1 += diff1 * diff1;
815            sum_sq2 += diff2 * diff2;
816        }
817
818        let denominator = (sum_sq1 * sum_sq2).sqrt();
819        if denominator > 0.0 {
820            numerator / denominator
821        } else {
822            0.0
823        }
824    }
825}
826
827impl Default for EnsembleSelector {
828    fn default() -> Self {
829        Self::new()
830    }
831}
832
/// Internal ensemble candidate structure.
///
/// Intermediate representation produced during candidate generation;
/// the winning candidate's fields are moved into the public
/// `EnsembleSelectionResult`.
#[derive(Debug, Clone)]
struct EnsembleCandidate {
    ensemble_strategy: EnsembleStrategy,
    selected_models: Vec<ModelInfo>,
    // Parallel to `selected_models`
    model_weights: Vec<f64>,
}
840
841/// Convenience function for ensemble selection
842pub fn select_ensemble<E, X, Y>(
843    models: &[(E, String)],
844    x: &[X],
845    y: &[Y],
846    cv: &dyn CrossValidator,
847    scoring: &dyn Scoring,
848    max_size: Option<usize>,
849) -> Result<EnsembleSelectionResult>
850where
851    E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
852    E::Fitted: Predict<Vec<X>, Vec<f64>>,
853    X: Clone,
854    Y: Clone + Into<f64>,
855{
856    let mut selector = EnsembleSelector::new();
857    if let Some(size) = max_size {
858        selector = selector.max_ensemble_size(size);
859    }
860    selector.select_ensemble(models, x, y, cv, scoring)
861}
862
// NOTE(review): removed a stray `#[allow(non_snake_case)]` that previously sat
// on this module — every identifier in it is already snake_case / UpperCamelCase,
// so the allow suppressed nothing and carried no justifying comment.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    // Mock estimator for testing: predictions scale the input by
    // `performance_level`, so model quality is deterministic and rankable.
    #[derive(Clone)]
    struct MockEstimator {
        performance_level: f64,
    }

    // Fitted counterpart of `MockEstimator`; carries the level forward.
    struct MockTrained {
        performance_level: f64,
    }

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            // `&()` to a unit literal is promoted to a 'static reference.
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            Ok(MockTrained {
                performance_level: self.performance_level,
            })
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            Ok(x.iter().map(|&xi| xi * self.performance_level).collect())
        }
    }

    // Mock scoring function: negative mean squared error, so higher is better.
    struct MockScoring;

    impl Scoring for MockScoring {
        fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64> {
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
                .sum::<f64>()
                / y_true.len() as f64;
            Ok(-mse) // Higher is better
        }
    }

    // Default configuration should expose the documented size bounds and
    // enable greedy selection.
    #[test]
    fn test_ensemble_selector_creation() {
        let selector = EnsembleSelector::new();
        assert_eq!(selector.config.max_ensemble_size, 10);
        assert_eq!(selector.config.min_ensemble_size, 2);
        assert!(selector.config.use_greedy_selection);
    }

    // End-to-end selection over four mock models: the resulting ensemble must
    // respect the configured size bounds and report per-model performances.
    #[test]
    fn test_ensemble_selection() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Model A".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Model B".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.85,
                },
                "Model C".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.75,
                },
                "Model D".to_string(),
            ),
        ];

        let x: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.5).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = EnsembleSelector::new().max_ensemble_size(3);
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.expect("operation should succeed");
        assert!(result.selected_models.len() >= 2);
        assert!(result.selected_models.len() <= 3);
        assert_eq!(result.model_weights.len(), result.selected_models.len());
        assert!(!result.individual_performances.is_empty());
    }

    // Restricting the strategy pool should still yield a valid ensemble over
    // both candidate models.
    #[test]
    fn test_different_ensemble_strategies() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Good Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Decent Model".to_string(),
            ),
        ];

        let x: Vec<f64> = (0..50).map(|i| i as f64 * 0.05).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.3).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let strategies = vec![
            EnsembleStrategy::Voting,
            EnsembleStrategy::WeightedVoting,
            EnsembleStrategy::BayesianAveraging,
        ];

        let selector = EnsembleSelector::new().strategies(strategies);
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.expect("operation should succeed");
        assert_eq!(result.selected_models.len(), 2);
    }

    // The free-function wrapper should honor the `max_size` cap.
    #[test]
    fn test_convenience_function() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.95,
                },
                "Best Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.85,
                },
                "Good Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Okay Model".to_string(),
            ),
        ];

        let x: Vec<f64> = (0..40).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.4).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let result = select_ensemble(&models, &x, &y, &cv, &scoring, Some(2));
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert!(result.selected_models.len() <= 2);
        assert!(result.ensemble_performance.ensemble_size <= 2);
    }

    // Display formatting for each strategy variant, including parameterized ones.
    #[test]
    fn test_ensemble_strategy_display() {
        assert_eq!(format!("{}", EnsembleStrategy::Voting), "Simple Voting");
        assert_eq!(
            format!("{}", EnsembleStrategy::WeightedVoting),
            "Weighted Voting"
        );
        assert_eq!(
            format!(
                "{}",
                EnsembleStrategy::Stacking {
                    meta_learner: "Linear".to_string()
                }
            ),
            "Stacking (Linear)"
        );
        assert_eq!(
            format!("{}", EnsembleStrategy::Blending { blend_ratio: 0.2 }),
            "Blending (ratio: 0.20)"
        );
    }

    // A single candidate cannot satisfy `min_ensemble_size` (2), so selection
    // must fail rather than return a degenerate "ensemble".
    #[test]
    fn test_insufficient_models() {
        let models = vec![(
            MockEstimator {
                performance_level: 0.9,
            },
            "Only Model".to_string(),
        )];

        let x: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.5).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = EnsembleSelector::new();
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_err());
    }
}