sklears_model_selection/cv/
shuffle_cv.rs

1//! Shuffle-based cross-validation iterators
2//!
3//! This module provides cross-validation iterators that use random shuffling
4//! to create train/test splits. These methods are particularly useful when
5//! you want to control the exact size of training and test sets while
6//! maintaining randomness in the splits.
7
8use scirs2_core::ndarray::Array1;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::Rng;
11use scirs2_core::random::SeedableRng;
12use scirs2_core::SliceRandomExt;
13use std::collections::HashMap;
14
15use crate::cross_validation::CrossValidator;
16
/// Generate every `k`-element combination of `items`.
///
/// Combinations are returned in lexicographic order of the element positions,
/// and each combination keeps its elements in their original relative order.
///
/// * `k == 0` yields exactly one (empty) combination.
/// * `k > items.len()` yields no combinations.
///
/// The search only descends into prefixes that can still be completed, so the
/// work is proportional to the size of the output rather than to `2^n`.
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    /// Extend `prefix` with elements taken from `items[start..]` until it
    /// holds `k` elements, pushing every completed combination onto `out`.
    fn extend_prefix<T: Clone>(
        items: &[T],
        k: usize,
        start: usize,
        prefix: &mut Vec<T>,
        out: &mut Vec<Vec<T>>,
    ) {
        let missing = k - prefix.len();
        if missing == 0 {
            out.push(prefix.clone());
            return;
        }

        // The next element must leave at least `missing - 1` items after it,
        // so positions past `last_start` can never lead to a full combination.
        let last_start = items.len() - missing;
        for (offset, item) in items[start..=last_start].iter().enumerate() {
            prefix.push(item.clone());
            extend_prefix(items, k, start + offset + 1, prefix, out);
            prefix.pop();
        }
    }

    // Nothing can be built when more elements are requested than exist.
    if k > items.len() {
        return Vec::new();
    }

    let mut out = Vec::new();
    let mut prefix = Vec::with_capacity(k);
    extend_prefix(items, k, 0, &mut prefix, &mut out);
    out
}
42
/// Random-permutation cross-validator.
///
/// Holds the configuration for generating `n_splits` train/test splits from
/// shuffled sample indices. Sizes are expressed as fractions of the dataset.
#[derive(Debug, Clone)]
pub struct ShuffleSplit {
    /// Number of train/test splits to generate.
    n_splits: usize,
    /// Fraction of the samples placed in the test set.
    test_size: Option<f64>,
    /// Fraction of the samples placed in the training set, if set explicitly.
    train_size: Option<f64>,
    /// Seed for the random number generator, if reproducibility is wanted.
    random_state: Option<u64>,
}

impl ShuffleSplit {
    /// Build a splitter producing `n_splits` splits with a 10% test fraction,
    /// no explicit train fraction and no fixed seed.
    pub fn new(n_splits: usize) -> Self {
        let test_size = Some(0.1);
        Self {
            n_splits,
            test_size,
            train_size: None,
            random_state: None,
        }
    }

    /// Use `size` (a fraction in `0.0..=1.0`) as the test-set proportion.
    ///
    /// # Panics
    ///
    /// Panics if `size` lies outside `0.0..=1.0`.
    pub fn test_size(self, size: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&size),
            "test_size must be between 0.0 and 1.0"
        );
        Self {
            test_size: Some(size),
            ..self
        }
    }

    /// Use `size` (a fraction in `0.0..=1.0`) as the training-set proportion.
    ///
    /// # Panics
    ///
    /// Panics if `size` lies outside `0.0..=1.0`.
    pub fn train_size(self, size: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&size),
            "train_size must be between 0.0 and 1.0"
        );
        Self {
            train_size: Some(size),
            ..self
        }
    }

    /// Fix the random seed so repeated runs use the same generator state.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
