sklears_model_selection/cv/
shuffle_cv.rs1use scirs2_core::ndarray::Array1;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::Rng;
11use scirs2_core::random::SeedableRng;
12use scirs2_core::SliceRandomExt;
13use std::collections::HashMap;
14
15use crate::cross_validation::CrossValidator;
16
/// Returns every `k`-element combination of `items`, preserving input order.
///
/// Combinations are emitted in lexicographic order of their positions in
/// `items` (for `[1, 2, 3]` and `k = 2`: `[1, 2]`, `[1, 3]`, `[2, 3]`).
/// `k == 0` yields a single empty combination; `k > items.len()` yields none.
// Currently only exercised by the unit tests; silence the dead-code warning in
// non-test builds instead of deleting a reusable helper.
#[cfg_attr(not(test), allow(dead_code))]
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    if k == 0 {
        return vec![Vec::new()];
    }
    let (first, rest) = match items.split_first() {
        Some(parts) => parts,
        None => return Vec::new(),
    };

    // Combinations that contain `first`: prefix it to each (k-1)-combination
    // of the remaining items (built front-to-back, no element shifting).
    let mut result: Vec<Vec<T>> = combinations(rest, k - 1)
        .into_iter()
        .map(|tail| {
            let mut combo = Vec::with_capacity(k);
            combo.push(first.clone());
            combo.extend(tail);
            combo
        })
        .collect();

    // Combinations that skip `first` entirely.
    result.extend(combinations(rest, k));
    result
}
42
/// Random permutation cross-validator.
///
/// Every split independently shuffles all sample indices and carves out a
/// test fold followed by a train fold, so splits are not guaranteed to be
/// distinct. When `train_size + test_size < 1.0` some samples belong to
/// neither fold.
#[derive(Debug, Clone)]
pub struct ShuffleSplit {
    /// Number of train/test partitions produced by `split`.
    n_splits: usize,
    /// Fraction of samples placed in each test fold (`new` sets 0.1).
    test_size: Option<f64>,
    /// Fraction of samples placed in each train fold; `None` means
    /// `1.0 - test_size`.
    train_size: Option<f64>,
    /// Seed for reproducible shuffling; `None` seeds from the thread RNG.
    random_state: Option<u64>,
}

impl ShuffleSplit {
    /// Creates a splitter producing `n_splits` partitions with a 10% test fold,
    /// a complementary train fold and no fixed seed.
    pub fn new(n_splits: usize) -> Self {
        Self {
            random_state: None,
            train_size: None,
            test_size: Some(0.1),
            n_splits,
        }
    }

    /// Sets the fraction of samples used for each test fold.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not within `[0.0, 1.0]` (NaN included).
    pub fn test_size(self, size: f64) -> Self {
        if !(0.0..=1.0).contains(&size) {
            panic!("test_size must be between 0.0 and 1.0");
        }
        Self {
            test_size: Some(size),
            ..self
        }
    }

    /// Sets the fraction of samples used for each train fold.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not within `[0.0, 1.0]` (NaN included).
    pub fn train_size(self, size: f64) -> Self {
        if !(0.0..=1.0).contains(&size) {
            panic!("train_size must be between 0.0 and 1.0");
        }
        Self {
            train_size: Some(size),
            ..self
        }
    }

