sklears_model_selection/cv/
repeated_cv.rs

//! Repeated cross-validation methods for robust model evaluation.
//!
//! This module provides repeated cross-validation methods that run multiple rounds
//! of K-fold or stratified K-fold with different random states to obtain more robust
//! estimates of model performance. These methods are particularly useful when you
//! want to reduce the variance in your cross-validation estimates.
7
8use scirs2_core::ndarray::Array1;
9// use scirs2_core::SliceRandomExt;
10
11use crate::cross_validation::{CrossValidator, KFold, StratifiedKFold};
12
/// Repeated K-Fold cross-validator.
///
/// Runs shuffled K-Fold `n_repeats` times, each time with a different seed, and
/// returns the splits of every repetition back to back. Averaging a score over
/// several independently shuffled partitions reduces the variance that comes
/// from a single random split of the data.
///
/// # Seeding
///
/// Repetition `r` (0-based) shuffles with seed `base + r`, where `base` is the
/// value passed to [`RepeatedKFold::random_state`], or `42` when no seed was set.
/// The splits are therefore deterministic even without an explicit seed. Because
/// per-repetition seeds are consecutive, two validators whose base seeds differ
/// by less than `n_repeats` reuse some of the same per-repetition seeds.
///
/// # Examples
///
/// ```
/// use sklears_model_selection::{RepeatedKFold, CrossValidator};
///
/// let cv = RepeatedKFold::new(5, 3)  // 5-fold repeated 3 times
///     .random_state(42);
/// let splits = cv.split(100, None);
/// assert_eq!(splits.len(), 15);  // 5 * 3 = 15 splits
/// ```
#[derive(Debug, Clone)]
pub struct RepeatedKFold {
    /// Number of folds in each repetition (validated to be >= 2).
    n_splits: usize,
    /// Number of times K-Fold is repeated (validated to be >= 1).
    n_repeats: usize,
    /// Base shuffle seed; `None` means the splitter's default base seed is used.
    random_state: Option<u64>,
}

impl RepeatedKFold {
    /// Create a new `RepeatedKFold` cross-validator with no explicit seed.
    ///
    /// # Arguments
    /// * `n_splits` - Number of folds per repetition (must be >= 2)
    /// * `n_repeats` - Number of repetitions (must be >= 1)
    ///
    /// # Panics
    /// Panics if `n_splits < 2` or `n_repeats < 1`.
    pub fn new(n_splits: usize, n_repeats: usize) -> Self {
        assert!(n_splits >= 2, "n_splits must be at least 2");
        assert!(n_repeats >= 1, "n_repeats must be at least 1");
        Self {
            n_splits,
            n_repeats,
            random_state: None,
        }
    }

    /// Set the base seed from which each repetition's shuffle seed is derived.
    ///
    /// Without this call a fixed default base seed is used, so results are
    /// already deterministic; setting it lets you choose the partitions.
    ///
    /// # Arguments
    /// * `seed` - Base random seed value
    #[must_use]
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Get the number of splits per repetition.
    pub fn n_splits_per_repeat(&self) -> usize {
        self.n_splits
    }

