pub fn k_fold_split(
n_samples: usize,
n_folds: usize,
shuffle: bool,
random_seed: Option<u64>,
) -> Result<CrossValidationFolds>
Performs K-fold cross-validation splitting.
Splits the dataset into k consecutive folds. Each fold is used once as a validation set while the remaining k-1 folds form the training set. When `shuffle` is `true`, the sample indices are shuffled before being divided into folds.
§Arguments
- `n_samples` - Number of samples in the dataset
- `n_folds` - Number of folds (must be >= 2 and <= `n_samples`)
- `shuffle` - Whether to shuffle the data before splitting
- `random_seed` - Optional random seed for reproducible shuffling
§Returns
A `Result` wrapping the folds (`CrossValidationFolds`): one `(train_indices, validation_indices)` tuple per fold.
§Examples
use scirs2_datasets::utils::k_fold_split;
let splits = k_fold_split(10, 3, true, Some(42)).unwrap();
assert_eq!(splits.len(), 3);

// 10 samples over 3 folds: every validation fold holds 3 or 4 samples, and the
// train and validation indices of a fold together account for all 10 samples.
for (train, validation) in splits.iter() {
    assert!((3..=4).contains(&validation.len()));
    assert_eq!(train.len() + validation.len(), 10);
}
Examples found in repository:
`examples/cross_validation_demo.rs` (line 30)
9fn main() {
10 println!("=== Cross-Validation Demonstration ===\n");
11
12 // Create sample dataset
13 let data = Array2::from_shape_vec((20, 3), (0..60).map(|x| x as f64 / 10.0).collect()).unwrap();
14 let target = Array1::from(
15 (0..20)
16 .map(|i| if i % 2 == 0 { 0.0 } else { 1.0 })
17 .collect::<Vec<_>>(),
18 );
19
20 let dataset = Dataset::new(data.clone(), Some(target.clone()))
21 .with_description("Sample dataset for cross-validation demo".to_string());
22
23 println!("Dataset info:");
24 println!("- Samples: {}", dataset.n_samples());
25 println!("- Features: {}", dataset.n_features());
26 println!("- Description: {}\n", dataset.description.as_ref().unwrap());
27
28 // Demonstrate K-fold cross-validation
29 println!("=== K-Fold Cross-Validation (k=5) ===");
30 let k_folds = k_fold_split(dataset.n_samples(), 5, true, Some(42)).unwrap();
31
32 for (i, (train_indices, val_indices)) in k_folds.iter().enumerate() {
33 println!(
34 "Fold {}: Train size: {}, Validation size: {}",
35 i + 1,
36 train_indices.len(),
37 val_indices.len()
38 );
39 println!(
40 " Train indices: {:?}",
41 &train_indices[..5.min(train_indices.len())]
42 );
43 println!(" Val indices: {:?}", val_indices);
44 }
45 println!();
46
47 // Demonstrate Stratified K-fold cross-validation
48 println!("=== Stratified K-Fold Cross-Validation (k=4) ===");
49 let stratified_folds = stratified_k_fold_split(&target, 4, true, Some(42)).unwrap();
50
51 for (i, (train_indices, val_indices)) in stratified_folds.iter().enumerate() {
52 // Calculate class distribution in validation set
53 let val_targets: Vec<f64> = val_indices.iter().map(|&idx| target[idx]).collect();
54 let class_0_count = val_targets.iter().filter(|&&x| x == 0.0).count();
55 let class_1_count = val_targets.iter().filter(|&&x| x == 1.0).count();
56
57 println!(
58 "Fold {}: Train size: {}, Validation size: {}",
59 i + 1,
60 train_indices.len(),
61 val_indices.len()
62 );
63 println!(
64 " Class distribution in validation: Class 0: {}, Class 1: {}",
65 class_0_count, class_1_count
66 );
67 }
68 println!();
69
70 // Demonstrate Time Series cross-validation
71 println!("=== Time Series Cross-Validation ===");
72 let ts_folds = time_series_split(dataset.n_samples(), 3, 3, 1).unwrap();
73
74 for (i, (train_indices, val_indices)) in ts_folds.iter().enumerate() {
75 println!(
76 "Split {}: Train size: {}, Test size: {}",
77 i + 1,
78 train_indices.len(),
79 val_indices.len()
80 );
81 println!(
82 " Train range: {} to {}",
83 train_indices.first().unwrap_or(&0),
84 train_indices.last().unwrap_or(&0)
85 );
86 println!(
87 " Test range: {} to {}",
88 val_indices.first().unwrap_or(&0),
89 val_indices.last().unwrap_or(&0)
90 );
91 }
92 println!();
93
94 // Demonstrate usage with Dataset methods
95 println!("=== Using Cross-Validation with Dataset ===");
96 let first_fold = &k_folds[0];
97 let (train_indices, val_indices) = first_fold;
98
99 // Create training subset
100 let train_data = data.select(ndarray::Axis(0), train_indices);
101 let train_target = target.select(ndarray::Axis(0), train_indices);
102 let train_dataset = Dataset::new(train_data, Some(train_target))
103 .with_description("Training fold from K-fold CV".to_string());
104
105 // Create validation subset
106 let val_data = data.select(ndarray::Axis(0), val_indices);
107 let val_target = target.select(ndarray::Axis(0), val_indices);
108 let val_dataset = Dataset::new(val_data, Some(val_target))
109 .with_description("Validation fold from K-fold CV".to_string());
110
111 println!(
112 "Training dataset: {} samples, {} features",
113 train_dataset.n_samples(),
114 train_dataset.n_features()
115 );
116 println!(
117 "Validation dataset: {} samples, {} features",
118 val_dataset.n_samples(),
119 val_dataset.n_features()
120 );
121
122 println!("\n=== Cross-Validation Demo Complete ===");
123}