    /// Fixes the RNG seed so repeated calls to `split` give identical folds.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
91
92impl CrossValidator for ShuffleSplit {
93 fn n_splits(&self) -> usize {
94 self.n_splits
95 }
96
97 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
98 let test_size = self.test_size.unwrap_or(0.1);
99 let train_size = self.train_size.unwrap_or(1.0 - test_size);
100
101 assert!(
102 train_size + test_size <= 1.0,
103 "train_size + test_size cannot exceed 1.0"
104 );
105
106 let n_test = (n_samples as f64 * test_size).round() as usize;
107 let n_train = (n_samples as f64 * train_size).round() as usize;
108
109 assert!(
110 n_train + n_test <= n_samples,
111 "train_size + test_size results in more samples than available"
112 );
113
114 let mut rng = match self.random_state {
115 Some(seed) => StdRng::seed_from_u64(seed),
116 None => {
117 use scirs2_core::random::thread_rng;
118 StdRng::from_rng(&mut thread_rng())
119 }
120 };
121
122 let mut splits = Vec::new();
123
124 for _ in 0..self.n_splits {
125 let mut indices: Vec<usize> = (0..n_samples).collect();
126 indices.shuffle(&mut rng);
127
128 let test_indices = indices[..n_test].to_vec();
129 let train_indices = indices[n_test..n_test + n_train].to_vec();
130
131 splits.push((train_indices, test_indices));
132 }
133
134 splits
135 }
136}
137
/// Stratified random permutation cross-validator.
///
/// Works like a shuffle split, but shuffles and splits each class label
/// separately so that class proportions are approximately preserved in both
/// folds. Requires the labels `y` when splitting.
#[derive(Debug, Clone)]
pub struct StratifiedShuffleSplit {
    /// Number of train/test partitions produced by `split`.
    n_splits: usize,
    /// Per-class fraction placed in each test fold (`new` sets 0.1).
    test_size: Option<f64>,
    /// Per-class fraction placed in each train fold; `None` means
    /// `1.0 - test_size`.
    train_size: Option<f64>,
    /// Seed for reproducible shuffling; `None` seeds from the thread RNG.
    random_state: Option<u64>,
}

impl StratifiedShuffleSplit {
    /// Creates a splitter producing `n_splits` partitions with a 10% test fold,
    /// a complementary train fold and no fixed seed.
    pub fn new(n_splits: usize) -> Self {
        Self {
            random_state: None,
            train_size: None,
            test_size: Some(0.1),
            n_splits,
        }
    }

    /// Sets the per-class fraction of samples used for each test fold.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not within `[0.0, 1.0]` (NaN included).
    pub fn test_size(self, size: f64) -> Self {
        if !(0.0..=1.0).contains(&size) {
            panic!("test_size must be between 0.0 and 1.0");
        }
        Self {
            test_size: Some(size),
            ..self
        }
    }

    /// Sets the per-class fraction of samples used for each train fold.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not within `[0.0, 1.0]` (NaN included).
    pub fn train_size(self, size: f64) -> Self {
        if !(0.0..=1.0).contains(&size) {
            panic!("train_size must be between 0.0 and 1.0");
        }
        Self {
            train_size: Some(size),
            ..self
        }
    }

    /// Fixes the RNG seed so repeated calls to `split` give identical folds.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
186
187impl CrossValidator for StratifiedShuffleSplit {
188 fn n_splits(&self) -> usize {
189 self.n_splits
190 }
191
192 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
193 let y = y.expect("StratifiedShuffleSplit requires y to be provided");
194 assert_eq!(
195 y.len(),
196 n_samples,
197 "y must have the same length as n_samples"
198 );
199
200 let test_size = self.test_size.unwrap_or(0.1);
201 let train_size = self.train_size.unwrap_or(1.0 - test_size);
202
203 assert!(
204 train_size + test_size <= 1.0,
205 "train_size + test_size cannot exceed 1.0"
206 );
207
208 let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
210 for (idx, &label) in y.iter().enumerate() {
211 class_indices.entry(label).or_default().push(idx);
212 }
213
214 let mut rng = match self.random_state {
215 Some(seed) => StdRng::seed_from_u64(seed),
216 None => {
217 use scirs2_core::random::thread_rng;
218 StdRng::from_rng(&mut thread_rng())
219 }
220 };
221
222 let mut splits = Vec::new();
223
224 for _ in 0..self.n_splits {
225 let mut train_indices = Vec::new();
226 let mut test_indices = Vec::new();
227
228 for (_class, mut indices) in class_indices.clone() {
230 indices.shuffle(&mut rng);
231
232 let n_test_class = ((indices.len() as f64) * test_size).round() as usize;
233 let n_train_class = ((indices.len() as f64) * train_size).round() as usize;
234
235 test_indices.extend(&indices[..n_test_class]);
236 train_indices.extend(&indices[n_test_class..n_test_class + n_train_class]);
237 }
238
239 splits.push((train_indices, test_indices));
240 }
241
242 splits
243 }
244}
245
/// Bootstrap cross-validator.
///
/// Each train fold is drawn with replacement from all samples, so it may
/// contain duplicates. The test fold holds the samples that were never drawn
/// (the out-of-bag samples).
#[derive(Debug, Clone)]
pub struct BootstrapCV {
    /// Number of bootstrap rounds produced by `split`.
    n_splits: usize,
    /// Number of draws as a fraction of the sample count; `None` draws
    /// `n_samples` times.
    train_size: Option<f64>,
    /// Seed for reproducible sampling; `None` seeds from the thread RNG.
    random_state: Option<u64>,
}

