1use crate::cross_validation::CrossValidator;
8use scirs2_core::ndarray::{Array1, Array2};
9use sklears_core::{
10 error::{Result, SklearsError},
11 traits::{Estimator, Fit, Predict},
12};
/// Strategy for scoring a set of predictions against ground-truth targets.
///
/// NOTE(review): ranking in this module sorts scores descending, so
/// implementations are presumably "higher is better" (the test scorer returns
/// negated MSE) — confirm for new implementations.
pub trait Scoring {
    /// Scores `y_pred` against `y_true`; both slices are per-sample values.
    fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64>;
}
17use crate::model_comparison::{paired_t_test, StatisticalTestResult};
18use std::fmt::{self, Display, Formatter};
19
/// Outcome of cross-validated model selection: the winning model together
/// with per-model scores, the full ranking, and pairwise statistical tests.
#[derive(Debug, Clone)]
pub struct CVModelSelectionResult {
    /// Index (into the input model list) of the selected best model.
    pub best_model_index: usize,
    /// Models ordered best-first under the chosen selection criteria.
    pub model_rankings: Vec<ModelRanking>,
    /// Raw cross-validation score summaries for every candidate model.
    pub cv_scores: Vec<CVModelScore>,
    /// Pairwise comparison results (empty when statistical tests are disabled).
    pub statistical_comparisons: Vec<ModelComparisonPair>,
    /// The criteria that were used to rank and pick the best model.
    pub selection_criteria: ModelSelectionCriteria,
    /// Number of cross-validation folds reported by the splitter.
    pub n_folds: usize,
}
36
/// One model's position and summary statistics in the ranked comparison.
#[derive(Debug, Clone)]
pub struct ModelRanking {
    /// Index of the model in the original candidate list.
    pub model_index: usize,
    /// Human-readable model name supplied by the caller.
    pub model_name: String,
    /// 1-based rank; 1 is best under the selection criteria.
    pub rank: usize,
    /// Mean score across CV folds.
    pub mean_score: f64,
    /// Sample standard deviation of the fold scores.
    pub std_score: f64,
    /// Approximate confidence interval around the mean score; collapses to
    /// `(mean, mean)` when interval computation is disabled.
    pub confidence_interval: (f64, f64),
    /// Whether this model differs significantly from the top-ranked model.
    /// `None` when statistical tests were skipped or no comparison was found.
    pub significant_difference: Option<bool>,
}
55
/// Cross-validation score summary for a single model.
#[derive(Debug, Clone)]
pub struct CVModelScore {
    /// Index of the model in the original candidate list.
    pub model_index: usize,
    /// Human-readable model name supplied by the caller.
    pub model_name: String,
    /// Score obtained on each CV fold.
    pub fold_scores: Vec<f64>,
    /// Mean of `fold_scores`.
    pub mean_score: f64,
    /// Sample standard deviation of `fold_scores`.
    pub std_score: f64,
    /// Standard error of the mean: `std_score / sqrt(n_folds)`.
    pub std_error: f64,
    /// Smallest fold score.
    pub min_score: f64,
    /// Largest fold score.
    pub max_score: f64,
}
76
/// Pairwise statistical comparison between two models' fold scores.
#[derive(Debug, Clone)]
pub struct ModelComparisonPair {
    /// Index of the first model in the pair.
    pub model1_index: usize,
    /// Index of the second model in the pair.
    pub model2_index: usize,
    /// Paired t-test result on the two models' fold scores.
    pub test_result: StatisticalTestResult,
    /// Cohen's d effect size between the two fold-score samples.
    pub effect_size: f64,
}
89
/// Strategy used to rank candidate models from their CV scores.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelSelectionCriteria {
    /// Pick the model with the highest mean CV score.
    HighestMean,
    /// Prefer models whose mean score is within one standard error of the
    /// best mean; ties fall back to mean score.
    OneStandardError,
    /// Rank by mean score; significance flags are annotated separately.
    StatisticalSignificance,
    /// Pick the model with the smallest fold-score standard deviation.
    MostConsistent,
    /// Combine normalized mean and spread with user-supplied weights.
    Weighted {
        /// Weight applied to the normalized mean score (reward).
        mean_weight: f64,
        /// Weight applied to the normalized standard deviation (penalty).
        std_weight: f64,
        /// NOTE(review): currently ignored by `rank_models` — confirm whether
        /// it should contribute to the weighted score.
        consistency_weight: f64,
    },
}
108
109impl Display for ModelSelectionCriteria {
110 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
111 match self {
112 ModelSelectionCriteria::HighestMean => write!(f, "Highest Mean Score"),
113 ModelSelectionCriteria::OneStandardError => write!(f, "One Standard Error Rule"),
114 ModelSelectionCriteria::StatisticalSignificance => {
115 write!(f, "Statistical Significance")
116 }
117 ModelSelectionCriteria::MostConsistent => write!(f, "Most Consistent"),
118 ModelSelectionCriteria::Weighted { .. } => write!(f, "Weighted Criteria"),
119 }
120 }
121}
122
123impl Display for CVModelSelectionResult {
124 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
125 writeln!(f, "Cross-Validation Model Selection Results:")?;
126 writeln!(f, "Selection Criteria: {}", self.selection_criteria)?;
127 writeln!(f, "CV Folds: {}", self.n_folds)?;
128 writeln!(
129 f,
130 "Best Model: {} (index {})",
131 self.model_rankings[0].model_name, self.best_model_index
132 )?;
133 writeln!(f, "\nModel Rankings:")?;
134 for ranking in &self.model_rankings {
135 writeln!(
136 f,
137 " {}. {} - Score: {:.4} ± {:.4}",
138 ranking.rank, ranking.model_name, ranking.mean_score, ranking.std_score
139 )?;
140 }
141 Ok(())
142 }
143}
144
/// Configuration knobs for cross-validated model selection.
#[derive(Debug, Clone)]
pub struct CVModelSelectionConfig {
    /// Criterion used to rank models and pick the winner.
    pub criteria: ModelSelectionCriteria,
    /// Whether to run pairwise paired t-tests between all model pairs.
    pub perform_statistical_tests: bool,
    /// Significance level (alpha) used by the statistical tests.
    pub significance_level: f64,
    /// Whether to compute confidence intervals for mean scores.
    pub compute_confidence_intervals: bool,
    /// Optional RNG seed. NOTE(review): stored but never read within this
    /// module — confirm it is consumed by the CV splitter elsewhere.
    pub random_seed: Option<u64>,
}
159
160impl Default for CVModelSelectionConfig {
161 fn default() -> Self {
162 Self {
163 criteria: ModelSelectionCriteria::HighestMean,
164 perform_statistical_tests: true,
165 significance_level: 0.05,
166 compute_confidence_intervals: true,
167 random_seed: None,
168 }
169 }
170}
171
/// Selects the best model from a candidate list via cross-validation,
/// driven by a [`CVModelSelectionConfig`].
pub struct CVModelSelector {
    // Selection behavior settings (criteria, tests, alpha, intervals, seed).
    config: CVModelSelectionConfig,
}
176
177impl CVModelSelector {
178 pub fn new() -> Self {
180 Self {
181 config: CVModelSelectionConfig::default(),
182 }
183 }
184
185 pub fn with_config(config: CVModelSelectionConfig) -> Self {
187 Self { config }
188 }
189
190 pub fn criteria(mut self, criteria: ModelSelectionCriteria) -> Self {
192 self.config.criteria = criteria;
193 self
194 }
195
196 pub fn statistical_tests(mut self, enable: bool) -> Self {
198 self.config.perform_statistical_tests = enable;
199 self
200 }
201
202 pub fn significance_level(mut self, level: f64) -> Self {
204 self.config.significance_level = level;
205 self
206 }
207
208 pub fn random_seed(mut self, seed: u64) -> Self {
210 self.config.random_seed = Some(seed);
211 self
212 }
213
    /// Runs cross-validated evaluation for every candidate model, ranks the
    /// models under the configured criteria, and returns the full selection
    /// result.
    ///
    /// NOTE(review): the per-model evaluation is currently a placeholder —
    /// the model, `x`, and `_y` are ignored and `evaluate_model_cv` is called
    /// with dummy empty arrays (see that method). As a consequence the
    /// returned `n_folds` (taken from the splitter) may disagree with the
    /// length of each `fold_scores` vector (fixed at 5 by the stub). Confirm
    /// before relying on real scores.
    ///
    /// # Errors
    /// Returns `InvalidParameter` when `models` is empty; propagates errors
    /// from evaluation, statistical testing, and ranking.
    pub fn select_model<E, X, Y>(
        &self,
        models: &[(E, String)],
        x: &[X],
        _y: &[Y],
        cv: &dyn CrossValidator,
        scoring: &dyn Scoring,
    ) -> Result<CVModelSelectionResult>
    where
        E: Estimator + Fit<X, Y> + Clone,
        E::Fitted: Predict<Vec<f64>, Vec<f64>>,
        X: Clone,
        Y: Clone,
    {
        // At least one candidate is required.
        if models.is_empty() {
            return Err(SklearsError::InvalidParameter {
                name: "models".to_string(),
                reason: "at least one model must be provided".to_string(),
            });
        }

        let n_samples = x.len();
        let splits = cv.split(n_samples, None);
        let n_folds = splits.len();

        let mut cv_scores = Vec::with_capacity(models.len());

        for (model_idx, (_model, name)) in models.iter().enumerate() {
            // NOTE(review): dummy placeholders — the actual model and data are
            // not used by the stubbed evaluation below.
            let dummy_x = Array2::zeros((0, 0));
            let dummy_y = Array1::zeros(0);
            let fold_scores = self.evaluate_model_cv(&(), &dummy_x, &dummy_y, &splits, scoring)?;

            // Summary statistics over the per-fold scores.
            let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
            let std_score = self.calculate_std(&fold_scores, mean_score);
            let std_error = std_score / (fold_scores.len() as f64).sqrt();
            let min_score = fold_scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
            let max_score = fold_scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

            cv_scores.push(CVModelScore {
                model_index: model_idx,
                model_name: name.clone(),
                fold_scores,
                mean_score,
                std_score,
                std_error,
                min_score,
                max_score,
            });
        }

        // Pairwise paired t-tests, if enabled.
        let statistical_comparisons = if self.config.perform_statistical_tests {
            self.perform_statistical_comparisons(&cv_scores)?
        } else {
            Vec::new()
        };

        let model_rankings = self.rank_models(&cv_scores, &statistical_comparisons)?;
        let best_model_index = self.select_best_model(&cv_scores, &model_rankings)?;

        Ok(CVModelSelectionResult {
            best_model_index,
            model_rankings,
            cv_scores,
            statistical_comparisons,
            selection_criteria: self.config.criteria.clone(),
            n_folds,
        })
    }
287
    /// Evaluates one model across CV folds and returns the per-fold scores.
    ///
    /// NOTE(review): placeholder implementation — every argument is ignored
    /// and a constant score of 0.5 is returned for a fixed 5 "folds", so the
    /// number of returned scores can disagree with the splitter's fold count
    /// reported by `select_model`.
    fn evaluate_model_cv(
        &self,
        _model: &(), _x: &Array2<f64>,
        _y: &Array1<f64>,
        _splits: &[(Vec<usize>, Vec<usize>)],
        _scoring: &dyn Scoring,
    ) -> Result<Vec<f64>> {
        // Stub: constant score per fold.
        Ok(vec![0.5; 5]) }
300
301 fn perform_statistical_comparisons(
303 &self,
304 cv_scores: &[CVModelScore],
305 ) -> Result<Vec<ModelComparisonPair>> {
306 let mut comparisons = Vec::new();
307
308 for i in 0..cv_scores.len() {
309 for j in (i + 1)..cv_scores.len() {
310 let scores1 = &cv_scores[i].fold_scores;
311 let scores2 = &cv_scores[j].fold_scores;
312
313 let scores1_array = Array1::from_vec(scores1.clone());
315 let scores2_array = Array1::from_vec(scores2.clone());
316 let test_result = paired_t_test(
317 &scores1_array,
318 &scores2_array,
319 self.config.significance_level,
320 )?;
321
322 let effect_size = self.calculate_cohens_d(scores1, scores2);
324
325 comparisons.push(ModelComparisonPair {
326 model1_index: i,
327 model2_index: j,
328 test_result,
329 effect_size,
330 });
331 }
332 }
333
334 Ok(comparisons)
335 }
336
337 fn calculate_cohens_d(&self, scores1: &[f64], scores2: &[f64]) -> f64 {
339 let mean1 = scores1.iter().sum::<f64>() / scores1.len() as f64;
340 let mean2 = scores2.iter().sum::<f64>() / scores2.len() as f64;
341
342 let var1 = self.calculate_variance(scores1, mean1);
343 let var2 = self.calculate_variance(scores2, mean2);
344
345 let pooled_std = ((var1 + var2) / 2.0).sqrt();
346
347 if pooled_std > 0.0 {
348 (mean1 - mean2) / pooled_std
349 } else {
350 0.0
351 }
352 }
353
354 fn rank_models(
356 &self,
357 cv_scores: &[CVModelScore],
358 statistical_comparisons: &[ModelComparisonPair],
359 ) -> Result<Vec<ModelRanking>> {
360 let mut rankings: Vec<ModelRanking> = cv_scores
361 .iter()
362 .map(|score| {
363 let confidence_interval = if self.config.compute_confidence_intervals {
364 self.calculate_confidence_interval(
365 &score.fold_scores,
366 score.mean_score,
367 score.std_error,
368 )
369 } else {
370 (score.mean_score, score.mean_score)
371 };
372
373 ModelRanking {
374 model_index: score.model_index,
375 model_name: score.model_name.clone(),
376 rank: 0, mean_score: score.mean_score,
378 std_score: score.std_score,
379 confidence_interval,
380 significant_difference: None, }
382 })
383 .collect();
384
385 match &self.config.criteria {
387 ModelSelectionCriteria::HighestMean => {
388 rankings.sort_by(|a, b| {
389 b.mean_score
390 .partial_cmp(&a.mean_score)
391 .expect("operation should succeed")
392 });
393 }
394 ModelSelectionCriteria::OneStandardError => {
395 let best_score = rankings
397 .iter()
398 .map(|r| r.mean_score)
399 .fold(f64::NEG_INFINITY, f64::max);
400
401 let best_se = cv_scores
402 .iter()
403 .find(|s| s.mean_score == best_score)
404 .map(|s| s.std_error)
405 .unwrap_or(0.0);
406
407 let threshold = best_score - best_se;
408
409 rankings.sort_by(|a, b| {
411 let a_within_se = a.mean_score >= threshold;
412 let b_within_se = b.mean_score >= threshold;
413
414 match (a_within_se, b_within_se) {
415 (true, false) => std::cmp::Ordering::Less,
416 (false, true) => std::cmp::Ordering::Greater,
417 _ => b
418 .mean_score
419 .partial_cmp(&a.mean_score)
420 .expect("operation should succeed"),
421 }
422 });
423 }
424 ModelSelectionCriteria::MostConsistent => {
425 rankings.sort_by(|a, b| {
426 a.std_score
427 .partial_cmp(&b.std_score)
428 .expect("operation should succeed")
429 });
430 }
431 ModelSelectionCriteria::StatisticalSignificance => {
432 rankings.sort_by(|a, b| {
434 b.mean_score
435 .partial_cmp(&a.mean_score)
436 .expect("operation should succeed")
437 });
438 }
439 ModelSelectionCriteria::Weighted {
440 mean_weight,
441 std_weight,
442 consistency_weight: _consistency_weight,
443 } => {
444 let max_mean = rankings
446 .iter()
447 .map(|r| r.mean_score)
448 .fold(f64::NEG_INFINITY, f64::max);
449 let min_std = rankings
450 .iter()
451 .map(|r| r.std_score)
452 .fold(f64::INFINITY, f64::min);
453
454 rankings.sort_by(|a, b| {
455 let score_a =
456 a.mean_score / max_mean * mean_weight - a.std_score / min_std * std_weight;
457 let score_b =
458 b.mean_score / max_mean * mean_weight - b.std_score / min_std * std_weight;
459 score_b
460 .partial_cmp(&score_a)
461 .expect("operation should succeed")
462 });
463 }
464 }
465
466 for (idx, ranking) in rankings.iter_mut().enumerate() {
468 ranking.rank = idx + 1;
469 }
470
471 if !statistical_comparisons.is_empty() && !rankings.is_empty() {
473 let best_model_idx = rankings[0].model_index;
474
475 for ranking in &mut rankings[1..] {
476 let comparison = statistical_comparisons.iter().find(|c| {
478 (c.model1_index == best_model_idx && c.model2_index == ranking.model_index)
479 || (c.model2_index == best_model_idx
480 && c.model1_index == ranking.model_index)
481 });
482
483 if let Some(comp) = comparison {
484 ranking.significant_difference =
485 Some(comp.test_result.p_value < self.config.significance_level);
486 }
487 }
488 }
489
490 Ok(rankings)
491 }
492
493 fn select_best_model(
495 &self,
496 _cv_scores: &[CVModelScore],
497 model_rankings: &[ModelRanking],
498 ) -> Result<usize> {
499 if model_rankings.is_empty() {
500 return Err(SklearsError::InvalidParameter {
501 name: "model_rankings".to_string(),
502 reason: "no models to select from".to_string(),
503 });
504 }
505
506 Ok(model_rankings[0].model_index)
508 }
509
510 fn calculate_std(&self, values: &[f64], mean: f64) -> f64 {
512 if values.len() <= 1 {
513 return 0.0;
514 }
515
516 let variance = self.calculate_variance(values, mean);
517 variance.sqrt()
518 }
519
520 fn calculate_variance(&self, values: &[f64], mean: f64) -> f64 {
522 if values.len() <= 1 {
523 return 0.0;
524 }
525
526 let sum_sq_diff = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>();
527
528 sum_sq_diff / (values.len() - 1) as f64
529 }
530
531 fn calculate_confidence_interval(
533 &self,
534 values: &[f64],
535 mean: f64,
536 std_error: f64,
537 ) -> (f64, f64) {
538 let n = values.len() as f64;
540 let t_critical = if n > 30.0 { 1.96 } else { 2.0 }; let margin = t_critical * std_error;
543 (mean - margin, mean + margin)
544 }
545}
546
547impl Default for CVModelSelector {
548 fn default() -> Self {
549 Self::new()
550 }
551}
552
553pub fn cv_select_model<E, X, Y>(
555 models: &[(E, String)],
556 x: &[X],
557 y: &[Y],
558 cv: &dyn CrossValidator,
559 scoring: &dyn Scoring,
560 criteria: Option<ModelSelectionCriteria>,
561) -> Result<CVModelSelectionResult>
562where
563 E: Estimator + Fit<X, Y> + Clone,
564 E::Fitted: Predict<Vec<f64>, Vec<f64>>,
565 X: Clone,
566 Y: Clone,
567{
568 let mut selector = CVModelSelector::new();
569 if let Some(crit) = criteria {
570 selector = selector.criteria(crit);
571 }
572 selector.select_model(models, x, y, cv, scoring)
573}
574
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    /// Minimal estimator stub; `performance_level` scales its predictions.
    #[derive(Clone)]
    struct MockEstimator {
        performance_level: f64,
    }

    /// Fitted counterpart of `MockEstimator`.
    struct MockTrained {
        performance_level: f64,
    }

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        // "Training" just carries the performance knob into the fitted model.
        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            Ok(MockTrained {
                performance_level: self.performance_level,
            })
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        // Predictions are the inputs scaled by the performance level.
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            Ok(x.iter().map(|&xi| xi * self.performance_level).collect())
        }
    }

    /// Scorer returning negated MSE, so higher scores are better.
    struct MockScoring;

    impl Scoring for MockScoring {
        fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64> {
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
                .sum::<f64>()
                / y_true.len() as f64;
            Ok(-mse)
        }
    }

    /// Default configuration: highest-mean criteria, tests on, alpha = 0.05.
    #[test]
    fn test_cv_model_selector_creation() {
        let selector = CVModelSelector::new();
        assert_eq!(
            selector.config.criteria,
            ModelSelectionCriteria::HighestMean
        );
        assert!(selector.config.perform_statistical_tests);
        assert_eq!(selector.config.significance_level, 0.05);
    }

    /// End-to-end selection over three mock models produces a ranking,
    /// per-model scores, and a valid best-model index.
    #[test]
    fn test_cv_model_selection() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Model A".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Model B".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.7,
                },
                "Model C".to_string(),
            ),
        ];

        let x: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64 * 0.1]).collect();
        let y: Vec<Vec<f64>> = x.iter().map(|xi| vec![xi[0] * 0.5 + 0.1]).collect();

        let cv = KFold::new(5);
        let scoring = MockScoring;

        let selector = CVModelSelector::new();
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.expect("operation should succeed");
        assert_eq!(result.model_rankings.len(), 3);
        assert_eq!(result.cv_scores.len(), 3);
        assert_eq!(result.n_folds, 5);
        assert!(result.best_model_index < 3);
    }

    /// Each built-in criterion completes without error on the same inputs.
    #[test]
    fn test_different_selection_criteria() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Consistent".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.85,
                },
                "High Variance".to_string(),
            ),
        ];

        let x: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64 * 0.1]).collect();
        let y: Vec<Vec<f64>> = x.iter().map(|xi| vec![xi[0] * 0.3]).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = CVModelSelector::new().criteria(ModelSelectionCriteria::HighestMean);
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);
        assert!(result.is_ok());

        let selector = CVModelSelector::new().criteria(ModelSelectionCriteria::MostConsistent);
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);
        assert!(result.is_ok());

        let selector = CVModelSelector::new().criteria(ModelSelectionCriteria::OneStandardError);
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);
        assert!(result.is_ok());
    }

    /// The free-function wrapper ranks models and assigns rank 1 first.
    #[test]
    fn test_convenience_function() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Best Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.7,
                },
                "Worse Model".to_string(),
            ),
        ];

        let x: Vec<Vec<f64>> = (0..30).map(|i| vec![i as f64 * 0.05]).collect();
        let y: Vec<Vec<f64>> = x.iter().map(|xi| vec![xi[0] * 0.4]).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let result = cv_select_model(
            &models,
            &x,
            &y,
            &cv,
            &scoring,
            Some(ModelSelectionCriteria::HighestMean),
        );
        if let Err(e) = &result {
            eprintln!("Error in cv_select_model: {:?}", e);
        }
        assert!(result.is_ok());

        let result = result.expect("operation should succeed");
        assert_eq!(result.model_rankings.len(), 2);
        assert_eq!(result.model_rankings[0].rank, 1);
    }

    /// With tests enabled, 3 models yield C(3, 2) = 3 pairwise comparisons.
    #[test]
    fn test_statistical_comparisons() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 1.0,
                },
                "Perfect".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Good".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Okay".to_string(),
            ),
        ];

        let x: Vec<Vec<f64>> = (0..60).map(|i| vec![i as f64 * 0.1]).collect();
        let y: Vec<Vec<f64>> = x.iter().map(|xi| vec![xi[0]]).collect();
        let cv = KFold::new(5);
        let scoring = MockScoring;

        let selector = CVModelSelector::new().statistical_tests(true);
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.expect("operation should succeed");
        assert!(!result.statistical_comparisons.is_empty());

        assert_eq!(result.statistical_comparisons.len(), 3);
    }
}