1use crate::cross_validation::CrossValidator;
8use sklears_core::{
9 error::{Result, SklearsError},
10 traits::{Estimator, Fit, Predict},
11};
12
/// Scoring function used to evaluate predictions against ground truth.
///
/// The selector maximizes the returned value, so implementations should
/// return scores where higher is better (e.g. a negated error metric).
pub trait Scoring {
    /// Scores `y_pred` against `y_true`; returns an error if scoring fails.
    fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64>;
}
17use std::fmt::{self, Display, Formatter};
18
/// Outcome of an ensemble-selection run: the winning strategy, the chosen
/// models with their weights, and the statistics gathered along the way.
#[derive(Debug, Clone)]
pub struct EnsembleSelectionResult {
    /// Combination strategy that achieved the best cross-validated score.
    pub ensemble_strategy: EnsembleStrategy,
    /// Models chosen for the ensemble, in selection order.
    pub selected_models: Vec<ModelInfo>,
    /// Combination weight per selected model (parallel to `selected_models`).
    pub model_weights: Vec<f64>,
    /// Cross-validated performance of the winning ensemble.
    pub ensemble_performance: EnsemblePerformance,
    /// Per-model cross-validation results for every candidate model.
    pub individual_performances: Vec<ModelPerformance>,
    /// Diversity statistics computed over the selected models.
    pub diversity_measures: DiversityMeasures,
}
35
/// Information about a single model that was selected into the ensemble.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Index of the model in the original candidate list.
    pub model_index: usize,
    /// Human-readable model name supplied by the caller.
    pub model_name: String,
    /// Combination weight assigned to this model.
    pub weight: f64,
    /// The model's individual cross-validation score.
    pub individual_score: f64,
    /// Marginal contribution of this model to the ensemble
    /// (currently always 0.0 — not yet computed; see `create_ensemble_candidate`).
    pub contribution_score: f64,
}
50
/// Cross-validated performance statistics for an ensemble.
#[derive(Debug, Clone)]
pub struct EnsemblePerformance {
    /// Mean score across CV folds.
    pub mean_score: f64,
    /// Sample standard deviation of the fold scores.
    pub std_score: f64,
    /// Raw per-fold scores.
    pub fold_scores: Vec<f64>,
    /// `mean_score` minus the best individual model's CV score.
    pub improvement_over_best: f64,
    /// Number of models in the ensemble.
    pub ensemble_size: usize,
}
65
/// Cross-validation results for a single candidate model.
#[derive(Debug, Clone)]
pub struct ModelPerformance {
    /// Index of the model in the original candidate list.
    pub model_index: usize,
    /// Human-readable model name supplied by the caller.
    pub model_name: String,
    /// Mean cross-validation score.
    pub cv_score: f64,
    /// Standard deviation of the cross-validation scores.
    pub cv_std: f64,
    /// Average correlation of this model's predictions with the other models'.
    pub avg_correlation: f64,
}
80
/// Diversity statistics over the selected ensemble members.
#[derive(Debug, Clone)]
pub struct DiversityMeasures {
    /// Average pairwise prediction correlation (lower = more diverse).
    pub avg_correlation: f64,
    /// Average pairwise disagreement rate.
    pub disagreement: f64,
    /// Yule's Q statistic averaged over model pairs.
    pub q_statistic: f64,
    /// Entropy-based diversity measure.
    pub entropy_diversity: f64,
}
93
/// Strategy used to combine the predictions of the selected models.
#[derive(Debug, Clone, PartialEq)]
pub enum EnsembleStrategy {
    /// Unweighted average of member predictions (equal weights).
    Voting,
    /// Average weighted by each member's individual CV score.
    WeightedVoting,
    /// Stacking; `meta_learner` names the second-level model.
    /// NOTE(review): prediction combination currently falls back to a
    /// weighted sum — no meta-learner is trained here.
    Stacking { meta_learner: String },
    /// Blending with the given ratio.
    /// NOTE(review): `blend_ratio` is carried in config/display but not
    /// used by the combination code in this file — confirm intended semantics.
    Blending { blend_ratio: f64 },
    /// Per-sample selection of the highest-weighted member's prediction.
    DynamicSelection,
    /// Softmax-style weights derived from member scores.
    BayesianAveraging,
}
110
111impl Display for EnsembleStrategy {
112 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
113 match self {
114 EnsembleStrategy::Voting => write!(f, "Simple Voting"),
115 EnsembleStrategy::WeightedVoting => write!(f, "Weighted Voting"),
116 EnsembleStrategy::Stacking { meta_learner } => write!(f, "Stacking ({})", meta_learner),
117 EnsembleStrategy::Blending { blend_ratio } => {
118 write!(f, "Blending (ratio: {:.2})", blend_ratio)
119 }
120 EnsembleStrategy::DynamicSelection => write!(f, "Dynamic Selection"),
121 EnsembleStrategy::BayesianAveraging => write!(f, "Bayesian Averaging"),
122 }
123 }
124}
125
126impl Display for EnsembleSelectionResult {
127 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
128 writeln!(f, "Ensemble Selection Results:")?;
129 writeln!(f, "Strategy: {}", self.ensemble_strategy)?;
130 writeln!(f, "Ensemble Size: {}", self.selected_models.len())?;
131 writeln!(
132 f,
133 "Ensemble Performance: {:.4} ± {:.4}",
134 self.ensemble_performance.mean_score, self.ensemble_performance.std_score
135 )?;
136 writeln!(
137 f,
138 "Improvement over Best Individual: {:.4}",
139 self.ensemble_performance.improvement_over_best
140 )?;
141 writeln!(
142 f,
143 "Average Diversity (Correlation): {:.4}",
144 self.diversity_measures.avg_correlation
145 )?;
146 writeln!(f, "\nSelected Models:")?;
147 for model in &self.selected_models {
148 writeln!(
149 f,
150 " {} - Weight: {:.3}, Score: {:.4}",
151 model.model_name, model.weight, model.individual_score
152 )?;
153 }
154 Ok(())
155 }
156}
157
/// Configuration controlling how ensembles are searched and selected.
#[derive(Debug, Clone)]
pub struct EnsembleSelectionConfig {
    /// Maximum number of models allowed in an ensemble.
    pub max_ensemble_size: usize,
    /// Minimum number of candidate models required (and minimum subset size
    /// for exhaustive search).
    pub min_ensemble_size: usize,
    /// Combination strategies to try; the best-scoring one wins.
    pub candidate_strategies: Vec<EnsembleStrategy>,
    /// Minimum diversity a subset must have for a model to be added.
    pub diversity_threshold: f64,
    /// If true, use greedy forward selection; otherwise enumerate subset sizes.
    pub use_greedy_selection: bool,
    /// Minimum estimated improvement required to add another model greedily.
    pub improvement_threshold: f64,
    /// Number of cross-validation folds.
    /// NOTE(review): not read by the code in this file — CV behavior comes
    /// from the `CrossValidator` passed to `select_ensemble`; confirm intent.
    pub cv_folds: usize,
    /// Optional RNG seed for reproducibility (currently unused here).
    pub random_seed: Option<u64>,
}
178
179impl Default for EnsembleSelectionConfig {
180 fn default() -> Self {
181 Self {
182 max_ensemble_size: 10,
183 min_ensemble_size: 2,
184 candidate_strategies: vec![
185 EnsembleStrategy::Voting,
186 EnsembleStrategy::WeightedVoting,
187 EnsembleStrategy::Stacking {
188 meta_learner: "Linear".to_string(),
189 },
190 EnsembleStrategy::Blending { blend_ratio: 0.2 },
191 ],
192 diversity_threshold: 0.1,
193 use_greedy_selection: true,
194 improvement_threshold: 0.01,
195 cv_folds: 5,
196 random_seed: None,
197 }
198 }
199}
200
/// Searches over candidate models and combination strategies to build the
/// best-scoring ensemble under the given configuration.
pub struct EnsembleSelector {
    // Selection parameters; see `EnsembleSelectionConfig`.
    config: EnsembleSelectionConfig,
}
205
impl EnsembleSelector {
    /// Creates a selector with the default configuration.
    pub fn new() -> Self {
        Self {
            config: EnsembleSelectionConfig::default(),
        }
    }