91
92impl CrossValidator for ShuffleSplit {
93    fn n_splits(&self) -> usize {
94        self.n_splits
95    }
96
97    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
98        let test_size = self.test_size.unwrap_or(0.1);
99        let train_size = self.train_size.unwrap_or(1.0 - test_size);
100
101        assert!(
102            train_size + test_size <= 1.0,
103            "train_size + test_size cannot exceed 1.0"
104        );
105
106        let n_test = (n_samples as f64 * test_size).round() as usize;
107        let n_train = (n_samples as f64 * train_size).round() as usize;
108
109        assert!(
110            n_train + n_test <= n_samples,
111            "train_size + test_size results in more samples than available"
112        );
113
114        let mut rng = match self.random_state {
115            Some(seed) => StdRng::seed_from_u64(seed),
116            None => {
117                use scirs2_core::random::thread_rng;
118                StdRng::from_rng(&mut thread_rng())
119            }
120        };
121
122        let mut splits = Vec::new();
123
124        for _ in 0..self.n_splits {
125            let mut indices: Vec<usize> = (0..n_samples).collect();
126            indices.shuffle(&mut rng);
127
128            let test_indices = indices[..n_test].to_vec();
129            let train_indices = indices[n_test..n_test + n_train].to_vec();
130
131            splits.push((train_indices, test_indices));
132        }
133
134        splits
135    }
136}
137
/// Stratified random-permutation cross-validator.
///
/// Holds the configuration for generating `n_splits` shuffled train/test
/// splits in which sampling is performed separately within each class label.
/// Sizes are expressed as fractions of the dataset.
#[derive(Debug, Clone)]
pub struct StratifiedShuffleSplit {
    /// Number of train/test splits to generate.
    n_splits: usize,
    /// Fraction of the samples placed in the test set.
    test_size: Option<f64>,
    /// Fraction of the samples placed in the training set, if set explicitly.
    train_size: Option<f64>,
    /// Seed for the random number generator, if reproducibility is wanted.
    random_state: Option<u64>,
}

impl StratifiedShuffleSplit {
    /// Build a splitter producing `n_splits` splits with a 10% test fraction,
    /// no explicit train fraction and no fixed seed.
    pub fn new(n_splits: usize) -> Self {
        let test_size = Some(0.1);
        Self {
            n_splits,
            test_size,
            train_size: None,
            random_state: None,
        }
    }

    /// Use `size` (a fraction in `0.0..=1.0`) as the test-set proportion.
    ///
    /// # Panics
    ///
    /// Panics if `size` lies outside `0.0..=1.0`.
    pub fn test_size(self, size: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&size),
            "test_size must be between 0.0 and 1.0"
        );
        Self {
            test_size: Some(size),
            ..self
        }
    }

    /// Use `size` (a fraction in `0.0..=1.0`) as the training-set proportion.
    ///
    /// # Panics
    ///
    /// Panics if `size` lies outside `0.0..=1.0`.
    pub fn train_size(self, size: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&size),
            "train_size must be between 0.0 and 1.0"
        );
        Self {
            train_size: Some(size),
            ..self
        }
    }

    /// Fix the random seed so repeated runs use the same generator state.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
186
187impl CrossValidator for StratifiedShuffleSplit {
188    fn n_splits(&self) -> usize {
189        self.n_splits
190    }
191
192    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
193        let y = y.expect("StratifiedShuffleSplit requires y to be provided");
194        assert_eq!(
195            y.len(),
196            n_samples,
197            "y must have the same length as n_samples"
198        );
199
200        let test_size = self.test_size.unwrap_or(0.1);
201        let train_size = self.train_size.unwrap_or(1.0 - test_size);
202
203        assert!(
204            train_size + test_size <= 1.0,
205            "train_size + test_size cannot exceed 1.0"
206        );
207
208        // Group indices by class
209        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
210        for (idx, &label) in y.iter().enumerate() {
211            class_indices.entry(label).or_default().push(idx);
212        }
213
214        let mut rng = match self.random_state {
215            Some(seed) => StdRng::seed_from_u64(seed),
216            None => {
217                use scirs2_core::random::thread_rng;
218                StdRng::from_rng(&mut thread_rng())
219            }
220        };
221
222        let mut splits = Vec::new();
223
224        for _ in 0..self.n_splits {
225            let mut train_indices = Vec::new();
226            let mut test_indices = Vec::new();
227
228            // Stratified sampling within each class
229            for (_class, mut indices) in class_indices.clone() {
230                indices.shuffle(&mut rng);
231
232                let n_test_class = ((indices.len() as f64) * test_size).round() as usize;
233                let n_train_class = ((indices.len() as f64) * train_size).round() as usize;
234
235                test_indices.extend(&indices[..n_test_class]);
236                train_indices.extend(&indices[n_test_class..n_test_class + n_train_class]);
237            }
238
239            splits.push((train_indices, test_indices));
240        }
241
242        splits
243    }
244}
245
/// Bootstrap cross-validator.
///
/// Holds the configuration for generating `n_splits` bootstrap resamples:
/// training indices are drawn with replacement and the samples that were
/// never drawn (out-of-bag) form the test set.
#[derive(Debug, Clone)]
pub struct BootstrapCV {
    /// Number of bootstrap resamples to generate.
    n_splits: usize,
    /// Training-set size as a fraction of the dataset; `None` means the
    /// training set has as many draws as there are samples.
    train_size: Option<f64>,
    /// Seed for the random number generator, if reproducibility is wanted.
    random_state: Option<u64>,
}

impl BootstrapCV {
    /// Build a bootstrap splitter producing `n_splits` resamples, with no
    /// explicit train fraction and no fixed seed.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is zero.
    pub fn new(n_splits: usize) -> Self {
        assert!(n_splits >= 1, "n_splits must be at least 1");
        Self {
            n_splits,
            // `None` selects a training set the same size as the dataset.
            train_size: None,
            random_state: None,
        }
    }

