// sklears_utils/cross_validation.rs

//! Advanced cross-validation utilities for machine learning
//!
//! This module provides sophisticated cross-validation techniques including
//! stratified k-fold, time series cross-validation, group-based CV (group k-fold
//! and leave-one-group-out), and more.
5
6use crate::{UtilsError, UtilsResult};
7use scirs2_core::random::rngs::StdRng;
8use scirs2_core::random::{Rng, SeedableRng};
9use std::collections::HashMap;
10
/// Cross-validation split information
///
/// Holds the sample positions (indices into the caller's data) used for
/// fitting and for evaluation in one fold. The splitters in this module always
/// produce disjoint `train` and `test` sets.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CVSplit {
    /// Training indices
    pub train: Vec<usize>,
    /// Test indices
    pub test: Vec<usize>,
}
19
20/// Stratified K-Fold cross-validation generator
21///
22/// Generates k-fold splits while maintaining class distribution in each fold.
23/// Useful for classification tasks with imbalanced classes.
24#[derive(Clone, Debug)]
25pub struct StratifiedKFold {
26    n_splits: usize,
27    shuffle: bool,
28    random_state: Option<u64>,
29}
30
31impl StratifiedKFold {
32    /// Create a new stratified k-fold splitter
33    ///
34    /// # Arguments
35    /// * `n_splits` - Number of folds (must be >= 2)
36    /// * `shuffle` - Whether to shuffle data before splitting
37    /// * `random_state` - Random seed for reproducibility
38    pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> UtilsResult<Self> {
39        if n_splits < 2 {
40            return Err(UtilsError::InvalidParameter(
41                "n_splits must be at least 2".to_string(),
42            ));
43        }
44
45        Ok(Self {
46            n_splits,
47            shuffle,
48            random_state,
49        })
50    }
51
52    /// Generate stratified k-fold splits for given labels
53    ///
54    /// # Arguments
55    /// * `y` - Class labels (0-indexed integers)
56    ///
57    /// # Returns
58    /// Vector of CVSplit structures containing train/test indices for each fold
59    pub fn split(&self, y: &[usize]) -> UtilsResult<Vec<CVSplit>> {
60        if y.is_empty() {
61            return Err(UtilsError::InvalidParameter(
62                "Cannot split empty label array".to_string(),
63            ));
64        }
65
66        // Group indices by class
67        let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
68        for (idx, &label) in y.iter().enumerate() {
69            class_indices.entry(label).or_default().push(idx);
70        }
71
72        // Check that each class has at least n_splits samples
73        for (class, indices) in &class_indices {
74            if indices.len() < self.n_splits {
75                return Err(UtilsError::InvalidParameter(format!(
76                    "Class {class} has only {} samples, need at least {} for {}-fold CV",
77                    indices.len(),
78                    self.n_splits,
79                    self.n_splits
80                )));
81            }
82        }
83
84        // Shuffle class indices if requested
85        let mut rng = self
86            .random_state
87            .map(StdRng::seed_from_u64)
88            .unwrap_or_else(|| StdRng::seed_from_u64(42));
89
90        if self.shuffle {
91            for indices in class_indices.values_mut() {
92                Self::shuffle_indices(indices, &mut rng);
93            }
94        }
95
96        // Create folds by distributing each class's samples
97        let mut fold_indices: Vec<Vec<usize>> = vec![Vec::new(); self.n_splits];
98
99        for indices in class_indices.values() {
100            let fold_sizes = Self::distribute_samples(indices.len(), self.n_splits);
101            let mut current_idx = 0;
102
103            for (fold_id, size) in fold_sizes.iter().enumerate() {
104                fold_indices[fold_id].extend(&indices[current_idx..current_idx + size]);
105                current_idx += size;
106            }
107        }
108
109        // Generate train/test splits
110        let mut splits = Vec::with_capacity(self.n_splits);
111        for test_fold_id in 0..self.n_splits {
112            let mut train = Vec::new();
113            for (fold_id, indices) in fold_indices.iter().enumerate() {
114                if fold_id != test_fold_id {
115                    train.extend(indices);
116                }
117            }
118
119            splits.push(CVSplit {
120                train,
121                test: fold_indices[test_fold_id].clone(),
122            });
123        }
124
125        Ok(splits)
126    }
127
128    fn shuffle_indices(indices: &mut [usize], rng: &mut StdRng) {
129        for i in (1..indices.len()).rev() {
130            let j = rng.gen_range(0..=i);
131            indices.swap(i, j);
132        }
133    }
134
135    fn distribute_samples(n_samples: usize, n_folds: usize) -> Vec<usize> {
136        let base_size = n_samples / n_folds;
137        let remainder = n_samples % n_folds;
138
139        (0..n_folds)
140            .map(|i| {
141                if i < remainder {
142                    base_size + 1
143                } else {
144                    base_size
145                }
146            })
147            .collect()
148    }
149}
150
151/// Time Series Cross-Validation
152///
153/// Generates splits suitable for time series data where temporal order must be preserved.
154/// Uses expanding window approach where training set grows with each split.
155#[derive(Clone, Debug)]
156pub struct TimeSeriesSplit {
157    n_splits: usize,
158    test_size: Option<usize>,
159    gap: usize,
160}
161
162impl TimeSeriesSplit {
163    /// Create a new time series cross-validator
164    ///
165    /// # Arguments
166    /// * `n_splits` - Number of splits to generate
167    /// * `test_size` - Size of test set (None = automatic sizing)
168    /// * `gap` - Number of samples to exclude between train and test sets
169    pub fn new(n_splits: usize, test_size: Option<usize>, gap: usize) -> UtilsResult<Self> {
170        if n_splits < 2 {
171            return Err(UtilsError::InvalidParameter(
172                "n_splits must be at least 2".to_string(),
173            ));
174        }
175
176        Ok(Self {
177            n_splits,
178            test_size,
179            gap,
180        })
181    }
182
183    /// Generate time series splits
184    ///
185    /// # Arguments
186    /// * `n_samples` - Total number of samples
187    ///
188    /// # Returns
189    /// Vector of CVSplit structures with expanding training windows
190    pub fn split(&self, n_samples: usize) -> UtilsResult<Vec<CVSplit>> {
191        let test_size = self.test_size.unwrap_or_else(|| {
192            // Default: divide remaining samples after first split
193            (n_samples - (n_samples / (self.n_splits + 1))) / self.n_splits
194        });
195
196        let min_train_size = n_samples / (self.n_splits + 1);
197
198        if min_train_size + self.gap + test_size > n_samples {
199            return Err(UtilsError::InvalidParameter(
200                "Not enough samples for requested split configuration".to_string(),
201            ));
202        }
203
204        let mut splits = Vec::with_capacity(self.n_splits);
205
206        for i in 0..self.n_splits {
207            let train_end = min_train_size + i * test_size;
208            let test_start = train_end + self.gap;
209            let test_end = test_start + test_size;
210
211            if test_end > n_samples {
212                break;
213            }
214
215            splits.push(CVSplit {
216                train: (0..train_end).collect(),
217                test: (test_start..test_end).collect(),
218            });
219        }
220
221        if splits.len() < self.n_splits {
222            return Err(UtilsError::InvalidParameter(
223                "Cannot generate requested number of splits with given parameters".to_string(),
224            ));
225        }
226
227        Ok(splits)
228    }
229}
230
231/// Group K-Fold Cross-Validation
232///
233/// Ensures that samples from the same group are not in both train and test sets.
234/// Useful for preventing data leakage in scenarios like patient data, user data, etc.
235#[derive(Clone, Debug)]
236pub struct GroupKFold {
237    n_splits: usize,
238}
239
240impl GroupKFold {
241    /// Create a new group k-fold splitter
242    pub fn new(n_splits: usize) -> UtilsResult<Self> {
243        if n_splits < 2 {
244            return Err(UtilsError::InvalidParameter(
245                "n_splits must be at least 2".to_string(),
246            ));
247        }
248
249        Ok(Self { n_splits })
250    }
251
252    /// Generate group-based k-fold splits
253    ///
254    /// # Arguments
255    /// * `groups` - Group identifier for each sample
256    ///
257    /// # Returns
258    /// Vector of CVSplit structures ensuring group separation
259    pub fn split(&self, groups: &[usize]) -> UtilsResult<Vec<CVSplit>> {
260        if groups.is_empty() {
261            return Err(UtilsError::InvalidParameter(
262                "Cannot split empty groups array".to_string(),
263            ));
264        }
265
266        // Map groups to sample indices
267        let mut group_to_indices: HashMap<usize, Vec<usize>> = HashMap::new();
268        for (idx, &group) in groups.iter().enumerate() {
269            group_to_indices.entry(group).or_default().push(idx);
270        }
271
272        let unique_groups: Vec<usize> = group_to_indices.keys().copied().collect();
273
274        if unique_groups.len() < self.n_splits {
275            return Err(UtilsError::InvalidParameter(format!(
276                "Number of unique groups ({}) must be >= n_splits ({})",
277                unique_groups.len(),
278                self.n_splits
279            )));
280        }
281
282        // Distribute groups into folds
283        let fold_sizes = Self::distribute_groups(unique_groups.len(), self.n_splits);
284        let mut group_folds: Vec<Vec<usize>> = vec![Vec::new(); self.n_splits];
285        let mut current_idx = 0;
286
287        for (fold_id, size) in fold_sizes.iter().enumerate() {
288            group_folds[fold_id].extend(&unique_groups[current_idx..current_idx + size]);
289            current_idx += size;
290        }
291
292        // Generate splits
293        let mut splits = Vec::with_capacity(self.n_splits);
294
295        for test_fold_id in 0..self.n_splits {
296            let mut train = Vec::new();
297            let mut test = Vec::new();
298
299            for (fold_id, groups_in_fold) in group_folds.iter().enumerate() {
300                let indices: Vec<usize> = groups_in_fold
301                    .iter()
302                    .flat_map(|g| group_to_indices.get(g).unwrap())
303                    .copied()
304                    .collect();
305
306                if fold_id == test_fold_id {
307                    test.extend(indices);
308                } else {
309                    train.extend(indices);
310                }
311            }
312
313            splits.push(CVSplit { train, test });
314        }
315
316        Ok(splits)
317    }
318
319    fn distribute_groups(n_groups: usize, n_folds: usize) -> Vec<usize> {
320        let base_size = n_groups / n_folds;
321        let remainder = n_groups % n_folds;
322
323        (0..n_folds)
324            .map(|i| {
325                if i < remainder {
326                    base_size + 1
327                } else {
328                    base_size
329                }
330            })
331            .collect()
332    }
333}
334
335/// Leave-One-Group-Out Cross-Validation
336///
337/// Creates one split for each unique group, using that group as test set.
338#[derive(Clone, Debug)]
339pub struct LeaveOneGroupOut;
340
341impl LeaveOneGroupOut {
342    /// Create a new leave-one-group-out splitter
343    pub fn new() -> Self {
344        Self
345    }
346
347    /// Generate leave-one-group-out splits
348    ///
349    /// # Arguments
350    /// * `groups` - Group identifier for each sample
351    ///
352    /// # Returns
353    /// Vector of CVSplit structures, one per unique group
354    pub fn split(&self, groups: &[usize]) -> UtilsResult<Vec<CVSplit>> {
355        if groups.is_empty() {
356            return Err(UtilsError::InvalidParameter(
357                "Cannot split empty groups array".to_string(),
358            ));
359        }
360
361        // Map groups to indices
362        let mut group_to_indices: HashMap<usize, Vec<usize>> = HashMap::new();
363        for (idx, &group) in groups.iter().enumerate() {
364            group_to_indices.entry(group).or_default().push(idx);
365        }
366
367        let unique_groups: Vec<usize> = group_to_indices.keys().copied().collect();
368        let mut splits = Vec::with_capacity(unique_groups.len());
369
370        for &test_group in &unique_groups {
371            let mut train = Vec::new();
372            let test = group_to_indices.get(&test_group).unwrap().clone();
373
374            for &group in &unique_groups {
375                if group != test_group {
376                    train.extend(group_to_indices.get(&group).unwrap());
377                }
378            }
379
380            splits.push(CVSplit { train, test });
381        }
382
383        Ok(splits)
384    }
385}
386
387impl Default for LeaveOneGroupOut {
388    fn default() -> Self {
389        Self::new()
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_stratified_kfold_basic() {
        let labels = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
        let folds = StratifiedKFold::new(3, false, Some(42))
            .unwrap()
            .split(&labels)
            .unwrap();

        assert_eq!(folds.len(), 3);

        // Stratification: every test fold holds at least one sample of each class.
        for fold in &folds {
            for class in 0..3 {
                assert!(fold.test.iter().any(|&i| labels[i] == class));
            }
        }
    }

    #[test]
    fn test_stratified_kfold_all_samples_used() {
        let labels = [0, 0, 1, 1, 2, 2];
        let folds = StratifiedKFold::new(2, false, None)
            .unwrap()
            .split(&labels)
            .unwrap();

        assert_eq!(folds.len(), 2);

        // Taken together, the test folds must cover every sample exactly once.
        let mut covered: Vec<usize> = folds
            .iter()
            .flat_map(|fold| fold.test.iter().copied())
            .collect();
        covered.sort_unstable();
        assert_eq!(covered, (0..labels.len()).collect::<Vec<usize>>());
    }

    #[test]
    fn test_time_series_split_basic() {
        let folds = TimeSeriesSplit::new(3, Some(2), 0)
            .unwrap()
            .split(10)
            .unwrap();

        assert_eq!(folds.len(), 3);

        for fold in &folds {
            assert!(!fold.train.is_empty());
            assert_eq!(fold.test.len(), 2);
        }

        // Training windows must grow strictly from one split to the next.
        for pair in folds.windows(2) {
            assert!(pair[1].train.len() > pair[0].train.len());
        }
    }

    #[test]
    fn test_time_series_split_with_gap() {
        let folds = TimeSeriesSplit::new(2, Some(2), 1)
            .unwrap()
            .split(10)
            .unwrap();

        for fold in &folds {
            // Every test index must come strictly after every train index.
            let bounds = (fold.train.iter().max(), fold.test.iter().min());
            if let (Some(&train_max), Some(&test_min)) = bounds {
                assert!(test_min > train_max);
            }
        }
    }

    #[test]
    fn test_group_kfold_basic() {
        let groups = [0, 0, 1, 1, 2, 2, 3, 3];
        let folds = GroupKFold::new(2).unwrap().split(&groups).unwrap();

        assert_eq!(folds.len(), 2);

        // No group may appear on both sides of a split.
        for fold in &folds {
            let train_groups: HashSet<usize> = fold.train.iter().map(|&i| groups[i]).collect();
            assert!(fold
                .test
                .iter()
                .all(|&i| !train_groups.contains(&groups[i])));
        }
    }

    #[test]
    fn test_leave_one_group_out() {
        let groups = [0, 0, 1, 1, 2, 2];
        let folds = LeaveOneGroupOut::new().split(&groups).unwrap();

        assert_eq!(folds.len(), 3); // one split per distinct group

        // Each test fold consists of exactly one group.
        for fold in &folds {
            let held_out: HashSet<usize> = fold.test.iter().map(|&i| groups[i]).collect();
            assert_eq!(held_out.len(), 1);
        }
    }

    #[test]
    fn test_stratified_kfold_error_too_few_samples() {
        // Each class has a single sample, fewer than the three folds requested.
        let splitter = StratifiedKFold::new(3, false, None).unwrap();
        assert!(splitter.split(&[0, 1]).is_err());
    }

    #[test]
    fn test_time_series_split_error_insufficient_samples() {
        // Five test windows of ten samples each cannot fit in twenty samples.
        let outcome = TimeSeriesSplit::new(5, Some(10), 0).unwrap().split(20);
        assert!(outcome.is_err());
    }
}
512}