sklears_model_selection/cv/
repeated_cv.rs1use scirs2_core::ndarray::Array1;
9use crate::cross_validation::{CrossValidator, KFold, StratifiedKFold};
12
/// Settings for a k-fold cross-validation scheme that is run several times.
///
/// Build one with [`RepeatedKFold::new`] and optionally attach a seed with
/// [`RepeatedKFold::random_state`].
#[derive(Debug, Clone)]
pub struct RepeatedKFold {
    /// Folds per repetition; `new` rejects values below 2.
    n_splits: usize,
    /// Number of repetitions; `new` rejects values below 1.
    n_repeats: usize,
    /// Caller-supplied seed, or `None` when no seed has been set.
    random_state: Option<u64>,
}

impl RepeatedKFold {
    /// Creates a splitter with `n_splits` folds per repetition and `n_repeats` repetitions.
    ///
    /// The seed starts out unset.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits < 2` or if `n_repeats < 1`.
    pub fn new(n_splits: usize, n_repeats: usize) -> Self {
        // Preconditions are checked in this order so the fold-count message wins
        // when both arguments are invalid.
        assert!(n_splits >= 2, "n_splits must be at least 2");
        assert!(n_repeats >= 1, "n_repeats must be at least 1");

        let random_state = None;
        Self {
            n_splits,
            n_repeats,
            random_state,
        }
    }

    /// Builder-style setter: returns the splitter with its seed set to `seed`.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Number of folds produced within a single repetition.
    pub fn n_splits_per_repeat(&self) -> usize {
        self.n_splits
    }

    /// Number of repetitions.
    pub fn n_repeats(&self) -> usize {
        self.n_repeats
    }
}
74
75impl CrossValidator for RepeatedKFold {
76 fn n_splits(&self) -> usize {
77 self.n_splits * self.n_repeats
78 }
79
80 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
81 let mut all_splits = Vec::new();
82
83 let base_seed = self.random_state.unwrap_or(42);
84
85 for repeat in 0..self.n_repeats {
86 let kfold = KFold::new(self.n_splits)
88 .shuffle(true)
89 .random_state(base_seed + repeat as u64);
90
91 let splits = kfold.split(n_samples, None);
92 all_splits.extend(splits);
93 }
94
95 all_splits
96 }
97}
98
/// Settings for a stratified k-fold cross-validation scheme that is run several times.
///
/// Build one with [`RepeatedStratifiedKFold::new`] and optionally attach a seed with
/// [`RepeatedStratifiedKFold::random_state`].
#[derive(Debug, Clone)]
pub struct RepeatedStratifiedKFold {
    /// Folds per repetition; `new` rejects values below 2.
    n_splits: usize,
    /// Number of repetitions; `new` rejects values below 1.
    n_repeats: usize,
    /// Caller-supplied seed, or `None` when no seed has been set.
    random_state: Option<u64>,
}

impl RepeatedStratifiedKFold {
    /// Creates a splitter with `n_splits` folds per repetition and `n_repeats` repetitions.
    ///
    /// The seed starts out unset.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits < 2` or if `n_repeats < 1`.
    pub fn new(n_splits: usize, n_repeats: usize) -> Self {
        // Preconditions are checked in this order so the fold-count message wins
        // when both arguments are invalid.
        assert!(n_splits >= 2, "n_splits must be at least 2");
        assert!(n_repeats >= 1, "n_repeats must be at least 1");

        let random_state = None;
        Self {
            n_splits,
            n_repeats,
            random_state,
        }
    }

    /// Builder-style setter: returns the splitter with its seed set to `seed`.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Number of folds produced within a single repetition.
    pub fn n_splits_per_repeat(&self) -> usize {
        self.n_splits
    }

