1use scirs2_core::ndarray::Array1;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::SeedableRng;
11use scirs2_core::SliceRandomExt;
12use std::collections::HashMap;
13
14use crate::cross_validation::CrossValidator;
15
/// Strategy used by [`GroupKFold`] to distribute groups across folds.
///
/// Whatever the strategy, a group is never divided between folds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GroupStrategy {
    /// Cut the sorted list of group ids into consecutive runs, one run per
    /// fold, so that folds contain (almost) the same number of groups.
    Direct,
    /// Visit groups from largest to smallest and put each one into the fold
    /// that currently holds the fewest samples.
    Balanced,
    /// Place the groups that have more than `max_group_size` samples into
    /// folds first, then give each remaining group to the fold that currently
    /// holds the fewest samples.
    SizeAware {
        /// Sample count above which a group is treated as "large".
        max_group_size: usize,
    },
}
26
/// K-fold cross-validator that keeps every group entirely on one side of a
/// split: all samples sharing a group label end up together in either the
/// training set or the test set.
#[derive(Debug, Clone)]
pub struct GroupKFold {
    /// Number of folds to produce; the constructors require at least 2.
    n_splits: usize,
    /// How groups are distributed across the folds.
    group_strategy: GroupStrategy,
}
36
37impl GroupKFold {
38 pub fn new(n_splits: usize) -> Self {
40 assert!(n_splits >= 2, "n_splits must be at least 2");
41 Self {
42 n_splits,
43 group_strategy: GroupStrategy::Direct,
44 }
45 }
46
47 pub fn new_balanced(n_splits: usize) -> Self {
49 assert!(n_splits >= 2, "n_splits must be at least 2");
50 Self {
51 n_splits,
52 group_strategy: GroupStrategy::Balanced,
53 }
54 }
55
56 pub fn new_size_aware(n_splits: usize, max_group_size: usize) -> Self {
58 assert!(n_splits >= 2, "n_splits must be at least 2");
59 Self {
60 n_splits,
61 group_strategy: GroupStrategy::SizeAware { max_group_size },
62 }
63 }
64
65 pub fn group_strategy(mut self, strategy: GroupStrategy) -> Self {
67 self.group_strategy = strategy;
68 self
69 }
70
71 pub fn split_with_groups(
73 &self,
74 n_samples: usize,
75 groups: &Array1<i32>,
76 ) -> Vec<(Vec<usize>, Vec<usize>)> {
77 assert_eq!(
78 groups.len(),
79 n_samples,
80 "groups must have the same length as n_samples"
81 );
82
83 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
85 for (idx, &group) in groups.iter().enumerate() {
86 group_indices.entry(group).or_default().push(idx);
87 }
88
89 let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
90 unique_groups.sort();
91
92 assert!(
93 unique_groups.len() >= self.n_splits,
94 "The number of groups ({}) must be at least equal to the number of splits ({})",
95 unique_groups.len(),
96 self.n_splits
97 );
98
99 match &self.group_strategy {
100 GroupStrategy::Direct => self.split_direct(&unique_groups, &group_indices),
101 GroupStrategy::Balanced => self.split_balanced(&unique_groups, &group_indices),
102 GroupStrategy::SizeAware { max_group_size } => {
103 self.split_size_aware(&unique_groups, &group_indices, *max_group_size)
104 }
105 }
106 }
107
108 fn split_direct(
109 &self,
110 unique_groups: &[i32],
111 group_indices: &HashMap<i32, Vec<usize>>,
112 ) -> Vec<(Vec<usize>, Vec<usize>)> {
113 let n_groups = unique_groups.len();
115 let groups_per_fold = n_groups / self.n_splits;
116 let n_larger_folds = n_groups % self.n_splits;
117
118 let mut splits = Vec::new();
119 let mut current_group_idx = 0;
120
121 for i in 0..self.n_splits {
122 let fold_size = if i < n_larger_folds {
123 groups_per_fold + 1
124 } else {
125 groups_per_fold
126 };
127
128 let test_groups = &unique_groups[current_group_idx..current_group_idx + fold_size];
129 let train_groups: Vec<i32> = unique_groups
130 .iter()
131 .filter(|&group| !test_groups.contains(group))
132 .cloned()
133 .collect();
134
135 let mut test_indices = Vec::new();
136 for &group in test_groups {
137 test_indices.extend(&group_indices[&group]);
138 }
139
140 let mut train_indices = Vec::new();
141 for &group in &train_groups {
142 train_indices.extend(&group_indices[&group]);
143 }
144
145 splits.push((train_indices, test_indices));
146 current_group_idx += fold_size;
147 }
148
149 splits
150 }
151
152 fn split_balanced(
153 &self,
154 unique_groups: &[i32],
155 group_indices: &HashMap<i32, Vec<usize>>,
156 ) -> Vec<(Vec<usize>, Vec<usize>)> {
157 let mut group_sizes: Vec<(i32, usize)> = unique_groups
159 .iter()
160 .map(|&group| (group, group_indices[&group].len()))
161 .collect();
162
163 group_sizes.sort_by(|a, b| b.1.cmp(&a.1));
165
166 let mut fold_assignments: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
167 let mut fold_sizes: Vec<usize> = vec![0; self.n_splits];
168
169 for (group, size) in group_sizes {
171 let min_fold = fold_sizes
172 .iter()
173 .enumerate()
174 .min_by_key(|(_, &size)| size)
175 .map(|(idx, _)| idx)
176 .unwrap();
177
178 fold_assignments[min_fold].push(group);
179 fold_sizes[min_fold] += size;
180 }
181
182 let mut splits = Vec::new();
183 for test_groups in fold_assignments.iter().take(self.n_splits) {
184 let train_groups: Vec<i32> = unique_groups
185 .iter()
186 .filter(|&group| !test_groups.contains(group))
187 .cloned()
188 .collect();
189
190 let mut test_indices = Vec::new();
191 for &group in test_groups {
192 test_indices.extend(&group_indices[&group]);
193 }
194
195 let mut train_indices = Vec::new();
196 for &group in &train_groups {
197 train_indices.extend(&group_indices[&group]);
198 }
199
200 splits.push((train_indices, test_indices));
201 }
202
203 splits
204 }
205
206 fn split_size_aware(
207 &self,
208 unique_groups: &[i32],
209 group_indices: &HashMap<i32, Vec<usize>>,
210 max_group_size: usize,
211 ) -> Vec<(Vec<usize>, Vec<usize>)> {
212 let mut large_groups = Vec::new();
214 let mut small_groups = Vec::new();
215
216 for &group in unique_groups {
217 if group_indices[&group].len() > max_group_size {
218 large_groups.push(group);
219 } else {
220 small_groups.push(group);
221 }
222 }
223
224 let mut fold_assignments: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
225 let mut fold_index = 0;
226
227 for group in large_groups {
229 if fold_index < self.n_splits {
230 fold_assignments[fold_index].push(group);
231 fold_index += 1;
232 } else {
233 fold_assignments[fold_index % self.n_splits].push(group);
235 }
236 }
237
238 let mut fold_sizes: Vec<usize> = fold_assignments
240 .iter()
241 .map(|groups| groups.iter().map(|&g| group_indices[&g].len()).sum())
242 .collect();
243
244 for group in small_groups {
245 let min_fold = fold_sizes
246 .iter()
247 .enumerate()
248 .min_by_key(|(_, &size)| size)
249 .map(|(idx, _)| idx)
250 .unwrap();
251
252 fold_assignments[min_fold].push(group);
253 fold_sizes[min_fold] += group_indices[&group].len();
254 }
255
256 let mut splits = Vec::new();
257 for test_groups in fold_assignments.iter().take(self.n_splits) {
258 let train_groups: Vec<i32> = unique_groups
259 .iter()
260 .filter(|&group| !test_groups.contains(group))
261 .cloned()
262 .collect();
263
264 let mut test_indices = Vec::new();
265 for &group in test_groups {
266 test_indices.extend(&group_indices[&group]);
267 }
268
269 let mut train_indices = Vec::new();
270 for &group in &train_groups {
271 train_indices.extend(&group_indices[&group]);
272 }
273
274 splits.push((train_indices, test_indices));
275 }
276
277 splits
278 }
279}
280
281impl CrossValidator for GroupKFold {
282 fn n_splits(&self) -> usize {
283 self.n_splits
284 }
285
286 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
287 let groups = y.expect("GroupKFold requires group labels to be provided in y parameter");
289 self.split_with_groups(n_samples, groups)
290 }
291}
292
/// Group-aware k-fold cross-validator that takes both class labels and group
/// labels; every group is kept entirely within a single fold.
#[derive(Debug, Clone)]
pub struct StratifiedGroupKFold {
    /// Number of folds to produce; `new` requires at least 2.
    n_splits: usize,
    /// Whether the group order is shuffled before groups are assigned to folds.
    shuffle: bool,
    /// Seed for the shuffle; when `None` the generator is seeded from the
    /// thread-local RNG.
    random_state: Option<u64>,
}
303
304impl StratifiedGroupKFold {
305 pub fn new(n_splits: usize) -> Self {
307 assert!(n_splits >= 2, "n_splits must be at least 2");
308 Self {
309 n_splits,
310 shuffle: false,
311 random_state: None,
312 }
313 }
314
315 pub fn shuffle(mut self, shuffle: bool) -> Self {
317 self.shuffle = shuffle;
318 self
319 }
320
321 pub fn random_state(mut self, seed: u64) -> Self {
323 self.random_state = Some(seed);
324 self
325 }
326
327 pub fn split_with_groups_and_labels(
329 &self,
330 n_samples: usize,
331 y: &Array1<i32>,
332 groups: &Array1<i32>,
333 ) -> Vec<(Vec<usize>, Vec<usize>)> {
334 assert_eq!(
335 n_samples,
336 groups.len(),
337 "n_samples and groups must have the same length"
338 );
339 assert_eq!(
340 n_samples,
341 y.len(),
342 "n_samples and y must have the same length"
343 );
344
345 let mut group_class_counts: HashMap<i32, HashMap<i32, usize>> = HashMap::new();
347 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
348
349 for (idx, (&group, &label)) in groups.iter().zip(y.iter()).enumerate() {
350 group_indices.entry(group).or_default().push(idx);
351
352 *group_class_counts
353 .entry(group)
354 .or_default()
355 .entry(label)
356 .or_insert(0) += 1;
357 }
358
359 let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
360 let n_groups = unique_groups.len();
361
362 assert!(
363 self.n_splits <= n_groups,
364 "Cannot have number of splits {} greater than the number of groups {}",
365 self.n_splits,
366 n_groups
367 );
368
369 unique_groups.sort_by_key(|&g| {
371 let size: usize = group_class_counts[&g].values().sum();
372 std::cmp::Reverse(size)
373 });
374
375 if self.shuffle {
377 let mut rng = match self.random_state {
378 Some(seed) => StdRng::seed_from_u64(seed),
379 None => {
380 use scirs2_core::random::thread_rng;
381 StdRng::from_rng(&mut thread_rng())
382 }
383 };
384 unique_groups.shuffle(&mut rng);
385 }
386
387 let mut fold_groups: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
389 let mut fold_class_counts: Vec<HashMap<i32, usize>> = vec![HashMap::new(); self.n_splits];
390
391 for group in unique_groups {
393 let mut best_fold = 0;
395 let mut min_size = usize::MAX;
396
397 for (fold_idx, fold_counts) in fold_class_counts.iter().enumerate() {
398 let fold_size: usize = fold_counts.values().sum();
399 if fold_size < min_size {
400 min_size = fold_size;
401 best_fold = fold_idx;
402 }
403 }
404
405 fold_groups[best_fold].push(group);
407
408 for (&class, &count) in &group_class_counts[&group] {
410 *fold_class_counts[best_fold].entry(class).or_insert(0) += count;
411 }
412 }
413
414 let mut splits = Vec::new();
416
417 for test_fold_idx in 0..self.n_splits {
418 let mut test_indices = Vec::new();
419 let mut train_indices = Vec::new();
420
421 for (fold_idx, groups_in_fold) in fold_groups.iter().enumerate() {
422 for &group in groups_in_fold {
423 if fold_idx == test_fold_idx {
424 test_indices.extend(&group_indices[&group]);
425 } else {
426 train_indices.extend(&group_indices[&group]);
427 }
428 }
429 }
430
431 test_indices.sort_unstable();
433 train_indices.sort_unstable();
434
435 splits.push((train_indices, test_indices));
436 }
437
438 splits
439 }
440}
441
impl CrossValidator for StratifiedGroupKFold {
    /// Returns the configured number of folds.
    fn n_splits(&self) -> usize {
        self.n_splits
    }

