use crate::{TrainError, TrainResult};
use scirs2_core::random::{SeedableRng, StdRng};
use std::collections::HashMap;

/// Strategy for splitting a dataset into training and validation folds.
pub trait CrossValidationSplit {
    /// Number of folds produced by this splitter.
    fn num_splits(&self) -> usize;

    /// Returns `(train_indices, val_indices)` for the given fold over `n_samples` samples.
    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)>;
}

#[derive(Debug, Clone)]
/// K-fold cross-validation: partitions samples into `n_splits` folds of near-equal size.
pub struct KFold {
    /// Number of folds.
    pub n_splits: usize,
    /// Whether to shuffle sample indices before splitting.
    pub shuffle: bool,
    /// Seed used when `shuffle` is enabled.
    pub random_seed: u64,
}

impl KFold {
    /// Creates a K-fold splitter with `n_splits` folds (must be at least 2).
    pub fn new(n_splits: usize) -> TrainResult<Self> {
        if n_splits < 2 {
            return Err(TrainError::InvalidParameter(
                "n_splits must be at least 2".to_string(),
            ));
        }
        Ok(Self {
            n_splits,
            shuffle: false,
            random_seed: 42,
        })
    }

    /// Enables shuffling and sets the random seed.
    pub fn with_shuffle(mut self, seed: u64) -> Self {
        self.shuffle = true;
        self.random_seed = seed;
        self
    }
}

impl CrossValidationSplit for KFold {
    fn num_splits(&self) -> usize {
        self.n_splits
    }

    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, self.n_splits
            )));
        }

        let mut indices: Vec<usize> = (0..n_samples).collect();

        if self.shuffle {
            // Fisher-Yates shuffle with a fixed seed for reproducibility.
            let mut rng = StdRng::seed_from_u64(self.random_seed);
            for i in (1..n_samples).rev() {
                let j = rng.gen_range(0..=i);
                indices.swap(i, j);
            }
        }

        // Distribute the remainder so the first `remainder` folds get one extra sample.
        let fold_size = n_samples / self.n_splits;
        let remainder = n_samples % self.n_splits;

        let mut fold_sizes = vec![fold_size; self.n_splits];
        for size in fold_sizes.iter_mut().take(remainder) {
            *size += 1;
        }

        // Prefix sums give the start/end boundary of each fold.
        let mut boundaries = vec![0];
        for size in &fold_sizes {
            boundaries.push(boundaries.last().unwrap() + size);
        }

        let val_start = boundaries[fold];
        let val_end = boundaries[fold + 1];
        let val_indices: Vec<usize> = indices[val_start..val_end].to_vec();

        // The training set is everything outside the validation window.
        let mut train_indices = Vec::new();
        train_indices.extend_from_slice(&indices[..val_start]);
        train_indices.extend_from_slice(&indices[val_end..]);

        Ok((train_indices, val_indices))
    }
}

#[derive(Debug, Clone)]
/// Stratified K-fold: each fold preserves the per-class sample proportions.
pub struct StratifiedKFold {
    /// Number of folds.
    pub n_splits: usize,
    /// Whether to shuffle indices within each class before splitting.
    pub shuffle: bool,
    /// Seed used when `shuffle` is enabled.
    pub random_seed: u64,
}

impl StratifiedKFold {
    /// Creates a stratified K-fold splitter with `n_splits` folds (must be at least 2).
    /// Shuffling is enabled by default.
    pub fn new(n_splits: usize) -> TrainResult<Self> {
        if n_splits < 2 {
            return Err(TrainError::InvalidParameter(
                "n_splits must be at least 2".to_string(),
            ));
        }
        Ok(Self {
            n_splits,
            shuffle: true,
            random_seed: 42,
        })
    }

    /// Sets the random seed used for shuffling.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.random_seed = seed;
        self
    }

    /// Returns `(train_indices, val_indices)` for the given fold, stratified by `labels`.
    pub fn get_stratified_split(
        &self,
        fold: usize,
        labels: &[usize],
    ) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, self.n_splits
            )));
        }

        // Group sample indices by class label.
        let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
        for (i, &label) in labels.iter().enumerate() {
            class_indices.entry(label).or_default().push(i);
        }

        if self.shuffle {
            // Fisher-Yates shuffle within each class, seeded for reproducibility.
            let mut rng = StdRng::seed_from_u64(self.random_seed);
            for indices in class_indices.values_mut() {
                for i in (1..indices.len()).rev() {
                    let j = rng.gen_range(0..=i);
                    indices.swap(i, j);
                }
            }
        }

        let mut train_indices = Vec::new();
        let mut val_indices = Vec::new();

        // Split each class independently with the same K-fold boundary logic,
        // so every fold keeps roughly the original class proportions.
        for indices in class_indices.values() {
            let class_size = indices.len();
            let fold_size = class_size / self.n_splits;
            let remainder = class_size % self.n_splits;

            let mut fold_sizes = vec![fold_size; self.n_splits];
            for size in fold_sizes.iter_mut().take(remainder) {
                *size += 1;
            }

            let mut boundaries = vec![0];
            for size in &fold_sizes {
                boundaries.push(boundaries.last().unwrap() + size);
            }

            let val_start = boundaries[fold];
            let val_end = boundaries[fold + 1];
            val_indices.extend_from_slice(&indices[val_start..val_end]);

            train_indices.extend_from_slice(&indices[..val_start]);
            train_indices.extend_from_slice(&indices[val_end..]);
        }

        Ok((train_indices, val_indices))
    }
}

impl CrossValidationSplit for StratifiedKFold {
    fn num_splits(&self) -> usize {
        self.n_splits
    }

    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        // Without real labels, fall back to synthetic labels that cycle through the folds;
        // callers with class labels should use `get_stratified_split` instead.
        let labels: Vec<usize> = (0..n_samples).map(|i| i % self.n_splits).collect();
        self.get_stratified_split(fold, &labels)
    }
}

#[derive(Debug, Clone)]
/// Time-series split: validation blocks are contiguous and always follow the training
/// window, so later samples are never used to predict earlier ones.
pub struct TimeSeriesSplit {
    /// Number of folds.
    pub n_splits: usize,
    /// Minimum number of training samples; checked only when `max_train_size` is not set.
    pub min_train_size: Option<usize>,
    /// Maximum number of training samples; older samples are dropped (rolling window).
    pub max_train_size: Option<usize>,
}

impl TimeSeriesSplit {
    /// Creates a time-series splitter with `n_splits` folds (must be at least 2).
    pub fn new(n_splits: usize) -> TrainResult<Self> {
        if n_splits < 2 {
            return Err(TrainError::InvalidParameter(
                "n_splits must be at least 2".to_string(),
            ));
        }
        Ok(Self {
            n_splits,
            min_train_size: None,
            max_train_size: None,
        })
    }

    /// Requires at least `size` training samples per fold.
    pub fn with_min_train_size(mut self, size: usize) -> Self {
        self.min_train_size = Some(size);
        self
    }

    /// Limits the training window to the most recent `size` samples.
    pub fn with_max_train_size(mut self, size: usize) -> Self {
        self.max_train_size = Some(size);
        self
    }
}

impl CrossValidationSplit for TimeSeriesSplit {
    fn num_splits(&self) -> usize {
        self.n_splits
    }

    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, self.n_splits
            )));
        }

        // Reserve one extra block so the first fold still has training data before it.
        let test_size = n_samples / (self.n_splits + 1);
        if test_size == 0 {
            return Err(TrainError::InvalidParameter(
                "Not enough samples for time series split".to_string(),
            ));
        }

        let val_start = (fold + 1) * test_size;
        let val_end = ((fold + 2) * test_size).min(n_samples);

        // The training window ends where validation begins; `max_train_size` turns it into
        // a rolling window, otherwise it expands from the start of the series.
        let train_end = val_start;
        let train_start = if let Some(max_size) = self.max_train_size {
            train_end.saturating_sub(max_size)
        } else if let Some(min_size) = self.min_train_size {
            if train_end < min_size {
                return Err(TrainError::InvalidParameter(
                    "Not enough samples for min_train_size".to_string(),
                ));
            }
            0
        } else {
            0
        };

        let train_indices: Vec<usize> = (train_start..train_end).collect();
        let val_indices: Vec<usize> = (val_start..val_end).collect();

        if train_indices.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Training set is empty for this fold".to_string(),
            ));
        }

        Ok((train_indices, val_indices))
    }
}

#[derive(Debug, Clone, Default)]
/// Leave-one-out cross-validation: each sample is used exactly once as the validation set.
pub struct LeaveOneOut;

impl LeaveOneOut {
    /// Creates a leave-one-out splitter.
    pub fn new() -> Self {
        Self
    }
}

impl CrossValidationSplit for LeaveOneOut {
    fn num_splits(&self) -> usize {
        // The true number of splits equals the number of samples, which is not known here.
        usize::MAX
    }

    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= n_samples {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, n_samples
            )));
        }

        // Validation set is the single sample at `fold`; everything else is training data.
        let val_indices = vec![fold];

        let mut train_indices: Vec<usize> = (0..fold).collect();
        train_indices.extend(fold + 1..n_samples);

        Ok((train_indices, val_indices))
    }
}

#[derive(Debug, Clone)]
/// Aggregated scores and metrics collected across cross-validation folds.
pub struct CrossValidationResults {
    /// Primary score for each fold.
    pub fold_scores: Vec<f64>,
    /// Named metrics for each fold.
    pub fold_metrics: Vec<HashMap<String, f64>>,
}

impl CrossValidationResults {
    /// Creates an empty results container.
    pub fn new() -> Self {
        Self {
            fold_scores: Vec::new(),
            fold_metrics: Vec::new(),
        }
    }

    /// Records the score and metrics for one fold.
    pub fn add_fold(&mut self, score: f64, metrics: HashMap<String, f64>) {
        self.fold_scores.push(score);
        self.fold_metrics.push(metrics);
    }

    /// Mean score across folds (0.0 if no folds have been recorded).
    pub fn mean_score(&self) -> f64 {
        if self.fold_scores.is_empty() {
            return 0.0;
        }
        self.fold_scores.iter().sum::<f64>() / self.fold_scores.len() as f64
    }

    /// Sample standard deviation of the fold scores (0.0 with fewer than two folds).
    pub fn std_score(&self) -> f64 {
        if self.fold_scores.len() <= 1 {
            return 0.0;
        }

        let mean = self.mean_score();
        let variance = self
            .fold_scores
            .iter()
            .map(|&score| (score - mean).powi(2))
            .sum::<f64>()
            / (self.fold_scores.len() - 1) as f64;

        variance.sqrt()
    }

    /// Mean of a named metric across the folds that recorded it, or `None` if no fold did.
    pub fn mean_metric(&self, metric_name: &str) -> Option<f64> {
        if self.fold_metrics.is_empty() {
            return None;
        }

        let mut sum = 0.0;
        let mut count = 0;

        for metrics in &self.fold_metrics {
            if let Some(&value) = metrics.get(metric_name) {
                sum += value;
                count += 1;
            }
        }

        if count > 0 {
            Some(sum / count as f64)
        } else {
            None
        }
    }

    /// Number of folds recorded so far.
    pub fn num_folds(&self) -> usize {
        self.fold_scores.len()
    }
}

impl Default for CrossValidationResults {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kfold_basic() {
        let kfold = KFold::new(3).unwrap();
        assert_eq!(kfold.num_splits(), 3);

        let (train, val) = kfold.get_split(0, 10).unwrap();
        assert!(!train.is_empty());
        assert!(!val.is_empty());

        // Train and validation sets must be disjoint.
        for &idx in &val {
            assert!(!train.contains(&idx));
        }

        // Together they must cover every sample exactly once.
        let mut all_indices = train.clone();
        all_indices.extend(&val);
        all_indices.sort();
        assert_eq!(all_indices, (0..10).collect::<Vec<_>>());
    }
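
    // Example: the `CrossValidationSplit` trait object lets callers swap splitting
    // strategies at runtime. A minimal sketch using only the splitters defined above.
    #[test]
    fn test_splitters_via_trait_object() {
        let splitters: Vec<Box<dyn CrossValidationSplit>> = vec![
            Box::new(KFold::new(4).unwrap()),
            Box::new(StratifiedKFold::new(4).unwrap()),
            Box::new(TimeSeriesSplit::new(4).unwrap()),
        ];

        for splitter in &splitters {
            // Every splitter must produce a non-empty, disjoint train/validation pair.
            let (train, val) = splitter.get_split(0, 12).unwrap();
            assert!(!train.is_empty());
            assert!(!val.is_empty());
            for idx in &val {
                assert!(!train.contains(idx));
            }
        }
    }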

    #[test]
    fn test_kfold_with_shuffle() {
        let kfold = KFold::new(3).unwrap().with_shuffle(42);
        let (train1, val1) = kfold.get_split(0, 10).unwrap();
        let (train2, val2) = kfold.get_split(0, 10).unwrap();

        // The shuffle is seeded, so repeated calls are deterministic.
        assert_eq!(train1, train2);
        assert_eq!(val1, val2);
    }

    #[test]
    fn test_kfold_invalid() {
        assert!(KFold::new(1).is_err());
        let kfold = KFold::new(3).unwrap();
        assert!(kfold.get_split(5, 10).is_err());
    }

    #[test]
    fn test_stratified_kfold() {
        let skfold = StratifiedKFold::new(3).unwrap();
        assert_eq!(skfold.num_splits(), 3);

        let labels = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let (_train, val) = skfold.get_stratified_split(0, &labels).unwrap();

        let mut val_classes: Vec<usize> = val.iter().map(|&i| labels[i]).collect();
        val_classes.sort();
        val_classes.dedup();

        assert!(!val.is_empty());
        // With 3 samples per class and 3 folds, every class appears in the validation fold.
        assert_eq!(val_classes, vec![0, 1, 2]);
    }
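
    // Example: within a single fold, the stratified train and validation sets are
    // disjoint and together cover every sample. A minimal sketch with three classes.
    #[test]
    fn test_stratified_kfold_partition_per_fold() {
        let skfold = StratifiedKFold::new(3).unwrap();
        let labels = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];

        for fold in 0..skfold.num_splits() {
            let (train, val) = skfold.get_stratified_split(fold, &labels).unwrap();

            for idx in &val {
                assert!(!train.contains(idx));
            }

            let mut all: Vec<usize> = train.clone();
            all.extend(&val);
            all.sort();
            assert_eq!(all, (0..labels.len()).collect::<Vec<_>>());
        }
    }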

    #[test]
    fn test_time_series_split() {
        let ts_split = TimeSeriesSplit::new(3).unwrap();
        assert_eq!(ts_split.num_splits(), 3);

        let (train, val) = ts_split.get_split(0, 10).unwrap();

        // All training indices must precede all validation indices.
        if !train.is_empty() && !val.is_empty() {
            assert!(train.iter().max().unwrap() < val.iter().min().unwrap());
        }
    }
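
    // Example: with neither window option set, the training window expands with each fold
    // while validation always stays ahead of it. A minimal sketch over 12 samples.
    #[test]
    fn test_time_series_split_expanding_window() {
        let ts_split = TimeSeriesSplit::new(3).unwrap();

        let mut prev_train_len = 0;
        for fold in 0..ts_split.num_splits() {
            let (train, val) = ts_split.get_split(fold, 12).unwrap();

            // Each successive fold trains on strictly more history.
            assert!(train.len() > prev_train_len);
            prev_train_len = train.len();

            // Validation data always comes after the training data in time.
            assert!(train.iter().max().unwrap() < val.iter().min().unwrap());
        }
    }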

    #[test]
    fn test_time_series_split_with_window() {
        let ts_split = TimeSeriesSplit::new(3)
            .unwrap()
            .with_min_train_size(2)
            .with_max_train_size(5);

        let (train, val) = ts_split.get_split(1, 20).unwrap();

        // The rolling window caps the training set at `max_train_size` samples.
        assert!(train.len() <= 5);
        assert!(!val.is_empty());
    }

    #[test]
    fn test_time_series_split_invalid() {
        let ts_split = TimeSeriesSplit::new(3).unwrap();

        // Too few samples to carve out any validation block.
        assert!(ts_split.get_split(0, 2).is_err());

        // Fold index beyond the configured number of splits.
        assert!(ts_split.get_split(5, 10).is_err());
    }
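
    // Example: when `min_train_size` is set (and no rolling window), folds whose training
    // window is too short are rejected. A minimal sketch with 10 samples.
    #[test]
    fn test_time_series_split_min_train_size_error() {
        let ts_split = TimeSeriesSplit::new(3).unwrap().with_min_train_size(10);

        // test_size = 10 / 4 = 2, so fold 0 would only have 2 training samples.
        assert!(ts_split.get_split(0, 10).is_err());
    }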

    #[test]
    fn test_leave_one_out() {
        let loo = LeaveOneOut::new();

        let (train, val) = loo.get_split(0, 5).unwrap();

        assert_eq!(val.len(), 1);
        assert_eq!(train.len(), 4);
        assert_eq!(val[0], 0);

        let (train, val) = loo.get_split(3, 5).unwrap();
        assert_eq!(val[0], 3);
        assert_eq!(train.len(), 4);
    }
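
    // Example: over all n folds, leave-one-out uses every sample exactly once for
    // validation. A minimal sketch with five samples.
    #[test]
    fn test_leave_one_out_all_folds() {
        let loo = LeaveOneOut::new();
        let n_samples = 5;

        let mut all_val_indices = Vec::new();
        for fold in 0..n_samples {
            let (train, val) = loo.get_split(fold, n_samples).unwrap();
            assert_eq!(train.len(), n_samples - 1);
            all_val_indices.extend(val);
        }

        all_val_indices.sort();
        assert_eq!(all_val_indices, (0..n_samples).collect::<Vec<_>>());
    }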

    #[test]
    fn test_leave_one_out_invalid() {
        let loo = LeaveOneOut::new();
        assert!(loo.get_split(5, 5).is_err());
    }

    #[test]
    fn test_cv_results() {
        let mut results = CrossValidationResults::new();

        let mut metrics1 = HashMap::new();
        metrics1.insert("accuracy".to_string(), 0.9);
        results.add_fold(0.85, metrics1);

        let mut metrics2 = HashMap::new();
        metrics2.insert("accuracy".to_string(), 0.95);
        results.add_fold(0.90, metrics2);

        let mut metrics3 = HashMap::new();
        metrics3.insert("accuracy".to_string(), 0.92);
        results.add_fold(0.88, metrics3);

        assert_eq!(results.num_folds(), 3);

        // Mean of 0.85, 0.90, 0.88.
        let mean = results.mean_score();
        assert!((mean - 0.8766666).abs() < 1e-6);

        let std = results.std_score();
        assert!(std > 0.0);

        // Mean of 0.9, 0.95, 0.92.
        let mean_acc = results.mean_metric("accuracy").unwrap();
        assert!((mean_acc - 0.923333).abs() < 1e-5);
    }
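
    // Example: `mean_metric` averages only the folds that recorded the metric and
    // returns `None` for a metric no fold recorded. A minimal sketch with two folds.
    #[test]
    fn test_cv_results_missing_metric() {
        let mut results = CrossValidationResults::new();

        let mut metrics1 = HashMap::new();
        metrics1.insert("accuracy".to_string(), 0.8);
        results.add_fold(0.8, metrics1);

        // Second fold records no metrics at all.
        results.add_fold(0.9, HashMap::new());

        assert!((results.mean_metric("accuracy").unwrap() - 0.8).abs() < 1e-12);
        assert!(results.mean_metric("f1").is_none());
    }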

    #[test]
    fn test_cv_results_empty() {
        let results = CrossValidationResults::new();
        assert_eq!(results.mean_score(), 0.0);
        assert_eq!(results.std_score(), 0.0);
        assert_eq!(results.num_folds(), 0);
        assert!(results.mean_metric("accuracy").is_none());
    }

    #[test]
    fn test_kfold_all_folds() {
        let kfold = KFold::new(5).unwrap();
        let n_samples = 20;

        let mut all_val_indices = Vec::new();

        for fold in 0..5 {
            let (_, val) = kfold.get_split(fold, n_samples).unwrap();
            all_val_indices.extend(val);
        }

        all_val_indices.sort();

        // The validation folds partition the full index range exactly once.
        assert_eq!(all_val_indices.len(), n_samples);
        assert_eq!(all_val_indices, (0..n_samples).collect::<Vec<_>>());
    }
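
    // Example: a full cross-validation loop wiring a splitter to `CrossValidationResults`.
    // The placeholder score below is a stand-in for real model training and evaluation.
    #[test]
    fn test_cv_loop_sketch() {
        let kfold = KFold::new(4).unwrap();
        let n_samples = 16;
        let mut results = CrossValidationResults::new();

        for fold in 0..kfold.num_splits() {
            let (train, val) = kfold.get_split(fold, n_samples).unwrap();

            // Placeholder "score": fraction of samples used for training.
            let score = train.len() as f64 / n_samples as f64;

            let mut metrics = HashMap::new();
            metrics.insert("val_fraction".to_string(), val.len() as f64 / n_samples as f64);
            results.add_fold(score, metrics);
        }

        assert_eq!(results.num_folds(), 4);
        assert!((results.mean_score() - 0.75).abs() < 1e-12);
        assert!((results.mean_metric("val_fraction").unwrap() - 0.25).abs() < 1e-12);
    }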
}