sklears_model_selection/cv/group_cv.rs

1//! Group-based cross-validation iterators
2//!
3//! This module provides cross-validation iterators that work with group labels to ensure
4//! no data leakage between groups. These are particularly useful when samples are not
5//! independent and can be grouped together (e.g., patients in medical studies, time series
6//! from the same entity, etc.).
7
8use scirs2_core::ndarray::Array1;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::SeedableRng;
11use scirs2_core::SliceRandomExt;
12use std::collections::HashMap;
13
14use crate::cross_validation::CrossValidator;
15
/// Strategy used by [`GroupKFold`] to distribute groups over folds.
///
/// Whatever the strategy, every group ends up in exactly one test fold, so no group is
/// ever split between a training set and its test set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GroupStrategy {
    /// Deal the label-sorted groups into contiguous runs of (nearly) equal group count.
    Direct,
    /// Visit groups from largest to smallest and place each one in the fold that currently
    /// holds the fewest samples, to balance the number of samples per fold.
    Balanced,
    /// Place groups with more than `max_group_size` samples first, one per fold; the
    /// remaining groups then fill the folds holding the fewest samples.
    SizeAware {
        /// Groups with strictly more samples than this are treated as "large".
        max_group_size: usize,
    },
}
26
27/// Group K-Fold cross-validator with custom group definitions
28///
29/// Ensures that samples from the same group are not in both training and test sets.
30/// Supports custom grouping strategies for advanced use cases.
31#[derive(Debug, Clone)]
32pub struct GroupKFold {
33    n_splits: usize,
34    group_strategy: GroupStrategy,
35}
36
37impl GroupKFold {
38    /// Create a new GroupKFold cross-validator with direct grouping strategy
39    pub fn new(n_splits: usize) -> Self {
40        assert!(n_splits >= 2, "n_splits must be at least 2");
41        Self {
42            n_splits,
43            group_strategy: GroupStrategy::Direct,
44        }
45    }
46
47    /// Create a GroupKFold with balanced group distribution strategy
48    pub fn new_balanced(n_splits: usize) -> Self {
49        assert!(n_splits >= 2, "n_splits must be at least 2");
50        Self {
51            n_splits,
52            group_strategy: GroupStrategy::Balanced,
53        }
54    }
55
56    /// Create a GroupKFold with size-aware group distribution strategy
57    pub fn new_size_aware(n_splits: usize, max_group_size: usize) -> Self {
58        assert!(n_splits >= 2, "n_splits must be at least 2");
59        Self {
60            n_splits,
61            group_strategy: GroupStrategy::SizeAware { max_group_size },
62        }
63    }
64
65    /// Set the group strategy
66    pub fn group_strategy(mut self, strategy: GroupStrategy) -> Self {
67        self.group_strategy = strategy;
68        self
69    }
70
71    /// Split based on groups using the configured strategy to ensure no leakage
72    pub fn split_with_groups(
73        &self,
74        n_samples: usize,
75        groups: &Array1<i32>,
76    ) -> Vec<(Vec<usize>, Vec<usize>)> {
77        assert_eq!(
78            groups.len(),
79            n_samples,
80            "groups must have the same length as n_samples"
81        );
82
83        // Group indices by group labels
84        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
85        for (idx, &group) in groups.iter().enumerate() {
86            group_indices.entry(group).or_default().push(idx);
87        }
88
89        let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
90        unique_groups.sort();
91
92        assert!(
93            unique_groups.len() >= self.n_splits,
94            "The number of groups ({}) must be at least equal to the number of splits ({})",
95            unique_groups.len(),
96            self.n_splits
97        );
98
99        match &self.group_strategy {
100            GroupStrategy::Direct => self.split_direct(&unique_groups, &group_indices),
101            GroupStrategy::Balanced => self.split_balanced(&unique_groups, &group_indices),
102            GroupStrategy::SizeAware { max_group_size } => {
103                self.split_size_aware(&unique_groups, &group_indices, *max_group_size)
104            }
105        }
106    }
107
108    fn split_direct(
109        &self,
110        unique_groups: &[i32],
111        group_indices: &HashMap<i32, Vec<usize>>,
112    ) -> Vec<(Vec<usize>, Vec<usize>)> {
113        // Original distribution strategy - evenly distribute groups
114        let n_groups = unique_groups.len();
115        let groups_per_fold = n_groups / self.n_splits;
116        let n_larger_folds = n_groups % self.n_splits;
117
118        let mut splits = Vec::new();
119        let mut current_group_idx = 0;
120
121        for i in 0..self.n_splits {
122            let fold_size = if i < n_larger_folds {
123                groups_per_fold + 1
124            } else {
125                groups_per_fold
126            };
127
128            let test_groups = &unique_groups[current_group_idx..current_group_idx + fold_size];
129            let train_groups: Vec<i32> = unique_groups
130                .iter()
131                .filter(|&group| !test_groups.contains(group))
132                .cloned()
133                .collect();
134
135            let mut test_indices = Vec::new();
136            for &group in test_groups {
137                test_indices.extend(&group_indices[&group]);
138            }
139
140            let mut train_indices = Vec::new();
141            for &group in &train_groups {
142                train_indices.extend(&group_indices[&group]);
143            }
144
145            splits.push((train_indices, test_indices));
146            current_group_idx += fold_size;
147        }
148
149        splits
150    }
151
152    fn split_balanced(
153        &self,
154        unique_groups: &[i32],
155        group_indices: &HashMap<i32, Vec<usize>>,
156    ) -> Vec<(Vec<usize>, Vec<usize>)> {
157        // Balanced strategy - try to balance the number of samples in each fold
158        let mut group_sizes: Vec<(i32, usize)> = unique_groups
159            .iter()
160            .map(|&group| (group, group_indices[&group].len()))
161            .collect();
162
163        // Sort by size to distribute large groups first
164        group_sizes.sort_by(|a, b| b.1.cmp(&a.1));
165
166        let mut fold_assignments: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
167        let mut fold_sizes: Vec<usize> = vec![0; self.n_splits];
168
169        // Assign each group to the fold with the smallest current size
170        for (group, size) in group_sizes {
171            let min_fold = fold_sizes
172                .iter()
173                .enumerate()
174                .min_by_key(|(_, &size)| size)
175                .map(|(idx, _)| idx)
176                .unwrap();
177
178            fold_assignments[min_fold].push(group);
179            fold_sizes[min_fold] += size;
180        }
181
182        let mut splits = Vec::new();
183        for test_groups in fold_assignments.iter().take(self.n_splits) {
184            let train_groups: Vec<i32> = unique_groups
185                .iter()
186                .filter(|&group| !test_groups.contains(group))
187                .cloned()
188                .collect();
189
190            let mut test_indices = Vec::new();
191            for &group in test_groups {
192                test_indices.extend(&group_indices[&group]);
193            }
194
195            let mut train_indices = Vec::new();
196            for &group in &train_groups {
197                train_indices.extend(&group_indices[&group]);
198            }
199
200            splits.push((train_indices, test_indices));
201        }
202
203        splits
204    }
205
206    fn split_size_aware(
207        &self,
208        unique_groups: &[i32],
209        group_indices: &HashMap<i32, Vec<usize>>,
210        max_group_size: usize,
211    ) -> Vec<(Vec<usize>, Vec<usize>)> {
212        // Size-aware strategy - large groups get their own folds
213        let mut large_groups = Vec::new();
214        let mut small_groups = Vec::new();
215
216        for &group in unique_groups {
217            if group_indices[&group].len() > max_group_size {
218                large_groups.push(group);
219            } else {
220                small_groups.push(group);
221            }
222        }
223
224        let mut fold_assignments: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
225        let mut fold_index = 0;
226
227        // Assign large groups to individual folds
228        for group in large_groups {
229            if fold_index < self.n_splits {
230                fold_assignments[fold_index].push(group);
231                fold_index += 1;
232            } else {
233                // If we have more large groups than folds, add to existing folds
234                fold_assignments[fold_index % self.n_splits].push(group);
235            }
236        }
237
238        // Distribute small groups among remaining folds
239        let mut fold_sizes: Vec<usize> = fold_assignments
240            .iter()
241            .map(|groups| groups.iter().map(|&g| group_indices[&g].len()).sum())
242            .collect();
243
244        for group in small_groups {
245            let min_fold = fold_sizes
246                .iter()
247                .enumerate()
248                .min_by_key(|(_, &size)| size)
249                .map(|(idx, _)| idx)
250                .unwrap();
251
252            fold_assignments[min_fold].push(group);
253            fold_sizes[min_fold] += group_indices[&group].len();
254        }
255
256        let mut splits = Vec::new();
257        for test_groups in fold_assignments.iter().take(self.n_splits) {
258            let train_groups: Vec<i32> = unique_groups
259                .iter()
260                .filter(|&group| !test_groups.contains(group))
261                .cloned()
262                .collect();
263
264            let mut test_indices = Vec::new();
265            for &group in test_groups {
266                test_indices.extend(&group_indices[&group]);
267            }
268
269            let mut train_indices = Vec::new();
270            for &group in &train_groups {
271                train_indices.extend(&group_indices[&group]);
272            }
273
274            splits.push((train_indices, test_indices));
275        }
276
277        splits
278    }
279}
280
281impl CrossValidator for GroupKFold {
282    fn n_splits(&self) -> usize {
283        self.n_splits
284    }
285
286    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
287        // For the generic interface, we assume y contains group labels
288        let groups = y.expect("GroupKFold requires group labels to be provided in y parameter");
289        self.split_with_groups(n_samples, groups)
290    }
291}
292
/// Group-aware K-fold cross-validator that also tries to keep class proportions even.
///
/// Like `GroupKFold`, a group never appears in both the training and the test set of a
/// split; in addition, groups are assigned to folds with the aim of preserving each class's
/// share of samples across folds.
#[derive(Clone, Debug)]
pub struct StratifiedGroupKFold {
    /// Number of folds to produce.
    n_splits: usize,
    /// Whether the group visiting order is shuffled before assignment.
    shuffle: bool,
    /// Optional seed for the shuffle.
    random_state: Option<u64>,
}
303
304impl StratifiedGroupKFold {
305    /// Create a new StratifiedGroupKFold cross-validator
306    pub fn new(n_splits: usize) -> Self {
307        assert!(n_splits >= 2, "n_splits must be at least 2");
308        Self {
309            n_splits,
310            shuffle: false,
311            random_state: None,
312        }
313    }
314
315    /// Set whether to shuffle the groups before splitting
316    pub fn shuffle(mut self, shuffle: bool) -> Self {
317        self.shuffle = shuffle;
318        self
319    }
320
321    /// Set the random state for shuffling
322    pub fn random_state(mut self, seed: u64) -> Self {
323        self.random_state = Some(seed);
324        self
325    }
326
327    /// Split data into train/test sets based on stratification and groups
328    pub fn split_with_groups_and_labels(
329        &self,
330        n_samples: usize,
331        y: &Array1<i32>,
332        groups: &Array1<i32>,
333    ) -> Vec<(Vec<usize>, Vec<usize>)> {
334        assert_eq!(
335            n_samples,
336            groups.len(),
337            "n_samples and groups must have the same length"
338        );
339        assert_eq!(
340            n_samples,
341            y.len(),
342            "n_samples and y must have the same length"
343        );
344
345        // Get unique groups and their class distributions
346        let mut group_class_counts: HashMap<i32, HashMap<i32, usize>> = HashMap::new();
347        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
348
349        for (idx, (&group, &label)) in groups.iter().zip(y.iter()).enumerate() {
350            group_indices.entry(group).or_default().push(idx);
351
352            *group_class_counts
353                .entry(group)
354                .or_default()
355                .entry(label)
356                .or_insert(0) += 1;
357        }
358
359        let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
360        let n_groups = unique_groups.len();
361
362        assert!(
363            self.n_splits <= n_groups,
364            "Cannot have number of splits {} greater than the number of groups {}",
365            self.n_splits,
366            n_groups
367        );
368
369        // Sort groups by size (descending) for better distribution
370        unique_groups.sort_by_key(|&g| {
371            let size: usize = group_class_counts[&g].values().sum();
372            std::cmp::Reverse(size)
373        });
374
375        // Shuffle if requested
376        if self.shuffle {
377            let mut rng = match self.random_state {
378                Some(seed) => StdRng::seed_from_u64(seed),
379                None => {
380                    use scirs2_core::random::thread_rng;
381                    StdRng::from_rng(&mut thread_rng())
382                }
383            };
384            unique_groups.shuffle(&mut rng);
385        }
386
387        // Distribute groups to folds to maintain class balance
388        let mut fold_groups: Vec<Vec<i32>> = vec![Vec::new(); self.n_splits];
389        let mut fold_class_counts: Vec<HashMap<i32, usize>> = vec![HashMap::new(); self.n_splits];
390
391        // Assign groups to folds using a greedy approach
392        for group in unique_groups {
393            // Find the fold with the smallest total size
394            let mut best_fold = 0;
395            let mut min_size = usize::MAX;
396
397            for (fold_idx, fold_counts) in fold_class_counts.iter().enumerate() {
398                let fold_size: usize = fold_counts.values().sum();
399                if fold_size < min_size {
400                    min_size = fold_size;
401                    best_fold = fold_idx;
402                }
403            }
404
405            // Add group to the selected fold
406            fold_groups[best_fold].push(group);
407
408            // Update fold class counts
409            for (&class, &count) in &group_class_counts[&group] {
410                *fold_class_counts[best_fold].entry(class).or_insert(0) += count;
411            }
412        }
413
414        // Generate train/test splits
415        let mut splits = Vec::new();
416
417        for test_fold_idx in 0..self.n_splits {
418            let mut test_indices = Vec::new();
419            let mut train_indices = Vec::new();
420
421            for (fold_idx, groups_in_fold) in fold_groups.iter().enumerate() {
422                for &group in groups_in_fold {
423                    if fold_idx == test_fold_idx {
424                        test_indices.extend(&group_indices[&group]);
425                    } else {
426                        train_indices.extend(&group_indices[&group]);
427                    }
428                }
429            }
430
431            // Sort indices for consistency
432            test_indices.sort_unstable();
433            train_indices.sort_unstable();
434
435            splits.push((train_indices, test_indices));
436        }
437
438        splits
439    }
440}
441
442impl CrossValidator for StratifiedGroupKFold {
443    fn n_splits(&self) -> usize {
444        self.n_splits
445    }
446
447    fn split(&self, _n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
448        panic!("StratifiedGroupKFold requires both groups and labels. Use split_with_groups_and_labels method instead.");
449    }
450}
451
/// Cross-validator producing random train/test splits drawn at the group level.
///
/// Whole groups are drawn for the test and training sets, so a group never appears on
/// both sides of a split.
#[derive(Clone, Debug)]
pub struct GroupShuffleSplit {
    /// Number of independent random splits to generate.
    n_splits: usize,
    /// Fraction of the groups used for testing.
    test_size: Option<f64>,
    /// Fraction of the groups used for training.
    train_size: Option<f64>,
    /// Optional seed for reproducible shuffling.
    random_state: Option<u64>,
}
463
464impl GroupShuffleSplit {
465    /// Create a new GroupShuffleSplit cross-validator
466    pub fn new(n_splits: usize) -> Self {
467        Self {
468            n_splits,
469            test_size: Some(0.2),
470            train_size: None,
471            random_state: None,
472        }
473    }
474
475    /// Set the test size as a proportion (0.0 to 1.0) of the groups
476    pub fn test_size(mut self, size: f64) -> Self {
477        assert!(
478            (0.0..=1.0).contains(&size),
479            "test_size must be between 0.0 and 1.0"
480        );
481        self.test_size = Some(size);
482        self
483    }
484
485    /// Set the train size as a proportion (0.0 to 1.0) of the groups
486    pub fn train_size(mut self, size: f64) -> Self {
487        assert!(
488            (0.0..=1.0).contains(&size),
489            "train_size must be between 0.0 and 1.0"
490        );
491        self.train_size = Some(size);
492        self
493    }
494
495    /// Set the random state for reproducible results
496    pub fn random_state(mut self, seed: u64) -> Self {
497        self.random_state = Some(seed);
498        self
499    }
500
501    /// Split data based on groups
502    pub fn split_with_groups(
503        &self,
504        n_samples: usize,
505        groups: &Array1<i32>,
506    ) -> Vec<(Vec<usize>, Vec<usize>)> {
507        assert_eq!(
508            groups.len(),
509            n_samples,
510            "groups must have the same length as n_samples"
511        );
512
513        // Group indices by group labels
514        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
515        for (idx, &group) in groups.iter().enumerate() {
516            group_indices.entry(group).or_default().push(idx);
517        }
518
519        let unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
520        let n_groups = unique_groups.len();
521
522        let test_size = self.test_size.unwrap_or(0.2);
523        let train_size = self.train_size.unwrap_or(1.0 - test_size);
524
525        assert!(
526            train_size + test_size <= 1.0,
527            "train_size + test_size cannot exceed 1.0"
528        );
529
530        let n_test_groups = ((n_groups as f64) * test_size).round() as usize;
531        let n_train_groups = ((n_groups as f64) * train_size).round() as usize;
532
533        assert!(
534            n_train_groups + n_test_groups <= n_groups,
535            "train_size + test_size results in more groups than available"
536        );
537
538        let mut rng = match self.random_state {
539            Some(seed) => StdRng::seed_from_u64(seed),
540            None => {
541                use scirs2_core::random::thread_rng;
542                StdRng::from_rng(&mut thread_rng())
543            }
544        };
545
546        let mut splits = Vec::new();
547
548        for _ in 0..self.n_splits {
549            let mut shuffled_groups = unique_groups.clone();
550            shuffled_groups.shuffle(&mut rng);
551
552            let test_groups = &shuffled_groups[..n_test_groups];
553            let train_groups = &shuffled_groups[n_test_groups..n_test_groups + n_train_groups];
554
555            let mut test_indices = Vec::new();
556            for &group in test_groups {
557                test_indices.extend(&group_indices[&group]);
558            }
559
560            let mut train_indices = Vec::new();
561            for &group in train_groups {
562                train_indices.extend(&group_indices[&group]);
563            }
564
565            // Sort for consistency
566            test_indices.sort_unstable();
567            train_indices.sort_unstable();
568
569            splits.push((train_indices, test_indices));
570        }
571
572        splits
573    }
574}
575
576impl CrossValidator for GroupShuffleSplit {
577    fn n_splits(&self) -> usize {
578        self.n_splits
579    }
580
581    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
582        // For the generic interface, we assume y contains group labels
583        let groups =
584            y.expect("GroupShuffleSplit requires group labels to be provided in y parameter");
585        self.split_with_groups(n_samples, groups)
586    }
587}
588
/// Cross-validator that holds out one distinct group per split.
///
/// The number of splits equals the number of distinct group labels. The type carries no
/// configuration.
#[derive(Clone, Debug)]
pub struct LeaveOneGroupOut;
594
595impl Default for LeaveOneGroupOut {
596    fn default() -> Self {
597        Self::new()
598    }
599}
600
601impl LeaveOneGroupOut {
602    /// Create a new LeaveOneGroupOut cross-validator
603    pub fn new() -> Self {
604        LeaveOneGroupOut
605    }
606
607    /// Split data based on groups
608    pub fn split_with_groups(
609        &self,
610        n_samples: usize,
611        groups: &Array1<i32>,
612    ) -> Vec<(Vec<usize>, Vec<usize>)> {
613        assert_eq!(
614            groups.len(),
615            n_samples,
616            "groups must have the same length as n_samples"
617        );
618
619        // Group indices by group labels
620        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
621        for (idx, &group) in groups.iter().enumerate() {
622            group_indices.entry(group).or_default().push(idx);
623        }
624
625        let mut unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
626        unique_groups.sort();
627
628        let mut splits = Vec::new();
629
630        // Leave out each group one at a time
631        for &test_group in &unique_groups {
632            let test_indices = group_indices[&test_group].clone();
633            let mut train_indices = Vec::new();
634
635            for &train_group in &unique_groups {
636                if train_group != test_group {
637                    train_indices.extend(&group_indices[&train_group]);
638                }
639            }
640
641            // Sort for consistency
642            train_indices.sort_unstable();
643
644            splits.push((train_indices, test_indices));
645        }
646
647        splits
648    }
649}
650
651impl CrossValidator for LeaveOneGroupOut {
652    fn n_splits(&self) -> usize {
653        // This is dynamic based on the number of unique groups
654        0 // Will be determined during split
655    }
656
657    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
658        // For the generic interface, we assume y contains group labels
659        let groups =
660            y.expect("LeaveOneGroupOut requires group labels to be provided in y parameter");
661        self.split_with_groups(n_samples, groups)
662    }
663}
664
/// Cross-validator whose test sets are every combination of `p` distinct groups.
#[derive(Clone, Debug)]
pub struct LeavePGroupsOut {
    /// Number of groups held out per split.
    p: usize,
}
672
673impl LeavePGroupsOut {
674    /// Create a new LeavePGroupsOut cross-validator
675    pub fn new(p: usize) -> Self {
676        assert!(p >= 1, "p must be at least 1");
677        Self { p }
678    }
679
680    /// Split data based on groups
681    pub fn split_with_groups(
682        &self,
683        n_samples: usize,
684        groups: &Array1<i32>,
685    ) -> Vec<(Vec<usize>, Vec<usize>)> {
686        assert_eq!(
687            groups.len(),
688            n_samples,
689            "groups must have the same length as n_samples"
690        );
691
692        // Group indices by group labels
693        let mut group_indices: HashMap<i32, Vec<usize>> = HashMap::new();
694        for (idx, &group) in groups.iter().enumerate() {
695            group_indices.entry(group).or_default().push(idx);
696        }
697
698        let unique_groups: Vec<i32> = group_indices.keys().cloned().collect();
699        let n_groups = unique_groups.len();
700
701        assert!(
702            self.p <= n_groups,
703            "p ({}) cannot be greater than number of groups ({})",
704            self.p,
705            n_groups
706        );
707
708        let mut splits = Vec::new();
709
710        // Generate all combinations of p groups for test sets
711        let group_combinations = combinations(&unique_groups, self.p);
712
713        for test_groups in group_combinations {
714            let mut test_indices = Vec::new();
715            for &group in &test_groups {
716                test_indices.extend(&group_indices[&group]);
717            }
718
719            let mut train_indices = Vec::new();
720            for &group in &unique_groups {
721                if !test_groups.contains(&group) {
722                    train_indices.extend(&group_indices[&group]);
723                }
724            }
725
726            // Sort for consistency
727            test_indices.sort_unstable();
728            train_indices.sort_unstable();
729
730            splits.push((train_indices, test_indices));
731        }
732
733        splits
734    }
735}
736
737impl CrossValidator for LeavePGroupsOut {
738    fn n_splits(&self) -> usize {
739        // This is dynamic based on the number of groups and p
740        0 // Will be determined during split
741    }
742
743    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
744        // For the generic interface, we assume y contains group labels
745        let groups =
746            y.expect("LeavePGroupsOut requires group labels to be provided in y parameter");
747        self.split_with_groups(n_samples, groups)
748    }
749}
750
/// Return every `k`-element combination of `items`, preserving the items' relative order.
///
/// Combinations are emitted in lexicographic order of their positions in `items`. When
/// `k == 0` the result is a single empty combination; when `k > items.len()` it is empty.
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k > n {
        return Vec::new();
    }

    // `picks` holds the strictly increasing positions of the current combination.
    let mut picks: Vec<usize> = (0..k).collect();
    let mut out: Vec<Vec<T>> = Vec::new();

    loop {
        out.push(picks.iter().map(|&pos| items[pos].clone()).collect());

        // Rightmost slot that can still move right without running out of items.
        let movable = (0..k).rev().find(|&slot| picks[slot] < n - k + slot);
        match movable {
            Some(slot) => {
                picks[slot] += 1;
                for next in slot + 1..k {
                    picks[next] = picks[next - 1] + 1;
                }
            }
            None => break,
        }
    }

    out
}
776
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};
    use std::collections::HashSet;

    /// Distinct group labels of the samples selected by `indices`.
    fn groups_of(indices: &[usize], groups: &Array1<i32>) -> HashSet<i32> {
        indices.iter().map(|&idx| groups[idx]).collect()
    }

    /// Assert that, in every split, no group contributes samples to both train and test.
    fn assert_no_group_leakage(splits: &[(Vec<usize>, Vec<usize>)], groups: &Array1<i32>) {
        for (train, test) in splits {
            let train_groups = groups_of(train, groups);
            let test_groups = groups_of(test, groups);
            assert!(train_groups.is_disjoint(&test_groups));
        }
    }

    #[test]
    fn test_group_kfold() {
        let groups = array![0, 0, 1, 1, 2, 2];
        let splits = GroupKFold::new(2).split_with_groups(6, &groups);

        assert_eq!(splits.len(), 2);
        assert_no_group_leakage(&splits, &groups);
    }

    #[test]
    fn test_group_kfold_custom_strategies() {
        // Groups of varying size: 2, 3, 4, 2 and 1 samples.
        let groups = array![0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4];

        let balanced = GroupKFold::new_balanced(3).split_with_groups(12, &groups);
        assert_eq!(balanced.len(), 3);
        assert_no_group_leakage(&balanced, &groups);

        // max_group_size = 3
        let size_aware = GroupKFold::new_size_aware(3, 3).split_with_groups(12, &groups);
        assert_eq!(size_aware.len(), 3);
        assert_no_group_leakage(&size_aware, &groups);

        // Strategy chosen through the builder instead of a dedicated constructor.
        let custom = GroupKFold::new(3)
            .group_strategy(GroupStrategy::Balanced)
            .split_with_groups(12, &groups);
        assert_eq!(custom.len(), 3);
    }

    #[test]
    fn test_stratified_group_kfold() {
        let groups = array![1, 1, 2, 2, 3, 3, 4, 4];
        let y = array![0, 0, 1, 1, 0, 1, 0, 1];
        let splits = StratifiedGroupKFold::new(2).split_with_groups_and_labels(8, &y, &groups);

        assert_eq!(splits.len(), 2);
        assert_no_group_leakage(&splits, &groups);

        // Both classes must still be present in every training set.
        for (train_idx, _) in &splits {
            let count_of = |label: i32| train_idx.iter().filter(|&&i| y[i] == label).count();
            assert!(count_of(0) > 0);
            assert!(count_of(1) > 0);
        }
    }

    #[test]
    fn test_group_shuffle_split() {
        let groups = array![0, 0, 1, 1, 2, 2, 3, 3];
        let splits = GroupShuffleSplit::new(3)
            .test_size(0.25)
            .random_state(42)
            .split_with_groups(8, &groups);

        assert_eq!(splits.len(), 3);
        assert_no_group_leakage(&splits, &groups);

        for (_, test) in &splits {
            // 25% of 4 groups
            assert_eq!(groups_of(test, &groups).len(), 1);
        }
    }

    #[test]
    fn test_leave_one_group_out() {
        let groups = array![0, 0, 1, 1, 2, 2];
        let splits = LeaveOneGroupOut::new().split_with_groups(6, &groups);

        // One split per distinct group.
        assert_eq!(splits.len(), 3);

        for (train, test) in &splits {
            let test_groups = groups_of(test, &groups);
            let train_groups = groups_of(train, &groups);

            assert_eq!(test_groups.len(), 1);
            assert_eq!(train_groups.len(), 2);
            assert!(train_groups.is_disjoint(&test_groups));
        }
    }

    #[test]
    fn test_leave_p_groups_out() {
        let groups = array![0, 0, 1, 1, 2, 2, 3, 3];
        let splits = LeavePGroupsOut::new(2).split_with_groups(8, &groups);

        // C(4,2) = 6 combinations
        assert_eq!(splits.len(), 6);

        for (train, test) in &splits {
            let test_groups = groups_of(test, &groups);
            let train_groups = groups_of(train, &groups);

            assert_eq!(test_groups.len(), 2);
            assert_eq!(train_groups.len(), 2);
            assert!(train_groups.is_disjoint(&test_groups));
        }
    }
}