1use crate::{TrainError, TrainResult};
11use scirs2_core::random::{SeedableRng, StdRng};
12use std::collections::HashMap;
13
14pub trait CrossValidationSplit {
16 fn num_splits(&self) -> usize;
18
19 fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)>;
28}
29
30#[derive(Debug, Clone)]
35pub struct KFold {
36 pub n_splits: usize,
38 pub shuffle: bool,
40 pub random_seed: u64,
42}
43
44impl KFold {
45 pub fn new(n_splits: usize) -> TrainResult<Self> {
50 if n_splits < 2 {
51 return Err(TrainError::InvalidParameter(
52 "n_splits must be at least 2".to_string(),
53 ));
54 }
55 Ok(Self {
56 n_splits,
57 shuffle: false,
58 random_seed: 42,
59 })
60 }
61
62 pub fn with_shuffle(mut self, seed: u64) -> Self {
64 self.shuffle = true;
65 self.random_seed = seed;
66 self
67 }
68}
69
70impl CrossValidationSplit for KFold {
71 fn num_splits(&self) -> usize {
72 self.n_splits
73 }
74
75 fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
76 if fold >= self.n_splits {
77 return Err(TrainError::InvalidParameter(format!(
78 "fold {} is out of range [0, {})",
79 fold, self.n_splits
80 )));
81 }
82
83 let mut indices: Vec<usize> = (0..n_samples).collect();
85
86 if self.shuffle {
88 let mut rng = StdRng::seed_from_u64(self.random_seed);
89 for i in (1..n_samples).rev() {
90 let j = rng.gen_range(0..=i);
91 indices.swap(i, j);
92 }
93 }
94
95 let fold_size = n_samples / self.n_splits;
97 let remainder = n_samples % self.n_splits;
98
99 let mut fold_sizes = vec![fold_size; self.n_splits];
100 for fold in fold_sizes.iter_mut().take(remainder) {
101 *fold += 1;
102 }
103
104 let mut boundaries = vec![0];
106 for size in &fold_sizes {
107 boundaries.push(
108 boundaries
109 .last()
110 .expect("boundaries is initialized non-empty")
111 + size,
112 );
113 }
114
115 let val_start = boundaries[fold];
117 let val_end = boundaries[fold + 1];
118 let val_indices: Vec<usize> = indices[val_start..val_end].to_vec();
119
120 let mut train_indices = Vec::new();
122 train_indices.extend_from_slice(&indices[..val_start]);
123 train_indices.extend_from_slice(&indices[val_end..]);
124
125 Ok((train_indices, val_indices))
126 }
127}
128
129#[derive(Debug, Clone)]
133pub struct StratifiedKFold {
134 pub n_splits: usize,
136 pub shuffle: bool,
138 pub random_seed: u64,
140}
141
142impl StratifiedKFold {
143 pub fn new(n_splits: usize) -> TrainResult<Self> {
148 if n_splits < 2 {
149 return Err(TrainError::InvalidParameter(
150 "n_splits must be at least 2".to_string(),
151 ));
152 }
153 Ok(Self {
154 n_splits,
155 shuffle: true,
156 random_seed: 42,
157 })
158 }
159
160 pub fn with_seed(mut self, seed: u64) -> Self {
162 self.random_seed = seed;
163 self
164 }
165
166 pub fn get_stratified_split(
172 &self,
173 fold: usize,
174 labels: &[usize],
175 ) -> TrainResult<(Vec<usize>, Vec<usize>)> {
176 if fold >= self.n_splits {
177 return Err(TrainError::InvalidParameter(format!(
178 "fold {} is out of range [0, {})",
179 fold, self.n_splits
180 )));
181 }
182
183 let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
185 for (i, &label) in labels.iter().enumerate() {
186 class_indices.entry(label).or_default().push(i);
187 }
188
189 if self.shuffle {
191 let mut rng = StdRng::seed_from_u64(self.random_seed);
192 for indices in class_indices.values_mut() {
193 for i in (1..indices.len()).rev() {
194 let j = rng.gen_range(0..=i);
195 indices.swap(i, j);
196 }
197 }
198 }
199
200 let mut train_indices = Vec::new();
202 let mut val_indices = Vec::new();
203
204 for indices in class_indices.values() {
205 let class_size = indices.len();
206 let fold_size = class_size / self.n_splits;
207 let remainder = class_size % self.n_splits;
208
209 let mut fold_sizes = vec![fold_size; self.n_splits];
210 for fold in fold_sizes.iter_mut().take(remainder) {
211 *fold += 1;
212 }
213
214 let mut boundaries = vec![0];
216 for size in &fold_sizes {
217 boundaries.push(
218 boundaries
219 .last()
220 .expect("boundaries is initialized non-empty")
221 + size,
222 );
223 }
224
225 let val_start = boundaries[fold];
227 let val_end = boundaries[fold + 1];
228 val_indices.extend_from_slice(&indices[val_start..val_end]);
229
230 train_indices.extend_from_slice(&indices[..val_start]);
232 train_indices.extend_from_slice(&indices[val_end..]);
233 }
234
235 Ok((train_indices, val_indices))
236 }
237}
238
239impl CrossValidationSplit for StratifiedKFold {
240 fn num_splits(&self) -> usize {
241 self.n_splits
242 }
243
244 fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
245 let labels: Vec<usize> = (0..n_samples).map(|i| i % self.n_splits).collect();
248 self.get_stratified_split(fold, &labels)
249 }
250}
251
252#[derive(Debug, Clone)]
257pub struct TimeSeriesSplit {
258 pub n_splits: usize,
260 pub min_train_size: Option<usize>,
262 pub max_train_size: Option<usize>,
264}
265
266impl TimeSeriesSplit {
267 pub fn new(n_splits: usize) -> TrainResult<Self> {
272 if n_splits < 2 {
273 return Err(TrainError::InvalidParameter(
274 "n_splits must be at least 2".to_string(),
275 ));
276 }
277 Ok(Self {
278 n_splits,
279 min_train_size: None,
280 max_train_size: None,
281 })
282 }
283
284 pub fn with_min_train_size(mut self, size: usize) -> Self {
286 self.min_train_size = Some(size);
287 self
288 }
289
290 pub fn with_max_train_size(mut self, size: usize) -> Self {
292 self.max_train_size = Some(size);
293 self
294 }
295}
296
297impl CrossValidationSplit for TimeSeriesSplit {
298 fn num_splits(&self) -> usize {
299 self.n_splits
300 }
301
302 fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
303 if fold >= self.n_splits {
304 return Err(TrainError::InvalidParameter(format!(
305 "fold {} is out of range [0, {})",
306 fold, self.n_splits
307 )));
308 }
309
310 let test_size = n_samples / (self.n_splits + 1);
312 if test_size == 0 {
313 return Err(TrainError::InvalidParameter(
314 "Not enough samples for time series split".to_string(),
315 ));
316 }
317
318 let val_start = (fold + 1) * test_size;
320 let val_end = ((fold + 2) * test_size).min(n_samples);
321
322 let train_end = val_start;
324 let train_start = if let Some(max_size) = self.max_train_size {
325 train_end.saturating_sub(max_size)
326 } else if let Some(min_size) = self.min_train_size {
327 if train_end < min_size {
328 return Err(TrainError::InvalidParameter(
329 "Not enough samples for min_train_size".to_string(),
330 ));
331 }
332 0
333 } else {
334 0
335 };
336
337 let train_indices: Vec<usize> = (train_start..train_end).collect();
338 let val_indices: Vec<usize> = (val_start..val_end).collect();
339
340 if train_indices.is_empty() {
341 return Err(TrainError::InvalidParameter(
342 "Training set is empty for this fold".to_string(),
343 ));
344 }
345
346 Ok((train_indices, val_indices))
347 }
348}
349
/// Leave-one-out cross-validation: each fold validates on exactly one sample
/// and trains on all the others.
#[derive(Debug, Clone, Default)]
pub struct LeaveOneOut;

