tensorlogic_train/crossval.rs

//! Cross-validation utilities for model evaluation.
//!
//! This module provides various cross-validation strategies:
//! - K-fold cross-validation
//! - Stratified K-fold (maintains class distribution)
//! - Time series split (preserves temporal order)
//! - Leave-one-out cross-validation
//! - Custom split strategies

use crate::{TrainError, TrainResult};
use scirs2_core::random::{SeedableRng, StdRng};
use std::collections::HashMap;

/// Trait for cross-validation splitting strategies.
pub trait CrossValidationSplit {
    /// Get the number of splits.
    fn num_splits(&self) -> usize;

    /// Get the train/validation indices for a specific fold.
    ///
    /// # Arguments
    /// * `fold` - Fold index (0 to num_splits - 1)
    /// * `n_samples` - Total number of samples
    ///
    /// # Returns
    /// (train_indices, validation_indices)
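    ///
    /// # Examples
    ///
    /// A minimal sketch of looping over every fold (marked `ignore` because it
    /// assumes this module is publicly reachable as `tensorlogic_train::crossval`):
    ///
    /// ```ignore
    /// use tensorlogic_train::crossval::{CrossValidationSplit, KFold};
    ///
    /// let splitter = KFold::new(5).unwrap();
    /// for fold in 0..splitter.num_splits() {
    ///     let (train, val) = splitter.get_split(fold, 100).unwrap();
    ///     assert_eq!(train.len() + val.len(), 100);
    /// }
    /// ```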
    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)>;
}

/// K-fold cross-validation.
///
/// Splits the data into K approximately equal-sized folds. Each fold is used once as
/// validation while the remaining K-1 folds form the training set.
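///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` because it assumes this module is
/// publicly reachable as `tensorlogic_train::crossval`):
///
/// ```ignore
/// use tensorlogic_train::crossval::{CrossValidationSplit, KFold};
///
/// // Three folds over ten samples; shuffling is reproducible for a fixed seed.
/// let kfold = KFold::new(3).unwrap().with_shuffle(42);
/// let (train, val) = kfold.get_split(0, 10).unwrap();
/// assert_eq!(train.len() + val.len(), 10);
/// assert!(val.iter().all(|i| !train.contains(i)));
/// ```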
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds.
    pub n_splits: usize,
    /// Whether to shuffle the data before splitting.
    pub shuffle: bool,
    /// Random seed for shuffling.
    pub random_seed: u64,
}

impl KFold {
    /// Create a new K-fold splitter.
    ///
    /// # Arguments
    /// * `n_splits` - Number of folds (must be >= 2)
    pub fn new(n_splits: usize) -> TrainResult<Self> {
        if n_splits < 2 {
            return Err(TrainError::InvalidParameter(
                "n_splits must be at least 2".to_string(),
            ));
        }
        Ok(Self {
            n_splits,
            shuffle: false,
            random_seed: 42,
        })
    }

    /// Enable shuffling with a specific seed.
    pub fn with_shuffle(mut self, seed: u64) -> Self {
        self.shuffle = true;
        self.random_seed = seed;
        self
    }
}

impl CrossValidationSplit for KFold {
    fn num_splits(&self) -> usize {
        self.n_splits
    }

    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, self.n_splits
            )));
        }

        // Create indices
        let mut indices: Vec<usize> = (0..n_samples).collect();

        // Shuffle if requested
        if self.shuffle {
            let mut rng = StdRng::seed_from_u64(self.random_seed);
            for i in (1..n_samples).rev() {
                let j = rng.gen_range(0..=i);
                indices.swap(i, j);
            }
        }

        // Split into folds
        let fold_size = n_samples / self.n_splits;
        let remainder = n_samples % self.n_splits;

        let mut fold_sizes = vec![fold_size; self.n_splits];
        for size in fold_sizes.iter_mut().take(remainder) {
            *size += 1;
        }
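        // e.g. n_samples = 10, n_splits = 3 -> fold_sizes = [4, 3, 3]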

        // Compute fold boundaries
        let mut boundaries = vec![0];
        for size in &fold_sizes {
            boundaries.push(boundaries.last().unwrap() + size);
        }

        // Get validation indices for this fold
        let val_start = boundaries[fold];
        let val_end = boundaries[fold + 1];
        let val_indices: Vec<usize> = indices[val_start..val_end].to_vec();

        // Get training indices (all others)
        let mut train_indices = Vec::new();
        train_indices.extend_from_slice(&indices[..val_start]);
        train_indices.extend_from_slice(&indices[val_end..]);

        Ok((train_indices, val_indices))
    }
}

/// Stratified K-fold cross-validation.
///
/// Maintains class distribution in each fold (useful for imbalanced datasets).
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds.
    pub n_splits: usize,
    /// Whether to shuffle the data before splitting.
    pub shuffle: bool,
    /// Random seed for shuffling.
    pub random_seed: u64,
}

impl StratifiedKFold {
    /// Create a new stratified K-fold splitter.
    ///
    /// # Arguments
    /// * `n_splits` - Number of folds (must be >= 2)
    pub fn new(n_splits: usize) -> TrainResult<Self> {
        if n_splits < 2 {
            return Err(TrainError::InvalidParameter(
                "n_splits must be at least 2".to_string(),
            ));
        }
        Ok(Self {
            n_splits,
            shuffle: true,
            random_seed: 42,
        })
    }

    /// Set random seed for shuffling.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.random_seed = seed;
        self
    }

    /// Get stratified split based on class labels.
    ///
    /// # Arguments
    /// * `fold` - Fold index
    /// * `labels` - Class labels for each sample
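    ///
    /// # Examples
    ///
    /// A minimal sketch (marked `ignore` because it assumes this module is
    /// publicly reachable as `tensorlogic_train::crossval`):
    ///
    /// ```ignore
    /// use tensorlogic_train::crossval::StratifiedKFold;
    ///
    /// // Three classes with three samples each.
    /// let labels = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];
    /// let skfold = StratifiedKFold::new(3).unwrap();
    /// let (train, val) = skfold.get_stratified_split(0, &labels).unwrap();
    /// // Each fold preserves the class proportions: one sample per class here.
    /// assert_eq!(val.len(), 3);
    /// assert_eq!(train.len(), 6);
    /// ```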
    pub fn get_stratified_split(
        &self,
        fold: usize,
        labels: &[usize],
    ) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, self.n_splits
            )));
        }

        // Group indices by class
        let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
        for (i, &label) in labels.iter().enumerate() {
            class_indices.entry(label).or_default().push(i);
        }

        // Shuffle each class if requested
        if self.shuffle {
            let mut rng = StdRng::seed_from_u64(self.random_seed);
            for indices in class_indices.values_mut() {
                for i in (1..indices.len()).rev() {
                    let j = rng.gen_range(0..=i);
                    indices.swap(i, j);
                }
            }
        }

        // Split each class into folds
        let mut train_indices = Vec::new();
        let mut val_indices = Vec::new();

        for indices in class_indices.values() {
            let class_size = indices.len();
            let fold_size = class_size / self.n_splits;
            let remainder = class_size % self.n_splits;

            let mut fold_sizes = vec![fold_size; self.n_splits];
            for size in fold_sizes.iter_mut().take(remainder) {
                *size += 1;
            }

            // Compute fold boundaries
            let mut boundaries = vec![0];
            for size in &fold_sizes {
                boundaries.push(boundaries.last().unwrap() + size);
            }

            // Get validation indices for this fold
            let val_start = boundaries[fold];
            let val_end = boundaries[fold + 1];
            val_indices.extend_from_slice(&indices[val_start..val_end]);

            // Get training indices (all others)
            train_indices.extend_from_slice(&indices[..val_start]);
            train_indices.extend_from_slice(&indices[val_end..]);
        }

        Ok((train_indices, val_indices))
    }
}

impl CrossValidationSplit for StratifiedKFold {
    fn num_splits(&self) -> usize {
        self.n_splits
    }

    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        // Default implementation: assign synthetic labels round-robin, so every
        // "class" is spread uniformly across folds. For actual stratification,
        // use get_stratified_split with real class labels.
        let labels: Vec<usize> = (0..n_samples).map(|i| i % self.n_splits).collect();
        self.get_stratified_split(fold, &labels)
    }
}

/// Time series split for temporal data.
///
/// Respects the temporal order of the data: each training set consists only of
/// samples that come before its validation set (no leakage from the future).
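///
/// # Examples
///
/// A minimal sketch (marked `ignore` because it assumes this module is
/// publicly reachable as `tensorlogic_train::crossval`):
///
/// ```ignore
/// use tensorlogic_train::crossval::{CrossValidationSplit, TimeSeriesSplit};
///
/// // Sliding window: at most the 5 most recent samples are used for training.
/// let ts = TimeSeriesSplit::new(3).unwrap().with_max_train_size(5);
/// let (train, val) = ts.get_split(1, 20).unwrap();
/// assert!(train.len() <= 5);
/// // Training indices always precede validation indices.
/// assert!(train.iter().max() < val.iter().min());
/// ```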
#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    /// Number of splits.
    pub n_splits: usize,
    /// Minimum training set size.
    pub min_train_size: Option<usize>,
    /// Maximum training set size (for sliding window).
    pub max_train_size: Option<usize>,
}

impl TimeSeriesSplit {
    /// Create a new time series split.
    ///
    /// # Arguments
    /// * `n_splits` - Number of splits (must be >= 2)
    pub fn new(n_splits: usize) -> TrainResult<Self> {
        if n_splits < 2 {
            return Err(TrainError::InvalidParameter(
                "n_splits must be at least 2".to_string(),
            ));
        }
        Ok(Self {
            n_splits,
            min_train_size: None,
            max_train_size: None,
        })
    }

    /// Set minimum training set size.
    pub fn with_min_train_size(mut self, size: usize) -> Self {
        self.min_train_size = Some(size);
        self
    }

    /// Set maximum training set size (for sliding window).
    pub fn with_max_train_size(mut self, size: usize) -> Self {
        self.max_train_size = Some(size);
        self
    }
}

impl CrossValidationSplit for TimeSeriesSplit {
    fn num_splits(&self) -> usize {
        self.n_splits
    }

    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= self.n_splits {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, self.n_splits
            )));
        }

        // Compute validation set size
        let test_size = n_samples / (self.n_splits + 1);
        if test_size == 0 {
            return Err(TrainError::InvalidParameter(
                "Not enough samples for time series split".to_string(),
            ));
        }
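        // e.g. n_samples = 10, n_splits = 3 -> test_size = 2;
        // fold 0 trains on indices [0, 2) and validates on [2, 4).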

        // Validation set for this fold
        let val_start = (fold + 1) * test_size;
        let val_end = ((fold + 2) * test_size).min(n_samples);

        // Training set: all data before validation
        let train_end = val_start;
        let train_start = if let Some(max_size) = self.max_train_size {
            train_end.saturating_sub(max_size)
        } else if let Some(min_size) = self.min_train_size {
            if train_end < min_size {
                return Err(TrainError::InvalidParameter(
                    "Not enough samples for min_train_size".to_string(),
                ));
            }
            0
        } else {
            0
        };

        let train_indices: Vec<usize> = (train_start..train_end).collect();
        let val_indices: Vec<usize> = (val_start..val_end).collect();

        if train_indices.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Training set is empty for this fold".to_string(),
            ));
        }

        Ok((train_indices, val_indices))
    }
}

/// Leave-one-out cross-validation.
///
/// Each sample is used once as validation while all others form the training set.
/// Useful for very small datasets but computationally expensive.
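///
/// # Examples
///
/// A minimal sketch (marked `ignore` because it assumes this module is
/// publicly reachable as `tensorlogic_train::crossval`):
///
/// ```ignore
/// use tensorlogic_train::crossval::{CrossValidationSplit, LeaveOneOut};
///
/// let loo = LeaveOneOut::new();
/// let (train, val) = loo.get_split(2, 5).unwrap();
/// assert_eq!(val, vec![2]);
/// assert_eq!(train, vec![0, 1, 3, 4]);
/// ```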
#[derive(Debug, Clone, Default)]
pub struct LeaveOneOut;

impl LeaveOneOut {
    /// Create a new leave-one-out splitter.
    pub fn new() -> Self {
        Self
    }
}

impl CrossValidationSplit for LeaveOneOut {
    fn num_splits(&self) -> usize {
        // This is a placeholder; actual number depends on n_samples
        usize::MAX
    }

    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
        if fold >= n_samples {
            return Err(TrainError::InvalidParameter(format!(
                "fold {} is out of range [0, {})",
                fold, n_samples
            )));
        }

        // Validation: single sample
        let val_indices = vec![fold];

        // Training: all other samples
        let mut train_indices: Vec<usize> = (0..fold).collect();
        train_indices.extend(fold + 1..n_samples);

        Ok((train_indices, val_indices))
    }
}

/// Cross-validation result aggregator.
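///
/// # Examples
///
/// A minimal sketch (marked `ignore` because it assumes this module is
/// publicly reachable as `tensorlogic_train::crossval`):
///
/// ```ignore
/// use std::collections::HashMap;
/// use tensorlogic_train::crossval::CrossValidationResults;
///
/// let mut results = CrossValidationResults::new();
/// results.add_fold(0.8, HashMap::new());
/// results.add_fold(0.9, HashMap::new());
/// assert!((results.mean_score() - 0.85).abs() < 1e-12);
/// assert!(results.std_score() > 0.0);
/// ```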
#[derive(Debug, Clone)]
pub struct CrossValidationResults {
    /// Scores for each fold.
    pub fold_scores: Vec<f64>,
    /// Additional metrics for each fold.
    pub fold_metrics: Vec<HashMap<String, f64>>,
}

impl CrossValidationResults {
    /// Create a new result aggregator.
    pub fn new() -> Self {
        Self {
            fold_scores: Vec::new(),
            fold_metrics: Vec::new(),
        }
    }

    /// Add a fold result.
    pub fn add_fold(&mut self, score: f64, metrics: HashMap<String, f64>) {
        self.fold_scores.push(score);
        self.fold_metrics.push(metrics);
    }

    /// Get mean score across all folds.
    pub fn mean_score(&self) -> f64 {
        if self.fold_scores.is_empty() {
            return 0.0;
        }
        self.fold_scores.iter().sum::<f64>() / self.fold_scores.len() as f64
    }

    /// Get standard deviation of scores.
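    ///
    /// Uses the unbiased sample estimator (variance divided by `n - 1`), so this
    /// returns 0.0 when fewer than two folds have been recorded.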
    pub fn std_score(&self) -> f64 {
        if self.fold_scores.len() <= 1 {
            return 0.0;
        }

        let mean = self.mean_score();
        let variance = self
            .fold_scores
            .iter()
            .map(|&score| (score - mean).powi(2))
            .sum::<f64>()
            / (self.fold_scores.len() - 1) as f64;

        variance.sqrt()
    }

    /// Get mean of a specific metric across all folds.
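    ///
    /// Folds that did not record `metric_name` are skipped; returns `None` if no
    /// fold recorded it.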
    pub fn mean_metric(&self, metric_name: &str) -> Option<f64> {
        if self.fold_metrics.is_empty() {
            return None;
        }

        let mut sum = 0.0;
        let mut count = 0;

        for metrics in &self.fold_metrics {
            if let Some(&value) = metrics.get(metric_name) {
                sum += value;
                count += 1;
            }
        }

        if count > 0 {
            Some(sum / count as f64)
        } else {
            None
        }
    }

    /// Get number of folds.
    pub fn num_folds(&self) -> usize {
        self.fold_scores.len()
    }
}

impl Default for CrossValidationResults {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kfold_basic() {
        let kfold = KFold::new(3).unwrap();
        assert_eq!(kfold.num_splits(), 3);

        let (train, val) = kfold.get_split(0, 10).unwrap();
        assert!(!train.is_empty());
        assert!(!val.is_empty());

        // Train and validation should be disjoint
        for &idx in &val {
            assert!(!train.contains(&idx));
        }

        // Together should cover all indices
        let mut all_indices = train.clone();
        all_indices.extend(&val);
        all_indices.sort();
        assert_eq!(all_indices, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_kfold_with_shuffle() {
        let kfold = KFold::new(3).unwrap().with_shuffle(42);
        let (train1, val1) = kfold.get_split(0, 10).unwrap();
        let (train2, val2) = kfold.get_split(0, 10).unwrap();

        // Same seed should produce same results
        assert_eq!(train1, train2);
        assert_eq!(val1, val2);
    }

    #[test]
    fn test_kfold_invalid() {
        assert!(KFold::new(1).is_err());
        let kfold = KFold::new(3).unwrap();
        assert!(kfold.get_split(5, 10).is_err()); // fold out of range
    }

    #[test]
    fn test_stratified_kfold() {
        let skfold = StratifiedKFold::new(3).unwrap();
        assert_eq!(skfold.num_splits(), 3);

        // Create balanced labels: [0, 0, 0, 1, 1, 1, 2, 2, 2]
        let labels = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let (_train, val) = skfold.get_stratified_split(0, &labels).unwrap();

        // Check that validation set has samples from each class
        let mut val_classes: Vec<usize> = val.iter().map(|&i| labels[i]).collect();
        val_classes.sort();
        val_classes.dedup();

        // Every class should be represented in this validation fold
        assert_eq!(val_classes.len(), 3);
        assert!(!val.is_empty());
    }

    #[test]
    fn test_time_series_split() {
        let ts_split = TimeSeriesSplit::new(3).unwrap();
        assert_eq!(ts_split.num_splits(), 3);

        let (train, val) = ts_split.get_split(0, 10).unwrap();

        // Training indices should be before validation indices
        if !train.is_empty() && !val.is_empty() {
            assert!(train.iter().max().unwrap() < val.iter().min().unwrap());
        }
    }

    #[test]
    fn test_time_series_split_with_window() {
        let ts_split = TimeSeriesSplit::new(3)
            .unwrap()
            .with_min_train_size(2)
            .with_max_train_size(5);

        let (train, val) = ts_split.get_split(1, 20).unwrap();

        // Training set should respect max size
        assert!(train.len() <= 5);
        assert!(!val.is_empty());
    }

    #[test]
    fn test_time_series_split_invalid() {
        let ts_split = TimeSeriesSplit::new(3).unwrap();

        // Too few samples
        assert!(ts_split.get_split(0, 2).is_err());

        // Fold out of range
        assert!(ts_split.get_split(5, 10).is_err());
    }

    #[test]
    fn test_leave_one_out() {
        let loo = LeaveOneOut::new();

        let (train, val) = loo.get_split(0, 5).unwrap();

        assert_eq!(val.len(), 1);
        assert_eq!(train.len(), 4);
        assert_eq!(val[0], 0);

        let (train, val) = loo.get_split(3, 5).unwrap();
        assert_eq!(val[0], 3);
        assert_eq!(train.len(), 4);
    }

    #[test]
    fn test_leave_one_out_invalid() {
        let loo = LeaveOneOut::new();
        assert!(loo.get_split(5, 5).is_err()); // fold out of range
    }

    #[test]
    fn test_cv_results() {
        let mut results = CrossValidationResults::new();

        let mut metrics1 = HashMap::new();
        metrics1.insert("accuracy".to_string(), 0.9);
        results.add_fold(0.85, metrics1);

        let mut metrics2 = HashMap::new();
        metrics2.insert("accuracy".to_string(), 0.95);
        results.add_fold(0.90, metrics2);

        let mut metrics3 = HashMap::new();
        metrics3.insert("accuracy".to_string(), 0.92);
        results.add_fold(0.88, metrics3);

        assert_eq!(results.num_folds(), 3);

        // Mean score: (0.85 + 0.90 + 0.88) / 3 = 0.876666...
        let mean = results.mean_score();
        assert!((mean - 0.8766666).abs() < 1e-6);

        // Standard deviation
        let std = results.std_score();
        assert!(std > 0.0);

        // Mean metric
        let mean_acc = results.mean_metric("accuracy").unwrap();
        assert!((mean_acc - 0.923333).abs() < 1e-5);
    }

    #[test]
    fn test_cv_results_empty() {
        let results = CrossValidationResults::new();
        assert_eq!(results.mean_score(), 0.0);
        assert_eq!(results.std_score(), 0.0);
        assert_eq!(results.num_folds(), 0);
        assert!(results.mean_metric("accuracy").is_none());
    }

    #[test]
    fn test_kfold_all_folds() {
        let kfold = KFold::new(5).unwrap();
        let n_samples = 20;

        let mut all_val_indices = Vec::new();

        // Collect validation indices from all folds
        for fold in 0..5 {
            let (_, val) = kfold.get_split(fold, n_samples).unwrap();
            all_val_indices.extend(val);
        }

        all_val_indices.sort();

        // All samples should appear exactly once in validation sets
        assert_eq!(all_val_indices.len(), n_samples);
        assert_eq!(all_val_indices, (0..n_samples).collect::<Vec<_>>());
    }
}