Skip to main content

sklears_model_selection/cv/
group_cv.rs

1//! Group-based cross-validation iterators
2//!
3//! This module provides cross-validation iterators that work with group labels to ensure
4//! no data leakage between groups. These are particularly useful when samples are not
5//! independent and can be grouped together (e.g., patients in medical studies, time series
6//! from the same entity, etc.).
7
8use scirs2_core::ndarray::Array1;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::SeedableRng;
11use scirs2_core::SliceRandomExt;
12use std::collections::HashMap;
13
14use crate::cross_validation::CrossValidator;
15
/// Strategy used by [`GroupKFold`] to assign groups to folds.
///
/// Whatever the strategy, every sample of a group ends up in the same fold, so a
/// group is never split between training and test data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GroupStrategy {
    /// Split the sorted unique group labels into `n_splits` contiguous runs whose
    /// lengths differ by at most one (sample counts are not considered).
    Direct,
    /// Visit groups from largest to smallest and put each one into the fold that
    /// currently holds the fewest samples, balancing sample counts across folds.
    Balanced,
    /// Treat groups with more than `max_group_size` samples as "large" and place
    /// them first, one per fold while free folds remain; the remaining groups are
    /// then added to whichever fold holds the fewest samples.
    SizeAware {
        /// Sample-count threshold above which a group counts as large.
        max_group_size: usize,
    },
}
26
27/// Group K-Fold cross-validator with custom group definitions
28///
29/// Ensures that samples from the same group are not in both training and test sets.
30/// Supports custom grouping strategies for advanced use cases.
31#[derive(Debug, Clone)]
32pub struct GroupKFold {
33    n_splits: usize,
34    group_strategy: GroupStrategy,
35}
36
37impl GroupKFold {
38    /// Create a new GroupKFold cross-validator with direct grouping strategy
39    pub fn new(n_splits: usize) -> Self {
40        assert!(n_splits >= 2, "n_splits must be at least 2");
41        Self {
42            n_splits,
43            group_strategy: GroupStrategy::Direct,
44        }
45    }
46
47    /// Create a GroupKFold with balanced group distribution strategy
48    pub fn new_balanced(n_splits: usize) -> Self {
49        assert!(n_splits >= 2, "n_splits must be at least 2");
50        Self {
51            n_splits,
52            group_strategy: GroupStrategy::Balanced,
53        }
54    }
55
56    /// Create a GroupKFold with size-aware group distribution strategy
57    pub fn new_size_aware(n_splits: usize, max_group_size: usize) -> Self {
58        assert!(n_splits >= 2, "n_splits must be at least 2");
59        Self {
60            n_splits,
61            group_strategy: GroupStrategy::SizeAware { max_group_size },
62        }
63    }
64
65    /// Set the group strategy
66    pub fn group_strategy(mut self, strategy: GroupStrategy) -> Self {
67        self.group_strategy = strategy;
68        self
69    }
70
71    /// Split based on groups using the configured strategy to ensure no leakage
72    pub fn split_with_groups(
73        &self,
74        n_samples: usize,
75        groups: &Array1<i32>,
76    ) -> Vec<(Vec<usize>, Vec<usize>)> {
77        assert_eq!(
78            groups.len(),
79            n_samples,
80            "groups must have the same length as n_samples"
81        );
82
83        // Group indices by group labels
84        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
85        for (idx, &group) in groups.iter().enumerate() {
86            group_indices.entry(group).or_default().push(idx);
87        }
88
89        let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
90        unique_groups.sort();
91
92        assert!(
93            unique_groups.len() >= self.n_splits,
94            "The number of groups ({}) must be at least equal to the number of splits ({})",
95            unique_groups.len(),
96            self.n_splits
97        );
98
99        match &self.group_strategy {
100            GroupStrategy::Direct => self.split_direct(&unique_groups, &group_indices),
101            GroupStrategy::Balanced => self.split_balanced(&unique_groups, &group_indices),
102            GroupStrategy::SizeAware { max_group_size } => {
103                self.split_size_aware(&unique_groups, &group_indices, *max_group_size)
104            }
105        }
106    }
107
108    fn split_direct(
109        &self,
110        unique_groups: &[i32],
111        group_indices: &HashMap<i32, Vec<usize>>,
112    ) -> Vec<(Vec<usize>, Vec<usize>)> {
113        // Original distribution strategy - evenly distribute groups
114        let n_groups = unique_groups.len();
115        let groups_per_fold = n_groups / self.n_splits;
116        let n_larger_folds = n_groups % self.n_splits;
117
118        let mut splits = Vec::new();
119        let mut current_group_idx = 0;
120
121        for i in 0..self.n_splits {
122            let fold_size = if i < n_larger_folds {
123                groups_per_fold + 1
124            } else {
125                groups_per_fold
126            };
127
128            let test_groups = &unique_groups[current_group_idx..current_group_idx + fold_size];
129            let train_groups: Vec<i32> = unique_groups
130                .iter()
131                .filter(|&group| !test_groups.contains(group))
132                .cloned()
133                .collect();
134
135            let mut test_indices = Vec::new();
136            for &group in test_groups {
137                test_indices.extend(&group_indices[&group]);
138            }
139
140            let mut train_indices = Vec::new();
141            for &group in &train_groups {
142                train_indices.extend(&group_indices[&group]);
143            }
144
145            splits.push((train_indices, test_indices));
146            current_group_idx += fold_size;
147        }
148
149        splits
150    }
151
152    fn split_balanced(
153        &self,
154        unique_groups: &[i32],
155        group_indices: &HashMap<i32, Vec<usize>>,
156    ) -> Vec<(Vec<usize>, Vec<usize>)> {
157        // Balanced strategy - try to balance the number of samples in each fold
158        let mut group_sizes: Vec<(i32, usize)> = unique_groups
159            .iter()
160            .map(|&group| (group, group_indices[&group].len()))
161            .collect();
162
163        // Sort by size to distribute large groups first
164        group_sizes.sort_by(|a, b| b.1.cmp(&a.1));
165
166        let mut fold_assignments: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
167        let mut fold_sizes: Vec<usize> = vec![0; self.n_splits];
168
169        // Assign each group to the fold with the smallest current size
170        for (group, size) in group_sizes {
171            let min_fold = fold_sizes
172                .iter()
173                .enumerate()
174                .min_by_key(|(_, &size)| size)
175                .map(|(idx, _)| idx)
176                .expect("operation should succeed");
177
178            fold_assignments[min_fold].push(group);
179            fold_sizes[min_fold] += size;
180        }
181
182        let mut splits = Vec::new();
183        for test_groups in fold_assignments.iter().take(self.n_splits) {
184            let train_groups: Vec<i32> = unique_groups
185                .iter()
186                .filter(|&group| !test_groups.contains(group))
187                .cloned()
188                .collect();
189
190            let mut test_indices = Vec::new();
191            for &group in test_groups {
192                test_indices.extend(&group_indices[&group]);
193            }
194
195            let mut train_indices = Vec::new();
196            for &group in &train_groups {
197                train_indices.extend(&group_indices[&group]);
198            }
199
200            splits.push((train_indices, test_indices));
201        }
202
203        splits
204    }
205
206    fn split_size_aware(
207        &self,
208        unique_groups: &[i32],
209        group_indices: &HashMap<i32, Vec<usize>>,
210        max_group_size: usize,
211    ) -> Vec<(Vec<usize>, Vec<usize>)> {
212        // Size-aware strategy - large groups get their own folds
213        let mut large_groups = Vec::new();
214        let mut small_groups = Vec::new();
215
216        for &group in unique_groups {
217            if group_indices[&group].len() > max_group_size {
218                large_groups.push(group);
219            } else {
220                small_groups.push(group);
221            }
222        }
223
224        let mut fold_assignments: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
225        let mut fold_index = 0;
226
227        // Assign large groups to individual folds
228        for group in large_groups {
229            if fold_index < self.n_splits {
230                fold_assignments[fold_index].push(group);
231                fold_index += 1;
232            } else {
233                // If we have more large groups than folds, add to existing folds
234                fold_assignments[fold_index % self.n_splits].push(group);
235            }
236        }
237
238        // Distribute small groups among remaining folds
239        let mut fold_sizes: Vec<usize> = fold_assignments
240            .iter()
241            .map(|groups| groups.iter().map(|&g| group_indices[&g].len()).sum())
242            .collect();
243
244        for group in small_groups {
245            let min_fold = fold_sizes
246                .iter()
247                .enumerate()
248                .min_by_key(|(_, &size)| size)
249                .map(|(idx, _)| idx)
250                .expect("operation should succeed");
251
252            fold_assignments[min_fold].push(group);
253            fold_sizes[min_fold] += group_indices[&group].len();
254        }
255
256        let mut splits = Vec::new();
257        for test_groups in fold_assignments.iter().take(self.n_splits) {
258            let train_groups: Vec<i32> = unique_groups
259                .iter()
260                .filter(|&group| !test_groups.contains(group))
261                .cloned()
262                .collect();
263
264            let mut test_indices = Vec::new();
265            for &group in test_groups {
266                test_indices.extend(&group_indices[&group]);
267            }
268
269            let mut train_indices = Vec::new();
270            for &group in &train_groups {
271                train_indices.extend(&group_indices[&group]);
272            }
273
274            splits.push((train_indices, test_indices));
275        }
276
277        splits
278    }
279}
280
281impl CrossValidator for GroupKFold {
282    fn n_splits(&self) -> usize {
283        self.n_splits
284    }
285
286    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
287        // For the generic interface, we assume y contains group labels
288        let groups = y.expect("GroupKFold requires group labels to be provided in y parameter");
289        self.split_with_groups(n_samples, groups)
290    }
291}
292
/// Stratified Group K-Fold cross-validator
///
/// A [`GroupKFold`] variant that also tries to keep the proportion of each class
/// label similar across folds. As with `GroupKFold`, a group is never present on
/// both the training and the test side of a split.
#[derive(Debug, Clone)]
pub struct StratifiedGroupKFold {
    /// Number of folds (at least 2 when built through `new`).
    n_splits: usize,
    /// Whether the group order is randomised before assignment.
    shuffle: bool,
    /// Seed for the shuffle; `None` draws a seed from the thread RNG.
    random_state: Option<u64>,
}
303
304impl StratifiedGroupKFold {
305    /// Create a new StratifiedGroupKFold cross-validator
306    pub fn new(n_splits: usize) -> Self {
307        assert!(n_splits >= 2, "n_splits must be at least 2");
308        Self {
309            n_splits,
310            shuffle: false,
311            random_state: None,
312        }
313    }
314
315    /// Set whether to shuffle the groups before splitting
316    pub fn shuffle(mut self, shuffle: bool) -> Self {
317        self.shuffle = shuffle;
318        self
319    }
320
321    /// Set the random state for shuffling
322    pub fn random_state(mut self, seed: u64) -> Self {
323        self.random_state = Some(seed);
324        self
325    }
326
327    /// Split data into train/test sets based on stratification and groups
328    pub fn split_with_groups_and_labels(
329        &self,
330        n_samples: usize,
331        y: &Array1<i32>,
332        groups: &Array1<i32>,
333    ) -> Vec<(Vec<usize>, Vec<usize>)> {
334        assert_eq!(
335            n_samples,
336            groups.len(),
337            "n_samples and groups must have the same length"
338        );
339        assert_eq!(
340            n_samples,
341            y.len(),
342            "n_samples and y must have the same length"
343        );
344
345        // Get unique groups and their class distributions
346        let mut group_class_counts: HashMap<i32, HashMap<i32, usize>> = HashMap::new();
347        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
348
349        for (idx, (&group, &label)) in groups.iter().zip(y.iter()).enumerate() {
350            group_indices.entry(group).or_default().push(idx);
351
352            *group_class_counts
353                .entry(group)
354                .or_default()
355                .entry(label)
356                .or_insert(0) += 1;
357        }
358
359        let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
360        let n_groups = unique_groups.len();
361
362        assert!(
363            self.n_splits <= n_groups,
364            "Cannot have number of splits {} greater than the number of groups {}",
365            self.n_splits,
366            n_groups
367        );
368
369        // Sort groups by size (descending) for better distribution
370        unique_groups.sort_by_key(|&g| {
371            let size: usize = group_class_counts[&g].values().sum();
372            std::cmp::Reverse(size)
373        });
374
375        // Shuffle if requested
376        if self.shuffle {
377            let mut rng = match self.random_state {
378                Some(seed) => StdRng::seed_from_u64(seed),
379                None => {
380                    use scirs2_core::random::thread_rng;
381                    StdRng::from_rng(&mut thread_rng())
382                }
383            };
384            unique_groups.shuffle(&mut rng);
385        }
386
387        // Distribute groups to folds to maintain class balance
388        let mut fold_groups: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
389        let mut fold_class_counts: Vec<HashMap<i32, usize>> = vec![HashMap::new(); self.n_splits];
390
391        // Assign groups to folds using a greedy approach
392        for group in unique_groups {
393            // Find the fold with the smallest total size
394            let mut best_fold = 0;
395            let mut min_size = usize::MAX;
396
397            for (fold_idx, fold_counts) in fold_class_counts.iter().enumerate() {
398                let fold_size: usize = fold_counts.values().sum();
399                if fold_size < min_size {
400                    min_size = fold_size;
401                    best_fold = fold_idx;
402                }
403            }
404
405            // Add group to the selected fold
406            fold_groups[best_fold].push(group);
407
408            // Update fold class counts
409            for (&class, &count) in &group_class_counts[&group] {
410                *fold_class_counts[best_fold].entry(class).or_insert(0) += count;
411            }
412        }
413
414        // Generate train/test splits
415        let mut splits = Vec::new();
416
417        for test_fold_idx in 0..self.n_splits {
418            let mut test_indices = Vec::new();
419            let mut train_indices = Vec::new();
420
421            for (fold_idx, groups_in_fold) in fold_groups.iter().enumerate() {
422                for &group in groups_in_fold {
423                    if fold_idx == test_fold_idx {
424                        test_indices.extend(&group_indices[&group]);
425                    } else {
426                        train_indices.extend(&group_indices[&group]);
427                    }
428                }
429            }
430
431            // Sort indices for consistency
432            test_indices.sort_unstable();
433            train_indices.sort_unstable();
434
435            splits.push((train_indices, test_indices));
436        }
437
438        splits
439    }
440}
441
442impl CrossValidator for StratifiedGroupKFold {
443    fn n_splits(&self) -> usize {
444        self.n_splits
445    }
446
447    fn split(&self, _n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
448        // StratifiedGroupKFold requires both groups and labels.
449        // Use split_with_groups_and_labels method instead.
450        // Return empty splits to signal that this method cannot be used directly.
451        Vec::new()
452    }
453}
454
/// Group Shuffle Split cross-validator
///
/// Draws `n_splits` independent random train/test partitions of the groups.
/// A group is never present on both sides of the same split.
#[derive(Debug, Clone)]
pub struct GroupShuffleSplit {
    /// Number of random splits to generate.
    n_splits: usize,
    /// Fraction of groups placed in the test set.
    test_size: Option<f64>,
    /// Fraction of groups placed in the training set; `None` derives it from `test_size`.
    train_size: Option<f64>,
    /// Seed for the shuffle; `None` draws a seed from the thread RNG.
    random_state: Option<u64>,
}
466
467impl GroupShuffleSplit {
468    /// Create a new GroupShuffleSplit cross-validator
469    pub fn new(n_splits: usize) -> Self {
470        Self {
471            n_splits,
472            test_size: Some(0.2),
473            train_size: None,
474            random_state: None,
475        }
476    }
477
478    /// Set the test size as a proportion (0.0 to 1.0) of the groups
479    pub fn test_size(mut self, size: f64) -> Self {
480        assert!(
481            (0.0..=1.0).contains(&size),
482            "test_size must be between 0.0 and 1.0"
483        );
484        self.test_size = Some(size);
485        self
486    }
487
488    /// Set the train size as a proportion (0.0 to 1.0) of the groups
489    pub fn train_size(mut self, size: f64) -> Self {
490        assert!(
491            (0.0..=1.0).contains(&size),
492            "train_size must be between 0.0 and 1.0"
493        );
494        self.train_size = Some(size);
495        self
496    }
497
498    /// Set the random state for reproducible results
499    pub fn random_state(mut self, seed: u64) -> Self {
500        self.random_state = Some(seed);
501        self
502    }
503
504    /// Split data based on groups
505    pub fn split_with_groups(
506        &self,
507        n_samples: usize,
508        groups: &Array1<i32>,
509    ) -> Vec<(Vec<usize>, Vec<usize>)> {
510        assert_eq!(
511            groups.len(),
512            n_samples,
513            "groups must have the same length as n_samples"
514        );
515
516        // Group indices by group labels
517        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
518        for (idx, &group) in groups.iter().enumerate() {
519            group_indices.entry(group).or_default().push(idx);
520        }
521
522        let unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
523        let n_groups = unique_groups.len();
524
525        let test_size = self.test_size.unwrap_or(0.2);
526        let train_size = self.train_size.unwrap_or(1.0 - test_size);
527
528        assert!(
529            train_size + test_size <= 1.0,
530            "train_size + test_size cannot exceed 1.0"
531        );
532
533        let n_test_groups = ((n_groups as f64) * test_size).round() as usize;
534        let n_train_groups = ((n_groups as f64) * train_size).round() as usize;
535
536        assert!(
537            n_train_groups + n_test_groups <= n_groups,
538            "train_size + test_size results in more groups than available"
539        );
540
541        let mut rng = match self.random_state {
542            Some(seed) => StdRng::seed_from_u64(seed),
543            None => {
544                use scirs2_core::random::thread_rng;
545                StdRng::from_rng(&mut thread_rng())
546            }
547        };
548
549        let mut splits = Vec::new();
550
551        for _ in 0..self.n_splits {
552            let mut shuffled_groups = unique_groups.clone();
553            shuffled_groups.shuffle(&mut rng);
554
555            let test_groups = &shuffled_groups[..n_test_groups];
556            let train_groups = &shuffled_groups[n_test_groups..n_test_groups + n_train_groups];
557
558            let mut test_indices = Vec::new();
559            for &group in test_groups {
560                test_indices.extend(&group_indices[&group]);
561            }
562
563            let mut train_indices = Vec::new();
564            for &group in train_groups {
565                train_indices.extend(&group_indices[&group]);
566            }
567
568            // Sort for consistency
569            test_indices.sort_unstable();
570            train_indices.sort_unstable();
571
572            splits.push((train_indices, test_indices));
573        }
574
575        splits
576    }
577}
578
579impl CrossValidator for GroupShuffleSplit {
580    fn n_splits(&self) -> usize {
581        self.n_splits
582    }
583
584    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
585        // For the generic interface, we assume y contains group labels
586        let groups =
587            y.expect("GroupShuffleSplit requires group labels to be provided in y parameter");
588        self.split_with_groups(n_samples, groups)
589    }
590}
591
/// Leave-One-Group-Out cross-validator.
///
/// Produces one train/test split per distinct group label; in each split the
/// test set is exactly the samples belonging to that single group.
#[derive(Debug, Clone)]
pub struct LeaveOneGroupOut;
597
598impl Default for LeaveOneGroupOut {
599    fn default() -> Self {
600        Self::new()
601    }
602}
603
604impl LeaveOneGroupOut {
605    /// Create a new LeaveOneGroupOut cross-validator
606    pub fn new() -> Self {
607        LeaveOneGroupOut
608    }
609
610    /// Split data based on groups
611    pub fn split_with_groups(
612        &self,
613        n_samples: usize,
614        groups: &Array1<i32>,
615    ) -> Vec<(Vec<usize>, Vec<usize>)> {
616        assert_eq!(
617            groups.len(),
618            n_samples,
619            "groups must have the same length as n_samples"
620        );
621
622        // Group indices by group labels
623        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
624        for (idx, &group) in groups.iter().enumerate() {
625            group_indices.entry(group).or_default().push(idx);
626        }
627
628        let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
629        unique_groups.sort();
630
631        let mut splits = Vec::new();
632
633        // Leave out each group one at a time
634        for &test_group in &unique_groups {
635            let test_indices = group_indices[&test_group].clone();
636            let mut train_indices = Vec::new();
637
638            for &train_group in &unique_groups {
639                if train_group != test_group {
640                    train_indices.extend(&group_indices[&train_group]);
641                }
642            }
643
644            // Sort for consistency
645            train_indices.sort_unstable();
646
647            splits.push((train_indices, test_indices));
648        }
649
650        splits
651    }
652}
653
654impl CrossValidator for LeaveOneGroupOut {
655    fn n_splits(&self) -> usize {
656        // This is dynamic based on the number of unique groups
657        0 // Will be determined during split
658    }
659
660    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
661        // For the generic interface, we assume y contains group labels
662        let groups =
663            y.expect("LeaveOneGroupOut requires group labels to be provided in y parameter");
664        self.split_with_groups(n_samples, groups)
665    }
666}
667
/// Leave-P-Groups-Out cross-validator.
///
/// Produces one train/test split for every combination of `p` distinct groups,
/// using the samples of those groups as the test set.
#[derive(Debug, Clone)]
pub struct LeavePGroupsOut {
    /// Number of groups held out per split (at least 1 when built through `new`).
    p: usize,
}
675
676impl LeavePGroupsOut {
677    /// Create a new LeavePGroupsOut cross-validator
678    pub fn new(p: usize) -> Self {
679        assert!(p >= 1, "p must be at least 1");
680        Self { p }
681    }
682
683    /// Split data based on groups
684    pub fn split_with_groups(
685        &self,
686        n_samples: usize,
687        groups: &Array1<i32>,
688    ) -> Vec<(Vec<usize>, Vec<usize>)> {
689        assert_eq!(
690            groups.len(),
691            n_samples,
692            "groups must have the same length as n_samples"
693        );
694
695        // Group indices by group labels
696        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
697        for (idx, &group) in groups.iter().enumerate() {
698            group_indices.entry(group).or_default().push(idx);
699        }
700
701        let unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
702        let n_groups = unique_groups.len();
703
704        assert!(
705            self.p <= n_groups,
706            "p ({}) cannot be greater than number of groups ({})",
707            self.p,
708            n_groups
709        );
710
711        let mut splits = Vec::new();
712
713        // Generate all combinations of p groups for test sets
714        let group_combinations = combinations(&unique_groups, self.p);
715
716        for test_groups in group_combinations {
717            let mut test_indices = Vec::new();
718            for &group in &test_groups {
719                test_indices.extend(&group_indices[&group]);
720            }
721
722            let mut train_indices = Vec::new();
723            for &group in &unique_groups {
724                if !test_groups.contains(&group) {
725                    train_indices.extend(&group_indices[&group]);
726                }
727            }
728
729            // Sort for consistency
730            test_indices.sort_unstable();
731            train_indices.sort_unstable();
732
733            splits.push((train_indices, test_indices));
734        }
735
736        splits
737    }
738}
739
740impl CrossValidator for LeavePGroupsOut {
741    fn n_splits(&self) -> usize {
742        // This is dynamic based on the number of groups and p
743        0 // Will be determined during split
744    }
745
746    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
747        // For the generic interface, we assume y contains group labels
748        let groups =
749            y.expect("LeavePGroupsOut requires group labels to be provided in y parameter");
750        self.split_with_groups(n_samples, groups)
751    }
752}
753
/// Every `k`-element combination of `items`, in lexicographic order of positions.
///
/// `k == 0` yields a single empty combination; `k > items.len()` yields none.
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k > n {
        return Vec::new();
    }

    // `picks` holds the positions of the current combination, strictly increasing.
    let mut picks: Vec<usize> = (0..k).collect();
    let mut out = Vec::new();

    loop {
        out.push(picks.iter().map(|&i| items[i].clone()).collect());

        // Advance the rightmost position that still has room, then reset
        // everything after it to the smallest valid increasing run.
        match (0..k).rev().find(|&slot| picks[slot] < n - k + slot) {
            Some(slot) => {
                picks[slot] += 1;
                for next in slot + 1..k {
                    picks[next] = picks[next - 1] + 1;
                }
            }
            None => break,
        }
    }

    out
}
779
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashSet;

    /// Group labels referenced by the given sample indices.
    fn labels_at(indices: &[usize], groups: &Array1<i32>) -> HashSet<i32> {
        indices.iter().map(|&i| groups[i]).collect()
    }

    /// Assert that no group label appears on both sides of any split.
    fn assert_groups_disjoint(splits: &[(Vec<usize>, Vec<usize>)], groups: &Array1<i32>) {
        for (train, test) in splits {
            let train_labels = labels_at(train, groups);
            let test_labels = labels_at(test, groups);
            assert!(train_labels.is_disjoint(&test_labels));
        }
    }

    #[test]
    fn test_group_kfold() {
        let groups = array![0, 0, 1, 1, 2, 2];
        let splits = GroupKFold::new(2).split_with_groups(6, &groups);

        assert_eq!(splits.len(), 2);
        assert_groups_disjoint(&splits, &groups);
    }

    #[test]
    fn test_group_kfold_custom_strategies() {
        // Groups of varying size.
        let groups = array![0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4];

        let balanced = GroupKFold::new_balanced(3).split_with_groups(12, &groups);
        assert_eq!(balanced.len(), 3);
        assert_groups_disjoint(&balanced, &groups);

        // max_group_size = 3
        let size_aware = GroupKFold::new_size_aware(3, 3).split_with_groups(12, &groups);
        assert_eq!(size_aware.len(), 3);
        assert_groups_disjoint(&size_aware, &groups);

        let custom = GroupKFold::new(3)
            .group_strategy(GroupStrategy::Balanced)
            .split_with_groups(12, &groups);
        assert_eq!(custom.len(), 3);
    }

    #[test]
    fn test_stratified_group_kfold() {
        let groups = array![1, 1, 2, 2, 3, 3, 4, 4];
        let y = array![0, 0, 1, 1, 0, 1, 0, 1];
        let splits = StratifiedGroupKFold::new(2).split_with_groups_and_labels(8, &y, &groups);

        assert_eq!(splits.len(), 2);
        assert_groups_disjoint(&splits, &groups);

        // Both classes must be represented in every training set.
        for (train, _) in &splits {
            let class_0 = train.iter().filter(|&&i| y[i] == 0).count();
            let class_1 = train.iter().filter(|&&i| y[i] == 1).count();
            assert!(class_0 > 0);
            assert!(class_1 > 0);
        }
    }

    #[test]
    fn test_group_shuffle_split() {
        let groups = array![0, 0, 1, 1, 2, 2, 3, 3];
        let splits = GroupShuffleSplit::new(3)
            .test_size(0.25)
            .random_state(42)
            .split_with_groups(8, &groups);

        assert_eq!(splits.len(), 3);
        assert_groups_disjoint(&splits, &groups);

        // 25% of 4 groups -> exactly one test group per split.
        for (_, test) in &splits {
            assert_eq!(labels_at(test, &groups).len(), 1);
        }
    }

    #[test]
    fn test_leave_one_group_out() {
        let groups = array![0, 0, 1, 1, 2, 2];
        let splits = LeaveOneGroupOut::new().split_with_groups(6, &groups);

        // One split per distinct group.
        assert_eq!(splits.len(), 3);
        assert_groups_disjoint(&splits, &groups);

        for (train, test) in &splits {
            assert_eq!(labels_at(test, &groups).len(), 1);
            assert_eq!(labels_at(train, &groups).len(), 2);
        }
    }

    #[test]
    fn test_leave_p_groups_out() {
        let groups = array![0, 0, 1, 1, 2, 2, 3, 3];
        let splits = LeavePGroupsOut::new(2).split_with_groups(8, &groups);

        // C(4, 2) = 6 combinations.
        assert_eq!(splits.len(), 6);
        assert_groups_disjoint(&splits, &groups);

        for (train, test) in &splits {
            assert_eq!(labels_at(test, &groups).len(), 2);
            assert_eq!(labels_at(train, &groups).len(), 2);
        }
    }
}