sklears_model_selection/ensemble_selection.rs

1//! Ensemble model selection with automatic composition strategies
2//!
3//! This module provides tools for automatically selecting and composing ensemble models.
4//! It includes various ensemble strategies like voting, stacking, blending, and dynamic
5//! selection with automatic hyperparameter optimization for ensemble components.
6
7use crate::cross_validation::CrossValidator;
8use sklears_core::{
9    error::{Result, SklearsError},
10    traits::{Estimator, Fit, Predict},
11};
12
13// Simple scoring trait for testing
14pub trait Scoring {
15    fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64>;
16}
17use std::fmt::{self, Display, Formatter};
18
19/// Result of ensemble model selection
20#[derive(Debug, Clone)]
21pub struct EnsembleSelectionResult {
22    /// Selected ensemble strategy
23    pub ensemble_strategy: EnsembleStrategy,
24    /// Base models included in the ensemble
25    pub selected_models: Vec<ModelInfo>,
26    /// Ensemble weights (if applicable)
27    pub model_weights: Vec<f64>,
28    /// Cross-validation performance of the ensemble
29    pub ensemble_performance: EnsemblePerformance,
30    /// Individual model performances
31    pub individual_performances: Vec<ModelPerformance>,
32    /// Diversity measures
33    pub diversity_measures: DiversityMeasures,
34}
35
/// Description of one member of a selected ensemble.
#[derive(Clone, Debug)]
pub struct ModelInfo {
    /// Position of the model in the caller's original candidate list.
    pub model_index: usize,
    /// Human-readable model name / identifier.
    pub model_name: String,
    /// Weight the model carries inside the ensemble.
    pub weight: f64,
    /// Stand-alone performance score of the model.
    pub individual_score: f64,
    /// Contribution of the model to the ensemble's performance.
    pub contribution_score: f64,
}
50
/// Cross-validated performance summary of an ensemble.
#[derive(Clone, Debug)]
pub struct EnsemblePerformance {
    /// Mean of the per-fold scores.
    pub mean_score: f64,
    /// Standard deviation of the per-fold scores.
    pub std_score: f64,
    /// Score obtained on each cross-validation fold.
    pub fold_scores: Vec<f64>,
    /// Mean score minus the best individual model's score.
    pub improvement_over_best: f64,
    /// Number of models in the ensemble.
    pub ensemble_size: usize,
}
65
/// Cross-validated performance of a single candidate model.
#[derive(Clone, Debug)]
pub struct ModelPerformance {
    /// Position of the model in the caller's candidate list.
    pub model_index: usize,
    /// Human-readable model name.
    pub model_name: String,
    /// Mean cross-validation score.
    pub cv_score: f64,
    /// Standard deviation of the cross-validation score.
    pub cv_std: f64,
    /// Average correlation of this model's predictions with the other models.
    pub avg_correlation: f64,
}
80
/// Diversity statistics describing how differently ensemble members predict.
#[derive(Clone, Debug)]
pub struct DiversityMeasures {
    /// Average pairwise correlation between member predictions.
    pub avg_correlation: f64,
    /// Disagreement measure between members.
    pub disagreement: f64,
    /// Average pairwise Q statistic.
    pub q_statistic: f64,
    /// Entropy-based diversity measure.
    pub entropy_diversity: f64,
}
93
/// Ways of combining several base models into one ensemble.
#[derive(Clone, Debug, PartialEq)]
pub enum EnsembleStrategy {
    /// Plain voting: every member gets the same weight.
    Voting,
    /// Voting with weights derived from individual performance.
    WeightedVoting,
    /// Stacking, with `meta_learner` naming the second-level model.
    Stacking { meta_learner: String },
    /// Holdout-based stacking; `blend_ratio` is the holdout fraction.
    Blending { blend_ratio: f64 },
    /// Per-instance dynamic model selection.
    DynamicSelection,
    /// Bayesian model averaging.
    BayesianAveraging,
}

