// tensorlogic_train/crossval.rs
1//! Cross-validation utilities for model evaluation.
2//!
3//! This module provides various cross-validation strategies:
4//! - K-fold cross-validation
5//! - Stratified K-fold (maintains class distribution)
6//! - Time series split (preserves temporal order)
7//! - Leave-one-out cross-validation
8//! - Custom split strategies
9
10use crate::{TrainError, TrainResult};
11use scirs2_core::random::{SeedableRng, StdRng};
12use std::collections::HashMap;
13
/// Trait for cross-validation splitting strategies.
///
/// Implementors partition `n_samples` indices into disjoint train/validation
/// sets, one pair per fold.
pub trait CrossValidationSplit {
    /// Get the number of splits.
    fn num_splits(&self) -> usize;

    /// Get the train/validation indices for a specific fold.
    ///
    /// # Arguments
    /// * `fold` - Fold index (0 to num_splits - 1)
    /// * `n_samples` - Total number of samples
    ///
    /// # Returns
    /// (train_indices, validation_indices)
    ///
    /// # Errors
    /// Implementations return `TrainError::InvalidParameter` when `fold` is
    /// out of range or `n_samples` is too small for the strategy.
    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)>;
}
29
/// K-fold cross-validation.
///
/// Splits the data into K equally-sized folds. Each fold is used once as validation
/// while the K-1 remaining folds form the training set.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds (>= 2, enforced by [`KFold::new`]).
    pub n_splits: usize,
    /// Whether to shuffle the data before splitting (disabled by default).
    pub shuffle: bool,
    /// Random seed for shuffling (defaults to 42; only used when `shuffle` is true).
    pub random_seed: u64,
}
43
44impl KFold {
45    /// Create a new K-fold splitter.
46    ///
47    /// # Arguments
48    /// * `n_splits` - Number of folds (must be >= 2)
49    pub fn new(n_splits: usize) -> TrainResult<Self> {
50        if n_splits < 2 {
51            return Err(TrainError::InvalidParameter(
52                "n_splits must be at least 2".to_string(),
53            ));
54        }
55        Ok(Self {
56            n_splits,
57            shuffle: false,
58            random_seed: 42,
59        })
60    }
61
62    /// Enable shuffling with a specific seed.
63    pub fn with_shuffle(mut self, seed: u64) -> Self {
64        self.shuffle = true;
65        self.random_seed = seed;
66        self
67    }
68}
69
70impl CrossValidationSplit for KFold {
71    fn num_splits(&self) -> usize {
72        self.n_splits
73    }
74
75    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
76        if fold >= self.n_splits {
77            return Err(TrainError::InvalidParameter(format!(
78                "fold {} is out of range [0, {})",
79                fold, self.n_splits
80            )));
81        }
82
83        // Create indices
84        let mut indices: Vec<usize> = (0..n_samples).collect();
85
86        // Shuffle if requested
87        if self.shuffle {
88            let mut rng = StdRng::seed_from_u64(self.random_seed);
89            for i in (1..n_samples).rev() {
90                let j = rng.gen_range(0..=i);
91                indices.swap(i, j);
92            }
93        }
94
95        // Split into folds
96        let fold_size = n_samples / self.n_splits;
97        let remainder = n_samples % self.n_splits;
98
99        let mut fold_sizes = vec![fold_size; self.n_splits];
100        for fold in fold_sizes.iter_mut().take(remainder) {
101            *fold += 1;
102        }
103
104        // Compute fold boundaries
105        let mut boundaries = vec![0];
106        for size in &fold_sizes {
107            boundaries.push(
108                boundaries
109                    .last()
110                    .expect("boundaries is initialized non-empty")
111                    + size,
112            );
113        }
114
115        // Get validation indices for this fold
116        let val_start = boundaries[fold];
117        let val_end = boundaries[fold + 1];
118        let val_indices: Vec<usize> = indices[val_start..val_end].to_vec();
119
120        // Get training indices (all others)
121        let mut train_indices = Vec::new();
122        train_indices.extend_from_slice(&indices[..val_start]);
123        train_indices.extend_from_slice(&indices[val_end..]);
124
125        Ok((train_indices, val_indices))
126    }
127}
128
/// Stratified K-fold cross-validation.
///
/// Maintains class distribution in each fold (useful for imbalanced datasets).
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds (>= 2, enforced by [`StratifiedKFold::new`]).
    pub n_splits: usize,
    /// Whether to shuffle the data before splitting.
    /// NOTE: defaults to `true` here, unlike `KFold` where it defaults to `false`.
    pub shuffle: bool,
    /// Random seed for shuffling (defaults to 42).
    pub random_seed: u64,
}
141
142impl StratifiedKFold {
143    /// Create a new stratified K-fold splitter.
144    ///
145    /// # Arguments
146    /// * `n_splits` - Number of folds (must be >= 2)
147    pub fn new(n_splits: usize) -> TrainResult<Self> {
148        if n_splits < 2 {
149            return Err(TrainError::InvalidParameter(
150                "n_splits must be at least 2".to_string(),
151            ));
152        }
153        Ok(Self {
154            n_splits,
155            shuffle: true,
156            random_seed: 42,
157        })
158    }
159
160    /// Set random seed for shuffling.
161    pub fn with_seed(mut self, seed: u64) -> Self {
162        self.random_seed = seed;
163        self
164    }
165
166    /// Get stratified split based on class labels.
167    ///
168    /// # Arguments
169    /// * `fold` - Fold index
170    /// * `labels` - Class labels for each sample
171    pub fn get_stratified_split(
172        &self,
173        fold: usize,
174        labels: &[usize],
175    ) -> TrainResult<(Vec<usize>, Vec<usize>)> {
176        if fold >= self.n_splits {
177            return Err(TrainError::InvalidParameter(format!(
178                "fold {} is out of range [0, {})",
179                fold, self.n_splits
180            )));
181        }
182
183        // Group indices by class
184        let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
185        for (i, &label) in labels.iter().enumerate() {
186            class_indices.entry(label).or_default().push(i);
187        }
188
189        // Shuffle each class if requested
190        if self.shuffle {
191            let mut rng = StdRng::seed_from_u64(self.random_seed);
192            for indices in class_indices.values_mut() {
193                for i in (1..indices.len()).rev() {
194                    let j = rng.gen_range(0..=i);
195                    indices.swap(i, j);
196                }
197            }
198        }
199
200        // Split each class into folds
201        let mut train_indices = Vec::new();
202        let mut val_indices = Vec::new();
203
204        for indices in class_indices.values() {
205            let class_size = indices.len();
206            let fold_size = class_size / self.n_splits;
207            let remainder = class_size % self.n_splits;
208
209            let mut fold_sizes = vec![fold_size; self.n_splits];
210            for fold in fold_sizes.iter_mut().take(remainder) {
211                *fold += 1;
212            }
213
214            // Compute fold boundaries
215            let mut boundaries = vec![0];
216            for size in &fold_sizes {
217                boundaries.push(
218                    boundaries
219                        .last()
220                        .expect("boundaries is initialized non-empty")
221                        + size,
222                );
223            }
224
225            // Get validation indices for this fold
226            let val_start = boundaries[fold];
227            let val_end = boundaries[fold + 1];
228            val_indices.extend_from_slice(&indices[val_start..val_end]);
229
230            // Get training indices (all others)
231            train_indices.extend_from_slice(&indices[..val_start]);
232            train_indices.extend_from_slice(&indices[val_end..]);
233        }
234
235        Ok((train_indices, val_indices))
236    }
237}
238
239impl CrossValidationSplit for StratifiedKFold {
240    fn num_splits(&self) -> usize {
241        self.n_splits
242    }
243
244    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
245        // Default implementation: uniform distribution
246        // For actual stratification, use get_stratified_split with labels
247        let labels: Vec<usize> = (0..n_samples).map(|i| i % self.n_splits).collect();
248        self.get_stratified_split(fold, &labels)
249    }
250}
251
/// Time series split for temporal data.
///
/// Respects the temporal order of data. Each training set consists of data
/// before the validation set (no data leakage from future).
#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    /// Number of splits (>= 2, enforced by [`TimeSeriesSplit::new`]).
    pub n_splits: usize,
    /// Minimum training set size (no minimum when `None`).
    pub min_train_size: Option<usize>,
    /// Maximum training set size (for sliding window; expanding window when `None`).
    pub max_train_size: Option<usize>,
}
265
266impl TimeSeriesSplit {
267    /// Create a new time series split.
268    ///
269    /// # Arguments
270    /// * `n_splits` - Number of splits (must be >= 2)
271    pub fn new(n_splits: usize) -> TrainResult<Self> {
272        if n_splits < 2 {
273            return Err(TrainError::InvalidParameter(
274                "n_splits must be at least 2".to_string(),
275            ));
276        }
277        Ok(Self {
278            n_splits,
279            min_train_size: None,
280            max_train_size: None,
281        })
282    }
283
284    /// Set minimum training set size.
285    pub fn with_min_train_size(mut self, size: usize) -> Self {
286        self.min_train_size = Some(size);
287        self
288    }
289
290    /// Set maximum training set size (for sliding window).
291    pub fn with_max_train_size(mut self, size: usize) -> Self {
292        self.max_train_size = Some(size);
293        self
294    }
295}
296
297impl CrossValidationSplit for TimeSeriesSplit {
298    fn num_splits(&self) -> usize {
299        self.n_splits
300    }
301
302    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
303        if fold >= self.n_splits {
304            return Err(TrainError::InvalidParameter(format!(
305                "fold {} is out of range [0, {})",
306                fold, self.n_splits
307            )));
308        }
309
310        // Compute validation set size
311        let test_size = n_samples / (self.n_splits + 1);
312        if test_size == 0 {
313            return Err(TrainError::InvalidParameter(
314                "Not enough samples for time series split".to_string(),
315            ));
316        }
317
318        // Validation set for this fold
319        let val_start = (fold + 1) * test_size;
320        let val_end = ((fold + 2) * test_size).min(n_samples);
321
322        // Training set: all data before validation
323        let train_end = val_start;
324        let train_start = if let Some(max_size) = self.max_train_size {
325            train_end.saturating_sub(max_size)
326        } else if let Some(min_size) = self.min_train_size {
327            if train_end < min_size {
328                return Err(TrainError::InvalidParameter(
329                    "Not enough samples for min_train_size".to_string(),
330                ));
331            }
332            0
333        } else {
334            0
335        };
336
337        let train_indices: Vec<usize> = (train_start..train_end).collect();
338        let val_indices: Vec<usize> = (val_start..val_end).collect();
339
340        if train_indices.is_empty() {
341            return Err(TrainError::InvalidParameter(
342                "Training set is empty for this fold".to_string(),
343            ));
344        }
345
346        Ok((train_indices, val_indices))
347    }
348}
349
/// Leave-one-out cross-validation.
///
/// Each sample is used once as validation while all others form the training set.
/// Useful for very small datasets but computationally expensive (one fold per sample).
#[derive(Debug, Clone, Default)]
pub struct LeaveOneOut;

