1use scirs2_core::ndarray::Array1;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::SeedableRng;
11use scirs2_core::SliceRandomExt;
12use std::collections::HashMap;
13
14use crate::cross_validation::CrossValidator;
15
/// Strategy used by [`GroupKFold`] to distribute groups across folds.
#[derive(Debug, Clone)]
pub enum GroupStrategy {
    /// Assign contiguous runs of the sorted group ids to each fold.
    Direct,
    /// Place groups, largest first, on whichever fold currently holds the
    /// fewest samples.
    Balanced,
    /// Deal groups with more than `max_group_size` samples across the folds
    /// first, then place the remaining groups on the currently smallest fold.
    SizeAware {
        /// Sample count above which a group is treated as "large".
        max_group_size: usize,
    },
}
26
27#[derive(Debug, Clone)]
32pub struct GroupKFold {
33 n_splits: usize,
34 group_strategy: GroupStrategy,
35}
36
37impl GroupKFold {
38 pub fn new(n_splits: usize) -> Self {
40 assert!(n_splits >= 2, "n_splits must be at least 2");
41 Self {
42 n_splits,
43 group_strategy: GroupStrategy::Direct,
44 }
45 }
46
47 pub fn new_balanced(n_splits: usize) -> Self {
49 assert!(n_splits >= 2, "n_splits must be at least 2");
50 Self {
51 n_splits,
52 group_strategy: GroupStrategy::Balanced,
53 }
54 }
55
56 pub fn new_size_aware(n_splits: usize, max_group_size: usize) -> Self {
58 assert!(n_splits >= 2, "n_splits must be at least 2");
59 Self {
60 n_splits,
61 group_strategy: GroupStrategy::SizeAware { max_group_size },
62 }
63 }
64
65 pub fn group_strategy(mut self, strategy: GroupStrategy) -> Self {
67 self.group_strategy = strategy;
68 self
69 }
70
71 pub fn split_with_groups(
73 &self,
74 n_samples: usize,
75 groups: &Array1<i32>,
76 ) -> Vec<(Vec<usize>, Vec<usize>)> {
77 assert_eq!(
78 groups.len(),
79 n_samples,
80 "groups must have the same length as n_samples"
81 );
82
83 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
85 for (idx, &group) in groups.iter().enumerate() {
86 group_indices.entry(group).or_default().push(idx);
87 }
88
89 let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
90 unique_groups.sort();
91
92 assert!(
93 unique_groups.len() >= self.n_splits,
94 "The number of groups ({}) must be at least equal to the number of splits ({})",
95 unique_groups.len(),
96 self.n_splits
97 );
98
99 match &self.group_strategy {
100 GroupStrategy::Direct => self.split_direct(&unique_groups, &group_indices),
101 GroupStrategy::Balanced => self.split_balanced(&unique_groups, &group_indices),
102 GroupStrategy::SizeAware { max_group_size } => {
103 self.split_size_aware(&unique_groups, &group_indices, *max_group_size)
104 }
105 }
106 }
107
108 fn split_direct(
109 &self,
110 unique_groups: &[i32],
111 group_indices: &HashMap<i32, Vec<usize>>,
112 ) -> Vec<(Vec<usize>, Vec<usize>)> {
113 let n_groups = unique_groups.len();
115 let groups_per_fold = n_groups / self.n_splits;
116 let n_larger_folds = n_groups % self.n_splits;
117
118 let mut splits = Vec::new();
119 let mut current_group_idx = 0;
120
121 for i in 0..self.n_splits {
122 let fold_size = if i < n_larger_folds {
123 groups_per_fold + 1
124 } else {
125 groups_per_fold
126 };
127
128 let test_groups = &unique_groups[current_group_idx..current_group_idx + fold_size];
129 let train_groups: Vec<i32> = unique_groups
130 .iter()
131 .filter(|&group| !test_groups.contains(group))
132 .cloned()
133 .collect();
134
135 let mut test_indices = Vec::new();
136 for &group in test_groups {
137 test_indices.extend(&group_indices[&group]);
138 }
139
140 let mut train_indices = Vec::new();
141 for &group in &train_groups {
142 train_indices.extend(&group_indices[&group]);
143 }
144
145 splits.push((train_indices, test_indices));
146 current_group_idx += fold_size;
147 }
148
149 splits
150 }
151
152 fn split_balanced(
153 &self,
154 unique_groups: &[i32],
155 group_indices: &HashMap<i32, Vec<usize>>,
156 ) -> Vec<(Vec<usize>, Vec<usize>)> {
157 let mut group_sizes: Vec<(i32, usize)> = unique_groups
159 .iter()
160 .map(|&group| (group, group_indices[&group].len()))
161 .collect();
162
163 group_sizes.sort_by(|a, b| b.1.cmp(&a.1));
165
166 let mut fold_assignments: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
167 let mut fold_sizes: Vec<usize> = vec![0; self.n_splits];
168
169 for (group, size) in group_sizes {
171 let min_fold = fold_sizes
172 .iter()
173 .enumerate()
174 .min_by_key(|(_, &size)| size)
175 .map(|(idx, _)| idx)
176 .expect("operation should succeed");
177
178 fold_assignments[min_fold].push(group);
179 fold_sizes[min_fold] += size;
180 }
181
182 let mut splits = Vec::new();
183 for test_groups in fold_assignments.iter().take(self.n_splits) {
184 let train_groups: Vec<i32> = unique_groups
185 .iter()
186 .filter(|&group| !test_groups.contains(group))
187 .cloned()
188 .collect();
189
190 let mut test_indices = Vec::new();
191 for &group in test_groups {
192 test_indices.extend(&group_indices[&group]);
193 }
194
195 let mut train_indices = Vec::new();
196 for &group in &train_groups {
197 train_indices.extend(&group_indices[&group]);
198 }
199
200 splits.push((train_indices, test_indices));
201 }
202
203 splits
204 }
205
206 fn split_size_aware(
207 &self,
208 unique_groups: &[i32],
209 group_indices: &HashMap<i32, Vec<usize>>,
210 max_group_size: usize,
211 ) -> Vec<(Vec<usize>, Vec<usize>)> {
212 let mut large_groups = Vec::new();
214 let mut small_groups = Vec::new();
215
216 for &group in unique_groups {
217 if group_indices[&group].len() > max_group_size {
218 large_groups.push(group);
219 } else {
220 small_groups.push(group);
221 }
222 }
223
224 let mut fold_assignments: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
225 let mut fold_index = 0;
226
227 for group in large_groups {
229 if fold_index < self.n_splits {
230 fold_assignments[fold_index].push(group);
231 fold_index += 1;
232 } else {
233 fold_assignments[fold_index % self.n_splits].push(group);
235 }
236 }
237
238 let mut fold_sizes: Vec<usize> = fold_assignments
240 .iter()
241 .map(|groups| groups.iter().map(|&g| group_indices[&g].len()).sum())
242 .collect();
243
244 for group in small_groups {
245 let min_fold = fold_sizes
246 .iter()
247 .enumerate()
248 .min_by_key(|(_, &size)| size)
249 .map(|(idx, _)| idx)
250 .expect("operation should succeed");
251
252 fold_assignments[min_fold].push(group);
253 fold_sizes[min_fold] += group_indices[&group].len();
254 }
255
256 let mut splits = Vec::new();
257 for test_groups in fold_assignments.iter().take(self.n_splits) {
258 let train_groups: Vec<i32> = unique_groups
259 .iter()
260 .filter(|&group| !test_groups.contains(group))
261 .cloned()
262 .collect();
263
264 let mut test_indices = Vec::new();
265 for &group in test_groups {
266 test_indices.extend(&group_indices[&group]);
267 }
268
269 let mut train_indices = Vec::new();
270 for &group in &train_groups {
271 train_indices.extend(&group_indices[&group]);
272 }
273
274 splits.push((train_indices, test_indices));
275 }
276
277 splits
278 }
279}
280
281impl CrossValidator for GroupKFold {
282 fn n_splits(&self) -> usize {
283 self.n_splits
284 }
285
286 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
287 let groups = y.expect("GroupKFold requires group labels to be provided in y parameter");
289 self.split_with_groups(n_samples, groups)
290 }
291}
292
/// Group-aware K-fold splitter that takes both class labels and group ids.
#[derive(Debug, Clone)]
pub struct StratifiedGroupKFold {
    /// Number of folds to produce.
    n_splits: usize,
    /// Whether the group order is shuffled before folds are filled.
    shuffle: bool,
    /// Seed for the shuffle; `None` seeds from the thread-local generator.
    random_state: Option<u64>,
}
303
304impl StratifiedGroupKFold {
305 pub fn new(n_splits: usize) -> Self {
307 assert!(n_splits >= 2, "n_splits must be at least 2");
308 Self {
309 n_splits,
310 shuffle: false,
311 random_state: None,
312 }
313 }
314
315 pub fn shuffle(mut self, shuffle: bool) -> Self {
317 self.shuffle = shuffle;
318 self
319 }
320
321 pub fn random_state(mut self, seed: u64) -> Self {
323 self.random_state = Some(seed);
324 self
325 }
326
327 pub fn split_with_groups_and_labels(
329 &self,
330 n_samples: usize,
331 y: &Array1<i32>,
332 groups: &Array1<i32>,
333 ) -> Vec<(Vec<usize>, Vec<usize>)> {
334 assert_eq!(
335 n_samples,
336 groups.len(),
337 "n_samples and groups must have the same length"
338 );
339 assert_eq!(
340 n_samples,
341 y.len(),
342 "n_samples and y must have the same length"
343 );
344
345 let mut group_class_counts: HashMap<i32, HashMap<i32, usize>> = HashMap::new();
347 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
348
349 for (idx, (&group, &label)) in groups.iter().zip(y.iter()).enumerate() {
350 group_indices.entry(group).or_default().push(idx);
351
352 *group_class_counts
353 .entry(group)
354 .or_default()
355 .entry(label)
356 .or_insert(0) += 1;
357 }
358
359 let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
360 let n_groups = unique_groups.len();
361
362 assert!(
363 self.n_splits <= n_groups,
364 "Cannot have number of splits {} greater than the number of groups {}",
365 self.n_splits,
366 n_groups
367 );
368
369 unique_groups.sort_by_key(|&g| {
371 let size: usize = group_class_counts[&g].values().sum();
372 std::cmp::Reverse(size)
373 });
374
375 if self.shuffle {
377 let mut rng = match self.random_state {
378 Some(seed) => StdRng::seed_from_u64(seed),
379 None => {
380 use scirs2_core::random::thread_rng;
381 StdRng::from_rng(&mut thread_rng())
382 }
383 };
384 unique_groups.shuffle(&mut rng);
385 }
386
387 let mut fold_groups: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
389 let mut fold_class_counts: Vec<HashMap<i32, usize>> = vec![HashMap::new(); self.n_splits];
390
391 for group in unique_groups {
393 let mut best_fold = 0;
395 let mut min_size = usize::MAX;
396
397 for (fold_idx, fold_counts) in fold_class_counts.iter().enumerate() {
398 let fold_size: usize = fold_counts.values().sum();
399 if fold_size < min_size {
400 min_size = fold_size;
401 best_fold = fold_idx;
402 }
403 }
404
405 fold_groups[best_fold].push(group);
407
408 for (&class, &count) in &group_class_counts[&group] {
410 *fold_class_counts[best_fold].entry(class).or_insert(0) += count;
411 }
412 }
413
414 let mut splits = Vec::new();
416
417 for test_fold_idx in 0..self.n_splits {
418 let mut test_indices = Vec::new();
419 let mut train_indices = Vec::new();
420
421 for (fold_idx, groups_in_fold) in fold_groups.iter().enumerate() {
422 for &group in groups_in_fold {
423 if fold_idx == test_fold_idx {
424 test_indices.extend(&group_indices[&group]);
425 } else {
426 train_indices.extend(&group_indices[&group]);
427 }
428 }
429 }
430
431 test_indices.sort_unstable();
433 train_indices.sort_unstable();
434
435 splits.push((train_indices, test_indices));
436 }
437
438 splits
439 }
440}
441
442impl CrossValidator for StratifiedGroupKFold {
443 fn n_splits(&self) -> usize {
444 self.n_splits
445 }
446
447 fn split(&self, _n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
448 Vec::new()
452 }
453}
454
/// Repeated random train/test splitter that assigns whole groups to one side.
#[derive(Debug, Clone)]
pub struct GroupShuffleSplit {
    /// Number of independent random splits to draw.
    n_splits: usize,
    /// Fraction of the groups placed in the test set.
    test_size: Option<f64>,
    /// Fraction of the groups placed in the train set; when `None` the
    /// complement of `test_size` is used.
    train_size: Option<f64>,
    /// Seed for the shuffles; `None` seeds from the thread-local generator.
    random_state: Option<u64>,
}
466
467impl GroupShuffleSplit {
468 pub fn new(n_splits: usize) -> Self {
470 Self {
471 n_splits,
472 test_size: Some(0.2),
473 train_size: None,
474 random_state: None,
475 }
476 }
477
478 pub fn test_size(mut self, size: f64) -> Self {
480 assert!(
481 (0.0..=1.0).contains(&size),
482 "test_size must be between 0.0 and 1.0"
483 );
484 self.test_size = Some(size);
485 self
486 }
487
488 pub fn train_size(mut self, size: f64) -> Self {
490 assert!(
491 (0.0..=1.0).contains(&size),
492 "train_size must be between 0.0 and 1.0"
493 );
494 self.train_size = Some(size);
495 self
496 }
497
498 pub fn random_state(mut self, seed: u64) -> Self {
500 self.random_state = Some(seed);
501 self
502 }
503
504 pub fn split_with_groups(
506 &self,
507 n_samples: usize,
508 groups: &Array1<i32>,
509 ) -> Vec<(Vec<usize>, Vec<usize>)> {
510 assert_eq!(
511 groups.len(),
512 n_samples,
513 "groups must have the same length as n_samples"
514 );
515
516 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
518 for (idx, &group) in groups.iter().enumerate() {
519 group_indices.entry(group).or_default().push(idx);
520 }
521
522 let unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
523 let n_groups = unique_groups.len();
524
525 let test_size = self.test_size.unwrap_or(0.2);
526 let train_size = self.train_size.unwrap_or(1.0 - test_size);
527
528 assert!(
529 train_size + test_size <= 1.0,
530 "train_size + test_size cannot exceed 1.0"
531 );
532
533 let n_test_groups = ((n_groups as f64) * test_size).round() as usize;
534 let n_train_groups = ((n_groups as f64) * train_size).round() as usize;
535
536 assert!(
537 n_train_groups + n_test_groups <= n_groups,
538 "train_size + test_size results in more groups than available"
539 );
540
541 let mut rng = match self.random_state {
542 Some(seed) => StdRng::seed_from_u64(seed),
543 None => {
544 use scirs2_core::random::thread_rng;
545 StdRng::from_rng(&mut thread_rng())
546 }
547 };
548
549 let mut splits = Vec::new();
550
551 for _ in 0..self.n_splits {
552 let mut shuffled_groups = unique_groups.clone();
553 shuffled_groups.shuffle(&mut rng);
554
555 let test_groups = &shuffled_groups[..n_test_groups];
556 let train_groups = &shuffled_groups[n_test_groups..n_test_groups + n_train_groups];
557
558 let mut test_indices = Vec::new();
559 for &group in test_groups {
560 test_indices.extend(&group_indices[&group]);
561 }
562
563 let mut train_indices = Vec::new();
564 for &group in train_groups {
565 train_indices.extend(&group_indices[&group]);
566 }
567
568 test_indices.sort_unstable();
570 train_indices.sort_unstable();
571
572 splits.push((train_indices, test_indices));
573 }
574
575 splits
576 }
577}
578
579impl CrossValidator for GroupShuffleSplit {
580 fn n_splits(&self) -> usize {
581 self.n_splits
582 }
583
584 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
585 let groups =
587 y.expect("GroupShuffleSplit requires group labels to be provided in y parameter");
588 self.split_with_groups(n_samples, groups)
589 }
590}
591
/// Splitter that holds out one group at a time as the test set.
#[derive(Debug, Clone)]
pub struct LeaveOneGroupOut;
597
598impl Default for LeaveOneGroupOut {
599 fn default() -> Self {
600 Self::new()
601 }
602}
603
604impl LeaveOneGroupOut {
605 pub fn new() -> Self {
607 LeaveOneGroupOut
608 }
609
610 pub fn split_with_groups(
612 &self,
613 n_samples: usize,
614 groups: &Array1<i32>,
615 ) -> Vec<(Vec<usize>, Vec<usize>)> {
616 assert_eq!(
617 groups.len(),
618 n_samples,
619 "groups must have the same length as n_samples"
620 );
621
622 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
624 for (idx, &group) in groups.iter().enumerate() {
625 group_indices.entry(group).or_default().push(idx);
626 }
627
628 let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
629 unique_groups.sort();
630
631 let mut splits = Vec::new();
632
633 for &test_group in &unique_groups {
635 let test_indices = group_indices[&test_group].clone();
636 let mut train_indices = Vec::new();
637
638 for &train_group in &unique_groups {
639 if train_group != test_group {
640 train_indices.extend(&group_indices[&train_group]);
641 }
642 }
643
644 train_indices.sort_unstable();
646
647 splits.push((train_indices, test_indices));
648 }
649
650 splits
651 }
652}
653
654impl CrossValidator for LeaveOneGroupOut {
655 fn n_splits(&self) -> usize {
656 0 }
659
660 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
661 let groups =
663 y.expect("LeaveOneGroupOut requires group labels to be provided in y parameter");
664 self.split_with_groups(n_samples, groups)
665 }
666}
667
/// Splitter that holds out every combination of `p` groups as the test set.
#[derive(Debug, Clone)]
pub struct LeavePGroupsOut {
    /// Number of groups held out in each split.
    p: usize,
}
675
676impl LeavePGroupsOut {
677 pub fn new(p: usize) -> Self {
679 assert!(p >= 1, "p must be at least 1");
680 Self { p }
681 }
682
683 pub fn split_with_groups(
685 &self,
686 n_samples: usize,
687 groups: &Array1<i32>,
688 ) -> Vec<(Vec<usize>, Vec<usize>)> {
689 assert_eq!(
690 groups.len(),
691 n_samples,
692 "groups must have the same length as n_samples"
693 );
694
695 let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
697 for (idx, &group) in groups.iter().enumerate() {
698 group_indices.entry(group).or_default().push(idx);
699 }
700
701 let unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
702 let n_groups = unique_groups.len();
703
704 assert!(
705 self.p <= n_groups,
706 "p ({}) cannot be greater than number of groups ({})",
707 self.p,
708 n_groups
709 );
710
711 let mut splits = Vec::new();
712
713 let group_combinations = combinations(&unique_groups, self.p);
715
716 for test_groups in group_combinations {
717 let mut test_indices = Vec::new();
718 for &group in &test_groups {
719 test_indices.extend(&group_indices[&group]);
720 }
721
722 let mut train_indices = Vec::new();
723 for &group in &unique_groups {
724 if !test_groups.contains(&group) {
725 train_indices.extend(&group_indices[&group]);
726 }
727 }
728
729 test_indices.sort_unstable();
731 train_indices.sort_unstable();
732
733 splits.push((train_indices, test_indices));
734 }
735
736 splits
737 }
738}
739
740impl CrossValidator for LeavePGroupsOut {
741 fn n_splits(&self) -> usize {
742 0 }
745
746 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
747 let groups =
749 y.expect("LeavePGroupsOut requires group labels to be provided in y parameter");
750 self.split_with_groups(n_samples, groups)
751 }
752}
753
/// Returns every `k`-element combination of `items`, preserving element order
/// inside each combination and listing combinations in lexicographic order of
/// the chosen positions.
///
/// `k == 0` yields a single empty combination; `k > items.len()` yields none.
///
/// Combinations are generated iteratively from an index vector, so the work
/// is proportional to the output size. A naive include/exclude recursion
/// would walk on the order of `2^n` dead branches when `k` is close to `n`.
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    if k == 0 {
        return vec![vec![]];
    }
    let n = items.len();
    if k > n {
        return Vec::new();
    }

    // `picks` holds the strictly increasing positions of the current choice.
    let mut picks: Vec<usize> = (0..k).collect();
    let mut result = Vec::new();

    loop {
        result.push(picks.iter().map(|&i| items[i].clone()).collect());

        // Rightmost slot that can still move right; slot `p` may go up to
        // `n - k + p` while leaving room for the slots after it.
        let advance = (0..k).rev().find(|&p| picks[p] < n - k + p);
        let pos = match advance {
            Some(pos) => pos,
            None => break,
        };

        picks[pos] += 1;
        for p in pos + 1..k {
            picks[p] = picks[p - 1] + 1;
        }
    }

    result
}
779
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashSet;

    /// Distinct group ids covered by the given sample indices.
    fn group_set(indices: &[usize], groups: &Array1<i32>) -> HashSet<i32> {
        indices.iter().map(|&idx| groups[idx]).collect()
    }

    /// Asserts that no split has a group on both its train and test side.
    fn assert_groups_disjoint(splits: &[(Vec<usize>, Vec<usize>)], groups: &Array1<i32>) {
        for (train, test) in splits {
            let train_groups = group_set(train, groups);
            let test_groups = group_set(test, groups);
            assert!(train_groups.is_disjoint(&test_groups));
        }
    }

    #[test]
    fn test_group_kfold() {
        let groups = array![0, 0, 1, 1, 2, 2];
        let splits = GroupKFold::new(2).split_with_groups(6, &groups);

        assert_eq!(splits.len(), 2);
        assert_groups_disjoint(&splits, &groups);
    }

    #[test]
    fn test_group_kfold_custom_strategies() {
        // Group sizes: 2, 3, 4, 2, 1.
        let groups = array![0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4];

        let balanced = GroupKFold::new_balanced(3).split_with_groups(12, &groups);
        assert_eq!(balanced.len(), 3);
        assert_groups_disjoint(&balanced, &groups);

        // Groups with more than 3 samples count as large.
        let size_aware = GroupKFold::new_size_aware(3, 3).split_with_groups(12, &groups);
        assert_eq!(size_aware.len(), 3);
        assert_groups_disjoint(&size_aware, &groups);

        let custom = GroupKFold::new(3)
            .group_strategy(GroupStrategy::Balanced)
            .split_with_groups(12, &groups);
        assert_eq!(custom.len(), 3);
    }

    #[test]
    fn test_stratified_group_kfold() {
        let groups = array![1, 1, 2, 2, 3, 3, 4, 4];
        let y = array![0, 0, 1, 1, 0, 1, 0, 1];
        let splits = StratifiedGroupKFold::new(2).split_with_groups_and_labels(8, &y, &groups);

        assert_eq!(splits.len(), 2);

        for (train_idx, test_idx) in &splits {
            let train_groups = group_set(train_idx, &groups);
            let test_groups = group_set(test_idx, &groups);
            assert!(train_groups.is_disjoint(&test_groups));

            // Both classes must be present on the training side.
            let train_class_0 = train_idx.iter().filter(|&&i| y[i] == 0).count();
            let train_class_1 = train_idx.iter().filter(|&&i| y[i] == 1).count();

            assert!(train_class_0 > 0);
            assert!(train_class_1 > 0);
        }
    }

    #[test]
    fn test_group_shuffle_split() {
        let groups = array![0, 0, 1, 1, 2, 2, 3, 3];
        let splits = GroupShuffleSplit::new(3)
            .test_size(0.25)
            .random_state(42)
            .split_with_groups(8, &groups);

        assert_eq!(splits.len(), 3);

        for (train, test) in &splits {
            let train_groups = group_set(train, &groups);
            let test_groups = group_set(test, &groups);

            assert!(train_groups.is_disjoint(&test_groups));

            // 25% of 4 groups is exactly one test group.
            assert_eq!(test_groups.len(), 1);
        }
    }

    #[test]
    fn test_leave_one_group_out() {
        let groups = array![0, 0, 1, 1, 2, 2];
        let splits = LeaveOneGroupOut::new().split_with_groups(6, &groups);

        // One split per distinct group.
        assert_eq!(splits.len(), 3);

        for (train, test) in splits.iter() {
            let test_groups = group_set(test, &groups);
            assert_eq!(test_groups.len(), 1);

            let train_groups = group_set(train, &groups);
            assert_eq!(train_groups.len(), 2);

            assert!(train_groups.is_disjoint(&test_groups));
        }
    }

    #[test]
    fn test_leave_p_groups_out() {
        let groups = array![0, 0, 1, 1, 2, 2, 3, 3];
        let splits = LeavePGroupsOut::new(2).split_with_groups(8, &groups);

        // C(4, 2) = 6 ways to hold out two of four groups.
        assert_eq!(splits.len(), 6);

        for (train, test) in &splits {
            let test_groups = group_set(test, &groups);
            assert_eq!(test_groups.len(), 2);

            let train_groups = group_set(train, &groups);
            assert_eq!(train_groups.len(), 2);

            assert!(train_groups.is_disjoint(&test_groups));
        }
    }
}