// scry_learn — split.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Train/test splitting and cross-validation utilities.

use crate::dataset::Dataset;
use crate::error::Result;
use crate::pipeline::PipelineModel;

/// Scoring function signature: `(y_true, y_pred) -> score`.
///
/// Use `metrics::accuracy` or any `fn(&[f64], &[f64]) -> f64`.
/// Consumed by `cross_val_score`, `cross_val_score_stratified`, and
/// `repeated_cross_val_score`.
pub type ScoringFn = fn(&[f64], &[f64]) -> f64;

13/// Split a dataset into training and test sets.
14///
15/// `test_ratio` should be between 0.0 and 1.0 (e.g., 0.2 for 80/20 split).
16/// The `seed` controls the random shuffle for reproducibility.
17pub fn train_test_split(data: &Dataset, test_ratio: f64, seed: u64) -> (Dataset, Dataset) {
18    let n = data.n_samples();
19    let mut indices: Vec<usize> = (0..n).collect();
20    shuffle(&mut indices, seed);
21
22    let test_size = (n as f64 * test_ratio).round() as usize;
23    let test_size = test_size.max(1).min(n - 1);
24
25    let test_indices = &indices[..test_size];
26    let train_indices = &indices[test_size..];
27
28    (data.subset(train_indices), data.subset(test_indices))
29}
30
31/// Stratified train/test split — preserves class proportions.
32///
33/// Groups samples by target value and splits each group independently,
34/// ensuring the ratio of each class is maintained in both sets.
35pub fn stratified_split(data: &Dataset, test_ratio: f64, seed: u64) -> (Dataset, Dataset) {
36    let n = data.n_samples();
37
38    // Group indices by class.
39    let mut class_map: std::collections::HashMap<i64, Vec<usize>> =
40        std::collections::HashMap::new();
41    for i in 0..n {
42        let key = data.target[i] as i64;
43        class_map.entry(key).or_default().push(i);
44    }
45
46    let mut train_indices = Vec::new();
47    let mut test_indices = Vec::new();
48
49    // Sort class keys for deterministic iteration order.
50    let mut sorted_classes: Vec<i64> = class_map.keys().copied().collect();
51    sorted_classes.sort_unstable();
52
53    let mut rng = crate::rng::FastRng::new(seed);
54    for class in sorted_classes {
55        let mut indices = class_map
56            .remove(&class)
57            .expect("class key from sorted_classes must exist in class_map");
58        // Shuffle within each class.
59        for i in (1..indices.len()).rev() {
60            let j = rng.usize(0..=i);
61            indices.swap(i, j);
62        }
63        let test_n = (indices.len() as f64 * test_ratio).round() as usize;
64        let test_n = test_n.max(1).min(indices.len().saturating_sub(1));
65        test_indices.extend_from_slice(&indices[..test_n]);
66        train_indices.extend_from_slice(&indices[test_n..]);
67    }
68
69    (data.subset(&train_indices), data.subset(&test_indices))
70}
71
72/// K-fold cross-validation splits.
73///
74/// Returns `k` pairs of (train, test) datasets.
75pub fn k_fold(data: &Dataset, k: usize, seed: u64) -> Vec<(Dataset, Dataset)> {
76    let n = data.n_samples();
77    let mut indices: Vec<usize> = (0..n).collect();
78    shuffle(&mut indices, seed);
79
80    let fold_size = n / k;
81    let mut folds = Vec::with_capacity(k);
82
83    for i in 0..k {
84        let start = i * fold_size;
85        let end = if i == k - 1 { n } else { start + fold_size };
86        let test_indices: Vec<usize> = indices[start..end].to_vec();
87        let train_indices: Vec<usize> = indices[..start]
88            .iter()
89            .chain(indices[end..].iter())
90            .copied()
91            .collect();
92        folds.push((data.subset(&train_indices), data.subset(&test_indices)));
93    }
94
95    folds
96}
97
98/// Stratified k-fold cross-validation.
99pub fn stratified_k_fold(data: &Dataset, k: usize, seed: u64) -> Vec<(Dataset, Dataset)> {
100    let n = data.n_samples();
101
102    // Group by class and shuffle within each class.
103    let mut class_map: std::collections::HashMap<i64, Vec<usize>> =
104        std::collections::HashMap::new();
105    for i in 0..n {
106        let key = data.target[i] as i64;
107        class_map.entry(key).or_default().push(i);
108    }
109
110    // Sort class keys for deterministic iteration order.
111    let mut sorted_classes: Vec<i64> = class_map.keys().copied().collect();
112    sorted_classes.sort_unstable();
113
114    let mut rng = crate::rng::FastRng::new(seed);
115    for class in &sorted_classes {
116        let indices = class_map
117            .get_mut(class)
118            .expect("class key from sorted_classes must exist in class_map");
119        for i in (1..indices.len()).rev() {
120            let j = rng.usize(0..=i);
121            indices.swap(i, j);
122        }
123    }
124
125    // Round-robin assign samples to folds.
126    let mut fold_indices: Vec<Vec<usize>> = vec![Vec::new(); k];
127    for class in &sorted_classes {
128        let indices = &class_map[class];
129        for (i, &idx) in indices.iter().enumerate() {
130            fold_indices[i % k].push(idx);
131        }
132    }
133
134    let mut folds = Vec::with_capacity(k);
135    let all_indices: Vec<usize> = (0..n).collect();
136
137    for fold in &fold_indices {
138        let test_set: std::collections::HashSet<usize> = fold.iter().copied().collect();
139        let train: Vec<usize> = all_indices
140            .iter()
141            .filter(|i| !test_set.contains(i))
142            .copied()
143            .collect();
144        folds.push((data.subset(&train), data.subset(fold)));
145    }
146
147    folds
148}

