//! Cross-validation based model selection: scores candidate models over CV
//! folds, ranks them by a configurable criterion, and optionally compares
//! them with paired statistical tests.

use crate::cross_validation::CrossValidator;
use crate::model_comparison::{paired_t_test, StatisticalTestResult};
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Predict},
};
use std::fmt::{self, Display, Formatter};

/// Scoring strategy used to evaluate predictions against ground truth.
pub trait Scoring {
    /// Returns a score where higher values indicate better performance.
    fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64>;
}

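/// Aggregated outcome of a cross-validation model selection run.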
#[derive(Debug, Clone)]
pub struct CVModelSelectionResult {
    /// Index (into the input model slice) of the selected model.
    pub best_model_index: usize,
    /// Models ordered from best to worst under the selection criteria.
    pub model_rankings: Vec<ModelRanking>,
    /// Per-model cross-validation score summaries.
    pub cv_scores: Vec<CVModelScore>,
    /// Pairwise statistical comparisons (empty if tests were disabled).
    pub statistical_comparisons: Vec<ModelComparisonPair>,
    /// Criteria used to pick the best model.
    pub selection_criteria: ModelSelectionCriteria,
    /// Number of cross-validation folds used.
    pub n_folds: usize,
}

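/// A single model's position in the ranking produced by the selector.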
#[derive(Debug, Clone)]
pub struct ModelRanking {
    /// Index of the model in the input slice.
    pub model_index: usize,
    /// Human-readable model name.
    pub model_name: String,
    /// 1-based rank (1 = best).
    pub rank: usize,
    /// Mean cross-validation score.
    pub mean_score: f64,
    /// Standard deviation of the fold scores.
    pub std_score: f64,
    /// Approximate confidence interval for the mean score.
    pub confidence_interval: (f64, f64),
    /// Whether this model differs significantly from the top-ranked model
    /// (`None` when statistical tests were not performed).
    pub significant_difference: Option<bool>,
}

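/// Cross-validation score summary for a single candidate model.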
#[derive(Debug, Clone)]
pub struct CVModelScore {
    /// Index of the model in the input slice.
    pub model_index: usize,
    /// Human-readable model name.
    pub model_name: String,
    /// Score obtained on each fold.
    pub fold_scores: Vec<f64>,
    /// Mean of the fold scores.
    pub mean_score: f64,
    /// Standard deviation of the fold scores.
    pub std_score: f64,
    /// Standard error of the mean (std / sqrt(n_folds)).
    pub std_error: f64,
    /// Lowest fold score.
    pub min_score: f64,
    /// Highest fold score.
    pub max_score: f64,
}

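/// Result of a pairwise statistical comparison between two models.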
#[derive(Debug, Clone)]
pub struct ModelComparisonPair {
    /// Index of the first model.
    pub model1_index: usize,
    /// Index of the second model.
    pub model2_index: usize,
    /// Paired t-test result for the two models' fold scores.
    pub test_result: StatisticalTestResult,
    /// Cohen's d effect size for the score difference.
    pub effect_size: f64,
}

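/// Criteria used to choose the best model from the ranked candidates.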
#[derive(Debug, Clone, PartialEq)]
pub enum ModelSelectionCriteria {
    /// Pick the model with the highest mean cross-validation score.
    HighestMean,
    /// Prefer models whose mean score lies within one standard error of the best.
    OneStandardError,
    /// Rank by mean score; significance against the best model is annotated.
    StatisticalSignificance,
    /// Pick the model with the lowest score variance across folds.
    MostConsistent,
    /// Combine normalized mean score and score spread with user-supplied
    /// weights (`consistency_weight` is currently unused).
    Weighted {
        mean_weight: f64,
        std_weight: f64,
        consistency_weight: f64,
    },
}

impl Display for ModelSelectionCriteria {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            ModelSelectionCriteria::HighestMean => write!(f, "Highest Mean Score"),
            ModelSelectionCriteria::OneStandardError => write!(f, "One Standard Error Rule"),
            ModelSelectionCriteria::StatisticalSignificance => {
                write!(f, "Statistical Significance")
            }
            ModelSelectionCriteria::MostConsistent => write!(f, "Most Consistent"),
            ModelSelectionCriteria::Weighted { .. } => write!(f, "Weighted Criteria"),
        }
    }
}

impl Display for CVModelSelectionResult {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        writeln!(f, "Cross-Validation Model Selection Results:")?;
        writeln!(f, "Selection Criteria: {}", self.selection_criteria)?;
        writeln!(f, "CV Folds: {}", self.n_folds)?;
        writeln!(
            f,
            "Best Model: {} (index {})",
            self.model_rankings[0].model_name, self.best_model_index
        )?;
        writeln!(f, "\nModel Rankings:")?;
        for ranking in &self.model_rankings {
            writeln!(
                f,
                "  {}. {} - Score: {:.4} ± {:.4}",
                ranking.rank, ranking.model_name, ranking.mean_score, ranking.std_score
            )?;
        }
        Ok(())
    }
}

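/// Configuration for [`CVModelSelector`].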
#[derive(Debug, Clone)]
pub struct CVModelSelectionConfig {
    /// Criteria used to select the best model.
    pub criteria: ModelSelectionCriteria,
    /// Whether to run pairwise statistical tests between models.
    pub perform_statistical_tests: bool,
    /// Significance level (alpha) used by the statistical tests.
    pub significance_level: f64,
    /// Whether to compute confidence intervals for the mean scores.
    pub compute_confidence_intervals: bool,
    /// Optional random seed for reproducibility.
    pub random_seed: Option<u64>,
}

impl Default for CVModelSelectionConfig {
    fn default() -> Self {
        Self {
            criteria: ModelSelectionCriteria::HighestMean,
            perform_statistical_tests: true,
            significance_level: 0.05,
            compute_confidence_intervals: true,
            random_seed: None,
        }
    }
}

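/// Selects the best model from a set of candidates using cross-validation,
/// optionally backed by pairwise statistical tests.
///
/// Illustrative sketch only (marked `ignore`, so it is not compiled as a
/// doctest); `MyEstimator` and `MyScorer` are hypothetical stand-ins for
/// user-provided `Estimator + Fit` and `Scoring` implementations, and the
/// relevant items are assumed to be in scope:
///
/// ```ignore
/// let models = vec![
///     (MyEstimator::default(), "baseline".to_string()),
///     (MyEstimator::default(), "tuned".to_string()),
/// ];
/// let cv = KFold::new(5);
/// let result = CVModelSelector::new()
///     .criteria(ModelSelectionCriteria::OneStandardError)
///     .significance_level(0.05)
///     .select_model(&models, &x, &y, &cv, &MyScorer)?;
/// println!("{}", result);
/// ```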
pub struct CVModelSelector {
    config: CVModelSelectionConfig,
}

impl CVModelSelector {
    /// Creates a selector with the default configuration.
    pub fn new() -> Self {
        Self {
            config: CVModelSelectionConfig::default(),
        }
    }

    /// Creates a selector from an explicit configuration.
    pub fn with_config(config: CVModelSelectionConfig) -> Self {
        Self { config }
    }

    /// Sets the model selection criteria.
    pub fn criteria(mut self, criteria: ModelSelectionCriteria) -> Self {
        self.config.criteria = criteria;
        self
    }

    /// Enables or disables pairwise statistical tests.
    pub fn statistical_tests(mut self, enable: bool) -> Self {
        self.config.perform_statistical_tests = enable;
        self
    }

    /// Sets the significance level used by the statistical tests.
    pub fn significance_level(mut self, level: f64) -> Self {
        self.config.significance_level = level;
        self
    }

    /// Sets the random seed for reproducibility.
    pub fn random_seed(mut self, seed: u64) -> Self {
        self.config.random_seed = Some(seed);
        self
    }

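    /// Evaluates every candidate model with cross-validation, ranks the
    /// candidates, and returns the index of the best model together with the
    /// per-fold scores and (optionally) pairwise statistical comparisons.
    ///
    /// Returns an error if `models` is empty.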
    pub fn select_model<E, X, Y>(
        &self,
        models: &[(E, String)],
        x: &[X],
        _y: &[Y],
        cv: &dyn CrossValidator,
        scoring: &dyn Scoring,
    ) -> Result<CVModelSelectionResult>
    where
        E: Estimator + Fit<X, Y> + Clone,
        E::Fitted: Predict<Vec<f64>, Vec<f64>>,
        X: Clone,
        Y: Clone,
    {
        if models.is_empty() {
            return Err(SklearsError::InvalidParameter {
                name: "models".to_string(),
                reason: "at least one model must be provided".to_string(),
            });
        }

        let n_samples = x.len();
        let splits = cv.split(n_samples, None);
        let n_folds = splits.len();

        let mut cv_scores = Vec::with_capacity(models.len());

        for (model_idx, (_model, name)) in models.iter().enumerate() {
            // NOTE: `evaluate_model_cv` is currently a placeholder, so the
            // models are not actually fitted here; empty arrays are passed
            // through until real per-fold training is wired in.
            let dummy_x = Array2::zeros((0, 0));
            let dummy_y = Array1::zeros(0);
            let fold_scores = self.evaluate_model_cv(&(), &dummy_x, &dummy_y, &splits, scoring)?;

            let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
            let std_score = self.calculate_std(&fold_scores, mean_score);
            let std_error = std_score / (fold_scores.len() as f64).sqrt();
            let min_score = fold_scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
            let max_score = fold_scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

            cv_scores.push(CVModelScore {
                model_index: model_idx,
                model_name: name.clone(),
                fold_scores,
                mean_score,
                std_score,
                std_error,
                min_score,
                max_score,
            });
        }

        let statistical_comparisons = if self.config.perform_statistical_tests {
            self.perform_statistical_comparisons(&cv_scores)?
        } else {
            Vec::new()
        };

        let model_rankings = self.rank_models(&cv_scores, &statistical_comparisons)?;
        let best_model_index = self.select_best_model(&cv_scores, &model_rankings)?;

        Ok(CVModelSelectionResult {
            best_model_index,
            model_rankings,
            cv_scores,
            statistical_comparisons,
            selection_criteria: self.config.criteria.clone(),
            n_folds,
        })
    }

    fn evaluate_model_cv(
        &self,
        _model: &(),
        _x: &Array2<f64>,
        _y: &Array1<f64>,
        _splits: &[(Vec<usize>, Vec<usize>)],
        _scoring: &dyn Scoring,
    ) -> Result<Vec<f64>> {
        // Placeholder implementation: real per-fold fitting and scoring is not
        // wired in yet, so a constant score is returned for each of five folds.
        Ok(vec![0.5; 5])
    }

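    /// Runs a paired t-test and computes Cohen's d for every pair of models.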
    fn perform_statistical_comparisons(
        &self,
        cv_scores: &[CVModelScore],
    ) -> Result<Vec<ModelComparisonPair>> {
        let mut comparisons = Vec::new();

        for i in 0..cv_scores.len() {
            for j in (i + 1)..cv_scores.len() {
                let scores1 = &cv_scores[i].fold_scores;
                let scores2 = &cv_scores[j].fold_scores;

                let scores1_array = Array1::from_vec(scores1.clone());
                let scores2_array = Array1::from_vec(scores2.clone());
                let test_result = paired_t_test(
                    &scores1_array,
                    &scores2_array,
                    self.config.significance_level,
                )?;

                let effect_size = self.calculate_cohens_d(scores1, scores2);

                comparisons.push(ModelComparisonPair {
                    model1_index: i,
                    model2_index: j,
                    test_result,
                    effect_size,
                });
            }
        }

        Ok(comparisons)
    }

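    /// Cohen's d effect size: difference in means divided by the pooled
    /// standard deviation (0.0 when the pooled standard deviation is zero).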
    fn calculate_cohens_d(&self, scores1: &[f64], scores2: &[f64]) -> f64 {
        let mean1 = scores1.iter().sum::<f64>() / scores1.len() as f64;
        let mean2 = scores2.iter().sum::<f64>() / scores2.len() as f64;

        let var1 = self.calculate_variance(scores1, mean1);
        let var2 = self.calculate_variance(scores2, mean2);

        let pooled_std = ((var1 + var2) / 2.0).sqrt();

        if pooled_std > 0.0 {
            (mean1 - mean2) / pooled_std
        } else {
            0.0
        }
    }

    fn rank_models(
        &self,
        cv_scores: &[CVModelScore],
        statistical_comparisons: &[ModelComparisonPair],
    ) -> Result<Vec<ModelRanking>> {
        let mut rankings: Vec<ModelRanking> = cv_scores
            .iter()
            .map(|score| {
                let confidence_interval = if self.config.compute_confidence_intervals {
                    self.calculate_confidence_interval(
                        &score.fold_scores,
                        score.mean_score,
                        score.std_error,
                    )
                } else {
                    (score.mean_score, score.mean_score)
                };

                ModelRanking {
                    model_index: score.model_index,
                    model_name: score.model_name.clone(),
                    rank: 0, // assigned after sorting
                    mean_score: score.mean_score,
                    std_score: score.std_score,
                    confidence_interval,
                    significant_difference: None, // filled in below if tests were run
                }
            })
            .collect();

        match &self.config.criteria {
            ModelSelectionCriteria::HighestMean => {
                rankings.sort_by(|a, b| b.mean_score.partial_cmp(&a.mean_score).unwrap());
            }
            ModelSelectionCriteria::OneStandardError => {
                let best_score = rankings
                    .iter()
                    .map(|r| r.mean_score)
                    .fold(f64::NEG_INFINITY, f64::max);

                let best_se = cv_scores
                    .iter()
                    .find(|s| s.mean_score == best_score)
                    .map(|s| s.std_error)
                    .unwrap_or(0.0);

                let threshold = best_score - best_se;

                // Models within one standard error of the best score rank
                // ahead of the rest; ties are broken by mean score.
                rankings.sort_by(|a, b| {
                    let a_within_se = a.mean_score >= threshold;
                    let b_within_se = b.mean_score >= threshold;

                    match (a_within_se, b_within_se) {
                        (true, false) => std::cmp::Ordering::Less,
                        (false, true) => std::cmp::Ordering::Greater,
                        _ => b.mean_score.partial_cmp(&a.mean_score).unwrap(),
                    }
                });
            }
            ModelSelectionCriteria::MostConsistent => {
                rankings.sort_by(|a, b| a.std_score.partial_cmp(&b.std_score).unwrap());
            }
            ModelSelectionCriteria::StatisticalSignificance => {
                // Rank by mean score; significance against the best model is
                // annotated on the rankings after sorting.
                rankings.sort_by(|a, b| b.mean_score.partial_cmp(&a.mean_score).unwrap());
            }
            ModelSelectionCriteria::Weighted {
                mean_weight,
                std_weight,
                consistency_weight: _consistency_weight,
            } => {
                let max_mean = rankings
                    .iter()
                    .map(|r| r.mean_score)
                    .fold(f64::NEG_INFINITY, f64::max);
                let min_std = rankings
                    .iter()
                    .map(|r| r.std_score)
                    .fold(f64::INFINITY, f64::min);

                // Normalize by the best mean and the smallest std before
                // weighting; this assumes positive, non-zero normalizers.
                rankings.sort_by(|a, b| {
                    let score_a =
                        a.mean_score / max_mean * mean_weight - a.std_score / min_std * std_weight;
                    let score_b =
                        b.mean_score / max_mean * mean_weight - b.std_score / min_std * std_weight;
                    score_b.partial_cmp(&score_a).unwrap()
                });
            }
        }

        for (idx, ranking) in rankings.iter_mut().enumerate() {
            ranking.rank = idx + 1;
        }

        if !statistical_comparisons.is_empty() && !rankings.is_empty() {
            let best_model_idx = rankings[0].model_index;

            for ranking in &mut rankings[1..] {
                let comparison = statistical_comparisons.iter().find(|c| {
                    (c.model1_index == best_model_idx && c.model2_index == ranking.model_index)
                        || (c.model2_index == best_model_idx
                            && c.model1_index == ranking.model_index)
                });

                if let Some(comp) = comparison {
                    ranking.significant_difference =
                        Some(comp.test_result.p_value < self.config.significance_level);
                }
            }
        }

        Ok(rankings)
    }

    fn select_best_model(
        &self,
        _cv_scores: &[CVModelScore],
        model_rankings: &[ModelRanking],
    ) -> Result<usize> {
        if model_rankings.is_empty() {
            return Err(SklearsError::InvalidParameter {
                name: "model_rankings".to_string(),
                reason: "no models to select from".to_string(),
            });
        }

        Ok(model_rankings[0].model_index)
    }

    fn calculate_std(&self, values: &[f64], mean: f64) -> f64 {
        if values.len() <= 1 {
            return 0.0;
        }

        let variance = self.calculate_variance(values, mean);
        variance.sqrt()
    }

    fn calculate_variance(&self, values: &[f64], mean: f64) -> f64 {
        if values.len() <= 1 {
            return 0.0;
        }

        let sum_sq_diff = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>();

        sum_sq_diff / (values.len() - 1) as f64
    }

    fn calculate_confidence_interval(
        &self,
        values: &[f64],
        mean: f64,
        std_error: f64,
    ) -> (f64, f64) {
        let n = values.len() as f64;
        // Rough critical value: 1.96 (normal approximation) for large samples,
        // otherwise a conservative 2.0 in place of the exact t quantile.
        let t_critical = if n > 30.0 { 1.96 } else { 2.0 };
        let margin = t_critical * std_error;
        (mean - margin, mean + margin)
    }
}

impl Default for CVModelSelector {
    fn default() -> Self {
        Self::new()
    }
}

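/// Convenience wrapper around [`CVModelSelector::select_model`] for one-off
/// model selection with an optional criteria override.
///
/// Illustrative sketch only (marked `ignore`, so it is not compiled);
/// `models`, `x`, `y`, and `MyScorer` are assumed to be user-provided data
/// and a user `Scoring` implementation:
///
/// ```ignore
/// let result = cv_select_model(
///     &models,
///     &x,
///     &y,
///     &KFold::new(5),
///     &MyScorer,
///     Some(ModelSelectionCriteria::HighestMean),
/// )?;
/// println!("Best model index: {}", result.best_model_index);
/// ```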
pub fn cv_select_model<E, X, Y>(
    models: &[(E, String)],
    x: &[X],
    y: &[Y],
    cv: &dyn CrossValidator,
    scoring: &dyn Scoring,
    criteria: Option<ModelSelectionCriteria>,
) -> Result<CVModelSelectionResult>
where
    E: Estimator + Fit<X, Y> + Clone,
    E::Fitted: Predict<Vec<f64>, Vec<f64>>,
    X: Clone,
    Y: Clone,
{
    let mut selector = CVModelSelector::new();
    if let Some(crit) = criteria {
        selector = selector.criteria(crit);
    }
    selector.select_model(models, x, y, cv, scoring)
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cross_validation::KFold;

    /// Estimator stub whose predictions scale inputs by `performance_level`.
    #[derive(Clone)]
    struct MockEstimator {
        performance_level: f64,
    }

    struct MockTrained {
        performance_level: f64,
    }

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            Ok(MockTrained {
                performance_level: self.performance_level,
            })
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            Ok(x.iter().map(|&xi| xi * self.performance_level).collect())
        }
    }

    /// Scoring stub that returns negated mean squared error (higher is better).
    struct MockScoring;

    impl Scoring for MockScoring {
        fn score(&self, y_true: &[f64], y_pred: &[f64]) -> Result<f64> {
            let mse = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
                .sum::<f64>()
                / y_true.len() as f64;
            Ok(-mse)
        }
    }

    #[test]
    fn test_cv_model_selector_creation() {
        let selector = CVModelSelector::new();
        assert_eq!(
            selector.config.criteria,
            ModelSelectionCriteria::HighestMean
        );
        assert!(selector.config.perform_statistical_tests);
        assert_eq!(selector.config.significance_level, 0.05);
    }

    #[test]
    fn test_cv_model_selection() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Model A".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Model B".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.7,
                },
                "Model C".to_string(),
            ),
        ];

        let x: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64 * 0.1]).collect();
        let y: Vec<Vec<f64>> = x.iter().map(|xi| vec![xi[0] * 0.5 + 0.1]).collect();

        let cv = KFold::new(5);
        let scoring = MockScoring;

        let selector = CVModelSelector::new();
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.model_rankings.len(), 3);
        assert_eq!(result.cv_scores.len(), 3);
        assert_eq!(result.n_folds, 5);
        assert!(result.best_model_index < 3);
    }

    #[test]
    fn test_different_selection_criteria() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Consistent".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.85,
                },
                "High Variance".to_string(),
            ),
        ];

        let x: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64 * 0.1]).collect();
        let y: Vec<Vec<f64>> = x.iter().map(|xi| vec![xi[0] * 0.3]).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let selector = CVModelSelector::new().criteria(ModelSelectionCriteria::HighestMean);
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);
        assert!(result.is_ok());

        let selector = CVModelSelector::new().criteria(ModelSelectionCriteria::MostConsistent);
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);
        assert!(result.is_ok());

        let selector = CVModelSelector::new().criteria(ModelSelectionCriteria::OneStandardError);
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);
        assert!(result.is_ok());
    }

    #[test]
    fn test_convenience_function() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Best Model".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.7,
                },
                "Worse Model".to_string(),
            ),
        ];

        let x: Vec<Vec<f64>> = (0..30).map(|i| vec![i as f64 * 0.05]).collect();
        let y: Vec<Vec<f64>> = x.iter().map(|xi| vec![xi[0] * 0.4]).collect();
        let cv = KFold::new(3);
        let scoring = MockScoring;

        let result = cv_select_model(
            &models,
            &x,
            &y,
            &cv,
            &scoring,
            Some(ModelSelectionCriteria::HighestMean),
        );
        if let Err(e) = &result {
            eprintln!("Error in cv_select_model: {:?}", e);
        }
        assert!(result.is_ok());

        let result = result.unwrap();
        assert_eq!(result.model_rankings.len(), 2);
        assert_eq!(result.model_rankings[0].rank, 1);
    }

    #[test]
    fn test_statistical_comparisons() {
        let models = vec![
            (
                MockEstimator {
                    performance_level: 1.0,
                },
                "Perfect".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.9,
                },
                "Good".to_string(),
            ),
            (
                MockEstimator {
                    performance_level: 0.8,
                },
                "Okay".to_string(),
            ),
        ];

        let x: Vec<Vec<f64>> = (0..60).map(|i| vec![i as f64 * 0.1]).collect();
        let y: Vec<Vec<f64>> = x.iter().map(|xi| vec![xi[0]]).collect();
        let cv = KFold::new(5);
        let scoring = MockScoring;

        let selector = CVModelSelector::new().statistical_tests(true);
        let result = selector.select_model(&models, &x, &y, &cv, &scoring);

        assert!(result.is_ok());
        let result = result.unwrap();
        assert!(!result.statistical_comparisons.is_empty());

        // Three models yield C(3, 2) = 3 pairwise comparisons.
        assert_eq!(result.statistical_comparisons.len(), 3);
    }
}