scirs2_datasets/utils/
splitting.rs

1//! Data splitting utilities for machine learning workflows
2//!
3//! This module provides various functions for splitting datasets into training,
4//! validation, and test sets. It includes support for simple train-test splits,
5//! cross-validation (both standard and stratified), and time series splitting.
6
7use crate::error::{DatasetsError, Result};
8use crate::utils::Dataset;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::prelude::*;
12use scirs2_core::random::rngs::StdRng;
13use scirs2_core::random::seq::SliceRandom;
14use std::collections::HashMap;
15
/// Index sets produced by a cross-validation splitter.
///
/// The outer vector holds one entry per fold. Each entry is a
/// `(train_indices, validation_indices)` pair whose values are row positions
/// in the dataset that was split.
pub type CrossValidationFolds = Vec<(Vec<usize>, Vec<usize>)>;
21
22/// Split a dataset into training and test sets
23///
24/// This function creates a random split of the dataset while preserving
25/// the metadata and feature information in both resulting datasets.
26///
27/// # Arguments
28///
29/// * `dataset` - The dataset to split
30/// * `test_size` - Fraction of samples to include in test set (0.0 to 1.0)
31/// * `random_seed` - Optional random seed for reproducible splits
32///
33/// # Returns
34///
35/// A tuple of (train_dataset, test_dataset)
36///
37/// # Examples
38///
39/// ```rust
40/// use scirs2_core::ndarray::Array2;
41/// use scirs2_datasets::utils::{Dataset, train_test_split};
42///
43/// let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
44/// let dataset = Dataset::new(data, None);
45///
46/// let (train, test) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
47/// assert_eq!(train.n_samples() + test.n_samples(), 10);
48/// ```
49#[allow(dead_code)]
50pub fn train_test_split(
51    dataset: &Dataset,
52    test_size: f64,
53    random_seed: Option<u64>,
54) -> Result<(Dataset, Dataset)> {
55    if test_size <= 0.0 || test_size >= 1.0 {
56        return Err(DatasetsError::InvalidFormat(
57            "test_size must be between 0 and 1".to_string(),
58        ));
59    }
60
61    let n_samples = dataset.n_samples();
62    let n_test = (n_samples as f64 * test_size).round() as usize;
63    let n_train = n_samples - n_test;
64
65    if n_train == 0 || n_test == 0 {
66        return Err(DatasetsError::InvalidFormat(
67            "Both train and test sets must have at least one sample".to_string(),
68        ));
69    }
70
71    // Create shuffled indices
72    let mut indices: Vec<usize> = (0..n_samples).collect();
73    let mut rng = match random_seed {
74        Some(_seed) => StdRng::seed_from_u64(_seed),
75        None => {
76            let mut r = thread_rng();
77            StdRng::seed_from_u64(r.next_u64())
78        }
79    };
80    indices.shuffle(&mut rng);
81
82    let train_indices = &indices[0..n_train];
83    let test_indices = &indices[n_train..];
84
85    // Create training dataset
86    let train_data = dataset
87        .data
88        .select(scirs2_core::ndarray::Axis(0), train_indices);
89    let train_target = dataset
90        .target
91        .as_ref()
92        .map(|t| t.select(scirs2_core::ndarray::Axis(0), train_indices));
93
94    let mut train_dataset = Dataset::new(train_data, train_target);
95    if let Some(featurenames) = &dataset.featurenames {
96        train_dataset = train_dataset.with_featurenames(featurenames.clone());
97    }
98    if let Some(description) = &dataset.description {
99        train_dataset = train_dataset.with_description(description.clone());
100    }
101
102    // Create test dataset
103    let test_data = dataset
104        .data
105        .select(scirs2_core::ndarray::Axis(0), test_indices);
106    let test_target = dataset
107        .target
108        .as_ref()
109        .map(|t| t.select(scirs2_core::ndarray::Axis(0), test_indices));
110
111    let mut test_dataset = Dataset::new(test_data, test_target);
112    if let Some(featurenames) = &dataset.featurenames {
113        test_dataset = test_dataset.with_featurenames(featurenames.clone());
114    }
115    if let Some(description) = &dataset.description {
116        test_dataset = test_dataset.with_description(description.clone());
117    }
118
119    Ok((train_dataset, test_dataset))
120}
121
122/// Performs K-fold cross-validation splitting
123///
124/// Splits the dataset into k consecutive folds. Each fold is used once as a validation
125/// set while the remaining k-1 folds form the training set.
126///
127/// # Arguments
128///
129/// * `n_samples` - Number of samples in the dataset
130/// * `n_folds` - Number of folds (must be >= 2 and <= n_samples)
131/// * `shuffle` - Whether to shuffle the data before splitting
132/// * `random_seed` - Optional random seed for reproducible shuffling
133///
134/// # Returns
135///
136/// A vector of (train_indices, validation_indices) tuples for each fold
137///
138/// # Examples
139///
140/// ```rust
141/// use scirs2_datasets::utils::k_fold_split;
142///
143/// let folds = k_fold_split(10, 3, true, Some(42)).unwrap();
144/// assert_eq!(folds.len(), 3);
145///
146/// // Each fold should have roughly equal size
147/// for (train_idx, val_idx) in &folds {
148///     assert!(val_idx.len() >= 3 && val_idx.len() <= 4);
149///     assert_eq!(train_idx.len() + val_idx.len(), 10);
150/// }
151/// ```
152#[allow(dead_code)]
153pub fn k_fold_split(
154    n_samples: usize,
155    n_folds: usize,
156    shuffle: bool,
157    random_seed: Option<u64>,
158) -> Result<CrossValidationFolds> {
159    if n_folds < 2 {
160        return Err(DatasetsError::InvalidFormat(
161            "Number of _folds must be at least 2".to_string(),
162        ));
163    }
164
165    if n_folds > n_samples {
166        return Err(DatasetsError::InvalidFormat(
167            "Number of _folds cannot exceed number of _samples".to_string(),
168        ));
169    }
170
171    let mut indices: Vec<usize> = (0..n_samples).collect();
172
173    if shuffle {
174        let mut rng = match random_seed {
175            Some(_seed) => StdRng::seed_from_u64(_seed),
176            None => {
177                let mut r = thread_rng();
178                StdRng::seed_from_u64(r.next_u64())
179            }
180        };
181        indices.shuffle(&mut rng);
182    }
183
184    let mut folds = Vec::new();
185    let fold_size = n_samples / n_folds;
186    let remainder = n_samples % n_folds;
187
188    for i in 0..n_folds {
189        let start = i * fold_size + i.min(remainder);
190        let end = start + fold_size + if i < remainder { 1 } else { 0 };
191
192        let validation_indices = indices[start..end].to_vec();
193        let mut train_indices = Vec::new();
194        train_indices.extend(&indices[0..start]);
195        train_indices.extend(&indices[end..]);
196
197        folds.push((train_indices, validation_indices));
198    }
199
200    Ok(folds)
201}
202
203/// Performs stratified K-fold cross-validation splitting
204///
205/// Splits the dataset into k folds while preserving the percentage of samples
206/// for each target class in each fold. This is useful for classification tasks
207/// with imbalanced datasets.
208///
209/// # Arguments
210///
211/// * `targets` - Target values for stratification
212/// * `n_folds` - Number of folds (must be >= 2)
213/// * `shuffle` - Whether to shuffle the data before splitting
214/// * `random_seed` - Optional random seed for reproducible shuffling
215///
216/// # Returns
217///
218/// A vector of (train_indices, validation_indices) tuples for each fold
219///
220/// # Examples
221///
222/// ```rust
223/// use scirs2_core::ndarray::Array1;
224/// use scirs2_datasets::utils::stratified_k_fold_split;
225///
226/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 0.0, 1.0]);
227/// let folds = stratified_k_fold_split(&targets, 2, true, Some(42)).unwrap();
228/// assert_eq!(folds.len(), 2);
229///
230/// // Each fold should maintain class proportions
231/// for (train_idx, val_idx) in &folds {
232///     assert_eq!(train_idx.len() + val_idx.len(), 6);
233/// }
234/// ```
235#[allow(dead_code)]
236pub fn stratified_k_fold_split(
237    targets: &Array1<f64>,
238    n_folds: usize,
239    shuffle: bool,
240    random_seed: Option<u64>,
241) -> Result<CrossValidationFolds> {
242    if n_folds < 2 {
243        return Err(DatasetsError::InvalidFormat(
244            "Number of _folds must be at least 2".to_string(),
245        ));
246    }
247
248    let n_samples = targets.len();
249    if n_folds > n_samples {
250        return Err(DatasetsError::InvalidFormat(
251            "Number of _folds cannot exceed number of samples".to_string(),
252        ));
253    }
254
255    // Group indices by target class
256    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
257
258    for (i, &target) in targets.iter().enumerate() {
259        let class = target.round() as i64;
260        class_indices.entry(class).or_default().push(i);
261    }
262
263    // Shuffle indices within each class if requested
264    if shuffle {
265        let mut rng = match random_seed {
266            Some(_seed) => StdRng::seed_from_u64(_seed),
267            None => {
268                let mut r = thread_rng();
269                StdRng::seed_from_u64(r.next_u64())
270            }
271        };
272
273        for indices in class_indices.values_mut() {
274            indices.shuffle(&mut rng);
275        }
276    }
277
278    // Create _folds while maintaining class proportions
279    let mut folds = vec![Vec::new(); n_folds];
280
281    for (_, indices) in class_indices {
282        let class_size = indices.len();
283        let fold_size = class_size / n_folds;
284        let remainder = class_size % n_folds;
285
286        for (i, fold) in folds.iter_mut().enumerate() {
287            let start = i * fold_size + i.min(remainder);
288            let end = start + fold_size + if i < remainder { 1 } else { 0 };
289            fold.extend(&indices[start..end]);
290        }
291    }
292
293    // Convert to (train, validation) pairs
294    let cv_folds = (0..n_folds)
295        .map(|i| {
296            let validation_indices = folds[i].clone();
297            let mut train_indices = Vec::new();
298            for (j, fold) in folds.iter().enumerate() {
299                if i != j {
300                    train_indices.extend(fold);
301                }
302            }
303            (train_indices, validation_indices)
304        })
305        .collect();
306
307    Ok(cv_folds)
308}
309
310/// Performs time series cross-validation splitting
311///
312/// Creates splits suitable for time series data where future observations
313/// should not be used to predict past observations. Each training set contains
314/// all observations up to a certain point, and the validation set contains
315/// the next `n_test_samples` observations.
316///
317/// # Arguments
318///
319/// * `n_samples` - Number of samples in the dataset
320/// * `n_splits` - Number of splits to create
321/// * `n_test_samples` - Number of samples in each test set
322/// * `gap` - Number of samples to skip between train and test sets (default: 0)
323///
324/// # Returns
325///
326/// A vector of (train_indices, validation_indices) tuples for each split
327///
328/// # Examples
329///
330/// ```rust
331/// use scirs2_datasets::utils::time_series_split;
332///
333/// let folds = time_series_split(100, 5, 10, 0).unwrap();
334/// assert_eq!(folds.len(), 5);
335///
336/// // Training sets should be increasing in size
337/// for i in 1..folds.len() {
338///     assert!(folds[i].0.len() > folds[i-1].0.len());
339/// }
340/// ```
341#[allow(dead_code)]
342pub fn time_series_split(
343    n_samples: usize,
344    n_splits: usize,
345    n_test_samples: usize,
346    gap: usize,
347) -> Result<CrossValidationFolds> {
348    if n_splits < 1 {
349        return Err(DatasetsError::InvalidFormat(
350            "Number of _splits must be at least 1".to_string(),
351        ));
352    }
353
354    if n_test_samples < 1 {
355        return Err(DatasetsError::InvalidFormat(
356            "Number of test _samples must be at least 1".to_string(),
357        ));
358    }
359
360    // Calculate minimum _samples needed
361    let min_samples_needed = n_test_samples + gap + n_splits;
362    if n_samples < min_samples_needed {
363        return Err(DatasetsError::InvalidFormat(format!(
364            "Not enough _samples for time series split. Need at least {min_samples_needed}, got {n_samples}"
365        )));
366    }
367
368    let mut folds = Vec::new();
369    let test_starts = (0..n_splits)
370        .map(|i| {
371            let split_size = (n_samples - n_test_samples - gap) / n_splits;
372            split_size * (i + 1) + gap
373        })
374        .collect::<Vec<_>>();
375
376    for &test_start in &test_starts {
377        let train_end = test_start - gap;
378        let test_end = test_start + n_test_samples;
379
380        if test_end > n_samples {
381            break;
382        }
383
384        let train_indices = (0..train_end).collect();
385        let test_indices = (test_start..test_end).collect();
386
387        folds.push((train_indices, test_indices));
388    }
389
390    if folds.is_empty() {
391        return Err(DatasetsError::InvalidFormat(
392            "Could not create any valid time series _splits".to_string(),
393        ));
394    }
395
396    Ok(folds)
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Collects the validation indices of every fold into one sorted vector.
    fn sorted_validation_indices(folds: &[(Vec<usize>, Vec<usize>)]) -> Vec<usize> {
        let mut gathered: Vec<usize> = folds
            .iter()
            .flat_map(|(_, validation)| validation.iter().copied())
            .collect();
        gathered.sort();
        gathered
    }

    #[test]
    fn test_train_test_split() {
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        let labels = array![0.0, 1.0, 0.0, 1.0, 0.0];
        let dataset = Dataset::new(features, Some(labels));

        let (train, test) = train_test_split(&dataset, 0.4, Some(42)).unwrap();

        // 40% of 5 samples go to the test set, the rest to training.
        assert_eq!(train.n_samples() + test.n_samples(), 5);
        assert_eq!(test.n_samples(), 2);
        assert_eq!(train.n_samples(), 3);
    }

    #[test]
    fn test_train_test_split_invalid_size() {
        let dataset = Dataset::new(array![[1.0, 2.0]], None);

        // Fractions on or outside the (0, 1) boundaries must be rejected.
        for bad_size in [0.0, 1.0, 1.5] {
            assert!(train_test_split(&dataset, bad_size, None).is_err());
        }
    }

    #[test]
    fn test_k_fold_split() {
        let folds = k_fold_split(10, 3, false, Some(42)).unwrap();
        assert_eq!(folds.len(), 3);

        // Every sample must appear exactly once across the validation sets.
        let expected: Vec<usize> = (0..10).collect();
        assert_eq!(sorted_validation_indices(&folds), expected);
    }

    #[test]
    fn test_k_fold_split_invalid_params() {
        // A single fold is not a cross-validation.
        assert!(k_fold_split(10, 1, false, None).is_err());

        // More folds than samples cannot be filled.
        assert!(k_fold_split(5, 6, false, None).is_err());
    }

    #[test]
    fn test_stratified_k_fold_split() {
        // Three samples of class 0 and three of class 1.
        let targets = array![0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let folds = stratified_k_fold_split(&targets, 2, false, Some(42)).unwrap();
        assert_eq!(folds.len(), 2);

        // Every sample must be covered by the validation sets.
        let expected: Vec<usize> = (0..6).collect();
        assert_eq!(sorted_validation_indices(&folds), expected);
    }

    #[test]
    fn test_time_series_split() {
        let folds = time_series_split(20, 3, 5, 1).unwrap();
        assert_eq!(folds.len(), 3);

        // Each training set must be strictly larger than the one before it.
        for pair in folds.windows(2) {
            assert!(pair[1].0.len() > pair[0].0.len());
        }

        // Every validation window has the requested length.
        for (_, validation) in &folds {
            assert_eq!(validation.len(), 5);
        }
    }

    #[test]
    fn test_time_series_split_insufficient_data() {
        // Too few samples for the requested windows.
        assert!(time_series_split(5, 3, 5, 1).is_err());

        // Zero splits or zero test samples are invalid.
        assert!(time_series_split(100, 0, 10, 0).is_err());
        assert!(time_series_split(100, 5, 0, 0).is_err());
    }
}