impl Display for EnsembleStrategy {
    /// Human-readable strategy name, e.g. `"Blending (ratio: 0.20)"`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::Voting => f.write_str("Simple Voting"),
            Self::WeightedVoting => f.write_str("Weighted Voting"),
            Self::Stacking { meta_learner } => write!(f, "Stacking ({})", meta_learner),
            Self::Blending { blend_ratio } => write!(f, "Blending (ratio: {:.2})", blend_ratio),
            Self::DynamicSelection => f.write_str("Dynamic Selection"),
            Self::BayesianAveraging => f.write_str("Bayesian Averaging"),
        }
    }
}
125
126impl Display for EnsembleSelectionResult {
127    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
128        writeln!(f, "Ensemble Selection Results:")?;
129        writeln!(f, "Strategy: {}", self.ensemble_strategy)?;
130        writeln!(f, "Ensemble Size: {}", self.selected_models.len())?;
131        writeln!(
132            f,
133            "Ensemble Performance: {:.4} ± {:.4}",
134            self.ensemble_performance.mean_score, self.ensemble_performance.std_score
135        )?;
136        writeln!(
137            f,
138            "Improvement over Best Individual: {:.4}",
139            self.ensemble_performance.improvement_over_best
140        )?;
141        writeln!(
142            f,
143            "Average Diversity (Correlation): {:.4}",
144            self.diversity_measures.avg_correlation
145        )?;
146        writeln!(f, "\nSelected Models:")?;
147        for model in &self.selected_models {
148            writeln!(
149                f,
150                "  {} - Weight: {:.3}, Score: {:.4}",
151                model.model_name, model.weight, model.individual_score
152            )?;
153        }
154        Ok(())
155    }
156}
157
158/// Configuration for ensemble selection
159#[derive(Debug, Clone)]
160pub struct EnsembleSelectionConfig {
161    /// Maximum ensemble size
162    pub max_ensemble_size: usize,
163    /// Minimum ensemble size
164    pub min_ensemble_size: usize,
165    /// Strategies to consider
166    pub candidate_strategies: Vec<EnsembleStrategy>,
167    /// Diversity threshold (minimum required diversity)
168    pub diversity_threshold: f64,
169    /// Whether to use greedy selection
170    pub use_greedy_selection: bool,
171    /// Performance improvement threshold
172    pub improvement_threshold: f64,
173    /// Cross-validation folds for ensemble evaluation
174    pub cv_folds: usize,
175    /// Random seed for reproducibility
176    pub random_seed: Option<u64>,
177}
178
179impl Default for EnsembleSelectionConfig {
180    fn default() -> Self {
181        Self {
182            max_ensemble_size: 10,
183            min_ensemble_size: 2,
184            candidate_strategies: vec![
185                EnsembleStrategy::Voting,
186                EnsembleStrategy::WeightedVoting,
187                EnsembleStrategy::Stacking {
188                    meta_learner: "Linear".to_string(),
189                },
190                EnsembleStrategy::Blending { blend_ratio: 0.2 },
191            ],
192            diversity_threshold: 0.1,
193            use_greedy_selection: true,
194            improvement_threshold: 0.01,
195            cv_folds: 5,
196            random_seed: None,
197        }
198    }
199}
200
201/// Ensemble model selector
202pub struct EnsembleSelector {
203    config: EnsembleSelectionConfig,
204}
205
206impl EnsembleSelector {
207    /// Create a new ensemble selector with default configuration
208    pub fn new() -> Self {
209        Self {
210            config: EnsembleSelectionConfig::default(),
211        }
212    }
213
214    /// Create a new ensemble selector with custom configuration
215    pub fn with_config(config: EnsembleSelectionConfig) -> Self {
216        Self { config }
217    }
218
219    /// Set maximum ensemble size
220    pub fn max_ensemble_size(mut self, size: usize) -> Self {
221        self.config.max_ensemble_size = size;
222        self
223    }
224
225    /// Set minimum ensemble size
226    pub fn min_ensemble_size(mut self, size: usize) -> Self {
227        self.config.min_ensemble_size = size;
228        self
229    }
230
231    /// Set candidate strategies
232    pub fn strategies(mut self, strategies: Vec<EnsembleStrategy>) -> Self {
233        self.config.candidate_strategies = strategies;
234        self
235    }
236
237    /// Set diversity threshold
238    pub fn diversity_threshold(mut self, threshold: f64) -> Self {
239        self.config.diversity_threshold = threshold;
240        self
241    }
242
243    /// Enable or disable greedy selection
244    pub fn use_greedy_selection(mut self, use_greedy: bool) -> Self {
245        self.config.use_greedy_selection = use_greedy;
246        self
247    }
248
249    /// Select optimal ensemble from candidate models
250    pub fn select_ensemble<E, X, Y>(
251        &self,
252        models: &[(E, String)],
253        x: &[X],
254        y: &[Y],
255        cv: &dyn CrossValidator,
256        scoring: &dyn Scoring,
257    ) -> Result<EnsembleSelectionResult>
258    where
259        E: Estimator + Clone,
260        X: Clone,
261        Y: Clone + Into<f64>,
262    {
263        if models.len() < self.config.min_ensemble_size {
264            return Err(SklearsError::InvalidParameter {
265                name: "models".to_string(),
266                reason: format!(
267                    "at least {} models required for ensemble",
268                    self.config.min_ensemble_size
269                ),
270            });
271        }
272
273        // Evaluate individual models
274        let individual_performances = self.evaluate_individual_models(models, x, y, cv, scoring)?;
275
276        // Generate ensemble candidates
277        let ensemble_candidates = self.generate_ensemble_candidates(&individual_performances)?;
278
279        // Evaluate ensemble candidates
280        let mut best_ensemble = None;
281        let mut best_score = f64::NEG_INFINITY;
282
283        for candidate in &ensemble_candidates {
284            let ensemble_performance =
285                self.evaluate_ensemble_candidate(models, candidate, x, y, cv, scoring)?;
286
287            if ensemble_performance.mean_score > best_score {
288                best_score = ensemble_performance.mean_score;
289                best_ensemble = Some((candidate.clone(), ensemble_performance));
290            }
291        }
292
293        let (best_candidate, ensemble_performance) =
294            best_ensemble.ok_or_else(|| SklearsError::InvalidParameter {
295                name: "ensemble".to_string(),
296                reason: "no valid ensemble found".to_string(),
297            })?;
298
299        // Calculate diversity measures
300        let diversity_measures =
301            self.calculate_diversity_measures(models, &best_candidate.selected_models, x, y)?;
302
303        // Calculate improvement over best individual model
304        let best_individual_score = individual_performances
305            .iter()
306            .map(|p| p.cv_score)
307            .fold(f64::NEG_INFINITY, f64::max);
308
309        let mut ensemble_performance = ensemble_performance;
310        ensemble_performance.improvement_over_best =
311            ensemble_performance.mean_score - best_individual_score;
312
313        Ok(EnsembleSelectionResult {
314            ensemble_strategy: best_candidate.ensemble_strategy,
315            selected_models: best_candidate.selected_models,
316            model_weights: best_candidate.model_weights,
317            ensemble_performance,
318            individual_performances,
319            diversity_measures,
320        })
321    }
322
323    /// Evaluate individual model performances
324    fn evaluate_individual_models<E, X, Y>(
325        &self,
326        models: &[(E, String)],
327        _x: &[X],
328        _y: &[Y],
329        _cv: &dyn CrossValidator,
330        _scoring: &dyn Scoring,
331    ) -> Result<Vec<ModelPerformance>>
332    where
333        E: Estimator + Clone,
334        X: Clone,
335        Y: Clone + Into<f64>,
336    {
337        // Placeholder implementation - create dummy performance data
338        let mut performances = Vec::new();
339        for (idx, (_, name)) in models.iter().enumerate() {
340            performances.push(ModelPerformance {
341                model_index: idx,
342                model_name: name.clone(),
343                cv_score: 0.8 + (idx as f64) * 0.05, // Dummy scores
344                cv_std: 0.1,
345                avg_correlation: 0.3,
346            });
347        }
348        Ok(performances)
349    }
350
351    /// Generate ensemble candidates using different strategies
352    fn generate_ensemble_candidates(
353        &self,
354        individual_performances: &[ModelPerformance],
355    ) -> Result<Vec<EnsembleCandidate>> {
356        let mut candidates = Vec::new();
357
358        for strategy in &self.config.candidate_strategies {
359            if self.config.use_greedy_selection {
360                // Use greedy selection to build ensemble
361                let ensemble =
362                    self.greedy_ensemble_selection(individual_performances, strategy.clone())?;
363                candidates.push(ensemble);
364            } else {
365                // Try different subset sizes
366                for size in self.config.min_ensemble_size
367                    ..=self
368                        .config
369                        .max_ensemble_size
370                        .min(individual_performances.len())
371                {
372                    let ensemble = self.select_diverse_subset(
373                        individual_performances,
374                        size,
375                        strategy.clone(),
376                    )?;
377                    candidates.push(ensemble);
378                }
379            }
380        }
381
382        Ok(candidates)
383    }
384
385    /// Greedy ensemble selection algorithm
386    fn greedy_ensemble_selection(
387        &self,
388        individual_performances: &[ModelPerformance],
389        strategy: EnsembleStrategy,
390    ) -> Result<EnsembleCandidate> {
391        let mut selected_indices = Vec::new();
392        let mut remaining_indices: Vec<usize> = (0..individual_performances.len()).collect();
393
394        // Start with the best individual model
395        let best_idx = individual_performances
396            .iter()
397            .enumerate()
398            .max_by(|(_, a), (_, b)| a.cv_score.partial_cmp(&b.cv_score).unwrap())
399            .map(|(idx, _)| idx)
400            .unwrap();
401
402        selected_indices.push(best_idx);
403        remaining_indices.retain(|&x| x != best_idx);
404
405        // Greedily add models that improve ensemble performance
406        while selected_indices.len() < self.config.max_ensemble_size
407            && !remaining_indices.is_empty()
408        {
409            let mut best_addition = None;
410            let mut best_improvement = 0.0;
411
412            for &candidate_idx in &remaining_indices {
413                let mut test_ensemble = selected_indices.clone();
414                test_ensemble.push(candidate_idx);
415
416                // Check diversity
417                let diversity =
418                    self.calculate_subset_diversity(individual_performances, &test_ensemble);
419                if diversity < self.config.diversity_threshold {
420                    continue;
421                }
422
423                // Estimate performance improvement (simplified)
424                let estimated_improvement =
425                    self.estimate_ensemble_improvement(individual_performances, &test_ensemble);
426
427                if estimated_improvement > best_improvement + self.config.improvement_threshold {
428                    best_improvement = estimated_improvement;
429                    best_addition = Some(candidate_idx);
430                }
431            }
432
433            match best_addition {
434                Some(idx) => {
435                    selected_indices.push(idx);
436                    remaining_indices.retain(|&x| x != idx);
437                }
438                None => break, // No more beneficial additions
439            }
440        }
441
442        self.create_ensemble_candidate(individual_performances, selected_indices, strategy)
443    }
444
445    /// Select diverse subset of models
446    fn select_diverse_subset(
447        &self,
448        individual_performances: &[ModelPerformance],
449        subset_size: usize,
450        strategy: EnsembleStrategy,
451    ) -> Result<EnsembleCandidate> {
452        // Simple strategy: select top performers with diversity constraint
453        let mut candidates: Vec<(usize, f64)> = individual_performances
454            .iter()
455            .enumerate()
456            .map(|(idx, perf)| (idx, perf.cv_score))
457            .collect();
458
459        // Sort by performance
460        candidates.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
461
462        let mut selected_indices = Vec::new();
463        for (idx, _) in candidates {
464            if selected_indices.len() >= subset_size {
465                break;
466            }
467
468            // Check diversity constraint
469            let mut test_ensemble = selected_indices.clone();
470            test_ensemble.push(idx);
471
472            let diversity =
473                self.calculate_subset_diversity(individual_performances, &test_ensemble);
474            if diversity >= self.config.diversity_threshold || selected_indices.is_empty() {
475                selected_indices.push(idx);
476            }
477        }
478
479        self.create_ensemble_candidate(individual_performances, selected_indices, strategy)
480    }
481
482    /// Create ensemble candidate from selected model indices
483    fn create_ensemble_candidate(
484        &self,
485        individual_performances: &[ModelPerformance],
486        selected_indices: Vec<usize>,
487        strategy: EnsembleStrategy,
488    ) -> Result<EnsembleCandidate> {
489        let model_weights =
490            self.calculate_model_weights(&selected_indices, individual_performances, &strategy);
491
492        let selected_models = selected_indices
493            .iter()
494            .enumerate()
495            .map(|(i, &model_idx)| {
496                let perf = &individual_performances[model_idx];
497                ModelInfo {
498                    model_index: model_idx,
499                    model_name: perf.model_name.clone(),
500                    weight: model_weights[i],
501                    individual_score: perf.cv_score,
502                    contribution_score: 0.0, // Will be calculated during evaluation
503                }
504            })
505            .collect();
506
507        Ok(EnsembleCandidate {
508            ensemble_strategy: strategy,
509            selected_models,
510            model_weights,
511        })
512    }
513
514    /// Calculate model weights based on strategy
515    fn calculate_model_weights(
516        &self,
517        selected_indices: &[usize],
518        individual_performances: &[ModelPerformance],
519        strategy: &EnsembleStrategy,
520    ) -> Vec<f64> {
521        match strategy {
522            EnsembleStrategy::Voting => {
523                // Equal weights
524                vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
525            }
526            EnsembleStrategy::WeightedVoting => {
527                // Weights based on individual performance
528                let scores: Vec<f64> = selected_indices
529                    .iter()
530                    .map(|&idx| individual_performances[idx].cv_score.max(0.0))
531                    .collect();
532                let sum: f64 = scores.iter().sum();
533                if sum > 0.0 {
534                    scores.iter().map(|&s| s / sum).collect()
535                } else {
536                    vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
537                }
538            }
539            EnsembleStrategy::BayesianAveraging => {
540                // Bayesian model averaging weights (simplified)
541                let log_likelihoods: Vec<f64> = selected_indices
542                    .iter()
543                    .map(|&idx| individual_performances[idx].cv_score)
544                    .collect();
545
546                let max_ll = log_likelihoods
547                    .iter()
548                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
549                let exp_weights: Vec<f64> = log_likelihoods
550                    .iter()
551                    .map(|&ll| (ll - max_ll).exp())
552                    .collect();
553                let sum: f64 = exp_weights.iter().sum();
554
555                if sum > 0.0 {
556                    exp_weights.iter().map(|&w| w / sum).collect()
557                } else {
558                    vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
559                }
560            }
561            _ => {
562                // Default to equal weights for other strategies
563                vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
564            }
565        }
566    }
567
568    /// Evaluate an ensemble candidate
569    fn evaluate_ensemble_candidate<E, X, Y>(
570        &self,
571        _models: &[(E, String)],
572        candidate: &EnsembleCandidate,
573        x: &[X],
574        _y: &[Y],
575        cv: &dyn CrossValidator,
576        _scoring: &dyn Scoring,
577    ) -> Result<EnsemblePerformance>
578    where
579        E: Estimator + Clone,
580        X: Clone,
581        Y: Clone + Into<f64>,
582    {
583        let n_samples = x.len();
584        let splits = cv.split(n_samples, None);
585        let mut fold_scores = Vec::with_capacity(splits.len());
586
587        for (train_indices, _test_indices) in &splits {
588            // Placeholder implementation - in a real implementation, this would:
589            // 1. Create train and test sets from indices
590            // 2. Train ensemble models on training data
591            // 3. Make ensemble predictions on test data
592            // 4. Calculate score using the scoring function
593
594            // For now, just generate dummy scores
595            let dummy_score = 0.8 + (train_indices.len() as f64) * 0.01 / 100.0;
596            fold_scores.push(dummy_score);
597        }
598
599        let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
600        let std_score = self.calculate_std(&fold_scores, mean_score);
601
602        Ok(EnsemblePerformance {
603            mean_score,
604            std_score,
605            fold_scores,
606            improvement_over_best: 0.0, // Will be set later
607            ensemble_size: candidate.selected_models.len(),
608        })
609    }
610
611    /// Make ensemble predictions
612    fn make_ensemble_predictions<T, X>(
613        &self,
614        trained_models: &[T],
615        x_test: &[X],
616        weights: &[f64],
617        strategy: &EnsembleStrategy,
618    ) -> Result<Vec<f64>>
619    where
620        T: Predict<Vec<X>, Vec<f64>>,
621        X: Clone,
622    {
623        if trained_models.is_empty() {
624            return Err(SklearsError::InvalidParameter {
625                name: "trained_models".to_string(),
626                reason: "no trained models provided".to_string(),
627            });
628        }
629
630        // Get predictions from all models
631        let mut all_predictions = Vec::with_capacity(trained_models.len());
632        let x_test_vec = x_test.to_vec();
633        for model in trained_models {
634            let predictions = model.predict(&x_test_vec)?;
635            all_predictions.push(predictions);
636        }
637
638        if all_predictions.is_empty() {
639            return Ok(vec![]);
640        }
641
642        let n_samples = all_predictions[0].len();
643        let mut ensemble_predictions = vec![0.0; n_samples];
644
645        match strategy {
646            EnsembleStrategy::Voting
647            | EnsembleStrategy::WeightedVoting
648            | EnsembleStrategy::BayesianAveraging => {
649                // Weighted average
650                for i in 0..n_samples {
651                    let mut weighted_sum = 0.0;
652                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
653                        if i < predictions.len() {
654                            weighted_sum += predictions[i] * weights[model_idx];
655                        }
656                    }
657                    ensemble_predictions[i] = weighted_sum;
658                }
659            }
660            EnsembleStrategy::Stacking { .. } => {
661                // For now, use weighted average (meta-learner training would require more complex implementation)
662                for i in 0..n_samples {
663                    let mut weighted_sum = 0.0;
664                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
665                        if i < predictions.len() {
666                            weighted_sum += predictions[i] * weights[model_idx];
667                        }
668                    }
669                    ensemble_predictions[i] = weighted_sum;
670                }
671            }
672            EnsembleStrategy::Blending { .. } => {
673                // Similar to stacking for this implementation
674                for i in 0..n_samples {
675                    let mut weighted_sum = 0.0;
676                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
677                        if i < predictions.len() {
678                            weighted_sum += predictions[i] * weights[model_idx];
679                        }
680                    }
681                    ensemble_predictions[i] = weighted_sum;
682                }
683            }
684            EnsembleStrategy::DynamicSelection => {
685                // For now, use the best model per sample (simplified)
686                for i in 0..n_samples {
687                    if all_predictions[0].len() > i {
688                        let mut best_pred = all_predictions[0][i];
689                        let mut best_weight = weights[0];
690
691                        for (model_idx, predictions) in all_predictions.iter().enumerate() {
692                            if predictions.len() > i && weights[model_idx] > best_weight {
693                                best_pred = predictions[i];
694                                best_weight = weights[model_idx];
695                            }
696                        }
697                        ensemble_predictions[i] = best_pred;
698                    }
699                }
700            }
701        }
702
703        Ok(ensemble_predictions)
704    }
705
706    /// Calculate diversity measures for the ensemble
707    fn calculate_diversity_measures<E, X, Y>(
708        &self,
709        _models: &[(E, String)],
710        _selected_models: &[ModelInfo],
711        _x: &[X],
712        _y: &[Y],
713    ) -> Result<DiversityMeasures>
714    where
715        E: Estimator + Clone,
716        X: Clone,
717        Y: Clone + Into<f64>,
718    {
719        // Simplified diversity calculation - in a full implementation,
720        // this would require re-training models and comparing predictions
721        Ok(DiversityMeasures {
722            avg_correlation: 0.3,   // Placeholder
723            disagreement: 0.2,      // Placeholder
724            q_statistic: 0.1,       // Placeholder
725            entropy_diversity: 0.4, // Placeholder
726        })
727    }
728
729    /// Calculate diversity of a subset of models
730    fn calculate_subset_diversity(
731        &self,
732        individual_performances: &[ModelPerformance],
733        subset_indices: &[usize],
734    ) -> f64 {
735        if subset_indices.len() <= 1 {
736            return 0.0;
737        }
738
739        // Average pairwise correlation (lower correlation = higher diversity)
740        let mut correlations = Vec::new();
741        for i in 0..subset_indices.len() {
742            for j in (i + 1)..subset_indices.len() {
743                let corr1 = individual_performances[subset_indices[i]].avg_correlation;
744                let corr2 = individual_performances[subset_indices[j]].avg_correlation;
745                correlations.push((corr1 + corr2) / 2.0);
746            }
747        }
748
749        if correlations.is_empty() {
750            0.0
751        } else {
752            let avg_correlation = correlations.iter().sum::<f64>() / correlations.len() as f64;
753            1.0 - avg_correlation.abs() // Higher diversity when correlation is lower
754        }
755    }
756
757    /// Estimate ensemble improvement (simplified heuristic)
758    fn estimate_ensemble_improvement(
759        &self,
760        individual_performances: &[ModelPerformance],
761        ensemble_indices: &[usize],
762    ) -> f64 {
763        if ensemble_indices.is_empty() {
764            return 0.0;
765        }
766
767        // Simple heuristic: weighted average with diversity bonus
768        let avg_score = ensemble_indices
769            .iter()
770            .map(|&idx| individual_performances[idx].cv_score)
771            .sum::<f64>()
772            / ensemble_indices.len() as f64;
773
774        let diversity_bonus =
775            self.calculate_subset_diversity(individual_performances, ensemble_indices) * 0.1;
776
777        avg_score + diversity_bonus
778    }
779
780    /// Calculate standard deviation
781    fn calculate_std(&self, values: &[f64], mean: f64) -> f64 {
782        if values.len() <= 1 {
783            return 0.0;
784        }
785
786        let variance =
787            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
788
789        variance.sqrt()
790    }
791
792    /// Calculate correlation between two prediction vectors
793    fn calculate_correlation(&self, pred1: &[f64], pred2: &[f64]) -> f64 {
794        if pred1.len() != pred2.len() || pred1.is_empty() {
795            return 0.0;
796        }
797
798        let n = pred1.len() as f64;
799        let mean1 = pred1.iter().sum::<f64>() / n;
800        let mean2 = pred2.iter().sum::<f64>() / n;
801
802        let mut numerator = 0.0;
803        let mut sum_sq1 = 0.0;
804        let mut sum_sq2 = 0.0;
805
806        for i in 0..pred1.len() {
807            let diff1 = pred1[i] - mean1;
808            let diff2 = pred2[i] - mean2;
809            numerator += diff1 * diff2;
810            sum_sq1 += diff1 * diff1;
811            sum_sq2 += diff2 * diff2;
812        }
813
814        let denominator = (sum_sq1 * sum_sq2).sqrt();
815        if denominator > 0.0 {
816            numerator / denominator
817        } else {
818            0.0
819        }
820    }
821}
822
823impl Default for EnsembleSelector {
824    fn default() -> Self {
825        Self::new()
826    }
827}
828
829/// Internal ensemble candidate structure
830#[derive(Debug, Clone)]
831struct EnsembleCandidate {
832    ensemble_strategy: EnsembleStrategy,
833    selected_models: Vec<ModelInfo>,
834    model_weights: Vec<f64>,
835}
836
837/// Convenience function for ensemble selection
838pub fn select_ensemble<E, X, Y>(
839    models: &[(E, String)],
840    x: &[X],
841    y: &[Y],
842    cv: &dyn CrossValidator,
843    scoring: &dyn Scoring,
844    max_size: Option<usize>,
845) -> Result<EnsembleSelectionResult>
846where
847    E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
848    E::Fitted: Predict<Vec<X>, Vec<f64>>,
849    X: Clone,
850    Y: Clone + Into<f64>,
851{
852    let mut selector = EnsembleSelector::new();
853    if let Some(size) = max_size {
854        selector = selector.max_ensemble_size(size);
855    }
856    selector.select_ensemble(models, x, y, cv, scoring)
857}
858
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    /// Mock estimator whose fitted model predicts `x * performance_level`.
    #[derive(Clone)]
    struct MockEstimator {
        performance_level: f64,
    }

    /// Fitted counterpart of `MockEstimator`.
    struct MockTrained {
        performance_level: f64,
    }

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            Ok(MockTrained {
                performance_level: self.performance_level,
            })
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            Ok(x.iter().map(|&xi| xi * self.performance_level).collect())
        }
    }

    /// Mock scorer: negated mean squared error, so higher is better.
    struct MockScoring;

    impl Scoring for MockScoring {
        fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64> {
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
                .sum::<f64>()
                / y_true.len() as f64;
            Ok(-mse) // Higher is better
        }
    }

    /// Builds a candidate list of mock models from `(performance_level, name)` pairs.
    fn mock_models(specs: &[(f64, &str)]) -> Vec<(MockEstimator, String)> {
        specs
            .iter()
            .map(|&(performance_level, name)| {
                (MockEstimator { performance_level }, name.to_string())
            })
            .collect()
    }

    /// Generates `len` inputs `i * step` with linear targets `x * slope`.
    fn linear_data(len: usize, step: f64, slope: f64) -> (Vec<f64>, Vec<f64>) {
        let x: Vec<f64> = (0..len).map(|i| i as f64 * step).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * slope).collect();
        (x, y)
    }

    #[test]
    fn test_ensemble_selector_creation() {
        let selector = EnsembleSelector::new();
        assert_eq!(selector.config.max_ensemble_size, 10);
        assert_eq!(selector.config.min_ensemble_size, 2);
        assert!(selector.config.use_greedy_selection);
    }

    #[test]
    fn test_ensemble_selection() {
        let models = mock_models(&[
            (0.8, "Model A"),
            (0.9, "Model B"),
            (0.85, "Model C"),
            (0.75, "Model D"),
        ]);
        let (x, y) = linear_data(100, 0.1, 0.5);
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = EnsembleSelector::new().max_ensemble_size(3);
        // `expect` surfaces the underlying error on failure, unlike `assert!(is_ok())`.
        let result = selector
            .select_ensemble(&models, &x, &y, &cv, &scoring)
            .expect("ensemble selection should succeed with four candidate models");

        assert!(result.selected_models.len() >= 2);
        assert!(result.selected_models.len() <= 3);
        assert_eq!(result.model_weights.len(), result.selected_models.len());
        assert!(!result.individual_performances.is_empty());
    }

    #[test]
    fn test_different_ensemble_strategies() {
        let models = mock_models(&[(0.9, "Good Model"), (0.8, "Decent Model")]);
        let (x, y) = linear_data(50, 0.05, 0.3);
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let strategies = vec![
            EnsembleStrategy::Voting,
            EnsembleStrategy::WeightedVoting,
            EnsembleStrategy::BayesianAveraging,
        ];

        let selector = EnsembleSelector::new().strategies(strategies);
        let result = selector
            .select_ensemble(&models, &x, &y, &cv, &scoring)
            .expect("ensemble selection should succeed with the restricted strategy set");

        assert_eq!(result.selected_models.len(), 2);
    }

    #[test]
    fn test_convenience_function() {
        let models = mock_models(&[
            (0.95, "Best Model"),
            (0.85, "Good Model"),
            (0.8, "Okay Model"),
        ]);
        let (x, y) = linear_data(40, 0.1, 0.4);
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let result = select_ensemble(&models, &x, &y, &cv, &scoring, Some(2))
            .expect("convenience selection should succeed with three candidate models");

        assert!(result.selected_models.len() <= 2);
        assert!(result.ensemble_performance.ensemble_size <= 2);
    }

    #[test]
    fn test_ensemble_strategy_display() {
        assert_eq!(format!("{}", EnsembleStrategy::Voting), "Simple Voting");
        assert_eq!(
            format!("{}", EnsembleStrategy::WeightedVoting),
            "Weighted Voting"
        );
        assert_eq!(
            format!(
                "{}",
                EnsembleStrategy::Stacking {
                    meta_learner: "Linear".to_string()
                }
            ),
            "Stacking (Linear)"
        );
        assert_eq!(
            format!("{}", EnsembleStrategy::Blending { blend_ratio: 0.2 }),
            "Blending (ratio: 0.20)"
        );
    }

    #[test]
    fn test_insufficient_models() {
        let models = mock_models(&[(0.9, "Only Model")]);
        let (x, y) = linear_data(20, 0.1, 0.5);
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = EnsembleSelector::new();
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_err());
    }
}