use crate::dataset::Dataset;
use crate::error::Result;
use crate::pipeline::PipelineModel;
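
/// Scoring function signature used by the cross-validation helpers: called as
/// `scorer(y_true, y_pred)` and expected to return a single score per fold.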
pub type ScoringFn = fn(&[f64], &[f64]) -> f64;
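
/// Splits `data` into `(train, test)` using a seeded shuffle, so the same
/// `seed` always yields the same split. `test_ratio` is the fraction of
/// samples assigned to the test set; the test size is clamped to
/// `1..=n_samples - 1` so neither side is empty.
///
/// ```ignore
/// // Illustrative sketch only; assumes a `Dataset` constructed elsewhere.
/// let (train, test) = train_test_split(&data, 0.2, 42);
/// assert_eq!(train.n_samples() + test.n_samples(), data.n_samples());
/// ```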
pub fn train_test_split(data: &Dataset, test_ratio: f64, seed: u64) -> (Dataset, Dataset) {
    let n = data.n_samples();
    let mut indices: Vec<usize> = (0..n).collect();
    shuffle(&mut indices, seed);

    let test_size = (n as f64 * test_ratio).round() as usize;
    let test_size = test_size.max(1).min(n - 1);

    let test_indices = &indices[..test_size];
    let train_indices = &indices[test_size..];

    (data.subset(train_indices), data.subset(test_indices))
}
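
/// Like [`train_test_split`], but preserves the class proportions of
/// `data.target` (truncated to `i64`) in both splits. Each class is shuffled
/// independently with a seeded RNG and split at `test_ratio`, keeping at
/// least one sample on each side of the split per class when possible.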
pub fn stratified_split(data: &Dataset, test_ratio: f64, seed: u64) -> (Dataset, Dataset) {
    let n = data.n_samples();

    let mut class_map: std::collections::HashMap<i64, Vec<usize>> =
        std::collections::HashMap::new();
    for i in 0..n {
        let key = data.target[i] as i64;
        class_map.entry(key).or_default().push(i);
    }

    let mut train_indices = Vec::new();
    let mut test_indices = Vec::new();

    let mut sorted_classes: Vec<i64> = class_map.keys().copied().collect();
    sorted_classes.sort_unstable();

    let mut rng = crate::rng::FastRng::new(seed);
    for class in sorted_classes {
        let mut indices = class_map
            .remove(&class)
            .expect("class key from sorted_classes must exist in class_map");
        for i in (1..indices.len()).rev() {
            let j = rng.usize(0..=i);
            indices.swap(i, j);
        }
        let test_n = (indices.len() as f64 * test_ratio).round() as usize;
        let test_n = test_n.max(1).min(indices.len().saturating_sub(1));
        test_indices.extend_from_slice(&indices[..test_n]);
        train_indices.extend_from_slice(&indices[test_n..]);
    }

    (data.subset(&train_indices), data.subset(&test_indices))
}
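
/// Produces `k` cross-validation folds as `(train, test)` pairs. Samples are
/// shuffled with the given `seed` and partitioned into `k` contiguous blocks;
/// the final fold absorbs any remainder when `n_samples` is not divisible by
/// `k`.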
pub fn k_fold(data: &Dataset, k: usize, seed: u64) -> Vec<(Dataset, Dataset)> {
    let n = data.n_samples();
    let mut indices: Vec<usize> = (0..n).collect();
    shuffle(&mut indices, seed);

    let fold_size = n / k;
    let mut folds = Vec::with_capacity(k);

    for i in 0..k {
        let start = i * fold_size;
        let end = if i == k - 1 { n } else { start + fold_size };
        let test_indices: Vec<usize> = indices[start..end].to_vec();
        let train_indices: Vec<usize> = indices[..start]
            .iter()
            .chain(indices[end..].iter())
            .copied()
            .collect();
        folds.push((data.subset(&train_indices), data.subset(&test_indices)));
    }

    folds
}
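
/// Produces `k` folds while keeping class proportions roughly balanced:
/// samples are grouped by class (target truncated to `i64`), each class is
/// shuffled with a seeded RNG, and its members are dealt round-robin across
/// the `k` folds.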
pub fn stratified_k_fold(data: &Dataset, k: usize, seed: u64) -> Vec<(Dataset, Dataset)> {
    let n = data.n_samples();

    let mut class_map: std::collections::HashMap<i64, Vec<usize>> =
        std::collections::HashMap::new();
    for i in 0..n {
        let key = data.target[i] as i64;
        class_map.entry(key).or_default().push(i);
    }

    let mut sorted_classes: Vec<i64> = class_map.keys().copied().collect();
    sorted_classes.sort_unstable();

    let mut rng = crate::rng::FastRng::new(seed);
    for class in &sorted_classes {
        let indices = class_map
            .get_mut(class)
            .expect("class key from sorted_classes must exist in class_map");
        for i in (1..indices.len()).rev() {
            let j = rng.usize(0..=i);
            indices.swap(i, j);
        }
    }

    let mut fold_indices: Vec<Vec<usize>> = vec![Vec::new(); k];
    for class in &sorted_classes {
        let indices = &class_map[class];
        for (i, &idx) in indices.iter().enumerate() {
            fold_indices[i % k].push(idx);
        }
    }

    let mut folds = Vec::with_capacity(k);
    let all_indices: Vec<usize> = (0..n).collect();

    for fold in &fold_indices {
        let test_set: std::collections::HashSet<usize> = fold.iter().copied().collect();
        let train: Vec<usize> = all_indices
            .iter()
            // Deref once so `contains` sees `&usize`, not `&&usize`.
            .filter(|i| !test_set.contains(*i))
            .copied()
            .collect();
        folds.push((data.subset(&train), data.subset(fold)));
    }

    folds
}
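
/// Runs k-fold cross-validation for `model` and returns one score per fold,
/// computed by `scorer` on that fold's held-out data. Folds are evaluated in
/// parallel; the model is cloned and refit for each fold.
///
/// ```ignore
/// // Illustrative sketch only; assumes a model and dataset from this crate.
/// let scores = cross_val_score(&model, &data, 5, accuracy, 42)?;
/// let mean = scores.iter().sum::<f64>() / scores.len() as f64;
/// ```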
pub fn cross_val_score<M: PipelineModel + Clone + Send + Sync>(
    model: &M,
    data: &Dataset,
    k: usize,
    scorer: ScoringFn,
    seed: u64,
) -> Result<Vec<f64>> {
    let folds = k_fold(data, k, seed);
    run_cv(model, &folds, scorer)
}
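
/// Same as [`cross_val_score`], but builds the folds with
/// [`stratified_k_fold`] so class proportions are preserved in each fold.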
pub fn cross_val_score_stratified<M: PipelineModel + Clone + Send + Sync>(
    model: &M,
    data: &Dataset,
    k: usize,
    scorer: ScoringFn,
    seed: u64,
) -> Result<Vec<f64>> {
    let folds = stratified_k_fold(data, k, seed);
    run_cv(model, &folds, scorer)
}
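
/// Shared driver for the `cross_val_*` functions: fits a clone of the model
/// on each training fold in parallel (via rayon), scores its predictions on
/// the matching test fold, and propagates the first error if any fold fails.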
fn run_cv<M: PipelineModel + Clone + Send + Sync>(
    model: &M,
    folds: &[(Dataset, Dataset)],
    scorer: ScoringFn,
) -> Result<Vec<f64>> {
    use rayon::prelude::*;

    let results: Vec<Result<f64>> = folds
        .par_iter()
        .map(|(train, test)| {
            let mut m = model.clone();
            m.fit(train)?;
            let features = test.feature_matrix();
            let preds = m.predict(&features)?;
            Ok(scorer(&test.target, &preds))
        })
        .collect();

    results.into_iter().collect()
}
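
/// Seeded in-place Fisher-Yates shuffle, used to make splits reproducible.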
fn shuffle(arr: &mut [usize], seed: u64) {
    let mut rng = crate::rng::FastRng::new(seed);
    for i in (1..arr.len()).rev() {
        let j = rng.usize(0..=i);
        arr.swap(i, j);
    }
}
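
/// Configuration for repeated k-fold cross-validation: the k-fold split is
/// re-run `n_repeats` times, each repetition with a different seed derived
/// from `seed`, and all resulting folds are concatenated.
///
/// ```ignore
/// // Illustrative sketch only; assumes a `Dataset` constructed elsewhere.
/// let folds = RepeatedKFold::new(5, 3, 42).folds(&data); // 15 (train, test) pairs
/// ```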
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct RepeatedKFold {
    /// Number of folds per repetition.
    pub n_splits: usize,
    /// Number of times the k-fold split is repeated.
    pub n_repeats: usize,
    /// Base RNG seed; repetition `r` uses `seed.wrapping_add(r)`.
    pub seed: u64,
}

impl RepeatedKFold {
    pub fn new(n_splits: usize, n_repeats: usize, seed: u64) -> Self {
        Self {
            n_splits,
            n_repeats,
            seed,
        }
    }

    /// Builds all `n_splits * n_repeats` folds for `data`.
    pub fn folds(&self, data: &Dataset) -> Vec<(Dataset, Dataset)> {
        let mut all_folds = Vec::with_capacity(self.n_splits * self.n_repeats);
        for rep in 0..self.n_repeats {
            let rep_seed = self.seed.wrapping_add(rep as u64);
            all_folds.extend(k_fold(data, self.n_splits, rep_seed));
        }
        all_folds
    }
}
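
/// Convenience wrapper around [`RepeatedKFold`]: evaluates `model` on every
/// fold of every repetition and returns `n_splits * n_repeats` scores.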
pub fn repeated_cross_val_score<M: PipelineModel + Clone + Send + Sync>(
    model: &M,
    data: &Dataset,
    n_splits: usize,
    n_repeats: usize,
    scorer: ScoringFn,
    seed: u64,
) -> Result<Vec<f64>> {
    let rkf = RepeatedKFold::new(n_splits, n_repeats, seed);
    let folds = rkf.folds(data);
    run_cv(model, &folds, scorer)
}
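
/// Produces `k` folds such that all samples sharing a group label always land
/// in the same fold, preventing a group from leaking across the train/test
/// boundary. Groups are assigned to folds round-robin in order of first
/// appearance; no shuffling is performed.
///
/// ```ignore
/// // Illustrative sketch only; one group label per sample, e.g. a subject id.
/// let groups = vec![0, 0, 1, 1, 2, 2];
/// let folds = group_k_fold(&data, &groups, 3);
/// ```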
pub fn group_k_fold(data: &Dataset, groups: &[usize], k: usize) -> Vec<(Dataset, Dataset)> {
    assert_eq!(
        groups.len(),
        data.n_samples(),
        "groups length must match n_samples"
    );

    let mut unique_groups: Vec<usize> = Vec::new();
    for &g in groups {
        if !unique_groups.contains(&g) {
            unique_groups.push(g);
        }
    }

    let mut group_to_fold = std::collections::HashMap::new();
    for (i, &g) in unique_groups.iter().enumerate() {
        group_to_fold.insert(g, i % k);
    }

    let mut folds = Vec::with_capacity(k);
    for fold_idx in 0..k {
        let mut test_indices = Vec::new();
        let mut train_indices = Vec::new();
        for (sample_idx, &g) in groups.iter().enumerate() {
            if group_to_fold[&g] == fold_idx {
                test_indices.push(sample_idx);
            } else {
                train_indices.push(sample_idx);
            }
        }
        folds.push((data.subset(&train_indices), data.subset(&test_indices)));
    }

    folds
}
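
/// Produces `n_splits` expanding-window folds for ordered data: fold `i`
/// trains on the first `(i + 1) * chunk` samples and tests on the samples
/// that immediately follow, so the test set never precedes the training set.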
pub fn time_series_split(data: &Dataset, n_splits: usize) -> Vec<(Dataset, Dataset)> {
    let n = data.n_samples();
    let chunk = n / (n_splits + 1);
    let mut folds = Vec::with_capacity(n_splits);

    for i in 0..n_splits {
        let train_end = (i + 1) * chunk;
        let test_end = if i == n_splits - 1 {
            n
        } else {
            (i + 2) * chunk
        };
        let train_indices: Vec<usize> = (0..train_end).collect();
        let test_indices: Vec<usize> = (train_end..test_end).collect();
        folds.push((data.subset(&train_indices), data.subset(&test_indices)));
    }

    folds
}
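
/// Generates out-of-fold predictions: each sample is predicted exactly once,
/// by a model fit on the folds that do not contain it. Returns the
/// predictions in the original sample order.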
pub fn cross_val_predict<M: PipelineModel + Clone>(
    model: &M,
    data: &Dataset,
    k: usize,
    seed: u64,
) -> Result<Vec<f64>> {
    let n = data.n_samples();
    let mut indices_all: Vec<usize> = (0..n).collect();
    shuffle(&mut indices_all, seed);

    let fold_size = n / k;
    let mut predictions = vec![0.0; n];

    for i in 0..k {
        let start = i * fold_size;
        let end = if i == k - 1 { n } else { start + fold_size };

        let test_indices: Vec<usize> = indices_all[start..end].to_vec();
        let train_indices: Vec<usize> = indices_all[..start]
            .iter()
            .chain(indices_all[end..].iter())
            .copied()
            .collect();

        let train = data.subset(&train_indices);
        let test = data.subset(&test_indices);

        let mut m = model.clone();
        m.fit(&train)?;
        let features = test.feature_matrix();
        let preds = m.predict(&features)?;

        for (j, &idx) in test_indices.iter().enumerate() {
            predictions[idx] = preds[j];
        }
    }

    Ok(predictions)
}

#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    use crate::metrics::accuracy;
    use crate::tree::DecisionTreeClassifier;

    fn dummy_dataset(n: usize) -> Dataset {
        let features = vec![(0..n).map(|i| i as f64).collect()];
        let target = (0..n).map(|i| (i % 3) as f64).collect();
        Dataset::new(features, target, vec!["x".into()], "y")
    }

    fn separable_dataset() -> Dataset {
        let n = 60;
        let mut f0 = Vec::with_capacity(n);
        let mut f1 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);
        for i in 0..n {
            if i < n / 2 {
                f0.push(i as f64);
                f1.push(i as f64);
                target.push(0.0);
            } else {
                f0.push((i + 100) as f64);
                f1.push((i + 100) as f64);
                target.push(1.0);
            }
        }
        Dataset::new(vec![f0, f1], target, vec!["x".into(), "y".into()], "class")
    }

    #[test]
    fn test_train_test_split_sizes() {
        let ds = dummy_dataset(100);
        let (train, test) = train_test_split(&ds, 0.2, 42);
        assert_eq!(train.n_samples() + test.n_samples(), 100);
        assert_eq!(test.n_samples(), 20);
    }

    #[test]
    fn test_stratified_split_preserves_ratio() {
        let ds = dummy_dataset(90);
        let (train, test) = stratified_split(&ds, 0.2, 42);
        assert_eq!(train.n_samples() + test.n_samples(), 90);

        let test_class_0 = test.target.iter().filter(|&&v| v == 0.0).count();
        let test_class_1 = test.target.iter().filter(|&&v| v == 1.0).count();
        let test_class_2 = test.target.iter().filter(|&&v| v == 2.0).count();
        assert!((4..=8).contains(&test_class_0));
        assert!((4..=8).contains(&test_class_1));
        assert!((4..=8).contains(&test_class_2));
    }

    #[test]
    fn test_k_fold_count() {
        let ds = dummy_dataset(50);
        let folds = k_fold(&ds, 5, 42);
        assert_eq!(folds.len(), 5);
        for (train, test) in &folds {
            assert_eq!(train.n_samples() + test.n_samples(), 50);
        }
    }

    #[test]
    fn test_cross_val_score_dt() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let scores = cross_val_score(&model, &ds, 5, accuracy, 42).unwrap();
        assert_eq!(scores.len(), 5);
        for &s in &scores {
            assert!(s >= 0.8, "fold accuracy {s} < 0.8 on well-separated data");
        }
    }

    #[test]
    fn test_cross_val_score_stratified() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let scores = cross_val_score_stratified(&model, &ds, 5, accuracy, 42).unwrap();
        assert_eq!(scores.len(), 5);
        for &s in &scores {
            assert!(s >= 0.8, "stratified fold accuracy {s} < 0.8");
        }
    }

    #[test]
    fn test_cross_val_score_leave_one_out() {
        let ds = separable_dataset();
        let n = ds.n_samples();
        let model = DecisionTreeClassifier::new();
        let scores = cross_val_score(&model, &ds, n, accuracy, 42).unwrap();
        assert_eq!(scores.len(), n);
        for &s in &scores {
            assert!(s == 0.0 || s == 1.0);
        }
    }

    #[test]
    fn test_cross_val_score_custom_scorer() {
        fn always_one(_true: &[f64], _pred: &[f64]) -> f64 {
            1.0
        }
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let scores = cross_val_score(&model, &ds, 3, always_one, 42).unwrap();
        assert!(scores.iter().all(|&s| (s - 1.0).abs() < 1e-10));
    }

    #[test]
    fn test_repeated_k_fold_count() {
        let ds = dummy_dataset(50);
        let rkf = RepeatedKFold::new(5, 3, 42);
        let folds = rkf.folds(&ds);
        assert_eq!(folds.len(), 15);
        for (train, test) in &folds {
            assert_eq!(train.n_samples() + test.n_samples(), 50);
            assert!(!test.target.is_empty(), "test fold must not be empty");
        }
    }

    #[test]
    fn test_repeated_cross_val_score() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let scores = repeated_cross_val_score(&model, &ds, 5, 3, accuracy, 42).unwrap();
        assert_eq!(scores.len(), 15);
        for &s in &scores {
            assert!(s >= 0.5, "repeated CV fold accuracy {s} too low");
        }
    }

    #[test]
    fn test_group_k_fold_no_leakage() {
        let ds = dummy_dataset(12);
        let groups = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
        let folds = group_k_fold(&ds, &groups, 3);
        assert_eq!(folds.len(), 3);

        for (train, test) in &folds {
            assert_eq!(train.n_samples() + test.n_samples(), 12);
            assert_eq!(test.n_samples(), 4);
        }
    }

    #[test]
    fn test_group_k_fold_group_isolation() {
        let n = 15;
        let ds = dummy_dataset(n);
        let groups: Vec<usize> = (0..n).map(|i| i / 3).collect();
        let folds = group_k_fold(&ds, &groups, 3);

        for (fold_idx, (_train, test)) in folds.iter().enumerate() {
            assert!(!test.target.is_empty(), "fold {fold_idx} test set is empty");
        }
    }

    #[test]
    fn test_time_series_split_temporal_order() {
        let n = 24;
        let ds = dummy_dataset(n);
        let folds = time_series_split(&ds, 3);
        assert_eq!(folds.len(), 3);

        let mut prev_train_size = 0;
        for (train, test) in &folds {
            assert!(
                train.n_samples() > prev_train_size,
                "training size should grow"
            );
            prev_train_size = train.n_samples();
            assert!(!test.target.is_empty(), "test fold must not be empty");
        }
    }

    #[test]
    fn test_time_series_split_no_future_leak() {
        let n = 20;
        let features = vec![(0..n).map(|i| i as f64).collect::<Vec<_>>()];
        let target = (0..n).map(|i| i as f64).collect();
        let ds = Dataset::new(features, target, vec!["t".into()], "y");

        let folds = time_series_split(&ds, 4);
        for (train, test) in &folds {
            let train_max = train.features[0]
                .iter()
                .copied()
                .fold(f64::NEG_INFINITY, f64::max);
            let test_min = test.features[0]
                .iter()
                .copied()
                .fold(f64::INFINITY, f64::min);
            assert!(
                train_max < test_min,
                "train max {train_max} must be < test min {test_min}"
            );
        }
    }

    #[test]
    fn test_cross_val_predict_length() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let preds = cross_val_predict(&model, &ds, 5, 42).unwrap();
        assert_eq!(preds.len(), ds.n_samples());
    }

    #[test]
    fn test_cross_val_predict_reasonable_accuracy() {
        let ds = separable_dataset();
        let model = DecisionTreeClassifier::new();
        let preds = cross_val_predict(&model, &ds, 5, 42).unwrap();
        let acc = accuracy(&ds.target, &preds);
        assert!(
            acc >= 0.8,
            "cross_val_predict accuracy {acc} too low on separable data"
        );
    }
}