// scirs2_core/data_split.rs

//! Data splitting utilities for machine learning workflows
//!
//! This module provides tools for splitting datasets into training and test
//! (or validation) sets using various strategies:
//!
//! - [`train_test_split`] - Simple random train/test split
//! - [`stratified_train_test_split`] - Stratified split preserving class proportions
//! - [`KFold`] - K-fold cross-validation
//! - [`StratifiedKFold`] - Stratified K-fold cross-validation
//! - [`LeaveOneOut`] - Leave-one-out cross-validation
//! - [`TimeSeriesSplit`] - Time series cross-validation (expanding or sliding window)
//! - [`GroupKFold`] - Group K-fold (keeps groups intact)
//! - [`ShuffleSplit`] - Repeated random train/test splits
14
15use crate::error::{CoreError, CoreResult, ErrorContext};
16use rand::seq::SliceRandom;
17use rand::Rng;
18use rand::SeedableRng;
19use rand_chacha::ChaCha8Rng;
20use std::collections::HashMap;
21use std::hash::Hash;
22
/// One train/test partition expressed as sample indices.
///
/// The first element holds the training indices, the second the test indices;
/// both index into the original dataset.
pub type SplitIndices = (Vec<usize>, Vec<usize>);
25
26// ---------------------------------------------------------------------------
27// train_test_split
28// ---------------------------------------------------------------------------
29
30/// Split data indices into training and test sets.
31///
32/// # Arguments
33///
34/// * `n_samples` - Total number of samples
35/// * `test_size` - Fraction of data to use for testing (0.0 .. 1.0)
36/// * `seed` - Optional random seed for reproducibility
37///
38/// # Example
39///
40/// ```
41/// use scirs2_core::data_split::train_test_split;
42///
43/// let (train, test) = train_test_split(100, 0.2, Some(42)).expect("split failed");
44/// assert_eq!(train.len() + test.len(), 100);
45/// assert_eq!(test.len(), 20);
46/// ```
47pub fn train_test_split(
48    n_samples: usize,
49    test_size: f64,
50    seed: Option<u64>,
51) -> CoreResult<SplitIndices> {
52    validate_split_params(n_samples, test_size)?;
53
54    let n_test = (n_samples as f64 * test_size).round() as usize;
55    let n_test = n_test.max(1).min(n_samples - 1);
56
57    let mut indices: Vec<usize> = (0..n_samples).collect();
58    let mut rng = make_rng(seed);
59    indices.shuffle(&mut rng);
60
61    let test_indices = indices[..n_test].to_vec();
62    let train_indices = indices[n_test..].to_vec();
63    Ok((train_indices, test_indices))
64}
65
66/// Stratified train/test split that preserves the proportion of each class.
67///
68/// # Arguments
69///
70/// * `labels` - Class labels for each sample
71/// * `test_size` - Fraction of data for testing (0.0 .. 1.0)
72/// * `seed` - Optional random seed
73///
74/// # Example
75///
76/// ```
77/// use scirs2_core::data_split::stratified_train_test_split;
78///
79/// let labels = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2];
80/// let (train, test) = stratified_train_test_split(&labels, 0.3, Some(42)).expect("split");
81/// assert_eq!(train.len() + test.len(), 10);
82/// ```
83pub fn stratified_train_test_split<L: Eq + Hash + Clone>(
84    labels: &[L],
85    test_size: f64,
86    seed: Option<u64>,
87) -> CoreResult<SplitIndices> {
88    let n_samples = labels.len();
89    validate_split_params(n_samples, test_size)?;
90
91    let mut class_indices: HashMap<&L, Vec<usize>> = HashMap::new();
92    for (i, label) in labels.iter().enumerate() {
93        class_indices.entry(label).or_default().push(i);
94    }
95
96    let mut rng = make_rng(seed);
97    let mut train_indices = Vec::new();
98    let mut test_indices = Vec::new();
99
100    for (_label, mut indices) in class_indices {
101        indices.shuffle(&mut rng);
102        let n_class_test = (indices.len() as f64 * test_size).round() as usize;
103        let n_class_test = n_class_test.max(1).min(indices.len().saturating_sub(1));
104        test_indices.extend_from_slice(&indices[..n_class_test]);
105        train_indices.extend_from_slice(&indices[n_class_test..]);
106    }
107
108    Ok((train_indices, test_indices))
109}
110
111// ---------------------------------------------------------------------------
112// KFold
113// ---------------------------------------------------------------------------
114
115/// K-fold cross-validation splitter.
116///
117/// Splits the data into K consecutive folds. Each fold is used once as
118/// validation while the remaining K-1 folds form the training set.
119///
120/// # Example
121///
122/// ```
123/// use scirs2_core::data_split::KFold;
124///
125/// let kf = KFold::new(5, true, Some(42)).expect("kfold");
126/// let splits: Vec<_> = kf.split(100).collect();
127/// assert_eq!(splits.len(), 5);
128/// for (train, test) in &splits {
129///     assert_eq!(train.len() + test.len(), 100);
130/// }
131/// ```
132#[derive(Debug, Clone)]
133pub struct KFold {
134    /// Number of folds
135    pub n_splits: usize,
136    /// Whether to shuffle before splitting
137    pub shuffle: bool,
138    /// Random seed
139    pub seed: Option<u64>,
140}
141
142impl KFold {
143    /// Create a new KFold splitter.
144    pub fn new(n_splits: usize, shuffle: bool, seed: Option<u64>) -> CoreResult<Self> {
145        if n_splits < 2 {
146            return Err(CoreError::ValueError(ErrorContext::new(
147                "n_splits must be >= 2 for KFold",
148            )));
149        }
150        Ok(Self {
151            n_splits,
152            shuffle,
153            seed,
154        })
155    }
156
157    /// Generate splits for `n_samples` data points.
158    pub fn split(&self, n_samples: usize) -> impl Iterator<Item = SplitIndices> {
159        let mut indices: Vec<usize> = (0..n_samples).collect();
160        if self.shuffle {
161            let mut rng = make_rng(self.seed);
162            indices.shuffle(&mut rng);
163        }
164
165        let n_splits = self.n_splits;
166        let fold_sizes = compute_fold_sizes(n_samples, n_splits);
167        let mut folds: Vec<Vec<usize>> = Vec::with_capacity(n_splits);
168        let mut offset = 0;
169        for &size in &fold_sizes {
170            folds.push(indices[offset..offset + size].to_vec());
171            offset += size;
172        }
173
174        (0..n_splits).map(move |k| {
175            let test = folds[k].clone();
176            let train: Vec<usize> = folds
177                .iter()
178                .enumerate()
179                .filter(|(i, _)| *i != k)
180                .flat_map(|(_, f)| f.iter().copied())
181                .collect();
182            (train, test)
183        })
184    }
185}
186
187// ---------------------------------------------------------------------------
188// StratifiedKFold
189// ---------------------------------------------------------------------------
190
191/// Stratified K-fold cross-validation.
192///
193/// Each fold preserves the percentage of samples for each class.
194///
195/// # Example
196///
197/// ```
198/// use scirs2_core::data_split::StratifiedKFold;
199///
200/// let labels = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
201/// let skf = StratifiedKFold::new(5, true, Some(42)).expect("skf");
202/// let splits: Vec<_> = skf.split(&labels);
203/// assert_eq!(splits.len(), 5);
204/// ```
205#[derive(Debug, Clone)]
206pub struct StratifiedKFold {
207    /// Number of folds
208    pub n_splits: usize,
209    /// Whether to shuffle within each class
210    pub shuffle: bool,
211    /// Random seed
212    pub seed: Option<u64>,
213}
214
215impl StratifiedKFold {
216    /// Create a new StratifiedKFold.
217    pub fn new(n_splits: usize, shuffle: bool, seed: Option<u64>) -> CoreResult<Self> {
218        if n_splits < 2 {
219            return Err(CoreError::ValueError(ErrorContext::new(
220                "n_splits must be >= 2 for StratifiedKFold",
221            )));
222        }
223        Ok(Self {
224            n_splits,
225            shuffle,
226            seed,
227        })
228    }
229
230    /// Generate stratified splits.
231    pub fn split<L: Eq + Hash + Clone>(&self, labels: &[L]) -> Vec<SplitIndices> {
232        let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
233        let mut label_to_int: HashMap<&L, usize> = HashMap::new();
234        let mut next_id = 0usize;
235
236        for (i, label) in labels.iter().enumerate() {
237            let class_id = *label_to_int.entry(label).or_insert_with(|| {
238                let id = next_id;
239                next_id += 1;
240                id
241            });
242            class_indices.entry(class_id).or_default().push(i);
243        }
244
245        let mut rng = make_rng(self.seed);
246        if self.shuffle {
247            for indices in class_indices.values_mut() {
248                indices.shuffle(&mut rng);
249            }
250        }
251
252        // Assign each sample in each class to a fold in round-robin fashion
253        let n_samples = labels.len();
254        let mut fold_assignment = vec![0usize; n_samples];
255        for indices in class_indices.values() {
256            for (pos, &idx) in indices.iter().enumerate() {
257                fold_assignment[idx] = pos % self.n_splits;
258            }
259        }
260
261        (0..self.n_splits)
262            .map(|k| {
263                let mut train = Vec::new();
264                let mut test = Vec::new();
265                for (i, &fold) in fold_assignment.iter().enumerate() {
266                    if fold == k {
267                        test.push(i);
268                    } else {
269                        train.push(i);
270                    }
271                }
272                (train, test)
273            })
274            .collect()
275    }
276}
277
278// ---------------------------------------------------------------------------
279// LeaveOneOut
280// ---------------------------------------------------------------------------
281
282/// Leave-one-out cross-validation.
283///
284/// Each sample is used once as the test set. Equivalent to KFold with
285/// n_splits = n_samples.
286///
287/// # Example
288///
289/// ```
290/// use scirs2_core::data_split::LeaveOneOut;
291///
292/// let loo = LeaveOneOut;
293/// let splits: Vec<_> = loo.split(5).collect();
294/// assert_eq!(splits.len(), 5);
295/// for (train, test) in &splits {
296///     assert_eq!(test.len(), 1);
297///     assert_eq!(train.len(), 4);
298/// }
299/// ```
300pub struct LeaveOneOut;
301
302impl LeaveOneOut {
303    /// Generate leave-one-out splits.
304    pub fn split(&self, n_samples: usize) -> impl Iterator<Item = SplitIndices> {
305        (0..n_samples).map(move |i| {
306            let test = vec![i];
307            let train: Vec<usize> = (0..n_samples).filter(|&j| j != i).collect();
308            (train, test)
309        })
310    }
311}
312
313// ---------------------------------------------------------------------------
314// TimeSeriesSplit
315// ---------------------------------------------------------------------------
316
/// Time series split mode.
///
/// Selects how [`TimeSeriesSplit`] chooses the start of each training window.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimeSeriesMode {
    /// Expanding window: training set always starts at index 0 and grows with each split
    Expanding,
    /// Sliding window: training window is capped at `max_train_size` samples when set
    Sliding,
}
325
326/// Time series cross-validation splitter.
327///
328/// Provides train/test indices for time series data where the test set is always
329/// in the future relative to the training set.
330///
331/// # Example
332///
333/// ```
334/// use scirs2_core::data_split::{TimeSeriesSplit, TimeSeriesMode};
335///
336/// let ts = TimeSeriesSplit::new(3, TimeSeriesMode::Expanding, None).expect("ts");
337/// let splits: Vec<_> = ts.split(20);
338/// assert_eq!(splits.len(), 3);
339/// // Training sets grow: each has more data than the previous
340/// ```
341#[derive(Debug, Clone)]
342pub struct TimeSeriesSplit {
343    /// Number of splits
344    pub n_splits: usize,
345    /// Splitting mode
346    pub mode: TimeSeriesMode,
347    /// Maximum training set size (only for Sliding mode)
348    pub max_train_size: Option<usize>,
349    /// Gap between train and test sets
350    pub gap: usize,
351}
352
353impl TimeSeriesSplit {
354    /// Create a new TimeSeriesSplit.
355    pub fn new(
356        n_splits: usize,
357        mode: TimeSeriesMode,
358        max_train_size: Option<usize>,
359    ) -> CoreResult<Self> {
360        if n_splits < 1 {
361            return Err(CoreError::ValueError(ErrorContext::new(
362                "n_splits must be >= 1 for TimeSeriesSplit",
363            )));
364        }
365        Ok(Self {
366            n_splits,
367            mode,
368            max_train_size,
369            gap: 0,
370        })
371    }
372
373    /// Set the gap between train and test sets.
374    #[must_use]
375    pub fn with_gap(mut self, gap: usize) -> Self {
376        self.gap = gap;
377        self
378    }
379
380    /// Generate time series splits.
381    pub fn split(&self, n_samples: usize) -> Vec<SplitIndices> {
382        let test_size = n_samples / (self.n_splits + 1);
383        let test_size = test_size.max(1);
384
385        let mut splits = Vec::with_capacity(self.n_splits);
386
387        for k in 0..self.n_splits {
388            let test_start = (k + 1) * test_size;
389            let test_end = ((k + 2) * test_size).min(n_samples);
390            if test_start >= n_samples {
391                break;
392            }
393            let train_end = test_start.saturating_sub(self.gap);
394            let train_start = match self.mode {
395                TimeSeriesMode::Expanding => 0,
396                TimeSeriesMode::Sliding => {
397                    if let Some(max_size) = self.max_train_size {
398                        train_end.saturating_sub(max_size)
399                    } else {
400                        0
401                    }
402                }
403            };
404
405            if train_start >= train_end || test_start >= test_end {
406                continue;
407            }
408
409            let train: Vec<usize> = (train_start..train_end).collect();
410            let test: Vec<usize> = (test_start..test_end).collect();
411            splits.push((train, test));
412        }
413
414        splits
415    }
416}
417
418// ---------------------------------------------------------------------------
419// GroupKFold
420// ---------------------------------------------------------------------------
421
422/// Group K-fold cross-validation.
423///
424/// Ensures that the same group is not represented in both training and test
425/// sets. Useful when samples from the same subject/experiment should stay together.
426///
427/// # Example
428///
429/// ```
430/// use scirs2_core::data_split::GroupKFold;
431///
432/// let groups = vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4];
433/// let gkf = GroupKFold::new(5).expect("gkf");
434/// let splits = gkf.split(&groups);
435/// assert_eq!(splits.len(), 5);
436/// ```
437#[derive(Debug, Clone)]
438pub struct GroupKFold {
439    /// Number of folds
440    pub n_splits: usize,
441}
442
443impl GroupKFold {
444    /// Create a new GroupKFold.
445    pub fn new(n_splits: usize) -> CoreResult<Self> {
446        if n_splits < 2 {
447            return Err(CoreError::ValueError(ErrorContext::new(
448                "n_splits must be >= 2 for GroupKFold",
449            )));
450        }
451        Ok(Self { n_splits })
452    }
453
454    /// Generate splits where groups are kept together.
455    pub fn split<G: Eq + Hash + Clone>(&self, groups: &[G]) -> Vec<SplitIndices> {
456        // Collect unique groups and their sample indices
457        let mut group_to_indices: HashMap<usize, Vec<usize>> = HashMap::new();
458        let mut group_to_id: HashMap<&G, usize> = HashMap::new();
459        let mut next_id = 0usize;
460
461        for (i, group) in groups.iter().enumerate() {
462            let gid = *group_to_id.entry(group).or_insert_with(|| {
463                let id = next_id;
464                next_id += 1;
465                id
466            });
467            group_to_indices.entry(gid).or_default().push(i);
468        }
469
470        let n_groups = next_id;
471        let actual_splits = self.n_splits.min(n_groups);
472
473        // Assign groups to folds
474        let mut group_ids: Vec<usize> = (0..n_groups).collect();
475        // Sort by group size (largest first) for balanced folds
476        group_ids.sort_by(|a, b| {
477            let sa = group_to_indices.get(a).map(|v| v.len()).unwrap_or(0);
478            let sb = group_to_indices.get(b).map(|v| v.len()).unwrap_or(0);
479            sb.cmp(&sa)
480        });
481
482        // Greedy assignment: place each group in the fold with fewest samples
483        let mut fold_sizes = vec![0usize; actual_splits];
484        let mut group_fold = vec![0usize; n_groups];
485        for &gid in &group_ids {
486            let min_fold = fold_sizes
487                .iter()
488                .enumerate()
489                .min_by_key(|(_, &s)| s)
490                .map(|(i, _)| i)
491                .unwrap_or(0);
492            group_fold[gid] = min_fold;
493            fold_sizes[min_fold] += group_to_indices.get(&gid).map(|v| v.len()).unwrap_or(0);
494        }
495
496        (0..actual_splits)
497            .map(|k| {
498                let mut train = Vec::new();
499                let mut test = Vec::new();
500                for gid in 0..n_groups {
501                    let indices = group_to_indices.get(&gid).cloned().unwrap_or_default();
502                    if group_fold[gid] == k {
503                        test.extend(indices);
504                    } else {
505                        train.extend(indices);
506                    }
507                }
508                (train, test)
509            })
510            .collect()
511    }
512}
513
514// ---------------------------------------------------------------------------
515// ShuffleSplit
516// ---------------------------------------------------------------------------
517
518/// Repeated random train/test splits.
519///
520/// Generates independent random splits on each iteration.
521///
522/// # Example
523///
524/// ```
525/// use scirs2_core::data_split::ShuffleSplit;
526///
527/// let ss = ShuffleSplit::new(10, 0.2, Some(42)).expect("ss");
528/// let splits: Vec<_> = ss.split(100);
529/// assert_eq!(splits.len(), 10);
530/// for (train, test) in &splits {
531///     assert_eq!(train.len() + test.len(), 100);
532/// }
533/// ```
534#[derive(Debug, Clone)]
535pub struct ShuffleSplit {
536    /// Number of re-shuffled splits
537    pub n_splits: usize,
538    /// Fraction for test set
539    pub test_size: f64,
540    /// Random seed
541    pub seed: Option<u64>,
542}
543
544impl ShuffleSplit {
545    /// Create a new ShuffleSplit.
546    pub fn new(n_splits: usize, test_size: f64, seed: Option<u64>) -> CoreResult<Self> {
547        if n_splits < 1 {
548            return Err(CoreError::ValueError(ErrorContext::new(
549                "n_splits must be >= 1 for ShuffleSplit",
550            )));
551        }
552        if test_size <= 0.0 || test_size >= 1.0 {
553            return Err(CoreError::ValueError(ErrorContext::new(
554                "test_size must be between 0 and 1 (exclusive)",
555            )));
556        }
557        Ok(Self {
558            n_splits,
559            test_size,
560            seed,
561        })
562    }
563
564    /// Generate repeated random splits.
565    pub fn split(&self, n_samples: usize) -> Vec<SplitIndices> {
566        let n_test = ((n_samples as f64) * self.test_size).round() as usize;
567        let n_test = n_test.max(1).min(n_samples - 1);
568
569        let base_seed = self.seed.unwrap_or(0);
570        let mut splits = Vec::with_capacity(self.n_splits);
571
572        for k in 0..self.n_splits {
573            let mut indices: Vec<usize> = (0..n_samples).collect();
574            let mut rng = ChaCha8Rng::seed_from_u64(base_seed.wrapping_add(k as u64));
575            indices.shuffle(&mut rng);
576
577            let test = indices[..n_test].to_vec();
578            let train = indices[n_test..].to_vec();
579            splits.push((train, test));
580        }
581
582        splits
583    }
584}
585
586// ---------------------------------------------------------------------------
587// Helpers
588// ---------------------------------------------------------------------------
589
590fn validate_split_params(n_samples: usize, test_size: f64) -> CoreResult<()> {
591    if n_samples < 2 {
592        return Err(CoreError::ValueError(ErrorContext::new(
593            "Need at least 2 samples to split",
594        )));
595    }
596    if test_size <= 0.0 || test_size >= 1.0 {
597        return Err(CoreError::ValueError(ErrorContext::new(
598            "test_size must be between 0 and 1 (exclusive)",
599        )));
600    }
601    Ok(())
602}
603
604fn make_rng(seed: Option<u64>) -> ChaCha8Rng {
605    match seed {
606        Some(s) => ChaCha8Rng::seed_from_u64(s),
607        None => ChaCha8Rng::seed_from_u64(rand::rng().random()),
608    }
609}
610
/// Split `n_samples` into `n_splits` consecutive fold sizes that differ by at most one.
///
/// The first `n_samples % n_splits` folds receive one extra sample. Returns an
/// empty vector when `n_splits` is zero instead of dividing by zero.
fn compute_fold_sizes(n_samples: usize, n_splits: usize) -> Vec<usize> {
    if n_splits == 0 {
        return Vec::new();
    }
    let base_size = n_samples / n_splits;
    let remainder = n_samples % n_splits;
    (0..n_splits)
        .map(|i| base_size + usize::from(i < remainder))
        .collect()
}
620
621// ---------------------------------------------------------------------------
622// Tests
623// ---------------------------------------------------------------------------
624
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Sorted, de-duplicated union of two index lists.
    fn sorted_union(a: &[usize], b: &[usize]) -> Vec<usize> {
        let mut merged: Vec<usize> = a.iter().chain(b.iter()).copied().collect();
        merged.sort_unstable();
        merged.dedup();
        merged
    }

    #[test]
    fn test_train_test_split_basic() {
        let (train, test) = train_test_split(100, 0.2, Some(42)).expect("split");
        assert_eq!(train.len() + test.len(), 100);
        assert_eq!(test.len(), 20);
        // Every sample index appears exactly once across train and test
        assert_eq!(sorted_union(&train, &test).len(), 100);
    }

    #[test]
    fn test_train_test_split_reproducible() {
        let first = train_test_split(50, 0.3, Some(123)).expect("split1");
        let second = train_test_split(50, 0.3, Some(123)).expect("split2");
        assert_eq!(first, second);
    }

    #[test]
    fn test_train_test_split_invalid() {
        let cases: [(usize, f64); 3] = [(1, 0.5), (10, 0.0), (10, 1.0)];
        for &(n_samples, test_size) in cases.iter() {
            assert!(
                train_test_split(n_samples, test_size, None).is_err(),
                "expected error for n_samples={}, test_size={}",
                n_samples,
                test_size
            );
        }
    }

    #[test]
    fn test_stratified_split() {
        let labels = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let (train, test) = stratified_train_test_split(&labels, 0.4, Some(42)).expect("split");
        assert_eq!(train.len() + test.len(), 10);
        // Both classes should be represented in the test set
        let test_classes: HashSet<i32> = test.iter().map(|&i| labels[i]).collect();
        assert!(test_classes.contains(&0));
        assert!(test_classes.contains(&1));
    }

    #[test]
    fn test_kfold_basic() {
        let splits: Vec<SplitIndices> = KFold::new(5, false, None).expect("kf").split(100).collect();
        assert_eq!(splits.len(), 5);
        for (fold, (train, test)) in splits.iter().enumerate() {
            assert_eq!(train.len() + test.len(), 100, "fold {}", fold);
        }
    }

    #[test]
    fn test_kfold_shuffle() {
        let splits: Vec<SplitIndices> = KFold::new(3, true, Some(42)).expect("kf").split(30).collect();
        assert_eq!(splits.len(), 3);
        assert!(splits
            .iter()
            .all(|(train, test)| train.len() + test.len() == 30 && test.len() == 10));
    }

    #[test]
    fn test_kfold_invalid() {
        assert!(KFold::new(1, false, None).is_err());
    }

    #[test]
    fn test_stratified_kfold() {
        let labels = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
        let splits = StratifiedKFold::new(5, true, Some(42)).expect("skf").split(&labels);
        assert_eq!(splits.len(), 5);
        assert!(splits.iter().all(|(train, test)| train.len() + test.len() == 10));
    }

    #[test]
    fn test_leave_one_out() {
        let splits: Vec<SplitIndices> = LeaveOneOut.split(5).collect();
        assert_eq!(splits.len(), 5);
        for (train, test) in &splits {
            assert_eq!((train.len(), test.len()), (4, 1));
        }
    }

    #[test]
    fn test_time_series_expanding() {
        let ts = TimeSeriesSplit::new(3, TimeSeriesMode::Expanding, None).expect("ts");
        let splits = ts.split(20);
        assert_eq!(splits.len(), 3);
        // In expanding mode the training window never shrinks
        let train_sizes: Vec<usize> = splits.iter().map(|(train, _)| train.len()).collect();
        assert!(
            train_sizes.windows(2).all(|pair| pair[1] >= pair[0]),
            "expanding training sets should grow"
        );
    }

    #[test]
    fn test_time_series_sliding() {
        let ts = TimeSeriesSplit::new(3, TimeSeriesMode::Sliding, Some(5)).expect("ts");
        // Every training window must respect max_train_size
        for (train, _) in ts.split(20) {
            assert!(train.len() <= 5, "sliding window violated max_train_size");
        }
    }

    #[test]
    fn test_time_series_with_gap() {
        let ts = TimeSeriesSplit::new(3, TimeSeriesMode::Expanding, None)
            .expect("ts")
            .with_gap(2);
        for (train, test) in ts.split(20) {
            if let (Some(&train_max), Some(&test_min)) = (train.iter().max(), test.iter().min()) {
                assert!(test_min > train_max, "gap should separate train and test");
            }
        }
    }

    #[test]
    fn test_group_kfold() {
        let groups = vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4];
        let splits = GroupKFold::new(5).expect("gkf").split(&groups);
        assert_eq!(splits.len(), 5);

        // No group may appear on both sides of a split
        for (train, test) in &splits {
            let train_groups: HashSet<i32> = train.iter().map(|&i| groups[i]).collect();
            let test_groups: HashSet<i32> = test.iter().map(|&i| groups[i]).collect();
            let shared: Vec<&i32> = train_groups.intersection(&test_groups).collect();
            assert!(shared.is_empty(), "groups should not overlap: {:?}", shared);
        }
    }

    #[test]
    fn test_group_kfold_string_groups() {
        let groups = vec!["a", "a", "b", "b", "c", "c"];
        let splits = GroupKFold::new(3).expect("gkf").split(&groups);
        assert_eq!(splits.len(), 3);
    }

    #[test]
    fn test_shuffle_split() {
        let splits = ShuffleSplit::new(10, 0.2, Some(42)).expect("ss").split(100);
        assert_eq!(splits.len(), 10);
        for (train, test) in &splits {
            assert_eq!(train.len() + test.len(), 100);
            assert_eq!(test.len(), 20);
        }
    }

    #[test]
    fn test_shuffle_split_different_seeds() {
        let splits = ShuffleSplit::new(3, 0.3, Some(42)).expect("ss").split(50);
        // Consecutive splits are seeded differently, so their test sets should differ
        assert_ne!(splits[0].1, splits[1].1);
    }

    #[test]
    fn test_shuffle_split_invalid() {
        assert!(ShuffleSplit::new(0, 0.2, None).is_err());
        assert!(ShuffleSplit::new(5, 0.0, None).is_err());
        assert!(ShuffleSplit::new(5, 1.0, None).is_err());
    }

    #[test]
    fn test_fold_sizes_even() {
        assert_eq!(compute_fold_sizes(10, 5), vec![2; 5]);
    }

    #[test]
    fn test_fold_sizes_uneven() {
        let sizes = compute_fold_sizes(13, 5);
        assert_eq!(sizes.iter().sum::<usize>(), 13);
        // The remainder of 3 goes to the first three folds
        assert_eq!(sizes, vec![3, 3, 3, 2, 2]);
    }

    #[test]
    fn test_kfold_no_overlap() {
        let kf = KFold::new(4, true, Some(99)).expect("kf");
        // Together the test folds cover every sample exactly once
        let covered = kf
            .split(20)
            .fold(Vec::new(), |acc, (_, test)| sorted_union(&acc, &test));
        assert_eq!(covered.len(), 20);
    }

    #[test]
    fn test_stratified_kfold_proportions() {
        // 70% class 0, 30% class 1
        let labels: Vec<i32> = std::iter::repeat(0)
            .take(70)
            .chain(std::iter::repeat(1).take(30))
            .collect();
        let skf = StratifiedKFold::new(5, false, None).expect("skf");
        for (_, test) in skf.split(&labels) {
            if test.is_empty() {
                continue;
            }
            // Proportions should be roughly maintained in every test fold
            let n_class0 = test.iter().filter(|&&i| labels[i] == 0).count();
            let ratio = n_class0 as f64 / test.len() as f64;
            assert!(
                ratio > 0.5 && ratio < 0.9,
                "class 0 ratio {} not within expected range",
                ratio
            );
        }
    }
}