    /// Get the number of repetitions.
    pub fn n_repeats(&self) -> usize {
        self.n_repeats
    }
}
74
75impl CrossValidator for RepeatedKFold {
76    fn n_splits(&self) -> usize {
77        self.n_splits * self.n_repeats
78    }
79
80    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
81        let mut all_splits = Vec::new();
82
83        let base_seed = self.random_state.unwrap_or(42);
84
85        for repeat in 0..self.n_repeats {
86            // Create a KFold with shuffling and a different seed for each repeat
87            let kfold = KFold::new(self.n_splits)
88                .shuffle(true)
89                .random_state(base_seed + repeat as u64);
90
91            let splits = kfold.split(n_samples, None);
92            all_splits.extend(splits);
93        }
94
95        all_splits
96    }
97}
98
/// Repeated Stratified K-Fold cross-validator.
///
/// Runs shuffled Stratified K-Fold `n_repeats` times, each time with a different
/// seed, and returns the splits of every repetition back to back. Each repetition
/// maintains the class distribution in its folds, which makes this useful for
/// imbalanced datasets that need both stratification and low-variance estimates.
///
/// # Seeding
///
/// Repetition `r` (0-based) shuffles with seed `base + r`, where `base` is the
/// value passed to [`RepeatedStratifiedKFold::random_state`], or `42` when no
/// seed was set. The splits are therefore deterministic even without an explicit
/// seed. Because per-repetition seeds are consecutive, two validators whose base
/// seeds differ by less than `n_repeats` reuse some of the same per-repetition
/// seeds.
///
/// # Examples
///
/// ```
/// use sklears_model_selection::{RepeatedStratifiedKFold, CrossValidator};
/// use scirs2_core::ndarray::array;
///
/// let cv = RepeatedStratifiedKFold::new(2, 2)  // 2-fold repeated 2 times
///     .random_state(42);
/// let y = array![0, 0, 1, 1, 2, 2];
/// let splits = cv.split(6, Some(&y));
/// assert_eq!(splits.len(), 4);  // 2 * 2 = 4 splits
/// ```
#[derive(Debug, Clone)]
pub struct RepeatedStratifiedKFold {
    /// Number of folds in each repetition (validated to be >= 2).
    n_splits: usize,
    /// Number of times Stratified K-Fold is repeated (validated to be >= 1).
    n_repeats: usize,
    /// Base shuffle seed; `None` means the splitter's default base seed is used.
    random_state: Option<u64>,
}

impl RepeatedStratifiedKFold {
    /// Create a new `RepeatedStratifiedKFold` cross-validator with no explicit seed.
    ///
    /// # Arguments
    /// * `n_splits` - Number of folds per repetition (must be >= 2)
    /// * `n_repeats` - Number of repetitions (must be >= 1)
    ///
    /// # Panics
    /// Panics if `n_splits < 2` or `n_repeats < 1`.
    pub fn new(n_splits: usize, n_repeats: usize) -> Self {
        assert!(n_splits >= 2, "n_splits must be at least 2");
        assert!(n_repeats >= 1, "n_repeats must be at least 1");
        Self {
            n_splits,
            n_repeats,
            random_state: None,
        }
    }

    /// Set the base seed from which each repetition's shuffle seed is derived.
    ///
    /// Without this call a fixed default base seed is used, so results are
    /// already deterministic; setting it lets you choose the partitions.
    ///
    /// # Arguments
    /// * `seed` - Base random seed value
    #[must_use]
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Get the number of splits per repetition.
    pub fn n_splits_per_repeat(&self) -> usize {
        self.n_splits
    }