    /// Creates a selector from an explicit configuration.
    pub fn with_config(config: EnsembleSelectionConfig) -> Self {
        Self { config }
    }

    /// Builder: sets the maximum ensemble size.
    pub fn max_ensemble_size(mut self, size: usize) -> Self {
        self.config.max_ensemble_size = size;
        self
    }

    /// Builder: sets the minimum ensemble size.
    pub fn min_ensemble_size(mut self, size: usize) -> Self {
        self.config.min_ensemble_size = size;
        self
    }

    /// Builder: replaces the set of candidate combination strategies.
    pub fn strategies(mut self, strategies: Vec<EnsembleStrategy>) -> Self {
        self.config.candidate_strategies = strategies;
        self
    }

    /// Builder: sets the minimum diversity required to add a model.
    pub fn diversity_threshold(mut self, threshold: f64) -> Self {
        self.config.diversity_threshold = threshold;
        self
    }

    /// Builder: toggles greedy forward selection vs. subset enumeration.
    pub fn use_greedy_selection(mut self, use_greedy: bool) -> Self {
        self.config.use_greedy_selection = use_greedy;
        self
    }

    /// Runs the full selection pipeline: score each model individually,
    /// generate ensemble candidates per strategy, evaluate each candidate
    /// with cross-validation, and return the best one with its statistics.
    ///
    /// # Errors
    /// Returns `InvalidParameter` if fewer than `min_ensemble_size` models
    /// are supplied, or if no candidate ensemble could be evaluated.
    pub fn select_ensemble<E, X, Y>(
        &self,
        models: &[(E, String)],
        x: &[X],
        y: &[Y],
        cv: &dyn CrossValidator,
        scoring: &dyn Scoring,
    ) -> Result<EnsembleSelectionResult>
    where
        E: Estimator + Clone,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        // Guard: an ensemble needs at least `min_ensemble_size` candidates.
        if models.len() < self.config.min_ensemble_size {
            return Err(SklearsError::InvalidParameter {
                name: "models".to_string(),
                reason: format!(
                    "at least {} models required for ensemble",
                    self.config.min_ensemble_size
                ),
            });
        }