impl LeaveOneOut {
    /// Create a new leave-one-out splitter.
    pub fn new() -> Self {
        Self
    }
}
363
364impl CrossValidationSplit for LeaveOneOut {
365    fn num_splits(&self) -> usize {
366        // This is a placeholder; actual number depends on n_samples
367        usize::MAX
368    }
369
370    fn get_split(&self, fold: usize, n_samples: usize) -> TrainResult<(Vec<usize>, Vec<usize>)> {
371        if fold >= n_samples {
372            return Err(TrainError::InvalidParameter(format!(
373                "fold {} is out of range [0, {})",
374                fold, n_samples
375            )));
376        }
377
378        // Validation: single sample
379        let val_indices = vec![fold];
380
381        // Training: all other samples
382        let mut train_indices: Vec<usize> = (0..fold).collect();
383        train_indices.extend(fold + 1..n_samples);
384
385        Ok((train_indices, val_indices))
386    }
387}
388
/// Cross-validation result aggregator.
///
/// Collects a score and an optional metric map per fold, and computes
/// summary statistics (mean, sample standard deviation) across folds.
#[derive(Debug, Clone, Default)]
pub struct CrossValidationResults {
    /// Scores for each fold.
    pub fold_scores: Vec<f64>,
    /// Additional metrics for each fold.
    pub fold_metrics: Vec<HashMap<String, f64>>,
}