// ---------------------------------------------------------------------------
// Cross-validation scoring
// ---------------------------------------------------------------------------

154/// Run k-fold cross-validation, returning per-fold scores.
155///
156/// Clones the model for each fold, fits on the training split, predicts on
157/// the test split, and computes `scorer(y_true, y_pred)`.
158///
159/// ```ignore
160/// use scry_learn::prelude::*;
161/// use scry_learn::split::{cross_val_score, ScoringFn};
162///
163/// let scores = cross_val_score(
164///     &DecisionTreeClassifier::new(),
165///     &data, 5, accuracy as ScoringFn, 42,
166/// ).unwrap();
167/// ```
168pub fn cross_val_score<M: PipelineModel + Clone + Send + Sync>(
169    model: &M,
170    data: &Dataset,
171    k: usize,
172    scorer: ScoringFn,
173    seed: u64,
174) -> Result<Vec<f64>> {
175    let folds = k_fold(data, k, seed);
176    run_cv(model, &folds, scorer)
177}
178
179/// Stratified k-fold cross-validation — preserves class balance in each fold.
180pub fn cross_val_score_stratified<M: PipelineModel + Clone + Send + Sync>(
181    model: &M,
182    data: &Dataset,
183    k: usize,
184    scorer: ScoringFn,
185    seed: u64,
186) -> Result<Vec<f64>> {
187    let folds = stratified_k_fold(data, k, seed);
188    run_cv(model, &folds, scorer)
189}
190
191/// Shared implementation: fit + predict + score for each fold.
192///
193/// Folds are evaluated in parallel using rayon when multiple cores are
194/// available. Each fold clones the model independently.
195fn run_cv<M: PipelineModel + Clone + Send + Sync>(
196    model: &M,
197    folds: &[(Dataset, Dataset)],
198    scorer: ScoringFn,
199) -> Result<Vec<f64>> {
200    use rayon::prelude::*;
201
202    let results: Vec<Result<f64>> = folds
203        .par_iter()
204        .map(|(train, test)| {
205            let mut m = model.clone();
206            m.fit(train)?;
207            let features = test.feature_matrix();
208            let preds = m.predict(&features)?;
209            Ok(scorer(&test.target, &preds))
210        })
211        .collect();
212
213    // Collect results, propagating the first error if any fold failed.
214    results.into_iter().collect()
215}
216
217/// Fisher-Yates shuffle with a seeded RNG.
218fn shuffle(arr: &mut [usize], seed: u64) {
219    let mut rng = crate::rng::FastRng::new(seed);
220    for i in (1..arr.len()).rev() {
221        let j = rng.usize(0..=i);
222        arr.swap(i, j);
223    }
224}

