Function k_fold_split

pub fn k_fold_split(
    n_samples: usize,
    n_folds: usize,
    shuffle: bool,
    random_seed: Option<u64>,
) -> Result<CrossValidationFolds>

Performs K-fold cross-validation splitting

Splits the dataset into k consecutive folds. Each fold is used once as a validation set while the remaining k-1 folds form the training set.
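
When n_samples is not evenly divisible by n_folds, the fold sizes differ by at most one. The helper below is a minimal, standalone sketch of that common size convention (it is not the crate's internal implementation); it agrees with the size assertions in §Examples.

fn fold_sizes(n_samples: usize, n_folds: usize) -> Vec<usize> {
    // Assumed convention: the first (n_samples % n_folds) folds get one extra sample.
    let base = n_samples / n_folds;
    let extra = n_samples % n_folds;
    (0..n_folds)
        .map(|i| if i < extra { base + 1 } else { base })
        .collect()
}

assert_eq!(fold_sizes(10, 3), vec![4, 3, 3]); // validation sets of size 3 or 4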

§Arguments

  • n_samples - Number of samples in the dataset
  • n_folds - Number of folds (must be >= 2 and <= n_samples; see the sketch after this list)
  • shuffle - Whether to shuffle the data before splitting
  • random_seed - Optional random seed for reproducible shuffling
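
The bounds on n_folds above imply failure for degenerate inputs. A hedged sketch, assuming such violations are reported through the Err variant of the returned Result rather than by panicking:

use scirs2_datasets::utils::k_fold_split;

// Assumed behavior: constraint violations surface as Err values.
assert!(k_fold_split(10, 1, false, None).is_err()); // n_folds < 2
assert!(k_fold_split(5, 10, false, None).is_err()); // n_folds > n_samples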

§Returns

A CrossValidationFolds value: a vector of (train_indices, validation_indices) tuples, one per fold

§Examples

use scirs2_datasets::utils::k_fold_split;

let folds = k_fold_split(10, 3, true, Some(42)).unwrap();
assert_eq!(folds.len(), 3);

// Each fold should have roughly equal size
for (train_idx, val_idx) in &folds {
    assert!(val_idx.len() >= 3 && val_idx.len() <= 4);
    assert_eq!(train_idx.len() + val_idx.len(), 10);
}
Examples found in repository
examples/real_world_datasets.rs (line 325)
fn demonstrate_advanced_operations() -> Result<(), Box<dyn std::error::Error>> {
    println!("🔧 ADVANCED DATASET OPERATIONS");
    println!("{}", "-".repeat(40));

    let housing = load_california_housing()?;

    // Data preprocessing pipeline
    println!("Preprocessing pipeline for California Housing:");

    // 1. Train/test split
    let (mut train, test) = train_test_split(&housing, 0.2, Some(42))?;
    println!(
        "  1. Split: {} train, {} test",
        train.n_samples(),
        test.n_samples()
    );

    // 2. Feature scaling
    let mut pipeline = MLPipeline::default();
    train = pipeline.prepare_dataset(&train)?;
    println!("  2. Standardized features");

    // 3. Cross-validation setup
    let cv_folds = k_fold_split(train.n_samples(), 5, true, Some(42))?;
    println!("  3. Created {} CV folds", cv_folds.len());

    // Feature correlation analysis (simplified)
    println!("  4. Feature analysis:");
    println!("     • {} numerical features", train.n_features());
    println!("     • Ready for machine learning models");

    // Custom dataset configuration
    println!("\nCustom dataset loading configuration:");
    let config = RealWorldConfig {
        use_cache: true,
        download_if_missing: false, // Don't download in demo
        return_preprocessed: true,
        subset: Some("small".to_string()),
        random_state: Some(42),
        ..Default::default()
    };

    println!("  • Caching: {}", config.use_cache);
    println!("  • Download missing: {}", config.download_if_missing);
    println!("  • Preprocessed: {}", config.return_preprocessed);
    println!("  • Subset: {:?}", config.subset);

    println!();
    Ok(())
}
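
The cv_folds created at line 325 above are index pairs. A hedged sketch of one way to consume them per fold follows: evaluate_fold is a hypothetical placeholder for whatever fitting and scoring step is plugged in, and the row selection assumes the features live in an ndarray Array2, as in the demo further below.

use ndarray::{Array2, Axis};
use scirs2_datasets::utils::k_fold_split;

// Hypothetical placeholder for a model-specific fit/score step.
fn evaluate_fold(_train: &Array2<f64>, _val: &Array2<f64>) -> f64 {
    0.0
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let data = Array2::<f64>::zeros((20, 3)); // stand-in feature matrix
    let folds = k_fold_split(data.nrows(), 5, true, Some(42))?;

    // One score per (train, validation) pair; rows are selected by index.
    let scores: Vec<f64> = folds
        .iter()
        .map(|(train_idx, val_idx)| {
            let train = data.select(Axis(0), train_idx);
            let val = data.select(Axis(0), val_idx);
            evaluate_fold(&train, &val)
        })
        .collect();
    println!("collected {} fold scores", scores.len());
    Ok(())
}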
More examples
examples/datasets_cross_validation_demo.rs (line 31)
fn main() {
    println!("=== Cross-Validation Demonstration ===\n");

    // Create sample dataset
    let data = Array2::from_shape_vec((20, 3), (0..60).map(|x| x as f64 / 10.0).collect()).unwrap();
    let target = Array1::from(
        (0..20)
            .map(|i| if i % 2 == 0 { 0.0 } else { 1.0 })
            .collect::<Vec<_>>(),
    );

    let dataset = Dataset::new(data.clone(), Some(target.clone()))
        .with_description("Sample dataset for cross-validation demo".to_string());

    println!("Dataset info:");
    println!("- Samples: {}", dataset.n_samples());
    println!("- Features: {}", dataset.n_features());
    println!("- Description: {}\n", dataset.description.as_ref().unwrap());

    // Demonstrate K-fold cross-validation
    println!("=== K-Fold Cross-Validation (k=5) ===");
    let k_folds = k_fold_split(dataset.n_samples(), 5, true, Some(42)).unwrap();

    for (i, (train_indices, val_indices)) in k_folds.iter().enumerate() {
        println!(
            "Fold {}: Train size: {}, Validation size: {}",
            i + 1,
            train_indices.len(),
            val_indices.len()
        );
        println!(
            "  Train indices: {:?}",
            &train_indices[..5.min(train_indices.len())]
        );
        println!("  Val indices: {val_indices:?}");
    }
    println!();

    // Demonstrate Stratified K-fold cross-validation
    println!("=== Stratified K-Fold Cross-Validation (k=4) ===");
    let stratified_folds = stratified_k_fold_split(&target, 4, true, Some(42)).unwrap();

    for (i, (train_indices, val_indices)) in stratified_folds.iter().enumerate() {
        // Calculate class distribution in validation set
        let val_targets: Vec<f64> = val_indices.iter().map(|&idx| target[idx]).collect();
        let class_0_count = val_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = val_targets.iter().filter(|&&x| x == 1.0).count();

        println!(
            "Fold {}: Train size: {}, Validation size: {}",
            i + 1,
            train_indices.len(),
            val_indices.len()
        );
        println!(
            "  Class distribution in validation: Class 0: {class_0_count}, Class 1: {class_1_count}"
        );
    }
    println!();

    // Demonstrate Time Series cross-validation
    println!("=== Time Series Cross-Validation ===");
    let ts_folds = time_series_split(dataset.n_samples(), 3, 3, 1).unwrap();

    for (i, (train_indices, val_indices)) in ts_folds.iter().enumerate() {
        println!(
            "Split {}: Train size: {}, Test size: {}",
            i + 1,
            train_indices.len(),
            val_indices.len()
        );
        println!(
            "  Train range: {} to {}",
            train_indices.first().unwrap_or(&0),
            train_indices.last().unwrap_or(&0)
        );
        println!(
            "  Test range: {} to {}",
            val_indices.first().unwrap_or(&0),
            val_indices.last().unwrap_or(&0)
        );
    }
    println!();

    // Demonstrate usage with Dataset methods
    println!("=== Using Cross-Validation with Dataset ===");
    let first_fold = &k_folds[0];
    let (train_indices, val_indices) = first_fold;

    // Create training subset
    let traindata = data.select(ndarray::Axis(0), train_indices);
    let train_target = target.select(ndarray::Axis(0), train_indices);
    let traindataset = Dataset::new(traindata, Some(train_target))
        .with_description("Training fold from K-fold CV".to_string());

    // Create validation subset
    let valdata = data.select(ndarray::Axis(0), val_indices);
    let val_target = target.select(ndarray::Axis(0), val_indices);
    let valdataset = Dataset::new(valdata, Some(val_target))
        .with_description("Validation fold from K-fold CV".to_string());

    println!(
        "Training dataset: {} samples, {} features",
        traindataset.n_samples(),
        traindataset.n_features()
    );
    println!(
        "Validation dataset: {} samples, {} features",
        valdataset.n_samples(),
        valdataset.n_features()
    );

    println!("\n=== Cross-Validation Demo Complete ===");
}
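
As a follow-up to the demo above: once a score has been collected for each fold, the usual summary is a mean and standard deviation across folds. This sketch uses only the standard library and illustrative score values; no scirs2 APIs are assumed.

fn summarize(scores: &[f64]) -> (f64, f64) {
    // Mean and population standard deviation across per-fold scores.
    let n = scores.len() as f64;
    let mean = scores.iter().sum::<f64>() / n;
    let var = scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / n;
    (mean, var.sqrt())
}

fn main() {
    // Illustrative per-fold scores; in practice these come from the model evaluation.
    let (mean, std) = summarize(&[0.81, 0.79, 0.84, 0.80, 0.82]);
    println!("CV score: {mean:.3} ± {std:.3}");
}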