    /// Use `size` (a fraction in `0.0..=1.0`) as the training-set proportion.
    ///
    /// # Panics
    ///
    /// Panics if `size` lies outside `0.0..=1.0`.
    pub fn train_size(self, size: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&size),
            "train_size must be between 0.0 and 1.0"
        );
        Self {
            train_size: Some(size),
            ..self
        }
    }

    /// Fix the random seed so repeated runs use the same generator state.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
286
287impl CrossValidator for BootstrapCV {
288    fn n_splits(&self) -> usize {
289        self.n_splits
290    }
291
292    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
293        let train_size = match self.train_size {
294            Some(frac) => (frac * n_samples as f64).round() as usize,
295            None => n_samples, // Bootstrap typically uses same size as original
296        };
297
298        let mut rng = match self.random_state {
299            Some(seed) => StdRng::seed_from_u64(seed),
300            None => {
301                use scirs2_core::random::thread_rng;
302                StdRng::from_rng(&mut thread_rng())
303            }
304        };
305
306        let mut splits = Vec::with_capacity(self.n_splits);
307
308        for _ in 0..self.n_splits {
309            // Bootstrap sampling with replacement for training set
310            let mut train_indices = Vec::with_capacity(train_size);
311            let mut sampled_indices = std::collections::HashSet::new();
312
313            for _ in 0..train_size {
314                let idx = rng.gen_range(0..n_samples);
315                train_indices.push(idx);
316                sampled_indices.insert(idx);
317            }
318
319            // Out-of-bag samples for test set
320            let test_indices: Vec<usize> = (0..n_samples)
321                .filter(|idx| !sampled_indices.contains(idx))
322                .collect();
323
324            splits.push((train_indices, test_indices));
325        }
326
327        splits
328    }
329}
330
/// Monte Carlo (repeated random subsampling) cross-validator.
///
/// Holds the configuration for generating `n_splits` independent random
/// train/test partitions. Unlike K-fold, a sample may land in the test set of
/// several splits or of none. Sizes are expressed as fractions of the dataset.
#[derive(Debug, Clone)]
pub struct MonteCarloCV {
    /// Number of random partitions to generate.
    n_splits: usize,
    /// Fraction of the samples placed in the test set.
    test_size: f64,
    /// Fraction of the samples placed in the training set, if set explicitly.
    train_size: Option<f64>,
    /// Seed for the random number generator, if reproducibility is wanted.
    random_state: Option<u64>,
}

impl MonteCarloCV {
    /// Build a splitter producing `n_splits` partitions with the given test
    /// fraction, no explicit train fraction and no fixed seed.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is zero or `test_size` lies outside `0.0..=1.0`.
    pub fn new(n_splits: usize, test_size: f64) -> Self {
        assert!(n_splits >= 1, "n_splits must be at least 1");
        assert!(
            (0.0..=1.0).contains(&test_size),
            "test_size must be between 0.0 and 1.0"
        );
        let (train_size, random_state) = (None, None);
        Self {
            n_splits,
            test_size,
            train_size,
            random_state,
        }
    }

    /// Use `size` (a fraction in `0.0..=1.0`) as the training-set proportion.
    ///
    /// # Panics
    ///
    /// Panics if `size` lies outside `0.0..=1.0`.
    pub fn train_size(self, size: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&size),
            "train_size must be between 0.0 and 1.0"
        );
        Self {
            train_size: Some(size),
            ..self
        }
    }

    /// Fix the random seed so repeated runs use the same generator state.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