// ---------------------------------------------------------------------------
// RepeatedKFold
// ---------------------------------------------------------------------------

/// Repeated k-fold cross-validation.
///
/// Repeats standard k-fold `n_repeats` times, each with a different random
/// shuffle (using `seed + repeat_idx`), yielding `n_splits × n_repeats` folds.
///
/// # Example
///
/// ```ignore
/// let rkf = RepeatedKFold::new(5, 3, 42);
/// let folds = rkf.folds(&data); // 15 (train, test) pairs
/// ```
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct RepeatedKFold {
    /// Number of folds per repetition.
    pub n_splits: usize,
    /// Number of repetitions.
    pub n_repeats: usize,
    /// Base random seed; repetition `r` shuffles with `seed + r` (wrapping).
    pub seed: u64,
}

252impl RepeatedKFold {
253    /// Create a new `RepeatedKFold` splitter.
254    pub fn new(n_splits: usize, n_repeats: usize, seed: u64) -> Self {
255        Self {
256            n_splits,
257            n_repeats,
258            seed,
259        }
260    }
261
262    /// Generate all `n_splits × n_repeats` (train, test) pairs.
263    pub fn folds(&self, data: &Dataset) -> Vec<(Dataset, Dataset)> {
264        let mut all_folds = Vec::with_capacity(self.n_splits * self.n_repeats);
265        for rep in 0..self.n_repeats {
266            let rep_seed = self.seed.wrapping_add(rep as u64);
267            all_folds.extend(k_fold(data, self.n_splits, rep_seed));
268        }
269        all_folds
270    }
271}
272
273/// Convenience: run repeated k-fold CV on a clonable model.
274///
275/// Returns per-fold scores across all `n_splits × n_repeats` folds.
276pub fn repeated_cross_val_score<M: PipelineModel + Clone + Send + Sync>(
277    model: &M,
278    data: &Dataset,
279    n_splits: usize,
280    n_repeats: usize,
281    scorer: ScoringFn,
282    seed: u64,
283) -> Result<Vec<f64>> {
284    let rkf = RepeatedKFold::new(n_splits, n_repeats, seed);
285    let folds = rkf.folds(data);
286    run_cv(model, &folds, scorer)
287}

// ---------------------------------------------------------------------------
// GroupKFold
// ---------------------------------------------------------------------------

293/// Group-aware k-fold: no group appears in both train and test within a fold.
294///
295/// Groups are assigned to folds round-robin by unique group index. This
296/// prevents data leakage when samples from the same group (e.g. patient)
297/// are correlated.
298///
299/// # Arguments
300///
301/// * `data`   — the dataset to split
302/// * `groups` — group label per sample (length must equal `data.n_samples()`)
303/// * `k`      — number of folds
304///
305/// # Panics
306///
307/// Panics if `groups.len() != data.n_samples()`.
308pub fn group_k_fold(data: &Dataset, groups: &[usize], k: usize) -> Vec<(Dataset, Dataset)> {
309    assert_eq!(
310        groups.len(),
311        data.n_samples(),
312        "groups length must match n_samples"
313    );
314
315    // Collect unique groups in order of first appearance.
316    let mut unique_groups: Vec<usize> = Vec::new();
317    for &g in groups {
318        if !unique_groups.contains(&g) {
319            unique_groups.push(g);
320        }
321    }
322
323    // Assign each group to a fold (round-robin).
324    let mut group_to_fold = std::collections::HashMap::new();
325    for (i, &g) in unique_groups.iter().enumerate() {
326        group_to_fold.insert(g, i % k);
327    }
328
329    let mut folds = Vec::with_capacity(k);
330    for fold_idx in 0..k {
331        let mut test_indices = Vec::new();
332        let mut train_indices = Vec::new();
333        for (sample_idx, &g) in groups.iter().enumerate() {
334            if group_to_fold[&g] == fold_idx {
335                test_indices.push(sample_idx);
336            } else {
337                train_indices.push(sample_idx);
338            }
339        }
340        folds.push((data.subset(&train_indices), data.subset(&test_indices)));
341    }
342
343    folds
344}