impl CrossValidationResults {
    /// Create a new, empty result aggregator.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record the score and metrics for one fold.
    pub fn add_fold(&mut self, score: f64, metrics: HashMap<String, f64>) {
        self.fold_scores.push(score);
        self.fold_metrics.push(metrics);
    }

    /// Mean score across all folds (0.0 when no folds recorded).
    pub fn mean_score(&self) -> f64 {
        match self.fold_scores.len() {
            0 => 0.0,
            n => self.fold_scores.iter().sum::<f64>() / n as f64,
        }
    }

    /// Sample standard deviation (n-1 denominator) of the fold scores.
    /// Returns 0.0 with fewer than two folds.
    pub fn std_score(&self) -> f64 {
        let n = self.fold_scores.len();
        if n < 2 {
            return 0.0;
        }
        let mean = self.mean_score();
        let sum_sq: f64 = self
            .fold_scores
            .iter()
            .map(|&score| (score - mean) * (score - mean))
            .sum();
        (sum_sq / (n - 1) as f64).sqrt()
    }

    /// Mean of a named metric over the folds that recorded it.
    /// `None` when no fold has the metric.
    pub fn mean_metric(&self, metric_name: &str) -> Option<f64> {
        let values: Vec<f64> = self
            .fold_metrics
            .iter()
            .filter_map(|metrics| metrics.get(metric_name).copied())
            .collect();
        if values.is_empty() {
            None
        } else {
            Some(values.iter().sum::<f64>() / values.len() as f64)
        }
    }