        let individual_performances = self.evaluate_individual_models(models, x, y, cv, scoring)?;

        let ensemble_candidates = self.generate_ensemble_candidates(&individual_performances)?;

        // Keep the candidate with the highest mean CV score.
        let mut best_ensemble = None;
        let mut best_score = f64::NEG_INFINITY;

        for candidate in &ensemble_candidates {
            let ensemble_performance =
                self.evaluate_ensemble_candidate(models, candidate, x, y, cv, scoring)?;

            if ensemble_performance.mean_score > best_score {
                best_score = ensemble_performance.mean_score;
                best_ensemble = Some((candidate.clone(), ensemble_performance));
            }
        }

        let (best_candidate, ensemble_performance) =
            best_ensemble.ok_or_else(|| SklearsError::InvalidParameter {
                name: "ensemble".to_string(),
                reason: "no valid ensemble found".to_string(),
            })?;

        let diversity_measures =
            self.calculate_diversity_measures(models, &best_candidate.selected_models, x, y)?;

        // Best single-model score, used as the baseline for improvement.
        let best_individual_score = individual_performances
            .iter()
            .map(|p| p.cv_score)
            .fold(f64::NEG_INFINITY, f64::max);

        // Fill in the improvement field now that the baseline is known.
        let mut ensemble_performance = ensemble_performance;
        ensemble_performance.improvement_over_best =
            ensemble_performance.mean_score - best_individual_score;