impl BootstrapCV {
    /// Creates a bootstrap splitter with `n_splits` rounds.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is zero.
    pub fn new(n_splits: usize) -> Self {
        if n_splits == 0 {
            panic!("n_splits must be at least 1");
        }
        Self {
            random_state: None,
            train_size: None,
            n_splits,
        }
    }

    /// Sets the number of draws per round as a fraction of the sample count.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not within `[0.0, 1.0]` (NaN included).
    pub fn train_size(self, size: f64) -> Self {
        if !(0.0..=1.0).contains(&size) {
            panic!("train_size must be between 0.0 and 1.0");
        }
        Self {
            train_size: Some(size),
            ..self
        }
    }

    /// Fixes the RNG seed so repeated calls to `split` give identical folds.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
286
287impl CrossValidator for BootstrapCV {
288 fn n_splits(&self) -> usize {
289 self.n_splits
290 }
291
292 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
293 let train_size = match self.train_size {
294 Some(frac) => (frac * n_samples as f64).round() as usize,
295 None => n_samples, };
297
298 let mut rng = match self.random_state {
299 Some(seed) => StdRng::seed_from_u64(seed),
300 None => {
301 use scirs2_core::random::thread_rng;
302 StdRng::from_rng(&mut thread_rng())
303 }
304 };
305
306 let mut splits = Vec::with_capacity(self.n_splits);
307
308 for _ in 0..self.n_splits {
309 let mut train_indices = Vec::with_capacity(train_size);
311 let mut sampled_indices = std::collections::HashSet::new();
312
313 for _ in 0..train_size {
314 let idx = rng.gen_range(0..n_samples);
315 train_indices.push(idx);
316 sampled_indices.insert(idx);
317 }
318
319 let test_indices: Vec<usize> = (0..n_samples)
321 .filter(|idx| !sampled_indices.contains(idx))
322 .collect();
323
324 splits.push((train_indices, test_indices));
325 }
326
327 splits
328 }
329}
330
/// Monte Carlo (repeated random sub-sampling) cross-validator.
///
/// Each split independently shuffles all sample indices and takes a test
/// fold of fixed fraction followed by a train fold.
#[derive(Debug, Clone)]
pub struct MonteCarloCV {
    /// Number of train/test partitions produced by `split`.
    n_splits: usize,
    /// Fraction of samples placed in each test fold.
    test_size: f64,
    /// Fraction of samples placed in each train fold; `None` uses every
    /// sample not in the test fold.
    train_size: Option<f64>,
    /// Seed for reproducible shuffling; `None` seeds from the thread RNG.
    random_state: Option<u64>,
}

impl MonteCarloCV {
    /// Creates a splitter with `n_splits` rounds and the given test fraction.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is zero (checked first), or if `test_size` is not
    /// within `[0.0, 1.0]` (NaN included).
    pub fn new(n_splits: usize, test_size: f64) -> Self {
        if n_splits == 0 {
            panic!("n_splits must be at least 1");
        }
        if !(0.0..=1.0).contains(&test_size) {
            panic!("test_size must be between 0.0 and 1.0");
        }
        Self {
            random_state: None,
            train_size: None,
            test_size,
            n_splits,
        }
    }

    /// Sets the fraction of samples used for each train fold.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not within `[0.0, 1.0]` (NaN included).
    pub fn train_size(self, size: f64) -> Self {
        if !(0.0..=1.0).contains(&size) {
            panic!("train_size must be between 0.0 and 1.0");
        }
        Self {
            train_size: Some(size),
            ..self
        }
    }

