1use crate::error::{CoreError, CoreResult, ErrorContext};
16use rand::seq::SliceRandom;
17use rand::Rng;
18use rand::SeedableRng;
19use rand_chacha::ChaCha8Rng;
20use std::collections::HashMap;
21use std::hash::Hash;
22
/// Pair of index vectors returned by the splitters in this module, in the
/// order `(train_indices, test_indices)`.
pub type SplitIndices = (Vec<usize>, Vec<usize>);
25
26pub fn train_test_split(
48 n_samples: usize,
49 test_size: f64,
50 seed: Option<u64>,
51) -> CoreResult<SplitIndices> {
52 validate_split_params(n_samples, test_size)?;
53
54 let n_test = (n_samples as f64 * test_size).round() as usize;
55 let n_test = n_test.max(1).min(n_samples - 1);
56
57 let mut indices: Vec<usize> = (0..n_samples).collect();
58 let mut rng = make_rng(seed);
59 indices.shuffle(&mut rng);
60
61 let test_indices = indices[..n_test].to_vec();
62 let train_indices = indices[n_test..].to_vec();
63 Ok((train_indices, test_indices))
64}
65
66pub fn stratified_train_test_split<L: Eq + Hash + Clone>(
84 labels: &[L],
85 test_size: f64,
86 seed: Option<u64>,
87) -> CoreResult<SplitIndices> {
88 let n_samples = labels.len();
89 validate_split_params(n_samples, test_size)?;
90
91 let mut class_indices: HashMap<&L, Vec<usize>> = HashMap::new();
92 for (i, label) in labels.iter().enumerate() {
93 class_indices.entry(label).or_default().push(i);
94 }
95
96 let mut rng = make_rng(seed);
97 let mut train_indices = Vec::new();
98 let mut test_indices = Vec::new();
99
100 for (_label, mut indices) in class_indices {
101 indices.shuffle(&mut rng);
102 let n_class_test = (indices.len() as f64 * test_size).round() as usize;
103 let n_class_test = n_class_test.max(1).min(indices.len().saturating_sub(1));
104 test_indices.extend_from_slice(&indices[..n_class_test]);
105 train_indices.extend_from_slice(&indices[n_class_test..]);
106 }
107
108 Ok((train_indices, test_indices))
109}
110
111#[derive(Debug, Clone)]
133pub struct KFold {
134 pub n_splits: usize,
136 pub shuffle: bool,
138 pub seed: Option<u64>,
140}
141
142impl KFold {
143 pub fn new(n_splits: usize, shuffle: bool, seed: Option<u64>) -> CoreResult<Self> {
145 if n_splits < 2 {
146 return Err(CoreError::ValueError(ErrorContext::new(
147 "n_splits must be >= 2 for KFold",
148 )));
149 }
150 Ok(Self {
151 n_splits,
152 shuffle,
153 seed,
154 })
155 }
156
157 pub fn split(&self, n_samples: usize) -> impl Iterator<Item = SplitIndices> {
159 let mut indices: Vec<usize> = (0..n_samples).collect();
160 if self.shuffle {
161 let mut rng = make_rng(self.seed);
162 indices.shuffle(&mut rng);
163 }
164
165 let n_splits = self.n_splits;
166 let fold_sizes = compute_fold_sizes(n_samples, n_splits);
167 let mut folds: Vec<Vec<usize>> = Vec::with_capacity(n_splits);
168 let mut offset = 0;
169 for &size in &fold_sizes {
170 folds.push(indices[offset..offset + size].to_vec());
171 offset += size;
172 }
173
174 (0..n_splits).map(move |k| {
175 let test = folds[k].clone();
176 let train: Vec<usize> = folds
177 .iter()
178 .enumerate()
179 .filter(|(i, _)| *i != k)
180 .flat_map(|(_, f)| f.iter().copied())
181 .collect();
182 (train, test)
183 })
184 }
185}
186
187#[derive(Debug, Clone)]
206pub struct StratifiedKFold {
207 pub n_splits: usize,
209 pub shuffle: bool,
211 pub seed: Option<u64>,
213}
214
215impl StratifiedKFold {
216 pub fn new(n_splits: usize, shuffle: bool, seed: Option<u64>) -> CoreResult<Self> {
218 if n_splits < 2 {
219 return Err(CoreError::ValueError(ErrorContext::new(
220 "n_splits must be >= 2 for StratifiedKFold",
221 )));
222 }
223 Ok(Self {
224 n_splits,
225 shuffle,
226 seed,
227 })
228 }
229
230 pub fn split<L: Eq + Hash + Clone>(&self, labels: &[L]) -> Vec<SplitIndices> {
232 let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
233 let mut label_to_int: HashMap<&L, usize> = HashMap::new();
234 let mut next_id = 0usize;
235
236 for (i, label) in labels.iter().enumerate() {
237 let class_id = *label_to_int.entry(label).or_insert_with(|| {
238 let id = next_id;
239 next_id += 1;
240 id
241 });
242 class_indices.entry(class_id).or_default().push(i);
243 }
244
245 let mut rng = make_rng(self.seed);
246 if self.shuffle {
247 for indices in class_indices.values_mut() {
248 indices.shuffle(&mut rng);
249 }
250 }
251
252 let n_samples = labels.len();
254 let mut fold_assignment = vec![0usize; n_samples];
255 for indices in class_indices.values() {
256 for (pos, &idx) in indices.iter().enumerate() {
257 fold_assignment[idx] = pos % self.n_splits;
258 }
259 }
260
261 (0..self.n_splits)
262 .map(|k| {
263 let mut train = Vec::new();
264 let mut test = Vec::new();
265 for (i, &fold) in fold_assignment.iter().enumerate() {
266 if fold == k {
267 test.push(i);
268 } else {
269 train.push(i);
270 }
271 }
272 (train, test)
273 })
274 .collect()
275 }
276}
277
278pub struct LeaveOneOut;
301
302impl LeaveOneOut {
303 pub fn split(&self, n_samples: usize) -> impl Iterator<Item = SplitIndices> {
305 (0..n_samples).map(move |i| {
306 let test = vec![i];
307 let train: Vec<usize> = (0..n_samples).filter(|&j| j != i).collect();
308 (train, test)
309 })
310 }
311}
312
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// How [`TimeSeriesSplit`] picks the start of each training window.
pub enum TimeSeriesMode {
    /// The training window always starts at index 0, so it grows with every
    /// split.
    Expanding,
    /// The training window is limited to the `max_train_size` samples that
    /// precede the test block (after the gap). When no maximum is configured
    /// the window starts at index 0, exactly like `Expanding`.
    Sliding,
}
325
326#[derive(Debug, Clone)]
342pub struct TimeSeriesSplit {
343 pub n_splits: usize,
345 pub mode: TimeSeriesMode,
347 pub max_train_size: Option<usize>,
349 pub gap: usize,
351}
352
353impl TimeSeriesSplit {
354 pub fn new(
356 n_splits: usize,
357 mode: TimeSeriesMode,
358 max_train_size: Option<usize>,
359 ) -> CoreResult<Self> {
360 if n_splits < 1 {
361 return Err(CoreError::ValueError(ErrorContext::new(
362 "n_splits must be >= 1 for TimeSeriesSplit",
363 )));
364 }
365 Ok(Self {
366 n_splits,
367 mode,
368 max_train_size,
369 gap: 0,
370 })
371 }
372
373 #[must_use]
375 pub fn with_gap(mut self, gap: usize) -> Self {
376 self.gap = gap;
377 self
378 }
379
380 pub fn split(&self, n_samples: usize) -> Vec<SplitIndices> {
382 let test_size = n_samples / (self.n_splits + 1);
383 let test_size = test_size.max(1);
384
385 let mut splits = Vec::with_capacity(self.n_splits);
386
387 for k in 0..self.n_splits {
388 let test_start = (k + 1) * test_size;
389 let test_end = ((k + 2) * test_size).min(n_samples);
390 if test_start >= n_samples {
391 break;
392 }
393 let train_end = test_start.saturating_sub(self.gap);
394 let train_start = match self.mode {
395 TimeSeriesMode::Expanding => 0,
396 TimeSeriesMode::Sliding => {
397 if let Some(max_size) = self.max_train_size {
398 train_end.saturating_sub(max_size)
399 } else {
400 0
401 }
402 }
403 };
404
405 if train_start >= train_end || test_start >= test_end {
406 continue;
407 }
408
409 let train: Vec<usize> = (train_start..train_end).collect();
410 let test: Vec<usize> = (test_start..test_end).collect();
411 splits.push((train, test));
412 }
413
414 splits
415 }
416}
417
418#[derive(Debug, Clone)]
438pub struct GroupKFold {
439 pub n_splits: usize,
441}
442
443impl GroupKFold {
444 pub fn new(n_splits: usize) -> CoreResult<Self> {
446 if n_splits < 2 {
447 return Err(CoreError::ValueError(ErrorContext::new(
448 "n_splits must be >= 2 for GroupKFold",
449 )));
450 }
451 Ok(Self { n_splits })
452 }
453
454 pub fn split<G: Eq + Hash + Clone>(&self, groups: &[G]) -> Vec<SplitIndices> {
456 let mut group_to_indices: HashMap<usize, Vec<usize>> = HashMap::new();
458 let mut group_to_id: HashMap<&G, usize> = HashMap::new();
459 let mut next_id = 0usize;
460
461 for (i, group) in groups.iter().enumerate() {
462 let gid = *group_to_id.entry(group).or_insert_with(|| {
463 let id = next_id;
464 next_id += 1;
465 id
466 });
467 group_to_indices.entry(gid).or_default().push(i);
468 }
469
470 let n_groups = next_id;
471 let actual_splits = self.n_splits.min(n_groups);
472
473 let mut group_ids: Vec<usize> = (0..n_groups).collect();
475 group_ids.sort_by(|a, b| {
477 let sa = group_to_indices.get(a).map(|v| v.len()).unwrap_or(0);
478 let sb = group_to_indices.get(b).map(|v| v.len()).unwrap_or(0);
479 sb.cmp(&sa)
480 });
481
482 let mut fold_sizes = vec![0usize; actual_splits];
484 let mut group_fold = vec![0usize; n_groups];
485 for &gid in &group_ids {
486 let min_fold = fold_sizes
487 .iter()
488 .enumerate()
489 .min_by_key(|(_, &s)| s)
490 .map(|(i, _)| i)
491 .unwrap_or(0);
492 group_fold[gid] = min_fold;
493 fold_sizes[min_fold] += group_to_indices.get(&gid).map(|v| v.len()).unwrap_or(0);
494 }
495
496 (0..actual_splits)
497 .map(|k| {
498 let mut train = Vec::new();
499 let mut test = Vec::new();
500 for gid in 0..n_groups {
501 let indices = group_to_indices.get(&gid).cloned().unwrap_or_default();
502 if group_fold[gid] == k {
503 test.extend(indices);
504 } else {
505 train.extend(indices);
506 }
507 }
508 (train, test)
509 })
510 .collect()
511 }
512}
513
514#[derive(Debug, Clone)]
535pub struct ShuffleSplit {
536 pub n_splits: usize,
538 pub test_size: f64,
540 pub seed: Option<u64>,
542}
543
544impl ShuffleSplit {
545 pub fn new(n_splits: usize, test_size: f64, seed: Option<u64>) -> CoreResult<Self> {
547 if n_splits < 1 {
548 return Err(CoreError::ValueError(ErrorContext::new(
549 "n_splits must be >= 1 for ShuffleSplit",
550 )));
551 }
552 if test_size <= 0.0 || test_size >= 1.0 {
553 return Err(CoreError::ValueError(ErrorContext::new(
554 "test_size must be between 0 and 1 (exclusive)",
555 )));
556 }
557 Ok(Self {
558 n_splits,
559 test_size,
560 seed,
561 })
562 }
563
564 pub fn split(&self, n_samples: usize) -> Vec<SplitIndices> {
566 let n_test = ((n_samples as f64) * self.test_size).round() as usize;
567 let n_test = n_test.max(1).min(n_samples - 1);
568
569 let base_seed = self.seed.unwrap_or(0);
570 let mut splits = Vec::with_capacity(self.n_splits);
571
572 for k in 0..self.n_splits {
573 let mut indices: Vec<usize> = (0..n_samples).collect();
574 let mut rng = ChaCha8Rng::seed_from_u64(base_seed.wrapping_add(k as u64));
575 indices.shuffle(&mut rng);
576
577 let test = indices[..n_test].to_vec();
578 let train = indices[n_test..].to_vec();
579 splits.push((train, test));
580 }
581
582 splits
583 }
584}
585
586fn validate_split_params(n_samples: usize, test_size: f64) -> CoreResult<()> {
591 if n_samples < 2 {
592 return Err(CoreError::ValueError(ErrorContext::new(
593 "Need at least 2 samples to split",
594 )));
595 }
596 if test_size <= 0.0 || test_size >= 1.0 {
597 return Err(CoreError::ValueError(ErrorContext::new(
598 "test_size must be between 0 and 1 (exclusive)",
599 )));
600 }
601 Ok(())
602}
603
604fn make_rng(seed: Option<u64>) -> ChaCha8Rng {
605 match seed {
606 Some(s) => ChaCha8Rng::seed_from_u64(s),
607 None => ChaCha8Rng::seed_from_u64(rand::rng().random()),
608 }
609}
610
/// Distributes `n_samples` over `n_splits` folds as evenly as possible.
///
/// Every fold gets `n_samples / n_splits` samples and the first
/// `n_samples % n_splits` folds get one more, so the sizes sum to
/// `n_samples`. Returns an empty vector when `n_splits` is 0 instead of
/// dividing by zero.
fn compute_fold_sizes(n_samples: usize, n_splits: usize) -> Vec<usize> {
    if n_splits == 0 {
        return Vec::new();
    }
    let base_size = n_samples / n_splits;
    let remainder = n_samples % n_splits;
    (0..n_splits)
        .map(|fold| base_size + usize::from(fold < remainder))
        .collect()
}
620
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_train_test_split_basic() {
        let (train, test) = train_test_split(100, 0.2, Some(42)).expect("split");
        assert_eq!(train.len() + test.len(), 100);
        assert_eq!(test.len(), 20);
        // Train and test together must cover every index exactly once.
        let mut all: Vec<usize> = train.iter().chain(test.iter()).copied().collect();
        all.sort_unstable();
        all.dedup();
        assert_eq!(all.len(), 100);
    }

    #[test]
    fn test_train_test_split_reproducible() {
        let (train1, test1) = train_test_split(50, 0.3, Some(123)).expect("split1");
        let (train2, test2) = train_test_split(50, 0.3, Some(123)).expect("split2");
        assert_eq!(train1, train2);
        assert_eq!(test1, test2);
    }

    #[test]
    fn test_train_test_split_invalid() {
        assert!(train_test_split(1, 0.5, None).is_err());
        assert!(train_test_split(10, 0.0, None).is_err());
        assert!(train_test_split(10, 1.0, None).is_err());
    }

    #[test]
    fn test_stratified_split() {
        let labels = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let (train, test) = stratified_train_test_split(&labels, 0.4, Some(42)).expect("split");
        assert_eq!(train.len() + test.len(), 10);
        // Both classes must be represented in the test set.
        let test_labels: Vec<i32> = test.iter().map(|&i| labels[i]).collect();
        assert!(test_labels.contains(&0));
        assert!(test_labels.contains(&1));
    }

    #[test]
    fn test_kfold_basic() {
        let kf = KFold::new(5, false, None).expect("kf");
        let splits: Vec<_> = kf.split(100).collect();
        assert_eq!(splits.len(), 5);
        for (train, test) in &splits {
            assert_eq!(train.len() + test.len(), 100);
        }
    }

    #[test]
    fn test_kfold_shuffle() {
        let kf = KFold::new(3, true, Some(42)).expect("kf");
        let splits: Vec<_> = kf.split(30).collect();
        assert_eq!(splits.len(), 3);
        for (train, test) in &splits {
            assert_eq!(train.len() + test.len(), 30);
            assert_eq!(test.len(), 10);
        }
    }

    #[test]
    fn test_kfold_invalid() {
        assert!(KFold::new(1, false, None).is_err());
    }

    #[test]
    fn test_stratified_kfold() {
        let labels = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let skf = StratifiedKFold::new(5, true, Some(42)).expect("skf");
        let splits = skf.split(&labels);
        assert_eq!(splits.len(), 5);
        for (train, test) in &splits {
            assert_eq!(train.len() + test.len(), 10);
        }
    }

    #[test]
    fn test_leave_one_out() {
        let loo = LeaveOneOut;
        let splits: Vec<_> = loo.split(5).collect();
        assert_eq!(splits.len(), 5);
        for (train, test) in &splits {
            assert_eq!(test.len(), 1);
            assert_eq!(train.len(), 4);
        }
    }

    #[test]
    fn test_time_series_expanding() {
        let ts = TimeSeriesSplit::new(3, TimeSeriesMode::Expanding, None).expect("ts");
        let splits = ts.split(20);
        assert_eq!(splits.len(), 3);
        let train_sizes: Vec<usize> = splits.iter().map(|(t, _)| t.len()).collect();
        for i in 1..train_sizes.len() {
            assert!(
                train_sizes[i] >= train_sizes[i - 1],
                "expanding training sets should grow"
            );
        }
    }

    #[test]
    fn test_time_series_sliding() {
        let ts = TimeSeriesSplit::new(3, TimeSeriesMode::Sliding, Some(5)).expect("ts");
        let splits = ts.split(20);
        // Guard against the loop below passing vacuously on an empty result.
        assert!(!splits.is_empty(), "expected at least one split");
        for (train, _test) in &splits {
            assert!(train.len() <= 5, "sliding window violated max_train_size");
        }
    }

    #[test]
    fn test_time_series_with_gap() {
        let gap = 2;
        let ts = TimeSeriesSplit::new(3, TimeSeriesMode::Expanding, None)
            .expect("ts")
            .with_gap(gap);
        let splits = ts.split(20);
        assert!(!splits.is_empty(), "expected at least one split");
        for (train, test) in &splits {
            if !train.is_empty() && !test.is_empty() {
                let train_max = *train.iter().max().unwrap_or(&0);
                let test_min = *test.iter().min().unwrap_or(&0);
                // At least `gap` indices must lie strictly between the last
                // training index and the first test index; a plain
                // `test_min > train_max` would also pass with no gap at all.
                assert!(
                    test_min > train_max + gap,
                    "gap should separate train and test"
                );
            }
        }
    }

    #[test]
    fn test_group_kfold() {
        let groups = vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4];
        let gkf = GroupKFold::new(5).expect("gkf");
        let splits = gkf.split(&groups);
        assert_eq!(splits.len(), 5);

        for (train, test) in &splits {
            // No group may appear on both sides of a split.
            let train_groups: std::collections::HashSet<i32> =
                train.iter().map(|&i| groups[i]).collect();
            let test_groups: std::collections::HashSet<i32> =
                test.iter().map(|&i| groups[i]).collect();
            let overlap: Vec<_> = train_groups.intersection(&test_groups).collect();
            assert!(
                overlap.is_empty(),
                "groups should not overlap: {:?}",
                overlap
            );
        }
    }

    #[test]
    fn test_group_kfold_string_groups() {
        let groups = vec!["a", "a", "b", "b", "c", "c"];
        let gkf = GroupKFold::new(3).expect("gkf");
        let splits = gkf.split(&groups);
        assert_eq!(splits.len(), 3);
    }

    #[test]
    fn test_shuffle_split() {
        let ss = ShuffleSplit::new(10, 0.2, Some(42)).expect("ss");
        let splits = ss.split(100);
        assert_eq!(splits.len(), 10);
        for (train, test) in &splits {
            assert_eq!(train.len() + test.len(), 100);
            assert_eq!(test.len(), 20);
        }
    }

    #[test]
    fn test_shuffle_split_different_seeds() {
        let ss = ShuffleSplit::new(3, 0.3, Some(42)).expect("ss");
        let splits = ss.split(50);
        // Consecutive splits use different seeds and should differ.
        assert_ne!(splits[0].1, splits[1].1);
    }

    #[test]
    fn test_shuffle_split_invalid() {
        assert!(ShuffleSplit::new(0, 0.2, None).is_err());
        assert!(ShuffleSplit::new(5, 0.0, None).is_err());
        assert!(ShuffleSplit::new(5, 1.0, None).is_err());
    }

    #[test]
    fn test_fold_sizes_even() {
        let sizes = compute_fold_sizes(10, 5);
        assert_eq!(sizes, vec![2, 2, 2, 2, 2]);
    }

    #[test]
    fn test_fold_sizes_uneven() {
        let sizes = compute_fold_sizes(13, 5);
        let total: usize = sizes.iter().sum();
        assert_eq!(total, 13);
        // The remainder is handed to the leading folds.
        assert_eq!(sizes, vec![3, 3, 3, 2, 2]);
    }

    #[test]
    fn test_kfold_no_overlap() {
        let kf = KFold::new(4, true, Some(99)).expect("kf");
        let splits: Vec<_> = kf.split(20).collect();
        // Across all folds every index must be tested exactly once.
        let mut all_test: Vec<usize> = splits.iter().flat_map(|(_, t)| t.iter().copied()).collect();
        assert_eq!(all_test.len(), 20);
        all_test.sort_unstable();
        all_test.dedup();
        assert_eq!(all_test.len(), 20);
    }

    #[test]
    fn test_stratified_kfold_proportions() {
        // 70% class 0, 30% class 1.
        let labels: Vec<i32> = vec![0; 70].into_iter().chain(vec![1; 30]).collect();
        let skf = StratifiedKFold::new(5, false, None).expect("skf");
        let splits = skf.split(&labels);
        for (_, test) in &splits {
            let n_class0 = test.iter().filter(|&&i| labels[i] == 0).count();
            let n_class1 = test.iter().filter(|&&i| labels[i] == 1).count();
            // Every test sample belongs to one of the two classes.
            assert_eq!(n_class0 + n_class1, test.len());
            if !test.is_empty() {
                let ratio = n_class0 as f64 / test.len() as f64;
                assert!(
                    ratio > 0.5 && ratio < 0.9,
                    "class 0 ratio {} not within expected range",
                    ratio
                );
            }
        }
    }
}