        Ok(EnsembleSelectionResult {
            ensemble_strategy: best_candidate.ensemble_strategy,
            selected_models: best_candidate.selected_models,
            model_weights: best_candidate.model_weights,
            ensemble_performance,
            individual_performances,
            diversity_measures,
        })
    }

    /// Scores each candidate model individually.
    ///
    /// NOTE(review): this is a placeholder — the data, CV splitter, and
    /// scorer are ignored and synthetic scores (0.8 + 0.05·index) are
    /// returned. Replace with real per-model cross-validation.
    fn evaluate_individual_models<E, X, Y>(
        &self,
        models: &[(E, String)],
        _x: &[X],
        _y: &[Y],
        _cv: &dyn CrossValidator,
        _scoring: &dyn Scoring,
    ) -> Result<Vec<ModelPerformance>>
    where
        E: Estimator + Clone,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        let mut performances = Vec::new();
        for (idx, (_, name)) in models.iter().enumerate() {
            performances.push(ModelPerformance {
                model_index: idx,
                // Dummy values: later-indexed models score higher.
                model_name: name.clone(),
                cv_score: 0.8 + (idx as f64) * 0.05,
                cv_std: 0.1,
                avg_correlation: 0.3,
            });
        }
        Ok(performances)
    }

    /// Builds one or more `EnsembleCandidate`s per configured strategy:
    /// a single greedy selection, or (when greedy is disabled) one candidate
    /// per subset size between the configured min and max.
    fn generate_ensemble_candidates(
        &self,
        individual_performances: &[ModelPerformance],
    ) -> Result<Vec<EnsembleCandidate>> {
        let mut candidates = Vec::new();

        for strategy in &self.config.candidate_strategies {
            if self.config.use_greedy_selection {
                let ensemble =
                    self.greedy_ensemble_selection(individual_performances, strategy.clone())?;
                candidates.push(ensemble);
            } else {
                // Enumerate every allowed subset size (capped by model count).
                for size in self.config.min_ensemble_size
                    ..=self
                        .config
                        .max_ensemble_size
                        .min(individual_performances.len())
                {
                    let ensemble = self.select_diverse_subset(
                        individual_performances,
                        size,
                        strategy.clone(),
                    )?;
                    candidates.push(ensemble);
                }
            }
        }

        Ok(candidates)
    }

    /// Greedy forward selection: start from the single best model, then
    /// repeatedly add the model whose estimated improvement is largest,
    /// subject to the diversity and improvement thresholds.
    ///
    /// # Panics
    /// Panics (via `expect`) if `individual_performances` is empty or a
    /// score is NaN; callers guard against the empty case.
    fn greedy_ensemble_selection(
        &self,
        individual_performances: &[ModelPerformance],
        strategy: EnsembleStrategy,
    ) -> Result<EnsembleCandidate> {
        let mut selected_indices = Vec::new();
        let mut remaining_indices: Vec<usize> = (0..individual_performances.len()).collect();

        // Seed the ensemble with the best-scoring individual model.
        let best_idx = individual_performances
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                a.cv_score
                    .partial_cmp(&b.cv_score)
                    .expect("operation should succeed")
            })
            .map(|(idx, _)| idx)
            .expect("operation should succeed");

        selected_indices.push(best_idx);
        remaining_indices.retain(|&x| x != best_idx);

        while selected_indices.len() < self.config.max_ensemble_size
            && !remaining_indices.is_empty()
        {
            let mut best_addition = None;
            let mut best_improvement = 0.0;

            for &candidate_idx in &remaining_indices {
                let mut test_ensemble = selected_indices.clone();
                test_ensemble.push(candidate_idx);

                // Skip additions that would make the subset too correlated.
                let diversity =
                    self.calculate_subset_diversity(individual_performances, &test_ensemble);
                if diversity < self.config.diversity_threshold {
                    continue;
                }

                let estimated_improvement =
                    self.estimate_ensemble_improvement(individual_performances, &test_ensemble);

                // Require the improvement to clear the configured threshold.
                if estimated_improvement > best_improvement + self.config.improvement_threshold {
                    best_improvement = estimated_improvement;
                    best_addition = Some(candidate_idx);
                }
            }

            match best_addition {
                Some(idx) => {
                    selected_indices.push(idx);
                    remaining_indices.retain(|&x| x != idx);
                }
                // No candidate improved the ensemble: stop growing it.
                None => break,
            }
        }

        self.create_ensemble_candidate(individual_performances, selected_indices, strategy)
    }

    /// Picks up to `subset_size` models, best score first, skipping any
    /// model that would drop subset diversity below the threshold
    /// (the first model is always accepted).
    fn select_diverse_subset(
        &self,
        individual_performances: &[ModelPerformance],
        subset_size: usize,
        strategy: EnsembleStrategy,
    ) -> Result<EnsembleCandidate> {
        let mut candidates: Vec<(usize, f64)> = individual_performances
            .iter()
            .enumerate()
            .map(|(idx, perf)| (idx, perf.cv_score))
            .collect();

        // Sort by score descending so the best models are considered first.
        candidates.sort_by(|(_, a), (_, b)| b.partial_cmp(a).expect("operation should succeed"));

        let mut selected_indices = Vec::new();
        for (idx, _) in candidates {
            if selected_indices.len() >= subset_size {
                break;
            }

            let mut test_ensemble = selected_indices.clone();
            test_ensemble.push(idx);

            let diversity =
                self.calculate_subset_diversity(individual_performances, &test_ensemble);
            if diversity >= self.config.diversity_threshold || selected_indices.is_empty() {
                selected_indices.push(idx);
            }
        }

        self.create_ensemble_candidate(individual_performances, selected_indices, strategy)
    }

    /// Assembles an `EnsembleCandidate` from chosen indices: computes the
    /// strategy-specific weights and builds one `ModelInfo` per member.
    fn create_ensemble_candidate(
        &self,
        individual_performances: &[ModelPerformance],
        selected_indices: Vec<usize>,
        strategy: EnsembleStrategy,
    ) -> Result<EnsembleCandidate> {
        let model_weights =
            self.calculate_model_weights(&selected_indices, individual_performances, &strategy);

        let selected_models = selected_indices
            .iter()
            .enumerate()
            .map(|(i, &model_idx)| {
                let perf = &individual_performances[model_idx];
                ModelInfo {
                    model_index: model_idx,
                    model_name: perf.model_name.clone(),
                    weight: model_weights[i],
                    individual_score: perf.cv_score,
                    // Not computed yet; see ModelInfo::contribution_score.
                    contribution_score: 0.0,
                }
            })
            .collect();

        Ok(EnsembleCandidate {
            ensemble_strategy: strategy,
            selected_models,
            model_weights,
        })
    }

    /// Computes per-member weights for a strategy:
    /// - `Voting` (and any strategy without special handling): uniform.
    /// - `WeightedVoting`: proportional to (non-negative) CV scores.
    /// - `BayesianAveraging`: softmax of scores (max-shifted for stability).
    /// All branches fall back to uniform weights if the normalizer is 0.
    fn calculate_model_weights(
        &self,
        selected_indices: &[usize],
        individual_performances: &[ModelPerformance],
        strategy: &EnsembleStrategy,
    ) -> Vec<f64> {
        match strategy {
            EnsembleStrategy::Voting => {
                vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
            }
            EnsembleStrategy::WeightedVoting => {
                // Clamp scores at 0 so negative scores cannot flip weights.
                let scores: Vec<f64> = selected_indices
                    .iter()
                    .map(|&idx| individual_performances[idx].cv_score.max(0.0))
                    .collect();
                let sum: f64 = scores.iter().sum();
                if sum > 0.0 {
                    scores.iter().map(|&s| s / sum).collect()
                } else {
                    vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
                }
            }
            EnsembleStrategy::BayesianAveraging => {
                // Treat CV scores as log-likelihood proxies.
                let log_likelihoods: Vec<f64> = selected_indices
                    .iter()
                    .map(|&idx| individual_performances[idx].cv_score)
                    .collect();

                // Subtract the max before exponentiating to avoid overflow.
                let max_ll = log_likelihoods
                    .iter()
                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let exp_weights: Vec<f64> = log_likelihoods
                    .iter()
                    .map(|&ll| (ll - max_ll).exp())
                    .collect();
                let sum: f64 = exp_weights.iter().sum();

                if sum > 0.0 {
                    exp_weights.iter().map(|&w| w / sum).collect()
                } else {
                    vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
                }
            }
            _ => {
                // Stacking / Blending / DynamicSelection: uniform weights.
                vec![1.0 / selected_indices.len() as f64; selected_indices.len()]
            }
        }
    }

    /// Cross-validates an ensemble candidate and summarizes the fold scores.
    ///
    /// NOTE(review): placeholder — models, targets, and the scorer are
    /// ignored; each fold's "score" is derived from the train split size.
    /// Also assumes `cv.split` returns at least one split, otherwise
    /// `mean_score` becomes NaN (division by zero) — confirm upstream.
    fn evaluate_ensemble_candidate<E, X, Y>(
        &self,
        _models: &[(E, String)],
        candidate: &EnsembleCandidate,
        x: &[X],
        _y: &[Y],
        cv: &dyn CrossValidator,
        _scoring: &dyn Scoring,
    ) -> Result<EnsemblePerformance>
    where
        E: Estimator + Clone,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        let n_samples = x.len();
        let splits = cv.split(n_samples, None);
        let mut fold_scores = Vec::with_capacity(splits.len());

        for (train_indices, _test_indices) in &splits {
            // Synthetic per-fold score; stands in for train/predict/score.
            let dummy_score = 0.8 + (train_indices.len() as f64) * 0.01 / 100.0;
            fold_scores.push(dummy_score);
        }

        let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
        let std_score = self.calculate_std(&fold_scores, mean_score);

        Ok(EnsemblePerformance {
            mean_score,
            std_score,
            fold_scores,
            // Baseline-relative improvement is filled in by select_ensemble.
            improvement_over_best: 0.0,
            ensemble_size: candidate.selected_models.len(),
        })
    }

    /// Combines the predictions of already-trained models according to the
    /// strategy. Voting-style strategies (including Stacking and Blending,
    /// which currently share the same code path) produce a weighted sum;
    /// DynamicSelection picks the highest-weighted model's prediction per
    /// sample.
    ///
    /// # Errors
    /// Returns `InvalidParameter` if `trained_models` is empty; propagates
    /// prediction errors from the models.
    fn make_ensemble_predictions<T, X>(
        &self,
        trained_models: &[T],
        x_test: &[X],
        weights: &[f64],
        strategy: &EnsembleStrategy,
    ) -> Result<Vec<f64>>
    where
        T: Predict<Vec<X>, Vec<f64>>,
        X: Clone,
    {
        if trained_models.is_empty() {
            return Err(SklearsError::InvalidParameter {
                name: "trained_models".to_string(),
                reason: "no trained models provided".to_string(),
            });
        }

        // Collect each model's predictions over the test set.
        let mut all_predictions = Vec::with_capacity(trained_models.len());
        let x_test_vec = x_test.to_vec();
        for model in trained_models {
            let predictions = model.predict(&x_test_vec)?;
            all_predictions.push(predictions);
        }

        if all_predictions.is_empty() {
            return Ok(vec![]);
        }

        let n_samples = all_predictions[0].len();
        let mut ensemble_predictions = vec![0.0; n_samples];

        match strategy {
            EnsembleStrategy::Voting
            | EnsembleStrategy::WeightedVoting
            | EnsembleStrategy::BayesianAveraging => {
                // Weighted sum of member predictions per sample.
                for i in 0..n_samples {
                    let mut weighted_sum = 0.0;
                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
                        if i < predictions.len() {
                            weighted_sum += predictions[i] * weights[model_idx];
                        }
                    }
                    ensemble_predictions[i] = weighted_sum;
                }
            }
            EnsembleStrategy::Stacking { .. } => {
                // NOTE(review): identical to the voting path — no trained
                // meta-learner is applied here yet.
                for i in 0..n_samples {
                    let mut weighted_sum = 0.0;
                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
                        if i < predictions.len() {
                            weighted_sum += predictions[i] * weights[model_idx];
                        }
                    }
                    ensemble_predictions[i] = weighted_sum;
                }
            }
            EnsembleStrategy::Blending { .. } => {
                // NOTE(review): identical to the voting path — blend_ratio
                // is not applied here yet.
                for i in 0..n_samples {
                    let mut weighted_sum = 0.0;
                    for (model_idx, predictions) in all_predictions.iter().enumerate() {
                        if i < predictions.len() {
                            weighted_sum += predictions[i] * weights[model_idx];
                        }
                    }
                    ensemble_predictions[i] = weighted_sum;
                }
            }
            EnsembleStrategy::DynamicSelection => {
                // Per sample, take the prediction of the model with the
                // largest weight (ties keep the earlier model).
                for i in 0..n_samples {
                    if all_predictions[0].len() > i {
                        let mut best_pred = all_predictions[0][i];
                        let mut best_weight = weights[0];

                        for (model_idx, predictions) in all_predictions.iter().enumerate() {
                            if predictions.len() > i && weights[model_idx] > best_weight {
                                best_pred = predictions[i];
                                best_weight = weights[model_idx];
                            }
                        }
                        ensemble_predictions[i] = best_pred;
                    }
                }
            }
        }

        Ok(ensemble_predictions)
    }

    /// Computes diversity statistics for the selected models.
    ///
    /// NOTE(review): placeholder — all inputs are ignored and fixed values
    /// are returned. Replace with real prediction-based diversity measures.
    fn calculate_diversity_measures<E, X, Y>(
        &self,
        _models: &[(E, String)],
        _selected_models: &[ModelInfo],
        _x: &[X],
        _y: &[Y],
    ) -> Result<DiversityMeasures>
    where
        E: Estimator + Clone,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        Ok(DiversityMeasures {
            avg_correlation: 0.3,
            disagreement: 0.2,
            q_statistic: 0.1,
            entropy_diversity: 0.4,
        })
    }

    /// Diversity proxy for a subset: 1 minus the average (absolute) pairwise
    /// correlation estimate, where each pair's correlation is approximated by
    /// the mean of the two members' precomputed `avg_correlation` values.
    /// Returns 0.0 for subsets of size 0 or 1.
    fn calculate_subset_diversity(
        &self,
        individual_performances: &[ModelPerformance],
        subset_indices: &[usize],
    ) -> f64 {
        if subset_indices.len() <= 1 {
            return 0.0;
        }

        // Approximate each pair's correlation from the per-model averages.
        let mut correlations = Vec::new();
        for i in 0..subset_indices.len() {
            for j in (i + 1)..subset_indices.len() {
                let corr1 = individual_performances[subset_indices[i]].avg_correlation;
                let corr2 = individual_performances[subset_indices[j]].avg_correlation;
                correlations.push((corr1 + corr2) / 2.0);
            }
        }

        if correlations.is_empty() {
            0.0
        } else {
            let avg_correlation = correlations.iter().sum::<f64>() / correlations.len() as f64;
            // Lower correlation => higher diversity.
            1.0 - avg_correlation.abs()
        }
    }

    /// Heuristic score for a prospective ensemble: the members' mean CV
    /// score plus a small bonus (10% of the subset diversity).
    fn estimate_ensemble_improvement(
        &self,
        individual_performances: &[ModelPerformance],
        ensemble_indices: &[usize],
    ) -> f64 {
        if ensemble_indices.is_empty() {
            return 0.0;
        }

        let avg_score = ensemble_indices
            .iter()
            .map(|&idx| individual_performances[idx].cv_score)
            .sum::<f64>()
            / ensemble_indices.len() as f64;

        let diversity_bonus =
            self.calculate_subset_diversity(individual_performances, ensemble_indices) * 0.1;

        avg_score + diversity_bonus
    }

    /// Sample standard deviation (n-1 denominator) around a precomputed
    /// mean; returns 0.0 for fewer than two values.
    fn calculate_std(&self, values: &[f64], mean: f64) -> f64 {
        if values.len() <= 1 {
            return 0.0;
        }

        let variance =
            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;

        variance.sqrt()
    }

    /// Pearson correlation between two prediction vectors; returns 0.0 for
    /// mismatched lengths, empty input, or zero variance in either vector.
    fn calculate_correlation(&self, pred1: &[f64], pred2: &[f64]) -> f64 {
        if pred1.len() != pred2.len() || pred1.is_empty() {
            return 0.0;
        }

        let n = pred1.len() as f64;
        let mean1 = pred1.iter().sum::<f64>() / n;
        let mean2 = pred2.iter().sum::<f64>() / n;

        let mut numerator = 0.0;
        let mut sum_sq1 = 0.0;
        let mut sum_sq2 = 0.0;

        for i in 0..pred1.len() {
            let diff1 = pred1[i] - mean1;
            let diff2 = pred2[i] - mean2;
            numerator += diff1 * diff2;
            sum_sq1 += diff1 * diff1;
            sum_sq2 += diff2 * diff2;
        }

        let denominator = (sum_sq1 * sum_sq2).sqrt();
        if denominator > 0.0 {
            numerator / denominator
        } else {
            0.0
        }
    }
}
826
827impl Default for EnsembleSelector {
828 fn default() -> Self {
829 Self::new()
830 }
831}
832
/// Internal candidate ensemble: a strategy plus the chosen members and their
/// weights, before cross-validated evaluation.
#[derive(Debug, Clone)]
struct EnsembleCandidate {
    // Combination strategy this candidate was built for.
    ensemble_strategy: EnsembleStrategy,
    // Chosen members, in selection order.
    selected_models: Vec<ModelInfo>,
    // Per-member weights (parallel to `selected_models`).
    model_weights: Vec<f64>,
}
840
841pub fn select_ensemble<E, X, Y>(
843 models: &[(E, String)],
844 x: &[X],
845 y: &[Y],
846 cv: &dyn CrossValidator,
847 scoring: &dyn Scoring,
848 max_size: Option<usize>,
849) -> Result<EnsembleSelectionResult>
850where
851 E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
852 E::Fitted: Predict<Vec<X>, Vec<f64>>,
853 X: Clone,
854 Y: Clone + Into<f64>,
855{
856 let mut selector = EnsembleSelector::new();
857 if let Some(size) = max_size {
858 selector = selector.max_ensemble_size(size);
859 }
860 selector.select_ensemble(models, x, y, cv, scoring)
861}
862
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    // Minimal estimator stub whose "quality" is a single scalar.
    #[derive(Clone)]
    struct MockEstimator {
        performance_level: f64,
    }

    // Fitted counterpart of MockEstimator; scales inputs by its level.
    struct MockTrained {
        performance_level: f64,
    }

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        // "Training" just carries the performance level over.
        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            Ok(MockTrained {
                performance_level: self.performance_level,
            })
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        // Prediction is a simple elementwise scaling of the input.
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            Ok(x.iter().map(|&xi| xi * self.performance_level).collect())
        }
    }

    // Scorer stub: negated MSE, so higher is better as the selector expects.
    struct MockScoring;

    impl Scoring for MockScoring {
        fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64> {
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
                .sum::<f64>()
                / y_true.len() as f64;
            // Negate so that a smaller error yields a larger score.
            Ok(-mse)
        }
    }

    // Defaults should match EnsembleSelectionConfig::default().
    #[test]
    fn test_ensemble_selector_creation() {
        let selector = EnsembleSelector::new();
        assert_eq!(selector.config.max_ensemble_size, 10);
        assert_eq!(selector.config.min_ensemble_size, 2);
        assert!(selector.config.use_greedy_selection);
    }

    // End-to-end run: the selected ensemble must respect the size bounds
    // and report consistent weights/performances.
    #[test]
    fn test_ensemble_selection() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Model A".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Model B".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.85,
                },
                "Model C".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.75,
                },
                "Model D".to_string(),
            ),
        ];

        let x: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.5).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = EnsembleSelector::new().max_ensemble_size(3);
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.expect("operation should succeed");
        assert!(result.selected_models.len() >= 2);
        assert!(result.selected_models.len() <= 3);
        assert_eq!(result.model_weights.len(), result.selected_models.len());
        assert!(!result.individual_performances.is_empty());
    }

    // With only two models, every strategy should select both of them.
    #[test]
    fn test_different_ensemble_strategies() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Good Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Decent Model".to_string(),
            ),
        ];

        let x: Vec<f64> = (0..50).map(|i| i as f64 * 0.05).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.3).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let strategies = vec![
            EnsembleStrategy::Voting,
            EnsembleStrategy::WeightedVoting,
            EnsembleStrategy::BayesianAveraging,
        ];

        let selector = EnsembleSelector::new().strategies(strategies);
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.expect("operation should succeed");
        assert_eq!(result.selected_models.len(), 2);
    }

    // The free-function wrapper must honor the optional size cap.
    #[test]
    fn test_convenience_function() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.95,
                },
                "Best Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.85,
                },
                "Good Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Okay Model".to_string(),
            ),
        ];

        let x: Vec<f64> = (0..40).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.4).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let result = select_ensemble(&models, &x, &y, &cv, &scoring, Some(2));
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert!(result.selected_models.len() <= 2);
        assert!(result.ensemble_performance.ensemble_size <= 2);
    }

    // Display labels for each strategy variant.
    #[test]
    fn test_ensemble_strategy_display() {
        assert_eq!(format!("{}", EnsembleStrategy::Voting), "Simple Voting");
        assert_eq!(
            format!("{}", EnsembleStrategy::WeightedVoting),
            "Weighted Voting"
        );
        assert_eq!(
            format!(
                "{}",
                EnsembleStrategy::Stacking {
                    meta_learner: "Linear".to_string()
                }
            ),
            "Stacking (Linear)"
        );
        assert_eq!(
            format!("{}", EnsembleStrategy::Blending { blend_ratio: 0.2 }),
            "Blending (ratio: 0.20)"
        );
    }

    // A single model is below the default minimum ensemble size of 2.
    #[test]
    fn test_insufficient_models() {
        let models = vec![(
            MockEstimator {
                performance_level: 0.9,
            },
            "Only Model".to_string(),
        )];

        let x: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi * 0.5).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = EnsembleSelector::new();
        let result = selector.select_ensemble(&models, &x, &y, &cv, &scoring);

        assert!(result.is_err());
    }
}