    /// Not supported through the trait: this signature carries a single label
    /// array, while this splitter needs both class labels and group labels.
    ///
    /// # Panics
    ///
    /// Always panics; call `split_with_groups_and_labels` instead.
    fn split(&self, _n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
        panic!("StratifiedGroupKFold requires both groups and labels. Use split_with_groups_and_labels method instead.");
    }
}
451
/// Randomised train/test splitter that draws whole groups: each split
/// shuffles the distinct groups and takes one portion of them for testing and
/// another for training.
#[derive(Debug, Clone)]
pub struct GroupShuffleSplit {
    /// Number of independent random splits to generate.
    n_splits: usize,
    /// Fraction of the groups (not samples) placed in the test set;
    /// `new` sets this to `Some(0.2)`.
    test_size: Option<f64>,
    /// Fraction of the groups placed in the training set; when `None`, the
    /// complement of `test_size` is used.
    train_size: Option<f64>,
    /// Seed for the shuffles; when `None` the generator is seeded from the
    /// thread-local RNG.
    random_state: Option<u64>,
}
463
464impl GroupShuffleSplit {
465 pub fn new(n_splits: usize) -> Self {
467 Self {
468 n_splits,
469 test_size: Some(0.2),
470 train_size: None,
471 random_state: None,
472 }
473 }
474
475 pub fn test_size(mut self, size: f64) -> Self {
477 assert!(
478 (0.0..=1.0).contains(&size),
479 "test_size must be between 0.0 and 1.0"
480 );
481 self.test_size = Some(size);
482 self
483 }
484
485 pub fn train_size(mut self, size: f64) -> Self {
487 assert!(
488 (0.0..=1.0).contains(&size),
489 "train_size must be between 0.0 and 1.0"
490 );
491 self.train_size = Some(size);
492 self
493 }
494
495 pub fn random_state(mut self, seed: u64) -> Self {
497 self.random_state = Some(seed);
498 self
499 }
500
501 pub fn split_with_groups(
503 &self,
504 n_samples: usize,
505 groups: &Array1<i32>,
506 ) -> Vec<(Vec<usize>, Vec<usize>)> {
507 assert_eq!(
508 groups.len(),
509 n_samples,
510 "groups must have the same length as n_samples"
511 );
512
513 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
515 for (idx, &group) in groups.iter().enumerate() {
516 group_indices.entry(group).or_default().push(idx);
517 }
518
519 let unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
520 let n_groups = unique_groups.len();
521
522 let test_size = self.test_size.unwrap_or(0.2);
523 let train_size = self.train_size.unwrap_or(1.0 - test_size);
524
525 assert!(
526 train_size + test_size <= 1.0,
527 "train_size + test_size cannot exceed 1.0"
528 );
529
530 let n_test_groups = ((n_groups as f64) * test_size).round() as usize;
531 let n_train_groups = ((n_groups as f64) * train_size).round() as usize;
532
533 assert!(
534 n_train_groups + n_test_groups <= n_groups,
535 "train_size + test_size results in more groups than available"
536 );
537
538 let mut rng = match self.random_state {
539 Some(seed) => StdRng::seed_from_u64(seed),
540 None => {
541 use scirs2_core::random::thread_rng;
542 StdRng::from_rng(&mut thread_rng())
543 }
544 };
545
546 let mut splits = Vec::new();
547
548 for _ in 0..self.n_splits {
549 let mut shuffled_groups = unique_groups.clone();
550 shuffled_groups.shuffle(&mut rng);
551
552 let test_groups = &shuffled_groups[..n_test_groups];
553 let train_groups = &shuffled_groups[n_test_groups..n_test_groups + n_train_groups];
554
555 let mut test_indices = Vec::new();
556 for &group in test_groups {
557 test_indices.extend(&group_indices[&group]);
558 }
559
560 let mut train_indices = Vec::new();
561 for &group in train_groups {
562 train_indices.extend(&group_indices[&group]);
563 }
564
565 test_indices.sort_unstable();
567 train_indices.sort_unstable();
568
569 splits.push((train_indices, test_indices));
570 }
571
572 splits
573 }
574}
575
576impl CrossValidator for GroupShuffleSplit {
577 fn n_splits(&self) -> usize {
578 self.n_splits
579 }
580
581 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
582 let groups =
584 y.expect("GroupShuffleSplit requires group labels to be provided in y parameter");
585 self.split_with_groups(n_samples, groups)
586 }
587}
588
/// Cross-validator that holds out one group at a time: each split uses the
/// samples of a single group as the test set and all other samples for
/// training.
///
/// The type carries no configuration, so `Default` is derived instead of
/// being written by hand.
#[derive(Debug, Clone, Default)]
pub struct LeaveOneGroupOut;
600
601impl LeaveOneGroupOut {
602 pub fn new() -> Self {
604 LeaveOneGroupOut
605 }
606
607 pub fn split_with_groups(
609 &self,
610 n_samples: usize,
611 groups: &Array1<i32>,
612 ) -> Vec<(Vec<usize>, Vec<usize>)> {
613 assert_eq!(
614 groups.len(),
615 n_samples,
616 "groups must have the same length as n_samples"
617 );
618
619 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
621 for (idx, &group) in groups.iter().enumerate() {
622 group_indices.entry(group).or_default().push(idx);
623 }
624
625 let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
626 unique_groups.sort();
627
628 let mut splits = Vec::new();
629
630 for &test_group in &unique_groups {
632 let test_indices = group_indices[&test_group].clone();
633 let mut train_indices = Vec::new();
634
635 for &train_group in &unique_groups {
636 if train_group != test_group {
637 train_indices.extend(&group_indices[&train_group]);
638 }
639 }
640
641 train_indices.sort_unstable();
643
644 splits.push((train_indices, test_indices));
645 }
646
647 splits
648 }
649}
650
651impl CrossValidator for LeaveOneGroupOut {
652 fn n_splits(&self) -> usize {
653 0 }
656
657 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
658 let groups =
660 y.expect("LeaveOneGroupOut requires group labels to be provided in y parameter");
661 self.split_with_groups(n_samples, groups)
662 }
663}
664
/// Cross-validator that holds out every combination of `p` groups: each split
/// uses the samples of `p` distinct groups as the test set and all other
/// samples for training.
#[derive(Debug, Clone)]
pub struct LeavePGroupsOut {
    /// Number of groups held out per split; `new` requires at least 1.
    p: usize,
}
672
673impl LeavePGroupsOut {
674 pub fn new(p: usize) -> Self {
676 assert!(p >= 1, "p must be at least 1");
677 Self { p }
678 }
679
680 pub fn split_with_groups(
682 &self,
683 n_samples: usize,
684 groups: &Array1<i32>,
685 ) -> Vec<(Vec<usize>, Vec<usize>)> {
686 assert_eq!(
687 groups.len(),
688 n_samples,
689 "groups must have the same length as n_samples"
690 );
691
692 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
694 for (idx, &group) in groups.iter().enumerate() {
695 group_indices.entry(group).or_default().push(idx);
696 }
697
698 let unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
699 let n_groups = unique_groups.len();
700
701 assert!(
702 self.p <= n_groups,
703 "p ({}) cannot be greater than number of groups ({})",
704 self.p,
705 n_groups
706 );
707
708 let mut splits = Vec::new();
709
710 let group_combinations = combinations(&unique_groups, self.p);
712
713 for test_groups in group_combinations {
714 let mut test_indices = Vec::new();
715 for &group in &test_groups {
716 test_indices.extend(&group_indices[&group]);
717 }
718
719 let mut train_indices = Vec::new();
720 for &group in &unique_groups {
721 if !test_groups.contains(&group) {
722 train_indices.extend(&group_indices[&group]);
723 }
724 }
725
726 test_indices.sort_unstable();
728 train_indices.sort_unstable();
729
730 splits.push((train_indices, test_indices));
731 }
732
733 splits
734 }
735}
736
737impl CrossValidator for LeavePGroupsOut {
738 fn n_splits(&self) -> usize {
739 0 }
742
743 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
744 let groups =
746 y.expect("LeavePGroupsOut requires group labels to be provided in y parameter");
747 self.split_with_groups(n_samples, groups)
748 }
749}
750
/// Returns every `k`-element combination of `items`, preserving the order of
/// `items` inside each combination and listing combinations in lexicographic
/// order of their element positions.
///
/// `k == 0` yields a single empty combination; `k > items.len()` yields none.
///
/// The combinations are enumerated iteratively by advancing a vector of `k`
/// increasing positions, so the work is proportional to the output size. The
/// previous recursive version also explored every branch that could no longer
/// reach `k` elements, which is exponential when `k` is close to or above
/// `items.len()`.
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k > n {
        return Vec::new();
    }

    // `indices` always holds k strictly increasing positions into `items`.
    let mut indices: Vec<usize> = (0..k).collect();
    let mut result = Vec::new();

    loop {
        result.push(indices.iter().map(|&i| items[i].clone()).collect());

        // Rightmost position that has not yet reached its maximum value
        // (slot `i` can go up to `i + n - k`).
        let pivot = match (0..k).rev().find(|&i| indices[i] != i + n - k) {
            Some(pivot) => pivot,
            None => break,
        };

        // Advance it and reset everything to its right to the smallest
        // increasing continuation.
        indices[pivot] += 1;
        for i in pivot + 1..k {
            indices[i] = indices[i - 1] + 1;
        }
    }

    result
}
776
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};
    use std::collections::HashSet;

    /// One `(train, test)` pair as returned by the splitters.
    type Split = (Vec<usize>, Vec<usize>);

    /// Distinct group labels of the samples at `indices`.
    fn groups_at(indices: &[usize], groups: &Array1<i32>) -> HashSet<i32> {
        indices.iter().map(|&idx| groups[idx]).collect()
    }

    /// Asserts that no split has a group on both the train and the test side.
    fn assert_no_group_leak(splits: &[Split], groups: &Array1<i32>) {
        for (train, test) in splits {
            let train_groups = groups_at(train, groups);
            let test_groups = groups_at(test, groups);
            assert!(train_groups.is_disjoint(&test_groups));
        }
    }

    #[test]
    fn test_group_kfold() {
        let groups = array![0, 0, 1, 1, 2, 2];
        let splits = GroupKFold::new(2).split_with_groups(6, &groups);

        assert_eq!(splits.len(), 2);
        assert_no_group_leak(&splits, &groups);
    }

    #[test]
    fn test_group_kfold_custom_strategies() {
        // Group sizes: 2, 3, 4, 2, 1.
        let groups = array![0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4];

        let balanced = GroupKFold::new_balanced(3).split_with_groups(12, &groups);
        assert_eq!(balanced.len(), 3);
        assert_no_group_leak(&balanced, &groups);

        let size_aware = GroupKFold::new_size_aware(3, 3).split_with_groups(12, &groups);
        assert_eq!(size_aware.len(), 3);
        assert_no_group_leak(&size_aware, &groups);

        let custom = GroupKFold::new(3)
            .group_strategy(GroupStrategy::Balanced)
            .split_with_groups(12, &groups);
        assert_eq!(custom.len(), 3);
    }

    #[test]
    fn test_stratified_group_kfold() {
        let groups = array![1, 1, 2, 2, 3, 3, 4, 4];
        let y = array![0, 0, 1, 1, 0, 1, 0, 1];
        let splits = StratifiedGroupKFold::new(2).split_with_groups_and_labels(8, &y, &groups);

        assert_eq!(splits.len(), 2);
        assert_no_group_leak(&splits, &groups);

        // Every training set must contain both classes.
        for (train_idx, _) in &splits {
            let count_of = |class: i32| train_idx.iter().filter(|&&i| y[i] == class).count();
            assert!(count_of(0) > 0);
            assert!(count_of(1) > 0);
        }
    }

    #[test]
    fn test_group_shuffle_split() {
        let groups = array![0, 0, 1, 1, 2, 2, 3, 3];
        let splits = GroupShuffleSplit::new(3)
            .test_size(0.25)
            .random_state(42)
            .split_with_groups(8, &groups);

        assert_eq!(splits.len(), 3);
        assert_no_group_leak(&splits, &groups);

        // 25% of four groups is exactly one test group per split.
        for (_, test) in &splits {
            assert_eq!(groups_at(test, &groups).len(), 1);
        }
    }

    #[test]
    fn test_leave_one_group_out() {
        let groups = array![0, 0, 1, 1, 2, 2];
        let splits = LeaveOneGroupOut::new().split_with_groups(6, &groups);

        // One split per distinct group.
        assert_eq!(splits.len(), 3);
        assert_no_group_leak(&splits, &groups);

        for (train, test) in &splits {
            assert_eq!(groups_at(test, &groups).len(), 1);
            assert_eq!(groups_at(train, &groups).len(), 2);
        }
    }

    #[test]
    fn test_leave_p_groups_out() {
        let groups = array![0, 0, 1, 1, 2, 2, 3, 3];
        let splits = LeavePGroupsOut::new(2).split_with_groups(8, &groups);

        // C(4, 2) = 6 ways to hold out two of four groups.
        assert_eq!(splits.len(), 6);
        assert_no_group_leak(&splits, &groups);

        for (train, test) in &splits {
            assert_eq!(groups_at(test, &groups).len(), 2);
            assert_eq!(groups_at(train, &groups).len(), 2);
        }
    }
}