    /// Get the number of repetitions.
    pub fn n_repeats(&self) -> usize {
        self.n_repeats
    }
}
163
164impl CrossValidator for RepeatedStratifiedKFold {
165    fn n_splits(&self) -> usize {
166        self.n_splits * self.n_repeats
167    }
168
169    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
170        let y = y.expect("RepeatedStratifiedKFold requires y to be provided");
171        let mut all_splits = Vec::new();
172
173        let base_seed = self.random_state.unwrap_or(42);
174
175        for repeat in 0..self.n_repeats {
176            // Create a StratifiedKFold with shuffling and a different seed for each repeat
177            let stratified_kfold = StratifiedKFold::new(self.n_splits)
178                .shuffle(true)
179                .random_state(base_seed + repeat as u64);
180
181            let splits = stratified_kfold.split(n_samples, Some(y));
182            all_splits.extend(splits);
183        }
184
185        all_splits
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashMap;

    /// Count how often each sample index in `0..n_samples` appears in a test fold.
    fn test_occurrences(splits: &[(Vec<usize>, Vec<usize>)], n_samples: usize) -> Vec<usize> {
        let mut counts = vec![0; n_samples];
        for (_, test) in splits {
            for &idx in test {
                counts[idx] += 1;
            }
        }
        counts
    }

    #[test]
    fn test_repeated_kfold_basic() {
        let cv = RepeatedKFold::new(3, 2).random_state(42);
        let splits = cv.split(9, None);

        // Should have n_splits * n_repeats splits
        assert_eq!(splits.len(), 6);
        assert_eq!(cv.n_splits(), 6);
        assert_eq!(cv.n_splits_per_repeat(), 3);
        assert_eq!(cv.n_repeats(), 2);

        // Each sample should appear in test sets exactly n_repeats times
        assert_eq!(test_occurrences(&splits, 9), vec![2; 9]);
    }

    #[test]
    fn test_repeated_kfold_no_overlap() {
        let cv = RepeatedKFold::new(3, 2).random_state(42);
        let splits = cv.split(9, None);

        // Verify no overlap between train and test in each split
        for (train, test) in &splits {
            assert!(test.iter().all(|idx| !train.contains(idx)));
        }
    }

    #[test]
    fn test_repeated_kfold_different_seeds() {
        let splits1 = RepeatedKFold::new(3, 2).random_state(42).split(9, None);
        let splits2 = RepeatedKFold::new(3, 2).random_state(123).split(9, None);

        // Different seeds should produce different splits
        assert_ne!(splits1, splits2);
    }

    #[test]
    fn test_repeated_stratified_kfold_basic() {
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let cv = RepeatedStratifiedKFold::new(3, 2).random_state(42);
        let splits = cv.split(9, Some(&y));

        // Should have n_splits * n_repeats splits
        assert_eq!(splits.len(), 6);
        assert_eq!(cv.n_splits(), 6);
        assert_eq!(cv.n_splits_per_repeat(), 3);
        assert_eq!(cv.n_repeats(), 2);

        // Every test fold should hold exactly one sample of each class
        for (_, test) in &splits {
            let mut class_counts: HashMap<i32, usize> = HashMap::new();
            for &idx in test {
                *class_counts.entry(y[idx]).or_default() += 1;
            }
            assert_eq!(class_counts, HashMap::from([(0, 1), (1, 1), (2, 1)]));
        }
    }

    #[test]
    #[should_panic(expected = "RepeatedStratifiedKFold requires y to be provided")]
    fn test_repeated_stratified_kfold_requires_y() {
        let cv = RepeatedStratifiedKFold::new(3, 2);
        let _ = cv.split(9, None);
    }

    #[test]
    fn test_repeated_stratified_kfold_class_distribution() {
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let cv = RepeatedStratifiedKFold::new(3, 2).random_state(42);
        let splits = cv.split(9, Some(&y));

        // Each sample should appear in test sets exactly n_repeats times
        assert_eq!(test_occurrences(&splits, 9), vec![2; 9]);
    }

    #[test]
    fn test_repeated_stratified_kfold_different_seeds() {
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let splits1 = RepeatedStratifiedKFold::new(3, 2)
            .random_state(42)
            .split(9, Some(&y));
        let splits2 = RepeatedStratifiedKFold::new(3, 2)
            .random_state(123)
            .split(9, Some(&y));

        // Different seeds should produce different splits
        assert_ne!(splits1, splits2);
    }

    #[test]
    #[should_panic(expected = "n_splits must be at least 2")]
    fn test_repeated_kfold_invalid_n_splits() {
        let _ = RepeatedKFold::new(1, 2);
    }

    #[test]
    #[should_panic(expected = "n_repeats must be at least 1")]
    fn test_repeated_kfold_invalid_n_repeats() {
        let _ = RepeatedKFold::new(3, 0);
    }

    #[test]
    #[should_panic(expected = "n_splits must be at least 2")]
    fn test_repeated_stratified_kfold_invalid_n_splits() {
        let _ = RepeatedStratifiedKFold::new(1, 2);
    }

    #[test]
    #[should_panic(expected = "n_repeats must be at least 1")]
    fn test_repeated_stratified_kfold_invalid_n_repeats() {
        let _ = RepeatedStratifiedKFold::new(3, 0);
    }

    #[test]
    fn test_repeated_kfold_single_repeat() {
        let cv = RepeatedKFold::new(3, 1).random_state(42);
        let splits = cv.split(9, None);

        // Should have exactly n_splits splits
        assert_eq!(splits.len(), 3);
        // Each sample should appear exactly once in test sets
        assert_eq!(test_occurrences(&splits, 9), vec![1; 9]);
    }

    #[test]
    fn test_repeated_stratified_kfold_single_repeat() {
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let cv = RepeatedStratifiedKFold::new(3, 1).random_state(42);
        let splits = cv.split(9, Some(&y));

        // Should have exactly n_splits splits
        assert_eq!(splits.len(), 3);
        // Each sample should appear exactly once in test sets
        assert_eq!(test_occurrences(&splits, 9), vec![1; 9]);
    }
}