// scirs2_datasets/utils/splitting.rs

//! Data splitting utilities for machine learning workflows
//!
//! This module provides various functions for splitting datasets into training,
//! validation, and test sets. It includes support for simple train-test splits,
//! cross-validation (both standard and stratified), and time series splitting.
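//!
//! # Examples
//!
//! A minimal sketch of consuming the fold indices produced by `k_fold_split`,
//! here applied directly to an `ndarray` feature matrix:
//!
//! ```rust
//! use ndarray::{Array2, Axis};
//! use scirs2_datasets::utils::k_fold_split;
//!
//! let data = Array2::from_shape_vec((6, 2), (0..12).map(|x| x as f64).collect()).unwrap();
//!
//! for (train_idx, val_idx) in k_fold_split(data.nrows(), 3, false, None).unwrap() {
//!     // Select the training and validation rows for this fold
//!     let train_rows = data.select(Axis(0), &train_idx);
//!     let val_rows = data.select(Axis(0), &val_idx);
//!     assert_eq!(train_rows.nrows() + val_rows.nrows(), data.nrows());
//! }
//! ```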

use crate::error::{DatasetsError, Result};
use crate::utils::Dataset;
use ndarray::Array1;
use rand::prelude::*;
use rand::rng;
use rand::rngs::StdRng;
use std::collections::HashMap;

/// Cross-validation fold indices
///
/// Each element is a tuple of (train_indices, validation_indices)
/// where indices refer to samples in the original dataset.
pub type CrossValidationFolds = Vec<(Vec<usize>, Vec<usize>)>;

/// Split a dataset into training and test sets
///
/// This function creates a random split of the dataset while preserving
/// the metadata and feature information in both resulting datasets.
///
/// # Arguments
///
/// * `dataset` - The dataset to split
/// * `test_size` - Fraction of samples to include in the test set (strictly between 0.0 and 1.0)
/// * `random_seed` - Optional random seed for reproducible splits
///
/// # Returns
///
/// A tuple of (train_dataset, test_dataset)
///
/// # Examples
///
/// ```rust
/// use ndarray::Array2;
/// use scirs2_datasets::utils::{Dataset, train_test_split};
///
/// let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
/// let dataset = Dataset::new(data, None);
///
/// let (train, test) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
/// assert_eq!(train.n_samples() + test.n_samples(), 10);
/// ```
pub fn train_test_split(
    dataset: &Dataset,
    test_size: f64,
    random_seed: Option<u64>,
) -> Result<(Dataset, Dataset)> {
    if test_size <= 0.0 || test_size >= 1.0 {
        return Err(DatasetsError::InvalidFormat(
            "test_size must be between 0 and 1".to_string(),
        ));
    }

    let n_samples = dataset.n_samples();
    let n_test = (n_samples as f64 * test_size).round() as usize;
    let n_train = n_samples - n_test;

    if n_train == 0 || n_test == 0 {
        return Err(DatasetsError::InvalidFormat(
            "Both train and test sets must have at least one sample".to_string(),
        ));
    }

    // Create shuffled indices
    let mut indices: Vec<usize> = (0..n_samples).collect();
    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };
    indices.shuffle(&mut rng);

    let train_indices = &indices[0..n_train];
    let test_indices = &indices[n_train..];

    // Create training dataset
    let train_data = dataset.data.select(ndarray::Axis(0), train_indices);
    let train_target = dataset
        .target
        .as_ref()
        .map(|t| t.select(ndarray::Axis(0), train_indices));

    let mut train_dataset = Dataset::new(train_data, train_target);
    if let Some(feature_names) = &dataset.feature_names {
        train_dataset = train_dataset.with_feature_names(feature_names.clone());
    }
    if let Some(description) = &dataset.description {
        train_dataset = train_dataset.with_description(description.clone());
    }

    // Create test dataset
    let test_data = dataset.data.select(ndarray::Axis(0), test_indices);
    let test_target = dataset
        .target
        .as_ref()
        .map(|t| t.select(ndarray::Axis(0), test_indices));

    let mut test_dataset = Dataset::new(test_data, test_target);
    if let Some(feature_names) = &dataset.feature_names {
        test_dataset = test_dataset.with_feature_names(feature_names.clone());
    }
    if let Some(description) = &dataset.description {
        test_dataset = test_dataset.with_description(description.clone());
    }

    Ok((train_dataset, test_dataset))
}

/// Performs K-fold cross-validation splitting
///
/// Splits the dataset into k consecutive folds. Each fold is used once as a validation
/// set while the remaining k-1 folds form the training set.
///
/// # Arguments
///
/// * `n_samples` - Number of samples in the dataset
/// * `n_folds` - Number of folds (must be >= 2 and <= n_samples)
/// * `shuffle` - Whether to shuffle the data before splitting
/// * `random_seed` - Optional random seed for reproducible shuffling
///
/// # Returns
///
/// A vector of (train_indices, validation_indices) tuples for each fold
///
/// # Examples
///
/// ```rust
/// use scirs2_datasets::utils::k_fold_split;
///
/// let folds = k_fold_split(10, 3, true, Some(42)).unwrap();
/// assert_eq!(folds.len(), 3);
///
/// // Each fold should have roughly equal size
/// for (train_idx, val_idx) in &folds {
///     assert!(val_idx.len() >= 3 && val_idx.len() <= 4);
///     assert_eq!(train_idx.len() + val_idx.len(), 10);
/// }
/// ```
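///
/// Within each fold, the training and validation indices are disjoint; a small
/// illustrative check:
///
/// ```rust
/// use scirs2_datasets::utils::k_fold_split;
///
/// let folds = k_fold_split(7, 3, false, None).unwrap();
/// for (train_idx, val_idx) in &folds {
///     // No sample may appear in both the training and validation set of a fold
///     assert!(train_idx.iter().all(|i| !val_idx.contains(i)));
/// }
/// ```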
pub fn k_fold_split(
    n_samples: usize,
    n_folds: usize,
    shuffle: bool,
    random_seed: Option<u64>,
) -> Result<CrossValidationFolds> {
    if n_folds < 2 {
        return Err(DatasetsError::InvalidFormat(
            "Number of folds must be at least 2".to_string(),
        ));
    }

    if n_folds > n_samples {
        return Err(DatasetsError::InvalidFormat(
            "Number of folds cannot exceed number of samples".to_string(),
        ));
    }

    let mut indices: Vec<usize> = (0..n_samples).collect();

    if shuffle {
        let mut rng = match random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                let mut r = rng();
                StdRng::seed_from_u64(r.next_u64())
            }
        };
        indices.shuffle(&mut rng);
    }

    let mut folds = Vec::new();
    let fold_size = n_samples / n_folds;
    let remainder = n_samples % n_folds;

    for i in 0..n_folds {
        let start = i * fold_size + i.min(remainder);
        let end = start + fold_size + if i < remainder { 1 } else { 0 };

        let validation_indices = indices[start..end].to_vec();
        let mut train_indices = Vec::new();
        train_indices.extend(&indices[0..start]);
        train_indices.extend(&indices[end..]);

        folds.push((train_indices, validation_indices));
    }

    Ok(folds)
}

/// Performs stratified K-fold cross-validation splitting
///
/// Splits the dataset into k folds while preserving the percentage of samples
/// for each target class in each fold. This is useful for classification tasks
/// with imbalanced datasets.
///
/// # Arguments
///
/// * `targets` - Target values used for stratification; each value is rounded to
///   the nearest integer to form its class label
/// * `n_folds` - Number of folds (must be >= 2 and <= number of samples)
/// * `shuffle` - Whether to shuffle the data before splitting
/// * `random_seed` - Optional random seed for reproducible shuffling
///
/// # Returns
///
/// A vector of (train_indices, validation_indices) tuples for each fold
///
/// # Examples
///
/// ```rust
/// use ndarray::Array1;
/// use scirs2_datasets::utils::stratified_k_fold_split;
///
/// let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 0.0, 1.0]);
/// let folds = stratified_k_fold_split(&targets, 2, true, Some(42)).unwrap();
/// assert_eq!(folds.len(), 2);
///
/// // Together, the train and validation indices of each fold cover every sample
/// for (train_idx, val_idx) in &folds {
///     assert_eq!(train_idx.len() + val_idx.len(), 6);
/// }
/// ```
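///
/// Class proportions are preserved across folds; a small illustrative check with
/// four samples of class 0 and two of class 1:
///
/// ```rust
/// use ndarray::Array1;
/// use scirs2_datasets::utils::stratified_k_fold_split;
///
/// let targets = Array1::from(vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0]);
/// let folds = stratified_k_fold_split(&targets, 2, false, None).unwrap();
/// for (_, val_idx) in &folds {
///     // Each validation fold receives two samples of class 0 and one of class 1
///     let zeros = val_idx.iter().filter(|&&i| targets[i] == 0.0).count();
///     assert_eq!(zeros, 2);
///     assert_eq!(val_idx.len() - zeros, 1);
/// }
/// ```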
pub fn stratified_k_fold_split(
    targets: &Array1<f64>,
    n_folds: usize,
    shuffle: bool,
    random_seed: Option<u64>,
) -> Result<CrossValidationFolds> {
    if n_folds < 2 {
        return Err(DatasetsError::InvalidFormat(
            "Number of folds must be at least 2".to_string(),
        ));
    }

    let n_samples = targets.len();
    if n_folds > n_samples {
        return Err(DatasetsError::InvalidFormat(
            "Number of folds cannot exceed number of samples".to_string(),
        ));
    }

    // Group indices by target class
    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();

    for (i, &target) in targets.iter().enumerate() {
        let class = target.round() as i64;
        class_indices.entry(class).or_default().push(i);
    }

    // Visit classes in sorted order so that a fixed random seed reproduces the
    // same folds on every run (HashMap iteration order is unspecified)
    let mut classes: Vec<i64> = class_indices.keys().copied().collect();
    classes.sort_unstable();

    // Shuffle indices within each class if requested
    if shuffle {
        let mut rng = match random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                let mut r = rng();
                StdRng::seed_from_u64(r.next_u64())
            }
        };

        for class in &classes {
            if let Some(indices) = class_indices.get_mut(class) {
                indices.shuffle(&mut rng);
            }
        }
    }

    // Create folds while maintaining class proportions
    let mut folds = vec![Vec::new(); n_folds];

    for class in &classes {
        let indices = &class_indices[class];
        let class_size = indices.len();
        let fold_size = class_size / n_folds;
        let remainder = class_size % n_folds;

        for (i, fold) in folds.iter_mut().enumerate() {
            let start = i * fold_size + i.min(remainder);
            let end = start + fold_size + if i < remainder { 1 } else { 0 };
            fold.extend(&indices[start..end]);
        }
    }

    // Convert to (train, validation) pairs
    let cv_folds = (0..n_folds)
        .map(|i| {
            let validation_indices = folds[i].clone();
            let mut train_indices = Vec::new();
            for (j, fold) in folds.iter().enumerate() {
                if i != j {
                    train_indices.extend(fold);
                }
            }
            (train_indices, validation_indices)
        })
        .collect();

    Ok(cv_folds)
}

/// Performs time series cross-validation splitting
///
/// Creates splits suitable for time series data, where future observations must
/// not be used to predict past observations. Each training set contains all
/// observations up to a certain point; after skipping `gap` observations, the
/// validation set contains the next `n_test_samples` observations.
///
/// # Arguments
///
/// * `n_samples` - Number of samples in the dataset
/// * `n_splits` - Number of splits to create
/// * `n_test_samples` - Number of samples in each test set
/// * `gap` - Number of samples to skip between the end of each training set and the start of its test set (use 0 for no gap)
///
/// # Returns
///
/// A vector of (train_indices, validation_indices) tuples for each split
///
/// # Examples
///
/// ```rust
/// use scirs2_datasets::utils::time_series_split;
///
/// let folds = time_series_split(100, 5, 10, 0).unwrap();
/// assert_eq!(folds.len(), 5);
///
/// // Training sets should be increasing in size
/// for i in 1..folds.len() {
///     assert!(folds[i].0.len() > folds[i-1].0.len());
/// }
/// ```
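///
/// With a non-zero `gap`, the samples immediately preceding each validation window
/// are left out of training; a small illustrative check:
///
/// ```rust
/// use scirs2_datasets::utils::time_series_split;
///
/// let folds = time_series_split(20, 3, 5, 1).unwrap();
/// for (train_idx, val_idx) in &folds {
///     // Exactly one sample is skipped between training and validation
///     assert_eq!(val_idx[0] - *train_idx.last().unwrap(), 2);
/// }
/// ```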
pub fn time_series_split(
    n_samples: usize,
    n_splits: usize,
    n_test_samples: usize,
    gap: usize,
) -> Result<CrossValidationFolds> {
    if n_splits < 1 {
        return Err(DatasetsError::InvalidFormat(
            "Number of splits must be at least 1".to_string(),
        ));
    }

    if n_test_samples < 1 {
        return Err(DatasetsError::InvalidFormat(
            "Number of test samples must be at least 1".to_string(),
        ));
    }

    // Calculate minimum samples needed
    let min_samples_needed = n_test_samples + gap + n_splits;
    if n_samples < min_samples_needed {
        return Err(DatasetsError::InvalidFormat(format!(
            "Not enough samples for time series split. Need at least {}, got {}",
            min_samples_needed, n_samples
        )));
    }

    let mut folds = Vec::new();
    let split_size = (n_samples - n_test_samples - gap) / n_splits;
    let test_starts = (0..n_splits)
        .map(|i| split_size * (i + 1) + gap)
        .collect::<Vec<_>>();

    for &test_start in &test_starts {
        let train_end = test_start - gap;
        let test_end = test_start + n_test_samples;

        if test_end > n_samples {
            break;
        }

        let train_indices = (0..train_end).collect();
        let test_indices = (test_start..test_end).collect();

        folds.push((train_indices, test_indices));
    }

    if folds.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "Could not create any valid time series splits".to_string(),
        ));
    }

    Ok(folds)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_train_test_split() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        let target = Some(array![0.0, 1.0, 0.0, 1.0, 0.0]);
        let dataset = Dataset::new(data, target);

        let (train, test) = train_test_split(&dataset, 0.4, Some(42)).unwrap();

        assert_eq!(train.n_samples() + test.n_samples(), 5);
        assert_eq!(test.n_samples(), 2); // 40% of 5 samples
        assert_eq!(train.n_samples(), 3); // Remaining samples
    }
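
    // Illustrative check: the same seed reproduces the same split, since the
    // shuffle is seeded deterministically
    #[test]
    fn test_train_test_split_reproducible_with_seed() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        let dataset = Dataset::new(data, None);

        let (train_a, test_a) = train_test_split(&dataset, 0.4, Some(7)).unwrap();
        let (train_b, test_b) = train_test_split(&dataset, 0.4, Some(7)).unwrap();

        assert_eq!(train_a.data, train_b.data);
        assert_eq!(test_a.data, test_b.data);
    }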

    #[test]
    fn test_train_test_split_invalid_size() {
        let data = array![[1.0, 2.0]];
        let dataset = Dataset::new(data, None);

        // Test invalid test sizes
        assert!(train_test_split(&dataset, 0.0, None).is_err());
        assert!(train_test_split(&dataset, 1.0, None).is_err());
        assert!(train_test_split(&dataset, 1.5, None).is_err());
    }
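
    // Illustrative check: metadata such as feature names carries over to both
    // splits, as described in the doc comment for train_test_split
    #[test]
    fn test_train_test_split_preserves_feature_names() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let dataset = Dataset::new(data, None)
            .with_feature_names(vec!["a".to_string(), "b".to_string()]);

        let (train, test) = train_test_split(&dataset, 0.5, Some(0)).unwrap();

        assert_eq!(train.feature_names, dataset.feature_names);
        assert_eq!(test.feature_names, dataset.feature_names);
    }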

    #[test]
    fn test_k_fold_split() {
        let folds = k_fold_split(10, 3, false, Some(42)).unwrap();

        assert_eq!(folds.len(), 3);

        // Check that all samples are covered exactly once in validation
        let mut all_validation_indices: Vec<usize> = Vec::new();
        for (_, val_indices) in &folds {
            all_validation_indices.extend(val_indices);
        }
        all_validation_indices.sort();

        let expected: Vec<usize> = (0..10).collect();
        assert_eq!(all_validation_indices, expected);
    }

    #[test]
    fn test_k_fold_split_invalid_params() {
        // Too few folds
        assert!(k_fold_split(10, 1, false, None).is_err());

        // Too many folds
        assert!(k_fold_split(5, 6, false, None).is_err());
    }
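
    // Illustrative check: shuffled folds are reproducible for a fixed seed
    #[test]
    fn test_k_fold_split_reproducible_with_seed() {
        let folds_a = k_fold_split(10, 3, true, Some(7)).unwrap();
        let folds_b = k_fold_split(10, 3, true, Some(7)).unwrap();
        assert_eq!(folds_a, folds_b);
    }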

    #[test]
    fn test_stratified_k_fold_split() {
        let targets = array![0.0, 0.0, 1.0, 1.0, 0.0, 1.0]; // 3 class 0, 3 class 1
        let folds = stratified_k_fold_split(&targets, 2, false, Some(42)).unwrap();

        assert_eq!(folds.len(), 2);

        // Check that all samples are covered
        let mut all_validation_indices: Vec<usize> = Vec::new();
        for (_, val_indices) in &folds {
            all_validation_indices.extend(val_indices);
        }
        all_validation_indices.sort();

        let expected: Vec<usize> = (0..6).collect();
        assert_eq!(all_validation_indices, expected);
    }

    #[test]
    fn test_time_series_split() {
        let folds = time_series_split(20, 3, 5, 1).unwrap();

        assert_eq!(folds.len(), 3);

        // Check that training sets are increasing in size
        for i in 1..folds.len() {
            assert!(folds[i].0.len() > folds[i - 1].0.len());
        }

        // Check that validation sets have correct size
        for (_, val_indices) in &folds {
            assert_eq!(val_indices.len(), 5);
        }
    }

    #[test]
    fn test_time_series_split_insufficient_data() {
        // Not enough samples
        assert!(time_series_split(5, 3, 5, 1).is_err());

        // Invalid parameters
        assert!(time_series_split(100, 0, 10, 0).is_err());
        assert!(time_series_split(100, 5, 0, 0).is_err());
    }
}