    /// Fixes the RNG seed so repeated calls to `split` give identical folds.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }
}
377
378impl CrossValidator for MonteCarloCV {
379 fn n_splits(&self) -> usize {
380 self.n_splits
381 }
382
383 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
384 let train_size = match self.train_size {
385 Some(frac) => (frac * n_samples as f64).round() as usize,
386 None => n_samples - (self.test_size * n_samples as f64).round() as usize,
387 };
388 let test_size = (self.test_size * n_samples as f64).round() as usize;
389
390 assert!(
391 train_size + test_size <= n_samples,
392 "train_size + test_size cannot exceed the number of samples"
393 );
394
395 let mut rng = match self.random_state {
396 Some(seed) => StdRng::seed_from_u64(seed),
397 None => {
398 use scirs2_core::random::thread_rng;
399 StdRng::from_rng(&mut thread_rng())
400 }
401 };
402
403 let mut splits = Vec::with_capacity(self.n_splits);
404
405 for _ in 0..self.n_splits {
406 let mut indices: Vec<usize> = (0..n_samples).collect();
407 indices.shuffle(&mut rng);
408
409 let test_indices = indices[..test_size].to_vec();
410 let train_indices = indices[test_size..test_size + train_size].to_vec();
411
412 splits.push((train_indices, test_indices));
413 }
414
415 splits
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};
    use std::collections::HashSet;

    /// Asserts that no sample index appears in both folds of a split.
    fn assert_disjoint(train: &[usize], test: &[usize]) {
        let train_set: HashSet<&usize> = train.iter().collect();
        let test_set: HashSet<&usize> = test.iter().collect();
        assert!(train_set.is_disjoint(&test_set));
    }

    #[test]
    fn test_shuffle_split() {
        let cv = ShuffleSplit::new(3)
            .test_size(0.2)
            .train_size(0.6)
            .random_state(42);

        let splits = cv.split(100, None);
        assert_eq!(splits.len(), 3);

        for (train, test) in &splits {
            assert_eq!(test.len(), 20);
            assert_eq!(train.len(), 60);
            assert_disjoint(train, test);
        }
    }

    #[test]
    fn test_shuffle_split_basic() {
        let cv = ShuffleSplit::new(3).test_size(0.2).random_state(42);
        let splits = cv.split(10, None::<&Array1<i32>>);

        assert_eq!(splits.len(), 3);

        for (train, test) in &splits {
            assert_eq!(test.len(), 2);
            assert_eq!(train.len(), 8);
            for idx in test {
                assert!(!train.contains(idx));
            }
        }
    }

    #[test]
    fn test_stratified_shuffle_split() {
        let y = Array1::from_vec(vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2]);
        let cv = StratifiedShuffleSplit::new(2)
            .test_size(0.3)
            .random_state(42);

        let splits = cv.split(10, Some(&y));
        assert_eq!(splits.len(), 2);

        for (train, test) in &splits {
            assert_eq!(test.len(), 3);
            assert_eq!(train.len(), 7);
            assert_disjoint(train, test);
        }
    }

    #[test]
    fn test_stratified_shuffle_split_basic() {
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
        let cv = StratifiedShuffleSplit::new(2)
            .test_size(0.25)
            .random_state(42);
        let splits = cv.split(8, Some(&y));

        assert_eq!(splits.len(), 2);

        for (_train, test) in &splits {
            let mut class_counts = HashMap::new();
            for &idx in test {
                *class_counts.entry(y[idx]).or_insert(0) += 1;
            }

            // Both classes must be represented, one sample each.
            assert_eq!(class_counts.len(), 2);
            assert_eq!(class_counts[&0], 1);
            assert_eq!(class_counts[&1], 1);
        }
    }

    #[test]
    fn test_bootstrap_cv() {
        let cv = BootstrapCV::new(3).random_state(42);

        let splits = cv.split(50, None);
        assert_eq!(splits.len(), 3);

        for (train, test) in &splits {
            assert_eq!(train.len(), 50);
            assert!(!test.is_empty());
            assert!(test.len() < 50);
        }
    }

    #[test]
    fn test_monte_carlo_cv() {
        let cv = MonteCarloCV::new(4, 0.25).random_state(42);

        let splits = cv.split(80, None);
        assert_eq!(splits.len(), 4);

        for (train, test) in &splits {
            assert_eq!(test.len(), 20);
            assert_eq!(train.len(), 60);
            assert_disjoint(train, test);
        }
    }

    #[test]
    fn test_combinations() {
        let items = vec![1, 2, 3, 4];
        let combos = combinations(&items, 2);
        assert_eq!(combos.len(), 6);

        let expected = vec![
            vec![1, 2],
            vec![1, 3],
            vec![1, 4],
            vec![2, 3],
            vec![2, 4],
            vec![3, 4],
        ];

        for combo in combos {
            assert_eq!(combo.len(), 2);
            assert!(expected.contains(&combo));
        }
    }
}