377
378impl CrossValidator for MonteCarloCV {
379    fn n_splits(&self) -> usize {
380        self.n_splits
381    }
382
383    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
384        let train_size = match self.train_size {
385            Some(frac) => (frac * n_samples as f64).round() as usize,
386            None => n_samples - (self.test_size * n_samples as f64).round() as usize,
387        };
388        let test_size = (self.test_size * n_samples as f64).round() as usize;
389
390        assert!(
391            train_size + test_size <= n_samples,
392            "train_size + test_size cannot exceed the number of samples"
393        );
394
395        let mut rng = match self.random_state {
396            Some(seed) => StdRng::seed_from_u64(seed),
397            None => {
398                use scirs2_core::random::thread_rng;
399                StdRng::from_rng(&mut thread_rng())
400            }
401        };
402
403        let mut splits = Vec::with_capacity(self.n_splits);
404
405        for _ in 0..self.n_splits {
406            let mut indices: Vec<usize> = (0..n_samples).collect();
407            indices.shuffle(&mut rng);
408
409            let test_indices = indices[..test_size].to_vec();
410            let train_indices = indices[test_size..test_size + train_size].to_vec();
411
412            splits.push((train_indices, test_indices));
413        }
414
415        splits
416    }
417}
418
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};
    use std::collections::HashSet;

    /// Assert that no index appears in both halves of a split.
    fn assert_disjoint(train: &[usize], test: &[usize]) {
        let train_set: HashSet<_> = train.iter().collect();
        let test_set: HashSet<_> = test.iter().collect();
        assert!(train_set.is_disjoint(&test_set));
    }

    /// Explicit train and test fractions yield the requested sizes.
    #[test]
    fn test_shuffle_split() {
        let splitter = ShuffleSplit::new(3)
            .test_size(0.2)
            .train_size(0.6)
            .random_state(42);

        let folds = splitter.split(100, None);
        assert_eq!(folds.len(), 3);

        for (train, test) in folds {
            assert_eq!(test.len(), 20); // 20% of 100
            assert_eq!(train.len(), 60); // 60% of 100
            assert_disjoint(&train, &test);
        }
    }

    /// With only a test fraction set, the rest of the data is used for training.
    #[test]
    fn test_shuffle_split_basic() {
        let splitter = ShuffleSplit::new(3).test_size(0.2).random_state(42);
        let folds = splitter.split(10, None::<&Array1<i32>>);

        assert_eq!(folds.len(), 3);

        for (train, test) in &folds {
            assert_eq!(test.len(), 2); // 20% of 10
            assert_eq!(train.len(), 8); // 80% of 10

            // No overlap between train and test
            assert!(test.iter().all(|idx| !train.contains(idx)));
        }
    }

    /// Stratified splitting over three classes keeps the overall sizes.
    #[test]
    fn test_stratified_shuffle_split() {
        let labels = Array1::from_vec(vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2]);
        let splitter = StratifiedShuffleSplit::new(2)
            .test_size(0.3)
            .random_state(42);

        let folds = splitter.split(10, Some(&labels));
        assert_eq!(folds.len(), 2);

        for (train, test) in folds {
            assert_eq!(test.len(), 3); // 30% of 10
            assert_eq!(train.len(), 7); // 70% of 10
            assert_disjoint(&train, &test);
        }
    }

    /// Each class contributes its own share of the test set.
    #[test]
    fn test_stratified_shuffle_split_basic() {
        let labels = array![0, 0, 0, 0, 1, 1, 1, 1];
        let splitter = StratifiedShuffleSplit::new(2)
            .test_size(0.25)
            .random_state(42);
        let folds = splitter.split(8, Some(&labels));

        assert_eq!(folds.len(), 2);

        for (_train, test) in &folds {
            // Count how many test samples each class contributed.
            let mut per_class: HashMap<i32, usize> = HashMap::new();
            for &idx in test {
                *per_class.entry(labels[idx]).or_insert(0) += 1;
            }

            // Both classes should be represented
            assert_eq!(per_class.len(), 2);

            // Each class should have roughly equal representation
            assert_eq!(per_class[&0], 1); // 25% of 4 samples of class 0
            assert_eq!(per_class[&1], 1); // 25% of 4 samples of class 1
        }
    }

    /// Bootstrap resamples are dataset-sized and leave some samples out-of-bag.
    #[test]
    fn test_bootstrap_cv() {
        let splitter = BootstrapCV::new(3).random_state(42);

        let folds = splitter.split(50, None);
        assert_eq!(folds.len(), 3);

        for (train, test) in folds {
            assert_eq!(train.len(), 50); // Bootstrap uses same size as original
            assert!(!test.is_empty()); // Out-of-bag samples
            assert!(test.len() < 50); // Should be less than original
        }
    }

    /// Monte Carlo splits honour the test fraction and use the rest for training.
    #[test]
    fn test_monte_carlo_cv() {
        let splitter = MonteCarloCV::new(4, 0.25).random_state(42);

        let folds = splitter.split(80, None);
        assert_eq!(folds.len(), 4);

        for (train, test) in folds {
            assert_eq!(test.len(), 20); // 25% of 80
            assert_eq!(train.len(), 60); // 75% of 80
            assert_disjoint(&train, &test);
        }
    }

    /// `combinations` yields C(4, 2) pairs, all drawn from the expected set.
    #[test]
    fn test_combinations() {
        let pool = vec![1, 2, 3, 4];
        let pairs = combinations(&pool, 2);
        assert_eq!(pairs.len(), 6); // C(4,2) = 6

        let expected = vec![
            vec![1, 2],
            vec![1, 3],
            vec![1, 4],
            vec![2, 3],
            vec![2, 4],
            vec![3, 4],
        ];

        for pair in pairs {
            assert_eq!(pair.len(), 2);
            assert!(expected.contains(&pair));
        }
    }
}