    /// Number of repetitions.
    pub fn n_repeats(&self) -> usize {
        self.n_repeats
    }
}
163
164impl CrossValidator for RepeatedStratifiedKFold {
165 fn n_splits(&self) -> usize {
166 self.n_splits * self.n_repeats
167 }
168
169 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
170 let y = y.expect("RepeatedStratifiedKFold requires y to be provided");
171 let mut all_splits = Vec::new();
172
173 let base_seed = self.random_state.unwrap_or(42);
174
175 for repeat in 0..self.n_repeats {
176 let stratified_kfold = StratifiedKFold::new(self.n_splits)
178 .shuffle(true)
179 .random_state(base_seed + repeat as u64);
180
181 let splits = stratified_kfold.split(n_samples, Some(y));
182 all_splits.extend(splits);
183 }
184
185 all_splits
186 }
187}
188
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashMap;

    /// Tallies, for every sample index in `0..n_samples`, how many test folds contain it.
    fn test_fold_hits(splits: &[(Vec<usize>, Vec<usize>)], n_samples: usize) -> Vec<i32> {
        let mut hits = vec![0; n_samples];
        for idx in splits.iter().flat_map(|(_, test)| test.iter().copied()) {
            hits[idx] += 1;
        }
        hits
    }

    /// Two repetitions of three folds yield six splits, and every sample is tested twice.
    #[test]
    fn test_repeated_kfold_basic() {
        let splitter = RepeatedKFold::new(3, 2).random_state(42);
        let folds = splitter.split(9, None);

        assert_eq!(folds.len(), 6);
        assert_eq!(splitter.n_splits(), 6);
        assert_eq!(splitter.n_splits_per_repeat(), 3);
        assert_eq!(splitter.n_repeats(), 2);

        // One appearance per repetition.
        assert_eq!(test_fold_hits(&folds, 9), vec![2; 9]);
    }

    /// No index may be in both the train and the test part of the same split.
    #[test]
    fn test_repeated_kfold_no_overlap() {
        let splitter = RepeatedKFold::new(3, 2).random_state(42);

        for (train, test) in splitter.split(9, None).iter() {
            assert!(test.iter().all(|idx| !train.contains(idx)));
        }
    }

    /// Different seeds are expected to give different splits.
    #[test]
    fn test_repeated_kfold_different_seeds() {
        let first = RepeatedKFold::new(3, 2).random_state(42).split(9, None);
        let second = RepeatedKFold::new(3, 2).random_state(123).split(9, None);

        assert_ne!(first, second);
    }

    /// With three samples per class and three folds, every test fold holds exactly one
    /// sample of each class.
    #[test]
    fn test_repeated_stratified_kfold_basic() {
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let splitter = RepeatedStratifiedKFold::new(3, 2).random_state(42);
        let folds = splitter.split(9, Some(&y));

        assert_eq!(folds.len(), 6);
        assert_eq!(splitter.n_splits(), 6);
        assert_eq!(splitter.n_splits_per_repeat(), 3);
        assert_eq!(splitter.n_repeats(), 2);

        for (_, test) in folds.iter() {
            let mut per_class: HashMap<i32, usize> = HashMap::new();
            for &idx in test.iter() {
                *per_class.entry(y[idx]).or_default() += 1;
            }

            assert_eq!(per_class.len(), 3);
            assert!(per_class.values().all(|&n| n == 1));
        }
    }

    /// Splitting without labels must panic.
    #[test]
    fn test_repeated_stratified_kfold_requires_y() {
        let splitter = RepeatedStratifiedKFold::new(3, 2);

        assert!(std::panic::catch_unwind(|| splitter.split(9, None)).is_err());
    }

    /// Every sample is tested once per repetition, hence twice overall.
    #[test]
    fn test_repeated_stratified_kfold_class_distribution() {
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let folds = RepeatedStratifiedKFold::new(3, 2)
            .random_state(42)
            .split(9, Some(&y));

        assert_eq!(test_fold_hits(&folds, 9), vec![2; 9]);
    }

    /// Different seeds are expected to give different stratified splits.
    #[test]
    fn test_repeated_stratified_kfold_different_seeds() {
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let first = RepeatedStratifiedKFold::new(3, 2)
            .random_state(42)
            .split(9, Some(&y));
        let second = RepeatedStratifiedKFold::new(3, 2)
            .random_state(123)
            .split(9, Some(&y));

        assert_ne!(first, second);
    }

    #[test]
    #[should_panic(expected = "n_splits must be at least 2")]
    fn test_repeated_kfold_invalid_n_splits() {
        let _ = RepeatedKFold::new(1, 2);
    }

    #[test]
    #[should_panic(expected = "n_repeats must be at least 1")]
    fn test_repeated_kfold_invalid_n_repeats() {
        let _ = RepeatedKFold::new(3, 0);
    }

    #[test]
    #[should_panic(expected = "n_splits must be at least 2")]
    fn test_repeated_stratified_kfold_invalid_n_splits() {
        let _ = RepeatedStratifiedKFold::new(1, 2);
    }

    #[test]
    #[should_panic(expected = "n_repeats must be at least 1")]
    fn test_repeated_stratified_kfold_invalid_n_repeats() {
        let _ = RepeatedStratifiedKFold::new(3, 0);
    }

    /// A single repetition gives three splits in which each sample is tested exactly once.
    #[test]
    fn test_repeated_kfold_single_repeat() {
        let folds = RepeatedKFold::new(3, 1).random_state(42).split(9, None);

        assert_eq!(folds.len(), 3);
        assert_eq!(test_fold_hits(&folds, 9), vec![1; 9]);
    }

    /// A single stratified repetition gives three splits in which each sample is tested
    /// exactly once.
    #[test]
    fn test_repeated_stratified_kfold_single_repeat() {
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
        let folds = RepeatedStratifiedKFold::new(3, 1)
            .random_state(42)
            .split(9, Some(&y));

        assert_eq!(folds.len(), 3);
        assert_eq!(test_fold_hits(&folds, 9), vec![1; 9]);
    }
}