impl LeaveOneOut {
    /// Creates a leave-one-out splitter (equivalent to `LeaveOneOut::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
363
364impl CrossValidationSplit for LeaveOneOut {
365 fn num_splits(&self) -> usize {
366 usize::MAX
368 }
369
370 fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
371 if fold >= n_samples {
372 return Err(TrainError::InvalidParameter(format!(
373 "fold {} is out of range [0, {})",
374 fold, n_samples
375 )));
376 }
377
378 let val_indices = vec![fold];
380
381 let mut train_indices: Vec<usize> = (0..fold).collect();
383 train_indices.extend(fold + 1..n_samples);
384
385 Ok((train_indices, val_indices))
386 }
387}
388
/// Per-fold scores and named metrics collected during cross-validation.
#[derive(Debug, Clone)]
pub struct CrossValidationResults {
    /// Primary score of each fold, in the order the folds were added.
    pub fold_scores: Vec<f64>,
    /// Named metrics of each fold, parallel to `fold_scores`.
    pub fold_metrics: Vec<HashMap<String, f64>>,
}

impl CrossValidationResults {
    /// Creates an empty result set.
    pub fn new() -> Self {
        Self {
            fold_scores: vec![],
            fold_metrics: vec![],
        }
    }

    /// Records one fold's primary score together with its named metrics.
    pub fn add_fold(&mut self, score: f64, metrics: HashMap<String, f64>) {
        self.fold_scores.push(score);
        self.fold_metrics.push(metrics);
    }