    /// Number of folds recorded so far.
    pub fn num_folds(&self) -> usize {
        self.fold_scores.len()
    }
}
472
#[cfg(test)]
mod tests {
    use super::*;

    // KFold: folds are disjoint and together cover every index exactly once.
    #[test]
    fn test_kfold_basic() {
        let kfold = KFold::new(3).expect("unwrap");
        assert_eq!(kfold.num_splits(), 3);

        let (train, val) = kfold.get_split(0, 10).expect("unwrap");
        assert!(!train.is_empty());
        assert!(!val.is_empty());

        // Train and validation should be disjoint
        for &idx in &val {
            assert!(!train.contains(&idx));
        }

        // Together should cover all indices
        let mut all_indices = train.clone();
        all_indices.extend(&val);
        all_indices.sort();
        assert_eq!(all_indices, (0..10).collect::<Vec<_>>());
    }

    // Shuffled KFold must be deterministic for a fixed seed.
    #[test]
    fn test_kfold_with_shuffle() {
        let kfold = KFold::new(3).expect("unwrap").with_shuffle(42);
        let (train1, val1) = kfold.get_split(0, 10).expect("unwrap");
        let (train2, val2) = kfold.get_split(0, 10).expect("unwrap");

        // Same seed should produce same results
        assert_eq!(train1, train2);
        assert_eq!(val1, val2);
    }

    // Constructor rejects n_splits < 2; get_split rejects out-of-range folds.
    #[test]
    fn test_kfold_invalid() {
        assert!(KFold::new(1).is_err());
        let kfold = KFold::new(3).expect("unwrap");
        assert!(kfold.get_split(5, 10).is_err()); // fold out of range
    }

    #[test]
    fn test_stratified_kfold() {
        let skfold = StratifiedKFold::new(3).expect("unwrap");
        assert_eq!(skfold.num_splits(), 3);

        // Create balanced labels: [0, 0, 0, 1, 1, 1, 2, 2, 2]
        let labels = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let (_train, val) = skfold.get_stratified_split(0, &labels).expect("unwrap");

        // Check that validation set has samples from each class
        let mut val_classes: Vec<usize> = val.iter().map(|&i| labels[i]).collect();
        val_classes.sort();
        val_classes.dedup();

        // Should have at least some class diversity
        assert!(!val.is_empty());
    }

