1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use sklears_core::error::{Result as SklResult, SklearsError};
9type Result<T> = SklResult<T>;
10use scirs2_core::random::{thread_rng, Rng};
11
12impl From<CrossValidationError> for SklearsError {
13 fn from(err: CrossValidationError) -> Self {
14 SklearsError::FitError(format!("Cross-validation error: {}", err))
15 }
16}
17use std::collections::HashMap;
18use thiserror::Error;
19
20#[derive(Debug, Error)]
21pub enum CrossValidationError {
22 #[error("Insufficient data for cross-validation")]
23 InsufficientData,
24 #[error("Invalid fold configuration")]
25 InvalidFoldConfiguration,
26 #[error("Feature and target length mismatch")]
27 LengthMismatch,
28 #[error("Invalid feature indices")]
29 InvalidFeatureIndices,
30 #[error("Empty feature selection")]
31 EmptyFeatureSelection,
32}
33
34#[derive(Debug, Clone)]
36pub struct NestedCrossValidation {
37 outer_folds: usize,
38 inner_folds: usize,
39 stratified: bool,
40 random_state: Option<u64>,
41}
42
43impl NestedCrossValidation {
44 pub fn new(
46 outer_folds: usize,
47 inner_folds: usize,
48 stratified: bool,
49 random_state: Option<u64>,
50 ) -> Self {
51 Self {
52 outer_folds,
53 inner_folds,
54 stratified,
55 random_state,
56 }
57 }
58
59 #[allow(non_snake_case)]
61 pub fn evaluate<F, G>(
62 &self,
63 X: ArrayView2<f64>,
64 y: ArrayView1<f64>,
65 feature_selector: F,
66 performance_evaluator: G,
67 ) -> Result<NestedCVResults>
68 where
69 F: Fn(ArrayView2<f64>, ArrayView1<f64>) -> Result<Vec<usize>> + Copy,
70 G: Fn(
71 ArrayView2<f64>,
72 ArrayView1<f64>,
73 ArrayView2<f64>,
74 ArrayView1<f64>,
75 &[usize],
76 ) -> Result<f64>
77 + Copy,
78 {
79 if X.nrows() != y.len() {
80 return Err(CrossValidationError::LengthMismatch.into());
81 }
82
83 if X.nrows() < self.outer_folds * 2 {
84 return Err(CrossValidationError::InsufficientData.into());
85 }
86
87 let n_samples = X.nrows();
88 let indices: Vec<usize> = (0..n_samples).collect();
89
90 let outer_splits = if self.stratified {
92 self.stratified_k_fold_split(&indices, y, self.outer_folds)?
93 } else {
94 self.k_fold_split(&indices, self.outer_folds)?
95 };
96
97 let mut outer_scores = Vec::with_capacity(self.outer_folds);
98 let mut feature_selection_stability = Vec::new();
99 let mut inner_cv_scores = Vec::new();
100
101 for (outer_fold, (train_idx, test_idx)) in outer_splits.into_iter().enumerate() {
102 let X_outer_train = self.extract_samples(X, &train_idx);
104 let y_outer_train = self.extract_targets(y, &train_idx);
105 let X_outer_test = self.extract_samples(X, &test_idx);
106 let y_outer_test = self.extract_targets(y, &test_idx);
107
108 let inner_splits = if self.stratified {
110 self.stratified_k_fold_split(
111 &(0..train_idx.len()).collect::<Vec<_>>(),
112 y_outer_train.view(),
113 self.inner_folds,
114 )?
115 } else {
116 self.k_fold_split(&(0..train_idx.len()).collect::<Vec<_>>(), self.inner_folds)?
117 };
118
119 let mut inner_fold_scores = Vec::new();
120 let mut inner_fold_features = Vec::new();
121
122 for (inner_train_idx, inner_val_idx) in inner_splits {
123 let X_inner_train = self.extract_samples(X_outer_train.view(), &inner_train_idx);
125 let y_inner_train = self.extract_targets(y_outer_train.view(), &inner_train_idx);
126 let X_inner_val = self.extract_samples(X_outer_train.view(), &inner_val_idx);
127 let y_inner_val = self.extract_targets(y_outer_train.view(), &inner_val_idx);
128
129 let selected_features =
131 feature_selector(X_inner_train.view(), y_inner_train.view())?;
132
133 if selected_features.is_empty() {
134 return Err(CrossValidationError::EmptyFeatureSelection.into());
135 }
136
137 let inner_score = performance_evaluator(
139 X_inner_train.view(),
140 y_inner_train.view(),
141 X_inner_val.view(),
142 y_inner_val.view(),
143 &selected_features,
144 )?;
145
146 inner_fold_scores.push(inner_score);
147 inner_fold_features.push(selected_features);
148 }
149
150 let inner_cv_mean =
152 inner_fold_scores.iter().sum::<f64>() / inner_fold_scores.len() as f64;
153 inner_cv_scores.push(InnerCVResult {
154 outer_fold,
155 inner_scores: inner_fold_scores,
156 mean_score: inner_cv_mean,
157 selected_features: inner_fold_features,
158 });
159
160 let final_selected_features =
162 feature_selector(X_outer_train.view(), y_outer_train.view())?;
163 feature_selection_stability.push(final_selected_features.clone());
164
165 let outer_score = performance_evaluator(
167 X_outer_train.view(),
168 y_outer_train.view(),
169 X_outer_test.view(),
170 y_outer_test.view(),
171 &final_selected_features,
172 )?;
173
174 outer_scores.push(outer_score);
175 }
176
177 let stability_metrics = self.compute_stability_metrics(&feature_selection_stability)?;
179
180 let outer_mean = outer_scores.iter().sum::<f64>() / outer_scores.len() as f64;
182 let outer_std = {
183 let variance = outer_scores
184 .iter()
185 .map(|score| (score - outer_mean).powi(2))
186 .sum::<f64>()
187 / outer_scores.len() as f64;
188 variance.sqrt()
189 };
190
191 let inner_mean = inner_cv_scores
192 .iter()
193 .map(|result| result.mean_score)
194 .sum::<f64>()
195 / inner_cv_scores.len() as f64;
196
197 Ok(NestedCVResults {
198 outer_scores,
199 outer_mean_score: outer_mean,
200 outer_std_score: outer_std,
201 inner_cv_results: inner_cv_scores,
202 inner_mean_score: inner_mean,
203 feature_stability: stability_metrics,
204 n_outer_folds: self.outer_folds,
205 n_inner_folds: self.inner_folds,
206 })
207 }
208
209 fn k_fold_split(
211 &self,
212 indices: &[usize],
213 n_folds: usize,
214 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
215 if indices.len() < n_folds {
216 return Err(CrossValidationError::InvalidFoldConfiguration.into());
217 }
218
219 let mut shuffled_indices = indices.to_vec();
220
221 if self.random_state.is_some() {
223 self.shuffle_indices(&mut shuffled_indices);
224 }
225
226 let fold_size = indices.len() / n_folds;
227 let remainder = indices.len() % n_folds;
228
229 let mut splits = Vec::new();
230
231 for fold in 0..n_folds {
232 let start = fold * fold_size + fold.min(remainder);
233 let end = start + fold_size + if fold < remainder { 1 } else { 0 };
234
235 let test_indices = shuffled_indices[start..end].to_vec();
236 let train_indices: Vec<usize> = shuffled_indices[..start]
237 .iter()
238 .chain(shuffled_indices[end..].iter())
239 .cloned()
240 .collect();
241
242 splits.push((train_indices, test_indices));
243 }
244
245 Ok(splits)
246 }
247
248 fn stratified_k_fold_split(
250 &self,
251 indices: &[usize],
252 y: ArrayView1<f64>,
253 n_folds: usize,
254 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
255 if indices.len() < n_folds {
256 return Err(CrossValidationError::InvalidFoldConfiguration.into());
257 }
258
259 let mut class_groups: HashMap<i32, Vec<usize>> = HashMap::new();
261 for &idx in indices {
262 let class = y[idx] as i32;
263 class_groups.entry(class).or_default().push(idx);
264 }
265
266 if self.random_state.is_some() {
268 for group in class_groups.values_mut() {
269 self.shuffle_indices(group);
270 }
271 }
272
273 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); n_folds];
275
276 for group in class_groups.values() {
277 let group_fold_size = group.len() / n_folds;
278 let group_remainder = group.len() % n_folds;
279
280 for fold in 0..n_folds {
281 let start = fold * group_fold_size + fold.min(group_remainder);
282 let end = start + group_fold_size + if fold < group_remainder { 1 } else { 0 };
283 folds[fold].extend_from_slice(&group[start..end]);
284 }
285 }
286
287 let mut splits = Vec::new();
289 for fold in 0..n_folds {
290 let test_indices = folds[fold].clone();
291 let train_indices: Vec<usize> = folds
292 .iter()
293 .enumerate()
294 .filter(|(i, _)| *i != fold)
295 .flat_map(|(_, fold_indices)| fold_indices.iter())
296 .cloned()
297 .collect();
298
299 splits.push((train_indices, test_indices));
300 }
301
302 Ok(splits)
303 }
304
305 fn shuffle_indices(&self, indices: &mut [usize]) {
307 for i in (1..indices.len()).rev() {
308 let j = (thread_rng().gen::<f64>() * (i + 1) as f64) as usize;
309 indices.swap(i, j);
310 }
311 }
312
313 fn extract_samples(&self, X: ArrayView2<f64>, indices: &[usize]) -> Array2<f64> {
315 let mut samples = Array2::zeros((indices.len(), X.ncols()));
316 for (i, &idx) in indices.iter().enumerate() {
317 samples.row_mut(i).assign(&X.row(idx));
318 }
319 samples
320 }
321
322 fn extract_targets(&self, y: ArrayView1<f64>, indices: &[usize]) -> Array1<f64> {
324 let mut targets = Array1::zeros(indices.len());
325 for (i, &idx) in indices.iter().enumerate() {
326 targets[i] = y[idx];
327 }
328 targets
329 }
330
331 fn compute_stability_metrics(
333 &self,
334 feature_selections: &[Vec<usize>],
335 ) -> Result<FeatureStabilityMetrics> {
336 if feature_selections.is_empty() {
337 return Ok(FeatureStabilityMetrics {
338 jaccard_similarity: 0.0,
339 intersection_stability: 0.0,
340 average_selection_size: 0.0,
341 unique_features_selected: 0,
342 feature_frequencies: Vec::new(),
343 });
344 }
345
346 let mut jaccard_similarities = Vec::new();
348 for i in 0..feature_selections.len() {
349 for j in (i + 1)..feature_selections.len() {
350 let set1: std::collections::HashSet<_> = feature_selections[i].iter().collect();
351 let set2: std::collections::HashSet<_> = feature_selections[j].iter().collect();
352
353 let intersection = set1.intersection(&set2).count() as f64;
354 let union = set1.union(&set2).count() as f64;
355
356 let jaccard = if union > 0.0 {
357 intersection / union
358 } else {
359 1.0
360 };
361
362 jaccard_similarities.push(jaccard);
363 }
364 }
365
366 let mean_jaccard = if jaccard_similarities.is_empty() {
367 1.0
368 } else {
369 jaccard_similarities.iter().sum::<f64>() / jaccard_similarities.len() as f64
370 };
371
372 let mut feature_counts: HashMap<usize, usize> = HashMap::new();
374 let mut total_features = 0;
375
376 for selection in feature_selections {
377 total_features += selection.len();
378 for &feature in selection {
379 *feature_counts.entry(feature).or_insert(0) += 1;
380 }
381 }
382
383 let average_selection_size = total_features as f64 / feature_selections.len() as f64;
384
385 let mut feature_frequencies: Vec<(usize, f64)> = feature_counts
386 .into_iter()
387 .map(|(feature, count)| {
388 let frequency = count as f64 / feature_selections.len() as f64;
389 (feature, frequency)
390 })
391 .collect();
392
393 feature_frequencies.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
394
395 let all_features: std::collections::HashSet<_> = feature_selections[0].iter().collect();
397 let intersection_features =
398 feature_selections
399 .iter()
400 .skip(1)
401 .fold(all_features, |acc, selection| {
402 let set: std::collections::HashSet<_> = selection.iter().collect();
403 acc.intersection(&set).cloned().collect()
404 });
405
406 let intersection_stability = intersection_features.len() as f64 / average_selection_size;
407
408 Ok(FeatureStabilityMetrics {
409 jaccard_similarity: mean_jaccard,
410 intersection_stability,
411 average_selection_size,
412 unique_features_selected: feature_frequencies.len(),
413 feature_frequencies,
414 })
415 }
416}
417
418#[derive(Debug, Clone)]
420pub struct StratifiedKFold {
421 n_splits: usize,
422 shuffle: bool,
423 random_state: Option<u64>,
424}
425
426impl StratifiedKFold {
427 pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
429 Self {
430 n_splits,
431 shuffle,
432 random_state,
433 }
434 }
435
436 pub fn split(
438 &self,
439 X: ArrayView2<f64>,
440 y: ArrayView1<f64>,
441 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
442 if X.nrows() != y.len() {
443 return Err(CrossValidationError::LengthMismatch.into());
444 }
445
446 let indices: Vec<usize> = (0..X.nrows()).collect();
447 self.stratified_split(&indices, y)
448 }
449
450 fn stratified_split(
451 &self,
452 indices: &[usize],
453 y: ArrayView1<f64>,
454 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
455 let mut class_groups: HashMap<i32, Vec<usize>> = HashMap::new();
457 for &idx in indices {
458 let class = y[idx] as i32;
459 class_groups.entry(class).or_default().push(idx);
460 }
461
462 for (class, group) in &class_groups {
464 if group.len() < self.n_splits {
465 return Err(SklearsError::InvalidInput(format!(
466 "Class {} has only {} samples, need at least {}",
467 class,
468 group.len(),
469 self.n_splits
470 )));
471 }
472 }
473
474 if self.shuffle {
476 for group in class_groups.values_mut() {
477 self.shuffle_indices(group);
478 }
479 }
480
481 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.n_splits];
483
484 for group in class_groups.values() {
485 let fold_size = group.len() / self.n_splits;
486 let remainder = group.len() % self.n_splits;
487
488 for fold in 0..self.n_splits {
489 let start = fold * fold_size + fold.min(remainder);
490 let end = start + fold_size + if fold < remainder { 1 } else { 0 };
491 folds[fold].extend_from_slice(&group[start..end]);
492 }
493 }
494
495 let mut splits = Vec::new();
497 for fold in 0..self.n_splits {
498 let test_indices = folds[fold].clone();
499 let train_indices: Vec<usize> = folds
500 .iter()
501 .enumerate()
502 .filter(|(i, _)| *i != fold)
503 .flat_map(|(_, fold_indices)| fold_indices.iter())
504 .cloned()
505 .collect();
506
507 splits.push((train_indices, test_indices));
508 }
509
510 Ok(splits)
511 }
512
513 fn shuffle_indices(&self, indices: &mut [usize]) {
514 for i in (1..indices.len()).rev() {
515 let j = (thread_rng().gen::<f64>() * (i + 1) as f64) as usize;
516 indices.swap(i, j);
517 }
518 }
519}
520
521#[derive(Debug, Clone)]
523pub struct TimeSeriesSplit {
524 n_splits: usize,
525 max_train_size: Option<usize>,
526 test_size: Option<usize>,
527}
528
529impl TimeSeriesSplit {
530 pub fn new(n_splits: usize, max_train_size: Option<usize>, test_size: Option<usize>) -> Self {
532 Self {
533 n_splits,
534 max_train_size,
535 test_size,
536 }
537 }
538
539 pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
541 if n_samples < self.n_splits + 1 {
542 return Err(CrossValidationError::InsufficientData.into());
543 }
544
545 let test_size = self.test_size.unwrap_or(n_samples / (self.n_splits + 1));
546 let mut splits = Vec::new();
547
548 for split in 0..self.n_splits {
549 let test_start = (split + 1) * test_size;
550 let test_end = test_start + test_size;
551
552 if test_end > n_samples {
553 break;
554 }
555
556 let train_end = test_start;
557 let train_start = if let Some(max_size) = self.max_train_size {
558 train_end.saturating_sub(max_size)
559 } else {
560 0
561 };
562
563 let train_indices: Vec<usize> = (train_start..train_end).collect();
564 let test_indices: Vec<usize> = (test_start..test_end).collect();
565
566 if !train_indices.is_empty() && !test_indices.is_empty() {
567 splits.push((train_indices, test_indices));
568 }
569 }
570
571 Ok(splits)
572 }
573}
574
575#[derive(Debug, Clone)]
577pub struct GroupKFold {
578 n_splits: usize,
579}
580
581impl GroupKFold {
582 pub fn new(n_splits: usize) -> Self {
584 Self { n_splits }
585 }
586
587 pub fn split(&self, groups: &[usize]) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
589 let mut unique_groups: Vec<usize> = groups.to_vec();
591 unique_groups.sort_unstable();
592 unique_groups.dedup();
593
594 if unique_groups.len() < self.n_splits {
595 return Err(CrossValidationError::InvalidFoldConfiguration.into());
596 }
597
598 let mut group_indices: HashMap<usize, Vec<usize>> = HashMap::new();
600 for (idx, &group) in groups.iter().enumerate() {
601 group_indices.entry(group).or_default().push(idx);
602 }
603
604 let groups_per_fold = unique_groups.len() / self.n_splits;
606 let remainder = unique_groups.len() % self.n_splits;
607
608 let mut splits = Vec::new();
609
610 for fold in 0..self.n_splits {
611 let start = fold * groups_per_fold + fold.min(remainder);
612 let end = start + groups_per_fold + if fold < remainder { 1 } else { 0 };
613
614 let test_groups = &unique_groups[start..end];
615 let train_groups: Vec<usize> = unique_groups[..start]
616 .iter()
617 .chain(unique_groups[end..].iter())
618 .cloned()
619 .collect();
620
621 let test_indices: Vec<usize> = test_groups
622 .iter()
623 .flat_map(|&group| group_indices[&group].iter())
624 .cloned()
625 .collect();
626
627 let train_indices: Vec<usize> = train_groups
628 .iter()
629 .flat_map(|&group| group_indices[&group].iter())
630 .cloned()
631 .collect();
632
633 splits.push((train_indices, test_indices));
634 }
635
636 Ok(splits)
637 }
638}
639
640#[derive(Debug, Clone)]
642pub struct RepeatedKFold {
643 n_splits: usize,
644 n_repeats: usize,
645 random_state: Option<u64>,
646}
647
648impl RepeatedKFold {
649 pub fn new(n_splits: usize, n_repeats: usize, random_state: Option<u64>) -> Self {
651 Self {
652 n_splits,
653 n_repeats,
654 random_state,
655 }
656 }
657
658 pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
660 let mut all_splits = Vec::new();
661
662 for repeat in 0..self.n_repeats {
663 let current_random_state = self.random_state.map(|s| s + repeat as u64);
664
665 let indices: Vec<usize> = (0..n_samples).collect();
667 let kfold_splits = self.k_fold_split(&indices, current_random_state)?;
668
669 all_splits.extend(kfold_splits);
670 }
671
672 Ok(all_splits)
673 }
674
675 fn k_fold_split(
676 &self,
677 indices: &[usize],
678 random_state: Option<u64>,
679 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
680 let mut shuffled_indices = indices.to_vec();
681
682 if random_state.is_some() {
683 self.shuffle_indices(&mut shuffled_indices);
684 }
685
686 let fold_size = indices.len() / self.n_splits;
687 let remainder = indices.len() % self.n_splits;
688
689 let mut splits = Vec::new();
690
691 for fold in 0..self.n_splits {
692 let start = fold * fold_size + fold.min(remainder);
693 let end = start + fold_size + if fold < remainder { 1 } else { 0 };
694
695 let test_indices = shuffled_indices[start..end].to_vec();
696 let train_indices: Vec<usize> = shuffled_indices[..start]
697 .iter()
698 .chain(shuffled_indices[end..].iter())
699 .cloned()
700 .collect();
701
702 splits.push((train_indices, test_indices));
703 }
704
705 Ok(splits)
706 }
707
708 fn shuffle_indices(&self, indices: &mut [usize]) {
709 for i in (1..indices.len()).rev() {
710 let j = (thread_rng().gen::<f64>() * (i + 1) as f64) as usize;
711 indices.swap(i, j);
712 }
713 }
714}
715
716#[derive(Debug, Clone)]
718pub struct NestedCVResults {
719 pub outer_scores: Vec<f64>,
720 pub outer_mean_score: f64,
721 pub outer_std_score: f64,
722 pub inner_cv_results: Vec<InnerCVResult>,
723 pub inner_mean_score: f64,
724 pub feature_stability: FeatureStabilityMetrics,
725 pub n_outer_folds: usize,
726 pub n_inner_folds: usize,
727}
728
729impl NestedCVResults {
730 pub fn report(&self) -> String {
732 let mut report = String::new();
733
734 report.push_str("=== Nested Cross-Validation Results ===\n\n");
735
736 report.push_str(&format!(
737 "Configuration: {} outer folds, {} inner folds\n\n",
738 self.n_outer_folds, self.n_inner_folds
739 ));
740
741 report.push_str("Outer CV Performance:\n");
742 report.push_str(&format!(
743 " Mean Score: {:.4} ± {:.4}\n",
744 self.outer_mean_score, self.outer_std_score
745 ));
746 report.push_str(&format!(
747 " Individual Scores: {:?}\n\n",
748 self.outer_scores
749 .iter()
750 .map(|s| format!("{:.4}", s))
751 .collect::<Vec<_>>()
752 ));
753
754 report.push_str("Inner CV Performance:\n");
755 report.push_str(&format!(" Mean Score: {:.4}\n", self.inner_mean_score));
756
757 for (i, inner_result) in self.inner_cv_results.iter().enumerate() {
758 report.push_str(&format!(
759 " Outer Fold {}: {:.4} ± {:.4}\n",
760 i,
761 inner_result.mean_score,
762 inner_result.std_score()
763 ));
764 }
765
766 report.push_str("\nFeature Selection Stability:\n");
767 report.push_str(&format!(
768 " Jaccard Similarity: {:.4}\n",
769 self.feature_stability.jaccard_similarity
770 ));
771 report.push_str(&format!(
772 " Intersection Stability: {:.4}\n",
773 self.feature_stability.intersection_stability
774 ));
775 report.push_str(&format!(
776 " Average Selection Size: {:.1}\n",
777 self.feature_stability.average_selection_size
778 ));
779 report.push_str(&format!(
780 " Unique Features Selected: {}\n",
781 self.feature_stability.unique_features_selected
782 ));
783
784 if !self.feature_stability.feature_frequencies.is_empty() {
785 report.push_str("\nTop 10 Most Frequent Features:\n");
786 for (feature, frequency) in self.feature_stability.feature_frequencies.iter().take(10) {
787 report.push_str(&format!(
788 " Feature {}: {:.1}%\n",
789 feature,
790 frequency * 100.0
791 ));
792 }
793 }
794
795 report
796 }
797}
798
/// Inner cross-validation outcome for a single outer fold.
#[derive(Clone, Debug)]
pub struct InnerCVResult {
    /// Index of the outer fold this result belongs to.
    pub outer_fold: usize,
    /// Score of every inner fold.
    pub inner_scores: Vec<f64>,
    /// Mean score used as the centre when computing the deviation.
    pub mean_score: f64,
    /// Features selected in every inner fold.
    pub selected_features: Vec<Vec<usize>>,
}