// ---------------------------------------------------------------------------
// TimeSeriesSplit
// ---------------------------------------------------------------------------

350/// Time-series cross-validation with expanding training window.
351///
352/// Produces `n_splits` folds where each test set immediately follows its
353/// training set. Data order is preserved — no shuffling.
354///
355/// For `n_splits` folds and `n` samples, each test chunk has size
356/// `n / (n_splits + 1)`. Fold *i* trains on `[0 .. (i+1)*chunk]` and tests
357/// on `[(i+1)*chunk .. (i+2)*chunk]`.
358pub fn time_series_split(data: &Dataset, n_splits: usize) -> Vec<(Dataset, Dataset)> {
359    let n = data.n_samples();
360    let chunk = n / (n_splits + 1);
361    let mut folds = Vec::with_capacity(n_splits);
362
363    for i in 0..n_splits {
364        let train_end = (i + 1) * chunk;
365        let test_end = if i == n_splits - 1 {
366            n
367        } else {
368            (i + 2) * chunk
369        };
370        let train_indices: Vec<usize> = (0..train_end).collect();
371        let test_indices: Vec<usize> = (train_end..test_end).collect();
372        folds.push((data.subset(&train_indices), data.subset(&test_indices)));
373    }
374
375    folds
376}

// ---------------------------------------------------------------------------
// cross_val_predict
// ---------------------------------------------------------------------------