    // Temporal ordering: every training index precedes every validation index.
    #[test]
    fn test_time_series_split() {
        let ts_split = TimeSeriesSplit::new(3).expect("unwrap");
        assert_eq!(ts_split.num_splits(), 3);

        let (train, val) = ts_split.get_split(0, 10).expect("unwrap");

        // Training indices should be before validation indices
        if !train.is_empty() && !val.is_empty() {
            assert!(train.iter().max().expect("unwrap") < val.iter().min().expect("unwrap"));
        }
    }

    // Sliding window: max_train_size caps the training set length.
    #[test]
    fn test_time_series_split_with_window() {
        let ts_split = TimeSeriesSplit::new(3)
            .expect("unwrap")
            .with_min_train_size(2)
            .with_max_train_size(5);

        let (train, val) = ts_split.get_split(1, 20).expect("unwrap");

        // Training set should respect max size
        assert!(train.len() <= 5);
        assert!(!val.is_empty());
    }

    #[test]
    fn test_time_series_split_invalid() {
        let ts_split = TimeSeriesSplit::new(3).expect("unwrap");

        // Too few samples
        assert!(ts_split.get_split(0, 2).is_err());

        // Fold out of range
        assert!(ts_split.get_split(5, 10).is_err());
    }

    // LOO: exactly one validation sample per fold, all others train.
    #[test]
    fn test_leave_one_out() {
        let loo = LeaveOneOut::new();

        let (train, val) = loo.get_split(0, 5).expect("unwrap");

        assert_eq!(val.len(), 1);
        assert_eq!(train.len(), 4);
        assert_eq!(val[0], 0);

        let (train, val) = loo.get_split(3, 5).expect("unwrap");
        assert_eq!(val[0], 3);
        assert_eq!(train.len(), 4);
    }

    #[test]
    fn test_leave_one_out_invalid() {
        let loo = LeaveOneOut::new();
        assert!(loo.get_split(5, 5).is_err()); // fold out of range
    }

    // Aggregator statistics: mean, sample std, and per-metric mean.
    #[test]
    fn test_cv_results() {
        let mut results = CrossValidationResults::new();

        let mut metrics1 = HashMap::new();
        metrics1.insert("accuracy".to_string(), 0.9);
        results.add_fold(0.85, metrics1);

        let mut metrics2 = HashMap::new();
        metrics2.insert("accuracy".to_string(), 0.95);
        results.add_fold(0.90, metrics2);

        let mut metrics3 = HashMap::new();
        metrics3.insert("accuracy".to_string(), 0.92);
        results.add_fold(0.88, metrics3);

        assert_eq!(results.num_folds(), 3);

        // Mean score: (0.85 + 0.90 + 0.88) / 3 = 0.876666...
        let mean = results.mean_score();
        assert!((mean - 0.8766666).abs() < 1e-6);

        // Standard deviation
        let std = results.std_score();
        assert!(std > 0.0);

        // Mean metric
        let mean_acc = results.mean_metric("accuracy").expect("unwrap");
        assert!((mean_acc - 0.923333).abs() < 1e-5);
    }

    // Empty aggregator returns safe defaults rather than panicking.
    #[test]
    fn test_cv_results_empty() {
        let results = CrossValidationResults::new();
        assert_eq!(results.mean_score(), 0.0);
        assert_eq!(results.std_score(), 0.0);
        assert_eq!(results.num_folds(), 0);
        assert!(results.mean_metric("accuracy").is_none());
    }

    // Across all folds, every sample appears in validation exactly once.
    #[test]
    fn test_kfold_all_folds() {
        let kfold = KFold::new(5).expect("unwrap");
        let n_samples = 20;

        let mut all_val_indices = Vec::new();

        // Collect validation indices from all folds
        for fold in 0..5 {
            let (_, val) = kfold.get_split(fold, n_samples).expect("unwrap");
            all_val_indices.extend(val);
        }

        all_val_indices.sort();

        // All samples should appear exactly once in validation sets
        assert_eq!(all_val_indices.len(), n_samples);
        assert_eq!(all_val_indices, (0..n_samples).collect::<Vec<_>>());
    }
}