    /// Arithmetic mean of the fold scores; `0.0` when no folds were recorded.
    pub fn mean_score(&self) -> f64 {
        match self.fold_scores.len() {
            0 => 0.0,
            n => self.fold_scores.iter().sum::<f64>() / n as f64,
        }
    }

    /// Sample standard deviation (divisor `n - 1`) of the fold scores; `0.0`
    /// when fewer than two folds were recorded.
    pub fn std_score(&self) -> f64 {
        let n = self.fold_scores.len();
        if n < 2 {
            return 0.0;
        }
        let mean = self.mean_score();
        let sum_sq: f64 = self
            .fold_scores
            .iter()
            .map(|&score| (score - mean).powi(2))
            .sum();
        (sum_sq / (n - 1) as f64).sqrt()
    }

    /// Mean of `metric_name` over the folds that reported it, or `None` when
    /// no fold did. Folds missing the metric are skipped, not counted as zero.
    pub fn mean_metric(&self, metric_name: &str) -> Option<f64> {
        let (total, hits) = self
            .fold_metrics
            .iter()
            .filter_map(|metrics| metrics.get(metric_name).copied())
            .fold((0.0_f64, 0_u32), |(acc, cnt), value| (acc + value, cnt + 1));
        if hits == 0 {
            None
        } else {
            Some(total / f64::from(hits))
        }
    }

