1use crate::error::{DatasetsError, Result};
8use crate::utils::Dataset;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::prelude::*;
12use scirs2_core::random::rngs::StdRng;
13use scirs2_core::random::seq::SliceRandom;
14use std::collections::HashMap;
15
/// A list of cross-validation folds.
///
/// Each entry is a pair of index vectors. The splitters in this file build
/// every pair as `(train_indices, validation_or_test_indices)`, in that order.
pub type CrossValidationFolds = Vec<(Vec<usize>, Vec<usize>)>;
21
22#[allow(dead_code)]
50pub fn train_test_split(
51 dataset: &Dataset,
52 test_size: f64,
53 random_seed: Option<u64>,
54) -> Result<(Dataset, Dataset)> {
55 if test_size <= 0.0 || test_size >= 1.0 {
56 return Err(DatasetsError::InvalidFormat(
57 "test_size must be between 0 and 1".to_string(),
58 ));
59 }
60
61 let n_samples = dataset.n_samples();
62 let n_test = (n_samples as f64 * test_size).round() as usize;
63 let n_train = n_samples - n_test;
64
65 if n_train == 0 || n_test == 0 {
66 return Err(DatasetsError::InvalidFormat(
67 "Both train and test sets must have at least one sample".to_string(),
68 ));
69 }
70
71 let mut indices: Vec<usize> = (0..n_samples).collect();
73 let mut rng = match random_seed {
74 Some(_seed) => StdRng::seed_from_u64(_seed),
75 None => {
76 let mut r = thread_rng();
77 StdRng::seed_from_u64(r.next_u64())
78 }
79 };
80 indices.shuffle(&mut rng);
81
82 let train_indices = &indices[0..n_train];
83 let test_indices = &indices[n_train..];
84
85 let train_data = dataset
87 .data
88 .select(scirs2_core::ndarray::Axis(0), train_indices);
89 let train_target = dataset
90 .target
91 .as_ref()
92 .map(|t| t.select(scirs2_core::ndarray::Axis(0), train_indices));
93
94 let mut train_dataset = Dataset::new(train_data, train_target);
95 if let Some(featurenames) = &dataset.featurenames {
96 train_dataset = train_dataset.with_featurenames(featurenames.clone());
97 }
98 if let Some(description) = &dataset.description {
99 train_dataset = train_dataset.with_description(description.clone());
100 }
101
102 let test_data = dataset
104 .data
105 .select(scirs2_core::ndarray::Axis(0), test_indices);
106 let test_target = dataset
107 .target
108 .as_ref()
109 .map(|t| t.select(scirs2_core::ndarray::Axis(0), test_indices));
110
111 let mut test_dataset = Dataset::new(test_data, test_target);
112 if let Some(featurenames) = &dataset.featurenames {
113 test_dataset = test_dataset.with_featurenames(featurenames.clone());
114 }
115 if let Some(description) = &dataset.description {
116 test_dataset = test_dataset.with_description(description.clone());
117 }
118
119 Ok((train_dataset, test_dataset))
120}
121
122#[allow(dead_code)]
153pub fn k_fold_split(
154 n_samples: usize,
155 n_folds: usize,
156 shuffle: bool,
157 random_seed: Option<u64>,
158) -> Result<CrossValidationFolds> {
159 if n_folds < 2 {
160 return Err(DatasetsError::InvalidFormat(
161 "Number of _folds must be at least 2".to_string(),
162 ));
163 }
164
165 if n_folds > n_samples {
166 return Err(DatasetsError::InvalidFormat(
167 "Number of _folds cannot exceed number of _samples".to_string(),
168 ));
169 }
170
171 let mut indices: Vec<usize> = (0..n_samples).collect();
172
173 if shuffle {
174 let mut rng = match random_seed {
175 Some(_seed) => StdRng::seed_from_u64(_seed),
176 None => {
177 let mut r = thread_rng();
178 StdRng::seed_from_u64(r.next_u64())
179 }
180 };
181 indices.shuffle(&mut rng);
182 }
183
184 let mut folds = Vec::new();
185 let fold_size = n_samples / n_folds;
186 let remainder = n_samples % n_folds;
187
188 for i in 0..n_folds {
189 let start = i * fold_size + i.min(remainder);
190 let end = start + fold_size + if i < remainder { 1 } else { 0 };
191
192 let validation_indices = indices[start..end].to_vec();
193 let mut train_indices = Vec::new();
194 train_indices.extend(&indices[0..start]);
195 train_indices.extend(&indices[end..]);
196
197 folds.push((train_indices, validation_indices));
198 }
199
200 Ok(folds)
201}
202
203#[allow(dead_code)]
236pub fn stratified_k_fold_split(
237 targets: &Array1<f64>,
238 n_folds: usize,
239 shuffle: bool,
240 random_seed: Option<u64>,
241) -> Result<CrossValidationFolds> {
242 if n_folds < 2 {
243 return Err(DatasetsError::InvalidFormat(
244 "Number of _folds must be at least 2".to_string(),
245 ));
246 }
247
248 let n_samples = targets.len();
249 if n_folds > n_samples {
250 return Err(DatasetsError::InvalidFormat(
251 "Number of _folds cannot exceed number of samples".to_string(),
252 ));
253 }
254
255 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
257
258 for (i, &target) in targets.iter().enumerate() {
259 let class = target.round() as i64;
260 class_indices.entry(class).or_default().push(i);
261 }
262
263 if shuffle {
265 let mut rng = match random_seed {
266 Some(_seed) => StdRng::seed_from_u64(_seed),
267 None => {
268 let mut r = thread_rng();
269 StdRng::seed_from_u64(r.next_u64())
270 }
271 };
272
273 for indices in class_indices.values_mut() {
274 indices.shuffle(&mut rng);
275 }
276 }
277
278 let mut folds = vec![Vec::new(); n_folds];
280
281 for (_, indices) in class_indices {
282 let class_size = indices.len();
283 let fold_size = class_size / n_folds;
284 let remainder = class_size % n_folds;
285
286 for (i, fold) in folds.iter_mut().enumerate() {
287 let start = i * fold_size + i.min(remainder);
288 let end = start + fold_size + if i < remainder { 1 } else { 0 };
289 fold.extend(&indices[start..end]);
290 }
291 }
292
293 let cv_folds = (0..n_folds)
295 .map(|i| {
296 let validation_indices = folds[i].clone();
297 let mut train_indices = Vec::new();
298 for (j, fold) in folds.iter().enumerate() {
299 if i != j {
300 train_indices.extend(fold);
301 }
302 }
303 (train_indices, validation_indices)
304 })
305 .collect();
306
307 Ok(cv_folds)
308}
309
310#[allow(dead_code)]
342pub fn time_series_split(
343 n_samples: usize,
344 n_splits: usize,
345 n_test_samples: usize,
346 gap: usize,
347) -> Result<CrossValidationFolds> {
348 if n_splits < 1 {
349 return Err(DatasetsError::InvalidFormat(
350 "Number of _splits must be at least 1".to_string(),
351 ));
352 }
353
354 if n_test_samples < 1 {
355 return Err(DatasetsError::InvalidFormat(
356 "Number of test _samples must be at least 1".to_string(),
357 ));
358 }
359
360 let min_samples_needed = n_test_samples + gap + n_splits;
362 if n_samples < min_samples_needed {
363 return Err(DatasetsError::InvalidFormat(format!(
364 "Not enough _samples for time series split. Need at least {min_samples_needed}, got {n_samples}"
365 )));
366 }
367
368 let mut folds = Vec::new();
369 let test_starts = (0..n_splits)
370 .map(|i| {
371 let split_size = (n_samples - n_test_samples - gap) / n_splits;
372 split_size * (i + 1) + gap
373 })
374 .collect::<Vec<_>>();
375
376 for &test_start in &test_starts {
377 let train_end = test_start - gap;
378 let test_end = test_start + n_test_samples;
379
380 if test_end > n_samples {
381 break;
382 }
383
384 let train_indices = (0..train_end).collect();
385 let test_indices = (test_start..test_end).collect();
386
387 folds.push((train_indices, test_indices));
388 }
389
390 if folds.is_empty() {
391 return Err(DatasetsError::InvalidFormat(
392 "Could not create any valid time series _splits".to_string(),
393 ));
394 }
395
396 Ok(folds)
397}
398
/// Unit tests for the splitting helpers defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// A 40% test fraction of five samples yields three training and two test rows.
    #[test]
    fn test_train_test_split() {
        let features = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        let labels = Some(array![0.0, 1.0, 0.0, 1.0, 0.0]);
        let dataset = Dataset::new(features, labels);

        let (train, test) = train_test_split(&dataset, 0.4, Some(42)).unwrap();

        assert_eq!(train.n_samples() + test.n_samples(), 5);
        assert_eq!(test.n_samples(), 2);
        assert_eq!(train.n_samples(), 3);
    }

    /// Fractions on or outside the (0, 1) boundary are rejected.
    #[test]
    fn test_train_test_split_invalid_size() {
        let dataset = Dataset::new(array![[1.0, 2.0]], None);

        for &fraction in &[0.0, 1.0, 1.5] {
            assert!(train_test_split(&dataset, fraction, None).is_err());
        }
    }

    /// The validation sets of a k-fold split cover every index exactly once.
    #[test]
    fn test_k_fold_split() {
        let folds = k_fold_split(10, 3, false, Some(42)).unwrap();

        assert_eq!(folds.len(), 3);

        let mut all_validation_indices: Vec<usize> = folds
            .iter()
            .flat_map(|(_, val_indices)| val_indices.iter().copied())
            .collect();
        all_validation_indices.sort();

        let expected: Vec<usize> = (0..10).collect();
        assert_eq!(all_validation_indices, expected);
    }

    /// Too few folds, or more folds than samples, are errors.
    #[test]
    fn test_k_fold_split_invalid_params() {
        let too_few_folds = k_fold_split(10, 1, false, None);
        assert!(too_few_folds.is_err());

        let more_folds_than_samples = k_fold_split(5, 6, false, None);
        assert!(more_folds_than_samples.is_err());
    }

    /// The validation sets of a stratified split cover every index exactly once.
    #[test]
    fn test_stratified_k_fold_split() {
        let targets = array![0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let folds = stratified_k_fold_split(&targets, 2, false, Some(42)).unwrap();

        assert_eq!(folds.len(), 2);

        let mut all_validation_indices: Vec<usize> = folds
            .iter()
            .flat_map(|(_, val_indices)| val_indices.iter().copied())
            .collect();
        all_validation_indices.sort();

        let expected: Vec<usize> = (0..6).collect();
        assert_eq!(all_validation_indices, expected);
    }

    /// Training sets grow from split to split while test windows keep their length.
    #[test]
    fn test_time_series_split() {
        let folds = time_series_split(20, 3, 5, 1).unwrap();

        assert_eq!(folds.len(), 3);

        for pair in folds.windows(2) {
            assert!(pair[1].0.len() > pair[0].0.len());
        }

        for (_, val_indices) in folds.iter() {
            assert_eq!(val_indices.len(), 5);
        }
    }

    /// Too little data, zero splits, and zero test samples are all errors.
    #[test]
    fn test_time_series_split_insufficient_data() {
        let not_enough_samples = time_series_split(5, 3, 5, 1);
        assert!(not_enough_samples.is_err());

        let zero_splits = time_series_split(100, 0, 10, 0);
        assert!(zero_splits.is_err());

        let zero_test_samples = time_series_split(100, 5, 0, 0);
        assert!(zero_test_samples.is_err());
    }
}