use crate::cross_validation::CrossValidator;
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Predict},
};
use std::fmt::{self, Display, Formatter};

/// Scores predictions against ground truth; by convention in this module,
/// higher scores are better.
pub trait Scoring {
    fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64>;
}

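/// Outcome of an ensemble selection run: the winning strategy, the selected
/// member models with their weights, cross-validated ensemble performance,
/// per-model performance, and diversity statistics.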
#[derive(Debug, Clone)]
pub struct EnsembleSelectionResult {
    pub ensemble_strategy: EnsembleStrategy,
    pub selected_models: Vec<ModelInfo>,
    pub model_weights: Vec<f64>,
    pub ensemble_performance: EnsemblePerformance,
    pub individual_performances: Vec<ModelPerformance>,
    pub diversity_measures: DiversityMeasures,
}

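/// A selected ensemble member: its index and name in the candidate list, its
/// ensemble weight, and its individual and contribution scores.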
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub model_index: usize,
    pub model_name: String,
    pub weight: f64,
    pub individual_score: f64,
    pub contribution_score: f64,
}

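/// Cross-validated performance of the selected ensemble, including the
/// improvement over the best individual model.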
#[derive(Debug, Clone)]
pub struct EnsemblePerformance {
    pub mean_score: f64,
    pub std_score: f64,
    pub fold_scores: Vec<f64>,
    pub improvement_over_best: f64,
    pub ensemble_size: usize,
}

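/// Cross-validated performance of a single candidate model.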
#[derive(Debug, Clone)]
pub struct ModelPerformance {
    pub model_index: usize,
    pub model_name: String,
    pub cv_score: f64,
    pub cv_std: f64,
    pub avg_correlation: f64,
}

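/// Diversity statistics for the selected ensemble (correlation, disagreement,
/// Q-statistic, and entropy-based diversity).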
#[derive(Debug, Clone)]
pub struct DiversityMeasures {
    pub avg_correlation: f64,
    pub disagreement: f64,
    pub q_statistic: f64,
    pub entropy_diversity: f64,
}

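/// Strategy used to combine the predictions of the selected models.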
#[derive(Debug, Clone, PartialEq)]
pub enum EnsembleStrategy {
    Voting,
    WeightedVoting,
    Stacking { meta_learner: String },
    Blending { blend_ratio: f64 },
    DynamicSelection,
    BayesianAveraging,
}

impl Display for EnsembleStrategy {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            EnsembleStrategy::Voting => write!(f, "Simple Voting"),
            EnsembleStrategy::WeightedVoting => write!(f, "Weighted Voting"),
            EnsembleStrategy::Stacking { meta_learner } => write!(f, "Stacking ({})", meta_learner),
            EnsembleStrategy::Blending { blend_ratio } => {
                write!(f, "Blending (ratio: {:.2})", blend_ratio)
            }
            EnsembleStrategy::DynamicSelection => write!(f, "Dynamic Selection"),
            EnsembleStrategy::BayesianAveraging => write!(f, "Bayesian Averaging"),
        }
    }
}

impl Display for EnsembleSelectionResult {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        writeln!(f, "Ensemble Selection Results:")?;
        writeln!(f, "Strategy: {}", self.ensemble_strategy)?;
        writeln!(f, "Ensemble Size: {}", self.selected_models.len())?;
        writeln!(
            f,
            "Ensemble Performance: {:.4} ± {:.4}",
            self.ensemble_performance.mean_score, self.ensemble_performance.std_score
        )?;
        writeln!(
            f,
            "Improvement over Best Individual: {:.4}",
            self.ensemble_performance.improvement_over_best
        )?;
        writeln!(
            f,
            "Average Diversity (Correlation): {:.4}",
            self.diversity_measures.avg_correlation
        )?;
        writeln!(f, "\nSelected Models:")?;
        for model in &self.selected_models {
            writeln!(
                f,
                " {} - Weight: {:.3}, Score: {:.4}",
                model.model_name, model.weight, model.individual_score
            )?;
        }
        Ok(())
    }
}

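/// Configuration for ensemble selection: size limits, candidate strategies,
/// diversity and improvement thresholds, cross-validation folds, and an
/// optional random seed.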
#[derive(Debug, Clone)]
pub struct EnsembleSelectionConfig {
    pub max_ensemble_size: usize,
    pub min_ensemble_size: usize,
    pub candidate_strategies: Vec<EnsembleStrategy>,
    pub diversity_threshold: f64,
    pub use_greedy_selection: bool,
    pub improvement_threshold: f64,
    pub cv_folds: usize,
    pub random_seed: Option<u64>,
}

impl Default for EnsembleSelectionConfig {
    fn default() -> Self {
        Self {
            max_ensemble_size: 10,
            min_ensemble_size: 2,
            candidate_strategies: vec![
                EnsembleStrategy::Voting,
                EnsembleStrategy::WeightedVoting,
                EnsembleStrategy::Stacking {
                    meta_learner: "Linear".to_string(),
                },
                EnsembleStrategy::Blending { blend_ratio: 0.2 },
            ],
            diversity_threshold: 0.1,
            use_greedy_selection: true,
            improvement_threshold: 0.01,
            cv_folds: 5,
            random_seed: None,
        }
    }
}

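/// Builder-style selector that evaluates candidate models, searches over the
/// configured ensemble strategies, and returns the best-scoring ensemble.
///
/// A minimal usage sketch (illustrative only; it assumes some estimator type
/// implementing `Estimator`, such as the `MockEstimator` used in the tests
/// below, plus the crate's `KFold` cross-validator and a `Scoring` impl):
///
/// ```ignore
/// let selector = EnsembleSelector::new()
///     .max_ensemble_size(3)
///     .diversity_threshold(0.1);
/// let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring)?;
/// println!("{result}");
/// ```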
pub struct EnsembleSelector {
    config: EnsembleSelectionConfig,
}

impl EnsembleSelector {
    pub fn new() -> Self {
        Self {
            config: EnsembleSelectionConfig::default(),
        }
    }

    pub fn with_config(config: EnsembleSelectionConfig) -> Self {
        Self { config }
    }

    pub fn max_ensemble_size(mut self, size: usize) -> Self {
        self.config.max_ensemble_size = size;
        self
    }

    pub fn min_ensemble_size(mut self, size: usize) -> Self {
        self.config.min_ensemble_size = size;
        self
    }

    pub fn strategies(mut self, strategies: Vec<EnsembleStrategy>) -> Self {
        self.config.candidate_strategies = strategies;
        self
    }

    pub fn diversity_threshold(mut self, threshold: f64) -> Self {
        self.config.diversity_threshold = threshold;
        self
    }

    pub fn use_greedy_selection(mut self, use_greedy: bool) -> Self {
        self.config.use_greedy_selection = use_greedy;
        self
    }

    pub fn select_ensemble<E, X, Y>(
        &self,
        models: &[(E, String)],
        x: &[X],
        y: &[Y],
        cv: &dyn CrossValidator,
        scoring: &dyn Scoring,
    ) -> Result<EnsembleSelectionResult>
    where
        E: Estimator + Clone,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        if models.len() < self.config.min_ensemble_size {
            return Err(SklearsError::InvalidParameter {
                name: "models".to_string(),
                reason: format!(
                    "at least {} models required for ensemble",
                    self.config.min_ensemble_size
                ),
            });
        }

        let individual_performances = self.evaluate_individual_models(models, x, y, cv, scoring)?;

        let ensemble_candidates = self.generate_ensemble_candidates(&individual_performances)?;

        let mut best_ensemble = None;
        let mut best_score = f64::NEG_INFINITY;

        for candidate in &ensemble_candidates {
            let ensemble_performance =
                self.evaluate_ensemble_candidate(models, candidate, x, y, cv, scoring)?;

            if ensemble_performance.mean_score > best_score {
                best_score = ensemble_performance.mean_score;
                best_ensemble = Some((candidate.clone(), ensemble_performance));
            }
        }

        let (best_candidate, ensemble_performance) =
            best_ensemble.ok_or_else(|| SklearsError::InvalidParameter {
                name: "ensemble".to_string(),
                reason: "no valid ensemble found".to_string(),
            })?;

        let diversity_measures =
            self.calculate_diversity_measures(models, &best_candidate.selected_models, x, y)?;

        let best_individual_score = individual_performances
            .iter()
            .map(|p| p.cv_score)
            .fold(f64::NEG_INFINITY, f64::max);

        let mut ensemble_performance = ensemble_performance;
        ensemble_performance.improvement_over_best =
            ensemble_performance.mean_score - best_individual_score;

        Ok(EnsembleSelectionResult {
            ensemble_strategy: best_candidate.ensemble_strategy,
            selected_models: best_candidate.selected_models,
            model_weights: best_candidate.model_weights,
            ensemble_performance,
            individual_performances,
            diversity_measures,
        })
    }

    fn evaluate_individual_models<E, X, Y>(
        &self,
        models: &[(E, String)],
        _x: &[X],
        _y: &[Y],
        _cv: &dyn CrossValidator,
        _scoring: &dyn Scoring,
    ) -> Result<Vec<ModelPerformance>>
    where
        E: Estimator + Clone,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        // Placeholder scoring: real per-model cross-validation is not yet wired
        // in, so assign deterministic mock scores based on the model index.
        let mut performances = Vec::new();
        for (idx, (_, name)) in models.iter().enumerate() {
            performances.push(ModelPerformance {
                model_index: idx,
                model_name: name.clone(),
                cv_score: 0.8 + (idx as f64) * 0.05,
                cv_std: 0.1,
                avg_correlation: 0.3,
            });
        }
        Ok(performances)
    }

    fn generate_ensemble_candidates(
        &self,
        individual_performances: &[ModelPerformance],
    ) -> Result<Vec<EnsembleCandidate>> {
        let mut candidates = Vec::new();

        for strategy in &self.config.candidate_strategies {
            if self.config.use_greedy_selection {
                let ensemble =
                    self.greedy_ensemble_selection(individual_performances, strategy.clone())?;
                candidates.push(ensemble);
            } else {
                for size in self.config.min_ensemble_size
                    ..=self
                        .config
                        .max_ensemble_size
                        .min(individual_performances.len())
                {
                    let ensemble = self.select_diverse_subset(
                        individual_performances,
                        size,
                        strategy.clone(),
                    )?;
                    candidates.push(ensemble);
                }
            }
        }

        Ok(candidates)
    }

    fn greedy_ensemble_selection(
        &self,
        individual_performances: &[ModelPerformance],
        strategy: EnsembleStrategy,
    ) -> Result<EnsembleCandidate> {
        let mut selected_indices = Vec::new();
        let mut remaining_indices: Vec<usize> = (0..individual_performances.len()).collect();

        // Seed the ensemble with the best individual model.
        let best_idx = individual_performances
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.cv_score.partial_cmp(&b.cv_score).unwrap())
            .map(|(idx, _)| idx)
            .unwrap();

        selected_indices.push(best_idx);
        remaining_indices.retain(|&x| x != best_idx);

        // Greedily add the model with the largest estimated improvement, subject
        // to the diversity threshold, until no addition helps or the size cap is hit.
        while selected_indices.len() < self.config.max_ensemble_size
            && !remaining_indices.is_empty()
        {
            let mut best_addition = None;
            let mut best_improvement = 0.0;

            for &candidate_idx in &remaining_indices {
                let mut test_ensemble = selected_indices.clone();
                test_ensemble.push(candidate_idx);

                let diversity =
                    self.calculate_subset_diversity(individual_performances, &test_ensemble);
                if diversity < self.config.diversity_threshold {
                    continue;
                }

                let estimated_improvement =
                    self.estimate_ensemble_improvement(individual_performances, &test_ensemble);

                if estimated_improvement > best_improvement + self.config.improvement_threshold {
                    best_improvement = estimated_improvement;
                    best_addition = Some(candidate_idx);
                }
            }

            match best_addition {
                Some(idx) => {
                    selected_indices.push(idx);
                    remaining_indices.retain(|&x| x != idx);
                }
                None => break, // no remaining model clears the improvement threshold
            }
        }

        self.create_ensemble_candidate(individual_performances, selected_indices, strategy)
    }

    fn select_diverse_subset(
        &self,
        individual_performances: &[ModelPerformance],
        subset_size: usize,
        strategy: EnsembleStrategy,
    ) -> Result<EnsembleCandidate> {
        let mut candidates: Vec<(usize, f64)> = individual_performances
            .iter()
            .enumerate()
            .map(|(idx, perf)| (idx, perf.cv_score))
            .collect();

        // Rank candidates by CV score, best first.
        candidates.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());

        let mut selected_indices = Vec::new();
        for (idx, _) in candidates {
            if selected_indices.len() >= subset_size {
                break;
            }

            let mut test_ensemble = selected_indices.clone();
            test_ensemble.push(idx);

            let diversity =
                self.calculate_subset_diversity(individual_performances, &test_ensemble);
            if diversity >= self.config.diversity_threshold || selected_indices.is_empty() {
                selected_indices.push(idx);
            }
        }

        self.create_ensemble_candidate(individual_performances, selected_indices, strategy)
    }

    fn create_ensemble_candidate(
        &self,
        individual_performances: &[ModelPerformance],
        selected_indices: Vec<usize>,
        strategy: EnsembleStrategy,
    ) -> Result<EnsembleCandidate> {
        let model_weights =
            self.calculate_model_weights(&selected_indices, individual_performances, &strategy);

        let selected_models = selected_indices
            .iter()
            .enumerate()
            .map(|(i, &model_idx)| {
                let perf = &individual_performances[model_idx];
                ModelInfo {
                    model_index: model_idx,
                    model_name: perf.model_name.clone(),
                    weight: model_weights[i],
                    individual_score: perf.cv_score,
                    contribution_score: 0.0, // not computed in this module yet
                }
            })
            .collect();

        Ok(EnsembleCandidate {
            ensemble_strategy: strategy,
            selected_models,
            model_weights,
        })
    }

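    /// Computes per-model weights for the given strategy: uniform weights for
    /// `Voting` (and any strategy without a dedicated rule), score-proportional
    /// weights for `WeightedVoting`, and a numerically stabilized softmax over
    /// CV scores for `BayesianAveraging`:
    /// `w_i = exp(s_i - max_j s_j) / sum_k exp(s_k - max_j s_j)`.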
    fn calculate_model_weights(
        &self,
        selected_indices: &[usize],
        individual_performances: &[ModelPerformance],
        strategy: &EnsembleStrategy,
    ) -> Vec<f64> {
        match strategy {
            EnsembleStrategy::Voting => {
                vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
            }
            EnsembleStrategy::WeightedVoting => {
                let scores: Vec<f64> = selected_indices
                    .iter()
                    .map(|&idx| individual_performances[idx].cv_score.max(0.0))
                    .collect();
                let sum: f64 = scores.iter().sum();
                if sum > 0.0 {
                    scores.iter().map(|&s| s / sum).collect()
                } else {
                    vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
                }
            }
            EnsembleStrategy::BayesianAveraging => {
                let log_likelihoods: Vec<f64> = selected_indices
                    .iter()
                    .map(|&idx| individual_performances[idx].cv_score)
                    .collect();

                let max_ll = log_likelihoods
                    .iter()
                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_weights: Vec<f64> = log_likelihoods
                    .iter()
                    .map(|&ll| (ll - max_ll).exp())
                    .collect();
                let sum: f64 = exp_weights.iter().sum();

                if sum > 0.0 {
                    exp_weights.iter().map(|&w| w / sum).collect()
                } else {
                    vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
                }
            }
            _ => {
                vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
            }
        }
    }

    fn evaluate_ensemble_candidate<E, X, Y>(
        &self,
        _models: &[(E, String)],
        candidate: &EnsembleCandidate,
        x: &[X],
        _y: &[Y],
        cv: &dyn CrossValidator,
        _scoring: &dyn Scoring,
    ) -> Result<EnsemblePerformance>
    where
        E: Estimator + Clone,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        let n_samples = x.len();
        let splits = cv.split(n_samples, None);
        let mut fold_scores = Vec::with_capacity(splits.len());

        for (train_indices, _test_indices) in &splits {
            // Placeholder fold score: training and scoring the ensemble on each
            // fold is not yet implemented here.
            let dummy_score = 0.8 + (train_indices.len() as f64) * 0.01 / 100.0;
            fold_scores.push(dummy_score);
        }

        let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
        let std_score = self.calculate_std(&fold_scores, mean_score);

        Ok(EnsemblePerformance {
            mean_score,
            std_score,
            fold_scores,
            improvement_over_best: 0.0, // filled in by `select_ensemble`
            ensemble_size: candidate.selected_models.len(),
        })
    }

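    /// Combines per-model predictions into ensemble predictions. Averaging-style
    /// strategies (and, at present, `Stacking` and `Blending`, which fall back to
    /// the same rule) use a weighted sum of model outputs; `DynamicSelection`
    /// takes, for each sample, the prediction of the highest-weighted model.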
    fn make_ensemble_predictions<T, X>(
        &self,
        trained_models: &[T],
        x_test: &[X],
        weights: &[f64],
        strategy: &EnsembleStrategy,
    ) -> Result<Vec<f64>>
    where
        T: Predict<Vec<X>, Vec<f64>>,
        X: Clone,
    {
        if trained_models.is_empty() {
            return Err(SklearsError::InvalidParameter {
                name: "trained_models".to_string(),
                reason: "no trained models provided".to_string(),
            });
        }

        let mut all_predictions = Vec::with_capacity(trained_models.len());
        let x_test_vec = x_test.to_vec();
        for model in trained_models {
            let predictions = model.predict(&x_test_vec)?;
            all_predictions.push(predictions);
        }

        if all_predictions.is_empty() {
            return Ok(vec![]);
        }

        let n_samples = all_predictions[0].len();
        let mut ensemble_predictions = vec![0.0; n_samples];

        match strategy {
            // Stacking and Blending currently fall back to the same weighted sum
            // used by the voting strategies.
            EnsembleStrategy::Voting
            | EnsembleStrategy::WeightedVoting
            | EnsembleStrategy::BayesianAveraging
            | EnsembleStrategy::Stacking { .. }
            | EnsembleStrategy::Blending { .. } => {
                for i in 0..n_samples {
                    let mut weighted_sum = 0.0;
                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
                        if i < predictions.len() {
                            weighted_sum += predictions[i] * weights[model_idx];
                        }
                    }
                    ensemble_predictions[i] = weighted_sum;
                }
            }
            EnsembleStrategy::DynamicSelection => {
                // Pick, per sample, the prediction of the highest-weighted model.
                for i in 0..n_samples {
                    if all_predictions[0].len() > i {
                        let mut best_pred = all_predictions[0][i];
                        let mut best_weight = weights[0];

                        for (model_idx, predictions) in all_predictions.iter().enumerate() {
                            if predictions.len() > i && weights[model_idx] > best_weight {
                                best_pred = predictions[i];
                                best_weight = weights[model_idx];
                            }
                        }
                        ensemble_predictions[i] = best_pred;
                    }
                }
            }
        }

        Ok(ensemble_predictions)
    }

    fn calculate_diversity_measures<E, X, Y>(
        &self,
        _models: &[(E, String)],
        _selected_models: &[ModelInfo],
        _x: &[X],
        _y: &[Y],
    ) -> Result<DiversityMeasures>
    where
        E: Estimator + Clone,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        // Placeholder diversity statistics; prediction-based measures are not
        // yet computed here.
        Ok(DiversityMeasures {
            avg_correlation: 0.3,
            disagreement: 0.2,
            q_statistic: 0.1,
            entropy_diversity: 0.4,
        })
    }

    fn calculate_subset_diversity(
        &self,
        individual_performances: &[ModelPerformance],
        subset_indices: &[usize],
    ) -> f64 {
        if subset_indices.len() <= 1 {
            return 0.0;
        }

        let mut correlations = Vec::new();
        for i in 0..subset_indices.len() {
            for j in (i + 1)..subset_indices.len() {
                let corr1 = individual_performances[subset_indices[i]].avg_correlation;
                let corr2 = individual_performances[subset_indices[j]].avg_correlation;
                correlations.push((corr1 + corr2) / 2.0);
            }
        }

        if correlations.is_empty() {
            0.0
        } else {
            let avg_correlation = correlations.iter().sum::<f64>() / correlations.len() as f64;
            1.0 - avg_correlation.abs()
        }
    }

    fn estimate_ensemble_improvement(
        &self,
        individual_performances: &[ModelPerformance],
        ensemble_indices: &[usize],
    ) -> f64 {
        if ensemble_indices.is_empty() {
            return 0.0;
        }

        let avg_score = ensemble_indices
            .iter()
            .map(|&idx| individual_performances[idx].cv_score)
            .sum::<f64>()
            / ensemble_indices.len() as f64;

        let diversity_bonus =
            self.calculate_subset_diversity(individual_performances, ensemble_indices) * 0.1;

        avg_score + diversity_bonus
    }

    fn calculate_std(&self, values: &[f64], mean: f64) -> f64 {
        if values.len() <= 1 {
            return 0.0;
        }

        let variance =
            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;

        variance.sqrt()
    }

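    /// Pearson correlation between two prediction vectors; returns 0.0 for
    /// mismatched lengths, empty input, or zero variance.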
    fn calculate_correlation(&self, pred1: &[f64], pred2: &[f64]) -> f64 {
        if pred1.len() != pred2.len() || pred1.is_empty() {
            return 0.0;
        }

        let n = pred1.len() as f64;
        let mean1 = pred1.iter().sum::<f64>() / n;
        let mean2 = pred2.iter().sum::<f64>() / n;

        let mut numerator = 0.0;
        let mut sum_sq1 = 0.0;
        let mut sum_sq2 = 0.0;

        for i in 0..pred1.len() {
            let diff1 = pred1[i] - mean1;
            let diff2 = pred2[i] - mean2;
            numerator += diff1 * diff2;
            sum_sq1 += diff1 * diff1;
            sum_sq2 += diff2 * diff2;
        }

        let denominator = (sum_sq1 * sum_sq2).sqrt();
        if denominator > 0.0 {
            numerator / denominator
        } else {
            0.0
        }
    }
}

impl Default for EnsembleSelector {
    fn default() -> Self {
        Self::new()
    }
}

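/// Internal candidate ensemble produced during the search: a strategy plus the
/// selected members and their weights.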
#[derive(Debug, Clone)]
struct EnsembleCandidate {
    ensemble_strategy: EnsembleStrategy,
    selected_models: Vec<ModelInfo>,
    model_weights: Vec<f64>,
}

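/// Convenience wrapper around [`EnsembleSelector`]: runs ensemble selection
/// with the default configuration, optionally capping the ensemble size.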
pub fn select_ensemble<E, X, Y>(
    models: &[(E, String)],
    x: &[X],
    y: &[Y],
    cv: &dyn CrossValidator,
    scoring: &dyn Scoring,
    max_size: Option<usize>,
) -> Result<EnsembleSelectionResult>
where
    E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
    E::Fitted: Predict<Vec<X>, Vec<f64>>,
    X: Clone,
    Y: Clone + Into<f64>,
{
    let mut selector = EnsembleSelector::new();
    if let Some(size) = max_size {
        selector = selector.max_ensemble_size(size);
    }
    selector.select_ensemble(models, x, y, cv, scoring)
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    #[derive(Clone)]
    struct MockEstimator {
        performance_level: f64,
    }

    struct MockTrained {
        performance_level: f64,
    }

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            Ok(MockTrained {
                performance_level: self.performance_level,
            })
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            Ok(x.iter().map(|&xi| xi * self.performance_level).collect())
        }
    }

    struct MockScoring;

    impl Scoring for MockScoring {
        fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64> {
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
                .sum::<f64>()
                / y_true.len() as f64;
            Ok(-mse) // negative MSE, so higher is better
        }
    }

    #[test]
    fn test_ensemble_selector_creation() {
        let selector = EnsembleSelector::new();
        assert_eq!(selector.config.max_ensemble_size, 10);
        assert_eq!(selector.config.min_ensemble_size, 2);
        assert!(selector.config.use_greedy_selection);
    }

    #[test]
    fn test_ensemble_selection() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Model A".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Model B".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.85,
                },
                "Model C".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.75,
                },
                "Model D".to_string(),
            ),
        ];

        let x: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.5).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = EnsembleSelector::new().max_ensemble_size(3);
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.unwrap();
        assert!(result.selected_models.len() >= 2);
        assert!(result.selected_models.len() <= 3);
        assert_eq!(result.model_weights.len(), result.selected_models.len());
        assert!(!result.individual_performances.is_empty());
    }

    #[test]
    fn test_different_ensemble_strategies() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Good Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Decent Model".to_string(),
            ),
        ];

        let x: Vec<f64> = (0..50).map(|i| i as f64 * 0.05).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.3).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let strategies = vec![
            EnsembleStrategy::Voting,
            EnsembleStrategy::WeightedVoting,
            EnsembleStrategy::BayesianAveraging,
        ];

        let selector = EnsembleSelector::new().strategies(strategies);
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.selected_models.len(), 2);
    }

    #[test]
    fn test_convenience_function() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.95,
                },
                "Best Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.85,
                },
                "Good Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Okay Model".to_string(),
            ),
        ];

        let x: Vec<f64> = (0..40).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.4).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let result = select_ensemble(&models, &x, &y, &cv, &scoring, Some(2));
        assert!(result.is_ok());

        let result = result.unwrap();
        assert!(result.selected_models.len() <= 2);
        assert!(result.ensemble_performance.ensemble_size <= 2);
    }

    #[test]
    fn test_ensemble_strategy_display() {
        assert_eq!(format!("{}", EnsembleStrategy::Voting), "Simple Voting");
        assert_eq!(
            format!("{}", EnsembleStrategy::WeightedVoting),
            "Weighted Voting"
        );
        assert_eq!(
            format!(
                "{}",
                EnsembleStrategy::Stacking {
                    meta_learner: "Linear".to_string()
                }
            ),
            "Stacking (Linear)"
        );
        assert_eq!(
            format!("{}", EnsembleStrategy::Blending { blend_ratio: 0.2 }),
            "Blending (ratio: 0.20)"
        );
    }

    #[test]
    fn test_insufficient_models() {
        let models = vec![(
            MockEstimator {
                performance_level: 0.9,
            },
            "Only Model".to_string(),
        )];

        let x: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.5).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = EnsembleSelector::new();
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_err());
    }
}