    /// Number of folds recorded so far.
    pub fn num_folds(&self) -> usize {
        self.fold_scores.len()
    }
}
466
467impl Default for CrossValidationResults {
468 fn default() -> Self {
469 Self::new()
470 }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kfold_basic() {
        let kfold = KFold::new(3).expect("unwrap");
        assert_eq!(kfold.num_splits(), 3);

        let (train, val) = kfold.get_split(0, 10).expect("unwrap");
        assert!(!train.is_empty());
        assert!(!val.is_empty());

        // Train and validation sets must be disjoint.
        for &idx in &val {
            assert!(!train.contains(&idx));
        }

        // Together they must cover every sample exactly once.
        let mut all_indices = train.clone();
        all_indices.extend(&val);
        all_indices.sort();
        assert_eq!(all_indices, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_kfold_with_shuffle() {
        let kfold = KFold::new(3).expect("unwrap").with_shuffle(42);
        let (train1, val1) = kfold.get_split(0, 10).expect("unwrap");
        let (train2, val2) = kfold.get_split(0, 10).expect("unwrap");

        // The same seed must produce the same split on every call.
        assert_eq!(train1, train2);
        assert_eq!(val1, val2);

        // Shuffling must still produce a partition of all samples.
        let mut all_indices = train1.clone();
        all_indices.extend(&val1);
        all_indices.sort();
        assert_eq!(all_indices, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_kfold_invalid() {
        assert!(KFold::new(1).is_err());
        let kfold = KFold::new(3).expect("unwrap");
        assert!(kfold.get_split(5, 10).is_err());
    }

    #[test]
    fn test_stratified_kfold() {
        let skfold = StratifiedKFold::new(3).expect("unwrap");
        assert_eq!(skfold.num_splits(), 3);

        let labels = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let (train, val) = skfold.get_stratified_split(0, &labels).expect("unwrap");

        // With three samples per class and three folds, each validation fold
        // must contain exactly one sample of every class.
        let mut val_classes: Vec<usize> = val.iter().map(|&i| labels[i]).collect();
        val_classes.sort();
        assert_eq!(val_classes, vec![0, 1, 2]);

        assert_eq!(val.len(), 3);
        assert_eq!(train.len(), 6);
        for &idx in &val {
            assert!(!train.contains(&idx));
        }
    }

    #[test]
    fn test_time_series_split() {
        let ts_split = TimeSeriesSplit::new(3).expect("unwrap");
        assert_eq!(ts_split.num_splits(), 3);

        let (train, val) = ts_split.get_split(0, 10).expect("unwrap");

        // A successful split always has non-empty train and validation sets,
        // and training data must strictly precede validation data.
        assert!(!train.is_empty());
        assert!(!val.is_empty());
        assert!(train.iter().max().expect("unwrap") < val.iter().min().expect("unwrap"));
    }

    #[test]
    fn test_time_series_split_with_window() {
        let ts_split = TimeSeriesSplit::new(3)
            .expect("unwrap")
            .with_min_train_size(2)
            .with_max_train_size(5);

        let (train, val) = ts_split.get_split(1, 20).expect("unwrap");

        assert!(train.len() <= 5);
        assert!(train.len() >= 2);
        assert!(!val.is_empty());
    }

    #[test]
    fn test_time_series_split_invalid() {
        let ts_split = TimeSeriesSplit::new(3).expect("unwrap");

        // Too few samples to form n_splits + 1 windows.
        assert!(ts_split.get_split(0, 2).is_err());

        // Fold index out of range.
        assert!(ts_split.get_split(5, 10).is_err());
    }

    #[test]
    fn test_leave_one_out() {
        let loo = LeaveOneOut::new();

        let (train, val) = loo.get_split(0, 5).expect("unwrap");

        assert_eq!(val.len(), 1);
        assert_eq!(train.len(), 4);
        assert_eq!(val[0], 0);
        assert!(!train.contains(&0));

        let (train, val) = loo.get_split(3, 5).expect("unwrap");
        assert_eq!(val[0], 3);
        assert_eq!(train.len(), 4);
        assert!(!train.contains(&3));
    }

    #[test]
    fn test_leave_one_out_invalid() {
        let loo = LeaveOneOut::new();
        assert!(loo.get_split(5, 5).is_err());
    }

    #[test]
    fn test_cv_results() {
        let mut results = CrossValidationResults::new();

        let mut metrics1 = HashMap::new();
        metrics1.insert("accuracy".to_string(), 0.9);
        results.add_fold(0.85, metrics1);

        let mut metrics2 = HashMap::new();
        metrics2.insert("accuracy".to_string(), 0.95);
        results.add_fold(0.90, metrics2);

        let mut metrics3 = HashMap::new();
        metrics3.insert("accuracy".to_string(), 0.92);
        results.add_fold(0.88, metrics3);

        assert_eq!(results.num_folds(), 3);

        let mean = results.mean_score();
        assert!((mean - 0.8766666).abs() < 1e-6);

        let std = results.std_score();
        assert!(std > 0.0);

        let mean_acc = results.mean_metric("accuracy").expect("unwrap");
        assert!((mean_acc - 0.923333).abs() < 1e-5);
    }

    #[test]
    fn test_cv_results_empty() {
        let results = CrossValidationResults::new();
        assert_eq!(results.mean_score(), 0.0);
        assert_eq!(results.std_score(), 0.0);
        assert_eq!(results.num_folds(), 0);
        assert!(results.mean_metric("accuracy").is_none());
    }

    #[test]
    fn test_kfold_all_folds() {
        let kfold = KFold::new(5).expect("unwrap");
        let n_samples = 20;

        let mut all_val_indices = Vec::new();

        // Validation folds across all splits must partition the samples.
        for fold in 0..5 {
            let (_, val) = kfold.get_split(fold, n_samples).expect("unwrap");
            all_val_indices.extend(val);
        }

        all_val_indices.sort();

        assert_eq!(all_val_indices.len(), n_samples);
        assert_eq!(all_val_indices, (0..n_samples).collect::<Vec<_>>());
    }
}