impl InnerCVResult {
    /// Population standard deviation of `inner_scores` around `mean_score`.
    ///
    /// Returns `0.0` when there are fewer than two scores.
    pub fn std_score(&self) -> f64 {
        let count = self.inner_scores.len();
        if count < 2 {
            return 0.0;
        }

        let centre = self.mean_score;
        let squared_deviations = self
            .inner_scores
            .iter()
            .map(|score| (score - centre).powi(2))
            .sum::<f64>();

        (squared_deviations / count as f64).sqrt()
    }
}
823
/// Summary of how consistently features were selected across folds.
#[derive(Clone, Debug)]
pub struct FeatureStabilityMetrics {
    /// Jaccard similarity between the per-fold feature selections.
    pub jaccard_similarity: f64,
    /// Intersection-based stability score of the selections.
    pub intersection_stability: f64,
    /// Average number of features per selection.
    pub average_selection_size: f64,
    /// Number of distinct features that were selected at least once.
    pub unique_features_selected: usize,
    /// `(feature index, selection frequency)` pairs.
    pub feature_frequencies: Vec<(usize, f64)>,
}
833
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Selector stub: always picks the first half of the columns.
    fn mock_feature_selector(X: ArrayView2<f64>, _y: ArrayView1<f64>) -> Result<Vec<usize>> {
        let half = X.ncols() / 2;
        Ok((0..half).collect())
    }

    /// Evaluator stub: ignores its inputs and returns a random score built as
    /// `0.7 + r * 0.2` from a random `f64` `r`.
    fn mock_performance_evaluator(
        _X_train: ArrayView2<f64>,
        _y_train: ArrayView1<f64>,
        _X_test: ArrayView2<f64>,
        _y_test: ArrayView1<f64>,
        _features: &[usize],
    ) -> Result<f64> {
        let noise = thread_rng().gen::<f64>();
        Ok(0.7 + noise * 0.2)
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_nested_cross_validation() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0],
            [5.0, 6.0, 7.0, 8.0],
            [6.0, 7.0, 8.0, 9.0],
            [7.0, 8.0, 9.0, 10.0],
            [8.0, 9.0, 10.0, 11.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0];

        let results = NestedCrossValidation::new(3, 2, false, Some(42))
            .evaluate(
                X.view(),
                y.view(),
                mock_feature_selector,
                mock_performance_evaluator,
            )
            .unwrap();

        assert_eq!(results.outer_scores.len(), 3);
        assert_eq!(results.inner_cv_results.len(), 3);
        assert!(results.outer_mean_score >= 0.0 && results.outer_mean_score <= 1.0);
        assert!(results.feature_stability.jaccard_similarity >= 0.0);

        let report = results.report();
        assert!(report.contains("Nested Cross-Validation"));
        assert!(report.contains("Feature Selection Stability"));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_stratified_k_fold() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0],
        ];
        let y = array![0.0, 0.0, 1.0, 1.0, 0.0, 1.0];

        let splitter = StratifiedKFold::new(3, true, Some(42));
        let splits = splitter.split(X.view(), y.view()).unwrap();

        assert_eq!(splits.len(), 3);

        // Every fold must be a non-empty partition of all samples.
        for (train_idx, test_idx) in splits {
            assert!(!train_idx.is_empty());
            assert!(!test_idx.is_empty());
            assert_eq!(train_idx.len() + test_idx.len(), X.nrows());
        }
    }

    #[test]
    fn test_time_series_split() {
        let splits = TimeSeriesSplit::new(3, None, Some(2)).split(10).unwrap();

        assert_eq!(splits.len(), 3);

        for (train_idx, test_idx) in splits {
            assert!(!train_idx.is_empty());
            assert_eq!(test_idx.len(), 2);

            // Training data must lie strictly before the test window.
            let last_train = train_idx.iter().max().unwrap();
            let first_test = test_idx.iter().min().unwrap();
            assert!(last_train < first_test);
        }
    }

    #[test]
    fn test_group_k_fold() {
        let groups = vec![0, 0, 1, 1, 2, 2];
        let splits = GroupKFold::new(3).split(&groups).unwrap();

        assert_eq!(splits.len(), 3);

        for (train_idx, test_idx) in splits {
            assert!(!train_idx.is_empty());
            assert!(!test_idx.is_empty());
            assert_eq!(train_idx.len() + test_idx.len(), groups.len());
        }
    }

    #[test]
    fn test_repeated_k_fold() {
        let splits = RepeatedKFold::new(3, 2, Some(42)).split(9).unwrap();

        // 3 folds repeated 2 times.
        assert_eq!(splits.len(), 6);

        for (train_idx, test_idx) in splits {
            assert!(!train_idx.is_empty());
            assert!(!test_idx.is_empty());
            assert_eq!(train_idx.len() + test_idx.len(), 9);
        }
    }
}