sklears_feature_selection/evaluation/cross_validation.rs

//! Cross-validation strategies for feature selection evaluation
//!
//! This module implements comprehensive cross-validation methods specifically designed
//! for evaluating feature selection algorithms. All implementations follow the SciRS2
//! policy of using scirs2-core for numerical computations.
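//!
//! Provided strategies include [`NestedCrossValidation`], [`StratifiedKFold`],
//! [`TimeSeriesSplit`], [`GroupKFold`], and [`RepeatedKFold`], together with
//! stability metrics ([`FeatureStabilityMetrics`]) for the selected feature sets.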

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::error::{Result as SklResult, SklearsError};
use std::collections::HashMap;
use thiserror::Error;

type Result<T> = SklResult<T>;

#[derive(Debug, Error)]
pub enum CrossValidationError {
    #[error("Insufficient data for cross-validation")]
    InsufficientData,
    #[error("Invalid fold configuration")]
    InvalidFoldConfiguration,
    #[error("Feature and target length mismatch")]
    LengthMismatch,
    #[error("Invalid feature indices")]
    InvalidFeatureIndices,
    #[error("Empty feature selection")]
    EmptyFeatureSelection,
}

impl From<CrossValidationError> for SklearsError {
    fn from(err: CrossValidationError) -> Self {
        SklearsError::FitError(format!("Cross-validation error: {}", err))
    }
}

/// Nested cross-validation for feature selection with inner and outer loops
#[derive(Debug, Clone)]
pub struct NestedCrossValidation {
    outer_folds: usize,
    inner_folds: usize,
    stratified: bool,
    random_state: Option<u64>,
}

impl NestedCrossValidation {
    /// Create a new nested cross-validation configuration
    pub fn new(
        outer_folds: usize,
        inner_folds: usize,
        stratified: bool,
        random_state: Option<u64>,
    ) -> Self {
        Self {
            outer_folds,
            inner_folds,
            stratified,
            random_state,
        }
    }

    /// Perform nested cross-validation for feature selection evaluation
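    ///
    /// A minimal usage sketch (illustrative only; the closures below are
    /// hypothetical stand-ins for a real feature selector and performance
    /// evaluator, and the import path is assumed from this file's location):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    /// use sklears_feature_selection::evaluation::cross_validation::NestedCrossValidation;
    ///
    /// let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];
    /// let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
    ///
    /// let cv = NestedCrossValidation::new(3, 2, false, Some(42));
    /// let results = cv.evaluate(
    ///     x.view(),
    ///     y.view(),
    ///     // Selector: keep every feature (a trivial placeholder).
    ///     |x_train, _y_train| Ok((0..x_train.ncols()).collect()),
    ///     // Evaluator: score is just the number of selected features (placeholder).
    ///     |_x_tr, _y_tr, _x_te, _y_te, features| Ok(features.len() as f64),
    /// ).unwrap();
    /// println!("{}", results.report());
    /// ```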
    #[allow(non_snake_case)]
    pub fn evaluate<F, G>(
        &self,
        X: ArrayView2<f64>,
        y: ArrayView1<f64>,
        feature_selector: F,
        performance_evaluator: G,
    ) -> Result<NestedCVResults>
    where
        F: Fn(ArrayView2<f64>, ArrayView1<f64>) -> Result<Vec<usize>> + Copy,
        G: Fn(
                ArrayView2<f64>,
                ArrayView1<f64>,
                ArrayView2<f64>,
                ArrayView1<f64>,
                &[usize],
            ) -> Result<f64>
            + Copy,
    {
        if X.nrows() != y.len() {
            return Err(CrossValidationError::LengthMismatch.into());
        }

        if X.nrows() < self.outer_folds * 2 {
            return Err(CrossValidationError::InsufficientData.into());
        }

        let n_samples = X.nrows();
        let indices: Vec<usize> = (0..n_samples).collect();

        // Create outer fold splits
        let outer_splits = if self.stratified {
            self.stratified_k_fold_split(&indices, y, self.outer_folds)?
        } else {
            self.k_fold_split(&indices, self.outer_folds)?
        };

        let mut outer_scores = Vec::with_capacity(self.outer_folds);
        let mut feature_selection_stability = Vec::new();
        let mut inner_cv_scores = Vec::new();

        for (outer_fold, (train_idx, test_idx)) in outer_splits.into_iter().enumerate() {
            // Extract outer training and test sets
            let X_outer_train = self.extract_samples(X, &train_idx);
            let y_outer_train = self.extract_targets(y, &train_idx);
            let X_outer_test = self.extract_samples(X, &test_idx);
            let y_outer_test = self.extract_targets(y, &test_idx);

            // Inner cross-validation for feature selection
            let inner_splits = if self.stratified {
                self.stratified_k_fold_split(
                    &(0..train_idx.len()).collect::<Vec<_>>(),
                    y_outer_train.view(),
                    self.inner_folds,
                )?
            } else {
                self.k_fold_split(&(0..train_idx.len()).collect::<Vec<_>>(), self.inner_folds)?
            };

            let mut inner_fold_scores = Vec::new();
            let mut inner_fold_features = Vec::new();

            for (inner_train_idx, inner_val_idx) in inner_splits {
                // Extract inner training and validation sets
                let X_inner_train = self.extract_samples(X_outer_train.view(), &inner_train_idx);
                let y_inner_train = self.extract_targets(y_outer_train.view(), &inner_train_idx);
                let X_inner_val = self.extract_samples(X_outer_train.view(), &inner_val_idx);
                let y_inner_val = self.extract_targets(y_outer_train.view(), &inner_val_idx);

                // Feature selection on inner training set
                let selected_features =
                    feature_selector(X_inner_train.view(), y_inner_train.view())?;

                if selected_features.is_empty() {
                    return Err(CrossValidationError::EmptyFeatureSelection.into());
                }

                // Evaluate on inner validation set
                let inner_score = performance_evaluator(
                    X_inner_train.view(),
                    y_inner_train.view(),
                    X_inner_val.view(),
                    y_inner_val.view(),
                    &selected_features,
                )?;

                inner_fold_scores.push(inner_score);
                inner_fold_features.push(selected_features);
            }

            // Store inner CV results
            let inner_cv_mean =
                inner_fold_scores.iter().sum::<f64>() / inner_fold_scores.len() as f64;
            inner_cv_scores.push(InnerCVResult {
                outer_fold,
                inner_scores: inner_fold_scores,
                mean_score: inner_cv_mean,
                selected_features: inner_fold_features,
            });

            // Select features using full outer training set
            let final_selected_features =
                feature_selector(X_outer_train.view(), y_outer_train.view())?;
            feature_selection_stability.push(final_selected_features.clone());

            // Evaluate on outer test set
            let outer_score = performance_evaluator(
                X_outer_train.view(),
                y_outer_train.view(),
                X_outer_test.view(),
                y_outer_test.view(),
                &final_selected_features,
            )?;

            outer_scores.push(outer_score);
        }

        // Compute stability metrics
        let stability_metrics = self.compute_stability_metrics(&feature_selection_stability)?;

        // Compute overall statistics
        let outer_mean = outer_scores.iter().sum::<f64>() / outer_scores.len() as f64;
        let outer_std = {
            let variance = outer_scores
                .iter()
                .map(|score| (score - outer_mean).powi(2))
                .sum::<f64>()
                / outer_scores.len() as f64;
            variance.sqrt()
        };

        let inner_mean = inner_cv_scores
            .iter()
            .map(|result| result.mean_score)
            .sum::<f64>()
            / inner_cv_scores.len() as f64;

        Ok(NestedCVResults {
            outer_scores,
            outer_mean_score: outer_mean,
            outer_std_score: outer_std,
            inner_cv_results: inner_cv_scores,
            inner_mean_score: inner_mean,
            feature_stability: stability_metrics,
            n_outer_folds: self.outer_folds,
            n_inner_folds: self.inner_folds,
        })
    }

    /// Create K-fold splits
    fn k_fold_split(
        &self,
        indices: &[usize],
        n_folds: usize,
    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
        if indices.len() < n_folds {
            return Err(CrossValidationError::InvalidFoldConfiguration.into());
        }

        let mut shuffled_indices = indices.to_vec();

        // Shuffle if a random state is provided (the state only enables shuffling;
        // the thread-local RNG itself is not seeded with it)
        if self.random_state.is_some() {
            self.shuffle_indices(&mut shuffled_indices);
        }

        let fold_size = indices.len() / n_folds;
        let remainder = indices.len() % n_folds;

        let mut splits = Vec::new();

        for fold in 0..n_folds {
            let start = fold * fold_size + fold.min(remainder);
            let end = start + fold_size + if fold < remainder { 1 } else { 0 };

            let test_indices = shuffled_indices[start..end].to_vec();
            let train_indices: Vec<usize> = shuffled_indices[..start]
                .iter()
                .chain(shuffled_indices[end..].iter())
                .cloned()
                .collect();

            splits.push((train_indices, test_indices));
        }

        Ok(splits)
    }

    /// Create stratified K-fold splits
    fn stratified_k_fold_split(
        &self,
        indices: &[usize],
        y: ArrayView1<f64>,
        n_folds: usize,
    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
        if indices.len() < n_folds {
            return Err(CrossValidationError::InvalidFoldConfiguration.into());
        }

        // Group indices by class
        let mut class_groups: HashMap<i32, Vec<usize>> = HashMap::new();
        for &idx in indices {
            let class = y[idx] as i32;
            class_groups.entry(class).or_default().push(idx);
        }

        // Shuffle each class group
        if self.random_state.is_some() {
            for group in class_groups.values_mut() {
                self.shuffle_indices(group);
            }
        }

        // Create folds maintaining class distribution
        let mut folds: Vec<Vec<usize>> = vec![Vec::new(); n_folds];

        for group in class_groups.values() {
            let group_fold_size = group.len() / n_folds;
            let group_remainder = group.len() % n_folds;

            for fold in 0..n_folds {
                let start = fold * group_fold_size + fold.min(group_remainder);
                let end = start + group_fold_size + if fold < group_remainder { 1 } else { 0 };
                folds[fold].extend_from_slice(&group[start..end]);
            }
        }

        // Create train/test splits
        let mut splits = Vec::new();
        for fold in 0..n_folds {
            let test_indices = folds[fold].clone();
            let train_indices: Vec<usize> = folds
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != fold)
                .flat_map(|(_, fold_indices)| fold_indices.iter())
                .cloned()
                .collect();

            splits.push((train_indices, test_indices));
        }

        Ok(splits)
    }

    /// Simple Fisher-Yates shuffle using the thread-local RNG
    /// (the configured `random_state` is not used as a seed)
    fn shuffle_indices(&self, indices: &mut [usize]) {
        for i in (1..indices.len()).rev() {
            let j = (thread_rng().gen::<f64>() * (i + 1) as f64) as usize;
            indices.swap(i, j);
        }
    }

    /// Extract samples by indices
    fn extract_samples(&self, X: ArrayView2<f64>, indices: &[usize]) -> Array2<f64> {
        let mut samples = Array2::zeros((indices.len(), X.ncols()));
        for (i, &idx) in indices.iter().enumerate() {
            samples.row_mut(i).assign(&X.row(idx));
        }
        samples
    }

    /// Extract targets by indices
    fn extract_targets(&self, y: ArrayView1<f64>, indices: &[usize]) -> Array1<f64> {
        let mut targets = Array1::zeros(indices.len());
        for (i, &idx) in indices.iter().enumerate() {
            targets[i] = y[idx];
        }
        targets
    }

    /// Compute feature selection stability metrics
    fn compute_stability_metrics(
        &self,
        feature_selections: &[Vec<usize>],
    ) -> Result<FeatureStabilityMetrics> {
        if feature_selections.is_empty() {
            return Ok(FeatureStabilityMetrics {
                jaccard_similarity: 0.0,
                intersection_stability: 0.0,
                average_selection_size: 0.0,
                unique_features_selected: 0,
                feature_frequencies: Vec::new(),
            });
        }

        // Compute pairwise Jaccard similarities
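        // Jaccard(A, B) = |A ∩ B| / |A ∪ B|; treated as 1.0 when both selections
        // are empty, so identical selections always count as perfectly stable.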
        let mut jaccard_similarities = Vec::new();
        for i in 0..feature_selections.len() {
            for j in (i + 1)..feature_selections.len() {
                let set1: std::collections::HashSet<_> = feature_selections[i].iter().collect();
                let set2: std::collections::HashSet<_> = feature_selections[j].iter().collect();

                let intersection = set1.intersection(&set2).count() as f64;
                let union = set1.union(&set2).count() as f64;

                let jaccard = if union > 0.0 {
                    intersection / union
                } else {
                    1.0
                };

                jaccard_similarities.push(jaccard);
            }
        }

        let mean_jaccard = if jaccard_similarities.is_empty() {
            1.0
        } else {
            jaccard_similarities.iter().sum::<f64>() / jaccard_similarities.len() as f64
        };

        // Compute feature frequencies
        let mut feature_counts: HashMap<usize, usize> = HashMap::new();
        let mut total_features = 0;

        for selection in feature_selections {
            total_features += selection.len();
            for &feature in selection {
                *feature_counts.entry(feature).or_insert(0) += 1;
            }
        }

        let average_selection_size = total_features as f64 / feature_selections.len() as f64;

        let mut feature_frequencies: Vec<(usize, f64)> = feature_counts
            .into_iter()
            .map(|(feature, count)| {
                let frequency = count as f64 / feature_selections.len() as f64;
                (feature, frequency)
            })
            .collect();

        feature_frequencies.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        // Intersection stability (features selected in all folds)
        let all_features: std::collections::HashSet<_> = feature_selections[0].iter().collect();
        let intersection_features =
            feature_selections
                .iter()
                .skip(1)
                .fold(all_features, |acc, selection| {
                    let set: std::collections::HashSet<_> = selection.iter().collect();
                    acc.intersection(&set).cloned().collect()
                });

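        // Ratio of the features common to every fold to the average selection size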
        let intersection_stability = intersection_features.len() as f64 / average_selection_size;

        Ok(FeatureStabilityMetrics {
            jaccard_similarity: mean_jaccard,
            intersection_stability,
            average_selection_size,
            unique_features_selected: feature_frequencies.len(),
            feature_frequencies,
        })
    }
}

/// Stratified K-Fold cross-validation
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    n_splits: usize,
    shuffle: bool,
    random_state: Option<u64>,
}

impl StratifiedKFold {
    /// Create a new stratified K-fold validator
    pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
        Self {
            n_splits,
            shuffle,
            random_state,
        }
    }

    /// Generate stratified splits
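    ///
    /// A small illustrative sketch (not from the original docs); each fold keeps
    /// the class ratio of `y` as closely as the per-class counts allow:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
    /// let y = array![0.0, 1.0, 0.0, 1.0];
    ///
    /// let skf = StratifiedKFold::new(2, true, Some(7));
    /// for (train_idx, test_idx) in skf.split(x.view(), y.view()).unwrap() {
    ///     // every fold contains one sample of each class here
    ///     assert_eq!(test_idx.len(), 2);
    /// }
    /// ```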
    pub fn split(
        &self,
        X: ArrayView2<f64>,
        y: ArrayView1<f64>,
    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
        if X.nrows() != y.len() {
            return Err(CrossValidationError::LengthMismatch.into());
        }

        let indices: Vec<usize> = (0..X.nrows()).collect();
        self.stratified_split(&indices, y)
    }

    fn stratified_split(
        &self,
        indices: &[usize],
        y: ArrayView1<f64>,
    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
        // Group indices by class
        let mut class_groups: HashMap<i32, Vec<usize>> = HashMap::new();
        for &idx in indices {
            let class = y[idx] as i32;
            class_groups.entry(class).or_default().push(idx);
        }

        // Check minimum samples per class
        for (class, group) in &class_groups {
            if group.len() < self.n_splits {
                return Err(SklearsError::InvalidInput(format!(
                    "Class {} has only {} samples, need at least {}",
                    class,
                    group.len(),
                    self.n_splits
                )));
            }
        }

        // Shuffle within each class if requested
        if self.shuffle {
            for group in class_groups.values_mut() {
                self.shuffle_indices(group);
            }
        }

        // Create stratified folds
        let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.n_splits];

        for group in class_groups.values() {
            let fold_size = group.len() / self.n_splits;
            let remainder = group.len() % self.n_splits;

            for fold in 0..self.n_splits {
                let start = fold * fold_size + fold.min(remainder);
                let end = start + fold_size + if fold < remainder { 1 } else { 0 };
                folds[fold].extend_from_slice(&group[start..end]);
            }
        }

        // Create train/test splits
        let mut splits = Vec::new();
        for fold in 0..self.n_splits {
            let test_indices = folds[fold].clone();
            let train_indices: Vec<usize> = folds
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != fold)
                .flat_map(|(_, fold_indices)| fold_indices.iter())
                .cloned()
                .collect();

            splits.push((train_indices, test_indices));
        }

        Ok(splits)
    }

    // Fisher-Yates shuffle; note that `random_state` is not used to seed the RNG
    fn shuffle_indices(&self, indices: &mut [usize]) {
        for i in (1..indices.len()).rev() {
            let j = (thread_rng().gen::<f64>() * (i + 1) as f64) as usize;
            indices.swap(i, j);
        }
    }
}

/// Time series cross-validation split
#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    n_splits: usize,
    max_train_size: Option<usize>,
    test_size: Option<usize>,
}

impl TimeSeriesSplit {
    /// Create a new time series splitter
    pub fn new(n_splits: usize, max_train_size: Option<usize>, test_size: Option<usize>) -> Self {
        Self {
            n_splits,
            max_train_size,
            test_size,
        }
    }

    /// Generate time series splits
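    ///
    /// A brief illustrative sketch (not from the original docs): with 10 samples,
    /// 3 splits, and a fixed test window of 2, the training data always precedes
    /// the test window in time:
    ///
    /// ```ignore
    /// let ts = TimeSeriesSplit::new(3, None, Some(2));
    /// for (train_idx, test_idx) in ts.split(10).unwrap() {
    ///     // e.g. ([0, 1], [2, 3]), then ([0..4], [4, 5]), then ([0..6], [6, 7])
    ///     assert!(train_idx.iter().max() < test_idx.iter().min());
    /// }
    /// ```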
    pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
        if n_samples < self.n_splits + 1 {
            return Err(CrossValidationError::InsufficientData.into());
        }

        let test_size = self.test_size.unwrap_or(n_samples / (self.n_splits + 1));
        let mut splits = Vec::new();

        for split in 0..self.n_splits {
            let test_start = (split + 1) * test_size;
            let test_end = test_start + test_size;

            if test_end > n_samples {
                break;
            }

            let train_end = test_start;
            let train_start = if let Some(max_size) = self.max_train_size {
                train_end.saturating_sub(max_size)
            } else {
                0
            };

            let train_indices: Vec<usize> = (train_start..train_end).collect();
            let test_indices: Vec<usize> = (test_start..test_end).collect();

            if !train_indices.is_empty() && !test_indices.is_empty() {
                splits.push((train_indices, test_indices));
            }
        }

        Ok(splits)
    }
}

/// Group K-Fold cross-validation
#[derive(Debug, Clone)]
pub struct GroupKFold {
    n_splits: usize,
}

impl GroupKFold {
    /// Create a new group K-fold validator
    pub fn new(n_splits: usize) -> Self {
        Self { n_splits }
    }

    /// Generate group-based splits
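    ///
    /// An illustrative sketch (not from the original docs): samples sharing a group
    /// label never appear on both the training and the test side of a split:
    ///
    /// ```ignore
    /// let groups = vec![0, 0, 1, 1, 2, 2];
    /// let gkf = GroupKFold::new(3);
    /// for (train_idx, test_idx) in gkf.split(&groups).unwrap() {
    ///     // exactly one group (two samples) forms the test set of each split
    ///     assert_eq!(test_idx.len(), 2);
    /// }
    /// ```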
    pub fn split(&self, groups: &[usize]) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
        // Get unique groups
        let mut unique_groups: Vec<usize> = groups.to_vec();
        unique_groups.sort_unstable();
        unique_groups.dedup();

        if unique_groups.len() < self.n_splits {
            return Err(CrossValidationError::InvalidFoldConfiguration.into());
        }

        // Create group index mapping
        let mut group_indices: HashMap<usize, Vec<usize>> = HashMap::new();
        for (idx, &group) in groups.iter().enumerate() {
            group_indices.entry(group).or_default().push(idx);
        }

        // Distribute groups among folds
        let groups_per_fold = unique_groups.len() / self.n_splits;
        let remainder = unique_groups.len() % self.n_splits;

        let mut splits = Vec::new();

        for fold in 0..self.n_splits {
            let start = fold * groups_per_fold + fold.min(remainder);
            let end = start + groups_per_fold + if fold < remainder { 1 } else { 0 };

            let test_groups = &unique_groups[start..end];
            let train_groups: Vec<usize> = unique_groups[..start]
                .iter()
                .chain(unique_groups[end..].iter())
                .cloned()
                .collect();

            let test_indices: Vec<usize> = test_groups
                .iter()
                .flat_map(|&group| group_indices[&group].iter())
                .cloned()
                .collect();

            let train_indices: Vec<usize> = train_groups
                .iter()
                .flat_map(|&group| group_indices[&group].iter())
                .cloned()
                .collect();

            splits.push((train_indices, test_indices));
        }

        Ok(splits)
    }
}

/// Repeated K-Fold cross-validation
#[derive(Debug, Clone)]
pub struct RepeatedKFold {
    n_splits: usize,
    n_repeats: usize,
    random_state: Option<u64>,
}

impl RepeatedKFold {
    /// Create a new repeated K-fold validator
    pub fn new(n_splits: usize, n_repeats: usize, random_state: Option<u64>) -> Self {
        Self {
            n_splits,
            n_repeats,
            random_state,
        }
    }

    /// Generate repeated K-fold splits
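    ///
    /// An illustrative sketch (not from the original docs): K-fold is repeated
    /// `n_repeats` times, so the number of splits is `n_splits * n_repeats`:
    ///
    /// ```ignore
    /// let rkf = RepeatedKFold::new(3, 2, Some(42));
    /// let splits = rkf.split(9).unwrap();
    /// assert_eq!(splits.len(), 6); // 3 folds x 2 repeats
    /// ```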
    pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
        let mut all_splits = Vec::new();

        for repeat in 0..self.n_repeats {
            // Derive a per-repeat state; note that it only toggles shuffling below,
            // the thread-local RNG itself is not seeded with it
            let current_random_state = self.random_state.map(|s| s + repeat as u64);

            // Create K-fold with current random state
            let indices: Vec<usize> = (0..n_samples).collect();
            let kfold_splits = self.k_fold_split(&indices, current_random_state)?;

            all_splits.extend(kfold_splits);
        }

        Ok(all_splits)
    }

    fn k_fold_split(
        &self,
        indices: &[usize],
        random_state: Option<u64>,
    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
        let mut shuffled_indices = indices.to_vec();

        if random_state.is_some() {
            self.shuffle_indices(&mut shuffled_indices);
        }

        let fold_size = indices.len() / self.n_splits;
        let remainder = indices.len() % self.n_splits;

        let mut splits = Vec::new();

        for fold in 0..self.n_splits {
            let start = fold * fold_size + fold.min(remainder);
            let end = start + fold_size + if fold < remainder { 1 } else { 0 };

            let test_indices = shuffled_indices[start..end].to_vec();
            let train_indices: Vec<usize> = shuffled_indices[..start]
                .iter()
                .chain(shuffled_indices[end..].iter())
                .cloned()
                .collect();

            splits.push((train_indices, test_indices));
        }

        Ok(splits)
    }

    fn shuffle_indices(&self, indices: &mut [usize]) {
        for i in (1..indices.len()).rev() {
            let j = (thread_rng().gen::<f64>() * (i + 1) as f64) as usize;
            indices.swap(i, j);
        }
    }
}

/// Results from nested cross-validation
#[derive(Debug, Clone)]
pub struct NestedCVResults {
    pub outer_scores: Vec<f64>,
    pub outer_mean_score: f64,
    pub outer_std_score: f64,
    pub inner_cv_results: Vec<InnerCVResult>,
    pub inner_mean_score: f64,
    pub feature_stability: FeatureStabilityMetrics,
    pub n_outer_folds: usize,
    pub n_inner_folds: usize,
}

impl NestedCVResults {
    /// Generate a detailed report
    pub fn report(&self) -> String {
        let mut report = String::new();

        report.push_str("=== Nested Cross-Validation Results ===\n\n");

        report.push_str(&format!(
            "Configuration: {} outer folds, {} inner folds\n\n",
            self.n_outer_folds, self.n_inner_folds
        ));

        report.push_str("Outer CV Performance:\n");
        report.push_str(&format!(
            "  Mean Score: {:.4} ± {:.4}\n",
            self.outer_mean_score, self.outer_std_score
        ));
        report.push_str(&format!(
            "  Individual Scores: {:?}\n\n",
            self.outer_scores
                .iter()
                .map(|s| format!("{:.4}", s))
                .collect::<Vec<_>>()
        ));

        report.push_str("Inner CV Performance:\n");
        report.push_str(&format!("  Mean Score: {:.4}\n", self.inner_mean_score));

        for (i, inner_result) in self.inner_cv_results.iter().enumerate() {
            report.push_str(&format!(
                "  Outer Fold {}: {:.4} ± {:.4}\n",
                i,
                inner_result.mean_score,
                inner_result.std_score()
            ));
        }

        report.push_str("\nFeature Selection Stability:\n");
        report.push_str(&format!(
            "  Jaccard Similarity: {:.4}\n",
            self.feature_stability.jaccard_similarity
        ));
        report.push_str(&format!(
            "  Intersection Stability: {:.4}\n",
            self.feature_stability.intersection_stability
        ));
        report.push_str(&format!(
            "  Average Selection Size: {:.1}\n",
            self.feature_stability.average_selection_size
        ));
        report.push_str(&format!(
            "  Unique Features Selected: {}\n",
            self.feature_stability.unique_features_selected
        ));

        if !self.feature_stability.feature_frequencies.is_empty() {
            report.push_str("\nTop 10 Most Frequent Features:\n");
            for (feature, frequency) in self.feature_stability.feature_frequencies.iter().take(10) {
                report.push_str(&format!(
                    "  Feature {}: {:.1}%\n",
                    feature,
                    frequency * 100.0
                ));
            }
        }

        report
    }
}

/// Inner cross-validation result for one outer fold
#[derive(Debug, Clone)]
pub struct InnerCVResult {
    pub outer_fold: usize,
    pub inner_scores: Vec<f64>,
    pub mean_score: f64,
    pub selected_features: Vec<Vec<usize>>,
}

impl InnerCVResult {
    /// Population standard deviation of the inner fold scores
    pub fn std_score(&self) -> f64 {
        if self.inner_scores.len() <= 1 {
            return 0.0;
        }

        let variance = self
            .inner_scores
            .iter()
            .map(|score| (score - self.mean_score).powi(2))
            .sum::<f64>()
            / self.inner_scores.len() as f64;
        variance.sqrt()
    }
}

/// Feature stability metrics from cross-validation
#[derive(Debug, Clone)]
pub struct FeatureStabilityMetrics {
    pub jaccard_similarity: f64,
    pub intersection_stability: f64,
    pub average_selection_size: f64,
    pub unique_features_selected: usize,
    pub feature_frequencies: Vec<(usize, f64)>,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Mock feature selector for testing
    fn mock_feature_selector(X: ArrayView2<f64>, _y: ArrayView1<f64>) -> Result<Vec<usize>> {
        // Select the first half of the features
        let n_features = X.ncols();
        Ok((0..(n_features / 2)).collect())
    }

    // Mock performance evaluator for testing
    fn mock_performance_evaluator(
        _X_train: ArrayView2<f64>,
        _y_train: ArrayView1<f64>,
        _X_test: ArrayView2<f64>,
        _y_test: ArrayView1<f64>,
        _features: &[usize],
    ) -> Result<f64> {
        // Return a random score between 0.7 and 0.9
        Ok(0.7 + thread_rng().gen::<f64>() * 0.2)
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_nested_cross_validation() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0],
            [5.0, 6.0, 7.0, 8.0],
            [6.0, 7.0, 8.0, 9.0],
            [7.0, 8.0, 9.0, 10.0],
            [8.0, 9.0, 10.0, 11.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0];

        let nested_cv = NestedCrossValidation::new(3, 2, false, Some(42));
        let results = nested_cv
            .evaluate(
                X.view(),
                y.view(),
                mock_feature_selector,
                mock_performance_evaluator,
            )
            .unwrap();

        assert_eq!(results.outer_scores.len(), 3);
        assert_eq!(results.inner_cv_results.len(), 3);
        assert!(results.outer_mean_score >= 0.0 && results.outer_mean_score <= 1.0);
        assert!(results.feature_stability.jaccard_similarity >= 0.0);

        let report = results.report();
        assert!(report.contains("Nested Cross-Validation"));
        assert!(report.contains("Feature Selection Stability"));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_stratified_k_fold() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0, 0.0, 1.0];

        let skf = StratifiedKFold::new(3, true, Some(42));
        let splits = skf.split(X.view(), y.view()).unwrap();

        assert_eq!(splits.len(), 3);

        for (train_idx, test_idx) in splits {
            assert!(!train_idx.is_empty());
            assert!(!test_idx.is_empty());
            assert_eq!(train_idx.len() + test_idx.len(), X.nrows());
        }
    }

    #[test]
    fn test_time_series_split() {
        let ts_split = TimeSeriesSplit::new(3, None, Some(2));
        let splits = ts_split.split(10).unwrap();

        assert_eq!(splits.len(), 3);

        for (train_idx, test_idx) in splits {
            assert!(!train_idx.is_empty());
            assert_eq!(test_idx.len(), 2);

            // Verify temporal order
            if !train_idx.is_empty() && !test_idx.is_empty() {
                let max_train = train_idx.iter().max().unwrap();
                let min_test = test_idx.iter().min().unwrap();
                assert!(max_train < min_test);
            }
        }
    }

    #[test]
    fn test_group_k_fold() {
        let groups = vec![0, 0, 1, 1, 2, 2];
        let gkf = GroupKFold::new(3);
        let splits = gkf.split(&groups).unwrap();

        assert_eq!(splits.len(), 3);

        for (train_idx, test_idx) in splits {
            assert!(!train_idx.is_empty());
            assert!(!test_idx.is_empty());
            assert_eq!(train_idx.len() + test_idx.len(), groups.len());
        }
    }

    #[test]
    fn test_repeated_k_fold() {
        let rkf = RepeatedKFold::new(3, 2, Some(42));
        let splits = rkf.split(9).unwrap();

        assert_eq!(splits.len(), 6); // 3 folds * 2 repeats

        for (train_idx, test_idx) in splits {
            assert!(!train_idx.is_empty());
            assert!(!test_idx.is_empty());
            assert_eq!(train_idx.len() + test_idx.len(), 9);
        }
    }
}
967}