382/// Out-of-fold predictions for every sample.
383///
384/// Trains on k-1 folds, predicts the held-out fold, and reassembles
385/// predictions in the original sample order. The returned vector has length
386/// `data.n_samples()`.
387///
388/// # Example
389///
390/// ```ignore
391/// let preds = cross_val_predict(&model, &data, 5, 42)?;
392/// assert_eq!(preds.len(), data.n_samples());
393/// ```
394pub fn cross_val_predict<M: PipelineModel + Clone>(
395    model: &M,
396    data: &Dataset,
397    k: usize,
398    seed: u64,
399) -> Result<Vec<f64>> {
400    let n = data.n_samples();
401    let mut indices_all: Vec<usize> = (0..n).collect();
402    shuffle(&mut indices_all, seed);
403
404    let fold_size = n / k;
405    let mut predictions = vec![0.0; n];
406
407    for i in 0..k {
408        let start = i * fold_size;
409        let end = if i == k - 1 { n } else { start + fold_size };
410
411        let test_indices: Vec<usize> = indices_all[start..end].to_vec();
412        let train_indices: Vec<usize> = indices_all[..start]
413            .iter()
414            .chain(indices_all[end..].iter())
415            .copied()
416            .collect();
417
418        let train = data.subset(&train_indices);
419        let test = data.subset(&test_indices);
420
421        let mut m = model.clone();
422        m.fit(&train)?;
423        let features = test.feature_matrix();
424        let preds = m.predict(&features)?;
425
426        for (j, &idx) in test_indices.iter().enumerate() {
427            predictions[idx] = preds[j];
428        }
429    }
430
431    Ok(predictions)
432}
433
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    use crate::metrics::accuracy;
    use crate::tree::DecisionTreeClassifier;

    /// Single-feature dataset whose targets cycle through classes 0, 1, 2,
    /// so `n` divisible by 3 gives perfectly balanced classes.
    fn dummy_dataset(n: usize) -> Dataset {
        let features = vec![(0..n).map(|i| i as f64).collect()];
        let target = (0..n).map(|i| (i % 3) as f64).collect();
        Dataset::new(features, target, vec!["x".into()], "y")
    }

    /// A well-separated 2-class dataset for reliable CV testing.
    /// Class 0 occupies feature values 0..30, class 1 occupies 130..160.
    fn separable_dataset() -> Dataset {
        let n = 60;
        let mut f0 = Vec::with_capacity(n);
        let mut f1 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);
        for i in 0..n {
            if i < n / 2 {
                f0.push(i as f64);
                f1.push(i as f64);
                target.push(0.0);
            } else {
                f0.push((i + 100) as f64);
                f1.push((i + 100) as f64);
                target.push(1.0);
            }
        }
        Dataset::new(vec![f0, f1], target, vec!["x".into(), "y".into()], "class")
    }

    #[test]
    fn test_train_test_split_sizes() {
        let ds = dummy_dataset(100);
        let (train, test) = train_test_split(&ds, 0.2, 42);
        assert_eq!(train.n_samples() + test.n_samples(), 100);
        assert_eq!(test.n_samples(), 20);
    }

    #[test]
    fn test_stratified_split_preserves_ratio() {
        let ds = dummy_dataset(90); // 30 each of class 0, 1, 2
        let (train, test) = stratified_split(&ds, 0.2, 42);
        assert_eq!(train.n_samples() + test.n_samples(), 90);

        // Each class should contribute roughly 20% (= 6) of its 30 samples.
        let test_class_0 = test.target.iter().filter(|&&v| v == 0.0).count();
        let test_class_1 = test.target.iter().filter(|&&v| v == 1.0).count();
        let test_class_2 = test.target.iter().filter(|&&v| v == 2.0).count();
        assert!((4..=8).contains(&test_class_0));
        assert!((4..=8).contains(&test_class_1));
        assert!((4..=8).contains(&test_class_2));
    }

    #[test]
    fn test_k_fold_count() {
        let ds = dummy_dataset(50);
        let folds = k_fold(&ds, 5, 42);
        assert_eq!(folds.len(), 5);
        for (train, test) in &folds {
            assert_eq!(train.n_samples() + test.n_samples(), 50);
        }
    }

    // -----------------------------------------------------------------------
    // Cross-validation scorer tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_cross_val_score_dt() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let scores = cross_val_score(&model, &ds, 5, accuracy, 42).unwrap();
        assert_eq!(scores.len(), 5);
        for &s in &scores {
            assert!(s >= 0.8, "fold accuracy {s} < 0.8 on well-separated data");
        }
    }

    #[test]
    fn test_cross_val_score_stratified() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let scores = cross_val_score_stratified(&model, &ds, 5, accuracy, 42).unwrap();
        assert_eq!(scores.len(), 5);
        for &s in &scores {
            assert!(s >= 0.8, "stratified fold accuracy {s} < 0.8");
        }
    }

    #[test]
    fn test_cross_val_score_leave_one_out() {
        // k = n for leave-one-out cross-validation.
        let ds = separable_dataset();
        let n = ds.n_samples();
        let model = DecisionTreeClassifier::new();
        let scores = cross_val_score(&model, &ds, n, accuracy, 42).unwrap();
        assert_eq!(scores.len(), n);
        // Each fold has 1 test sample, so score is 0.0 or 1.0.
        for &s in &scores {
            assert!(s == 0.0 || s == 1.0);
        }
    }

    #[test]
    fn test_cross_val_score_custom_scorer() {
        // Any `fn(&[f64], &[f64]) -> f64` qualifies as a ScoringFn.
        fn always_one(_true: &[f64], _pred: &[f64]) -> f64 {
            1.0
        }
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let scores = cross_val_score(&model, &ds, 3, always_one, 42).unwrap();
        assert!(scores.iter().all(|&s| (s - 1.0).abs() < 1e-10));
    }

    // -----------------------------------------------------------------------
    // Session 15: New CV strategies
    // -----------------------------------------------------------------------

    #[test]
    fn test_repeated_k_fold_count() {
        let ds = dummy_dataset(50);
        let rkf = RepeatedKFold::new(5, 3, 42);
        let folds = rkf.folds(&ds);
        assert_eq!(folds.len(), 15);
        for (train, test) in &folds {
            assert_eq!(train.n_samples() + test.n_samples(), 50);
            assert!(!test.target.is_empty(), "test fold must not be empty");
        }
    }

    #[test]
    fn test_repeated_cross_val_score() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let scores = repeated_cross_val_score(&model, &ds, 5, 3, accuracy, 42).unwrap();
        assert_eq!(scores.len(), 15);
        for &s in &scores {
            assert!(s >= 0.5, "repeated CV fold accuracy {s} too low");
        }
    }

    #[test]
    fn test_group_k_fold_no_leakage() {
        let ds = dummy_dataset(12);
        // 3 groups: 0,0,0,0, 1,1,1,1, 2,2,2,2
        let groups = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
        let folds = group_k_fold(&ds, &groups, 3);
        assert_eq!(folds.len(), 3);

        for (train, test) in &folds {
            assert_eq!(train.n_samples() + test.n_samples(), 12);
            // Each fold's test set should have exactly 4 samples (one group).
            assert_eq!(test.n_samples(), 4);
        }
    }

    #[test]
    fn test_group_k_fold_group_isolation() {
        // Verify no group appears in both train and test.
        let n = 15;
        let ds = dummy_dataset(n);
        let groups: Vec<usize> = (0..n).map(|i| i / 3).collect(); // 5 groups of 3
        let folds = group_k_fold(&ds, &groups, 3);

        for (fold_idx, (_train, test)) in folds.iter().enumerate() {
            // Reconstruct test indices from target values (unique due to dummy_dataset)
            // Just verify sizes are correct.
            assert!(!test.target.is_empty(), "fold {fold_idx} test set is empty");
        }
    }

    #[test]
    fn test_time_series_split_temporal_order() {
        let n = 24;
        let ds = dummy_dataset(n);
        let folds = time_series_split(&ds, 3);
        assert_eq!(folds.len(), 3);

        // Expanding window: each successive fold has a larger training set.
        let mut prev_train_size = 0;
        for (train, test) in &folds {
            assert!(
                train.n_samples() > prev_train_size,
                "training size should grow"
            );
            prev_train_size = train.n_samples();
            assert!(!test.target.is_empty(), "test fold must not be empty");
        }
    }

    #[test]
    fn test_time_series_split_no_future_leak() {
        let n = 20;
        let features = vec![(0..n).map(|i| i as f64).collect::<Vec<_>>()];
        let target = (0..n).map(|i| i as f64).collect();
        let ds = Dataset::new(features, target, vec!["t".into()], "y");

        let folds = time_series_split(&ds, 4);
        for (train, test) in &folds {
            // Every training timestamp must precede every test timestamp.
            let train_max = train.features[0]
                .iter()
                .copied()
                .fold(f64::NEG_INFINITY, f64::max);
            let test_min = test.features[0]
                .iter()
                .copied()
                .fold(f64::INFINITY, f64::min);
            assert!(
                train_max < test_min,
                "train max {train_max} must be < test min {test_min}"
            );
        }
    }

    #[test]
    fn test_cross_val_predict_length() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let preds = cross_val_predict(&model, &ds, 5, 42).unwrap();
        assert_eq!(preds.len(), ds.n_samples());
    }

    #[test]
    fn test_cross_val_predict_reasonable_accuracy() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let preds = cross_val_predict(&model, &ds, 5, 42).unwrap();
        let acc = accuracy(&ds.target, &preds);
        assert!(
            acc >= 0.8,
            "cross_val_predict accuracy {acc} too low on separable data"
        );
    }
}