1use crate::error::{DatasetsError, Result};
8use crate::utils::Dataset;
9use ndarray::Array1;
10use rand::prelude::*;
11use rand::rng;
12use rand::rngs::StdRng;
13use std::collections::HashMap;
14
/// Cross-validation folds: one `(train_indices, held_out_indices)` pair per fold, where
/// both vectors hold row indices into the samples being split (validation indices for
/// k-fold splits, test indices for time-series splits).
pub type CrossValidationFolds = Vec<(Vec<usize>, Vec<usize>)>;
20
21pub fn train_test_split(
49 dataset: &Dataset,
50 test_size: f64,
51 random_seed: Option<u64>,
52) -> Result<(Dataset, Dataset)> {
53 if test_size <= 0.0 || test_size >= 1.0 {
54 return Err(DatasetsError::InvalidFormat(
55 "test_size must be between 0 and 1".to_string(),
56 ));
57 }
58
59 let n_samples = dataset.n_samples();
60 let n_test = (n_samples as f64 * test_size).round() as usize;
61 let n_train = n_samples - n_test;
62
63 if n_train == 0 || n_test == 0 {
64 return Err(DatasetsError::InvalidFormat(
65 "Both train and test sets must have at least one sample".to_string(),
66 ));
67 }
68
69 let mut indices: Vec<usize> = (0..n_samples).collect();
71 let mut rng = match random_seed {
72 Some(seed) => StdRng::seed_from_u64(seed),
73 None => {
74 let mut r = rng();
75 StdRng::seed_from_u64(r.next_u64())
76 }
77 };
78 indices.shuffle(&mut rng);
79
80 let train_indices = &indices[0..n_train];
81 let test_indices = &indices[n_train..];
82
83 let train_data = dataset.data.select(ndarray::Axis(0), train_indices);
85 let train_target = dataset
86 .target
87 .as_ref()
88 .map(|t| t.select(ndarray::Axis(0), train_indices));
89
90 let mut train_dataset = Dataset::new(train_data, train_target);
91 if let Some(feature_names) = &dataset.feature_names {
92 train_dataset = train_dataset.with_feature_names(feature_names.clone());
93 }
94 if let Some(description) = &dataset.description {
95 train_dataset = train_dataset.with_description(description.clone());
96 }
97
98 let test_data = dataset.data.select(ndarray::Axis(0), test_indices);
100 let test_target = dataset
101 .target
102 .as_ref()
103 .map(|t| t.select(ndarray::Axis(0), test_indices));
104
105 let mut test_dataset = Dataset::new(test_data, test_target);
106 if let Some(feature_names) = &dataset.feature_names {
107 test_dataset = test_dataset.with_feature_names(feature_names.clone());
108 }
109 if let Some(description) = &dataset.description {
110 test_dataset = test_dataset.with_description(description.clone());
111 }
112
113 Ok((train_dataset, test_dataset))
114}
115
116pub fn k_fold_split(
147 n_samples: usize,
148 n_folds: usize,
149 shuffle: bool,
150 random_seed: Option<u64>,
151) -> Result<CrossValidationFolds> {
152 if n_folds < 2 {
153 return Err(DatasetsError::InvalidFormat(
154 "Number of folds must be at least 2".to_string(),
155 ));
156 }
157
158 if n_folds > n_samples {
159 return Err(DatasetsError::InvalidFormat(
160 "Number of folds cannot exceed number of samples".to_string(),
161 ));
162 }
163
164 let mut indices: Vec<usize> = (0..n_samples).collect();
165
166 if shuffle {
167 let mut rng = match random_seed {
168 Some(seed) => StdRng::seed_from_u64(seed),
169 None => {
170 let mut r = rng();
171 StdRng::seed_from_u64(r.next_u64())
172 }
173 };
174 indices.shuffle(&mut rng);
175 }
176
177 let mut folds = Vec::new();
178 let fold_size = n_samples / n_folds;
179 let remainder = n_samples % n_folds;
180
181 for i in 0..n_folds {
182 let start = i * fold_size + i.min(remainder);
183 let end = start + fold_size + if i < remainder { 1 } else { 0 };
184
185 let validation_indices = indices[start..end].to_vec();
186 let mut train_indices = Vec::new();
187 train_indices.extend(&indices[0..start]);
188 train_indices.extend(&indices[end..]);
189
190 folds.push((train_indices, validation_indices));
191 }
192
193 Ok(folds)
194}
195
196pub fn stratified_k_fold_split(
229 targets: &Array1<f64>,
230 n_folds: usize,
231 shuffle: bool,
232 random_seed: Option<u64>,
233) -> Result<CrossValidationFolds> {
234 if n_folds < 2 {
235 return Err(DatasetsError::InvalidFormat(
236 "Number of folds must be at least 2".to_string(),
237 ));
238 }
239
240 let n_samples = targets.len();
241 if n_folds > n_samples {
242 return Err(DatasetsError::InvalidFormat(
243 "Number of folds cannot exceed number of samples".to_string(),
244 ));
245 }
246
247 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
249
250 for (i, &target) in targets.iter().enumerate() {
251 let class = target.round() as i64;
252 class_indices.entry(class).or_default().push(i);
253 }
254
255 if shuffle {
257 let mut rng = match random_seed {
258 Some(seed) => StdRng::seed_from_u64(seed),
259 None => {
260 let mut r = rng();
261 StdRng::seed_from_u64(r.next_u64())
262 }
263 };
264
265 for indices in class_indices.values_mut() {
266 indices.shuffle(&mut rng);
267 }
268 }
269
270 let mut folds = vec![Vec::new(); n_folds];
272
273 for (_, indices) in class_indices {
274 let class_size = indices.len();
275 let fold_size = class_size / n_folds;
276 let remainder = class_size % n_folds;
277
278 for (i, fold) in folds.iter_mut().enumerate() {
279 let start = i * fold_size + i.min(remainder);
280 let end = start + fold_size + if i < remainder { 1 } else { 0 };
281 fold.extend(&indices[start..end]);
282 }
283 }
284
285 let cv_folds = (0..n_folds)
287 .map(|i| {
288 let validation_indices = folds[i].clone();
289 let mut train_indices = Vec::new();
290 for (j, fold) in folds.iter().enumerate() {
291 if i != j {
292 train_indices.extend(fold);
293 }
294 }
295 (train_indices, validation_indices)
296 })
297 .collect();
298
299 Ok(cv_folds)
300}
301
302pub fn time_series_split(
334 n_samples: usize,
335 n_splits: usize,
336 n_test_samples: usize,
337 gap: usize,
338) -> Result<CrossValidationFolds> {
339 if n_splits < 1 {
340 return Err(DatasetsError::InvalidFormat(
341 "Number of splits must be at least 1".to_string(),
342 ));
343 }
344
345 if n_test_samples < 1 {
346 return Err(DatasetsError::InvalidFormat(
347 "Number of test samples must be at least 1".to_string(),
348 ));
349 }
350
351 let min_samples_needed = n_test_samples + gap + n_splits;
353 if n_samples < min_samples_needed {
354 return Err(DatasetsError::InvalidFormat(format!(
355 "Not enough samples for time series split. Need at least {}, got {}",
356 min_samples_needed, n_samples
357 )));
358 }
359
360 let mut folds = Vec::new();
361 let test_starts = (0..n_splits)
362 .map(|i| {
363 let split_size = (n_samples - n_test_samples - gap) / n_splits;
364 split_size * (i + 1) + gap
365 })
366 .collect::<Vec<_>>();
367
368 for &test_start in &test_starts {
369 let train_end = test_start - gap;
370 let test_end = test_start + n_test_samples;
371
372 if test_end > n_samples {
373 break;
374 }
375
376 let train_indices = (0..train_end).collect();
377 let test_indices = (test_start..test_end).collect();
378
379 folds.push((train_indices, test_indices));
380 }
381
382 if folds.is_empty() {
383 return Err(DatasetsError::InvalidFormat(
384 "Could not create any valid time series splits".to_string(),
385 ));
386 }
387
388 Ok(folds)
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Checks the invariants every set of CV folds must satisfy. Within each fold, the
    /// train and validation indices are disjoint and together cover `0..n_samples`.
    /// Across folds, the validation sets partition `0..n_samples`.
    fn assert_valid_folds(folds: &CrossValidationFolds, n_samples: usize) {
        let expected: Vec<usize> = (0..n_samples).collect();

        for (train, validation) in folds {
            let mut combined: Vec<usize> = train.iter().chain(validation).copied().collect();
            combined.sort_unstable();
            assert_eq!(combined, expected);
        }

        let mut all_validation: Vec<usize> = folds
            .iter()
            .flat_map(|(_, validation)| validation.iter().copied())
            .collect();
        all_validation.sort_unstable();
        assert_eq!(all_validation, expected);
    }

    #[test]
    fn test_train_test_split() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        let target = Some(array![0.0, 1.0, 0.0, 1.0, 0.0]);
        let dataset = Dataset::new(data, target);

        let (train, test) = train_test_split(&dataset, 0.4, Some(42)).unwrap();

        assert_eq!(train.n_samples() + test.n_samples(), 5);
        assert_eq!(test.n_samples(), 2);
        assert_eq!(train.n_samples(), 3);

        // The same seed must reproduce the same split.
        let (train_again, test_again) = train_test_split(&dataset, 0.4, Some(42)).unwrap();
        assert_eq!(train.data, train_again.data);
        assert_eq!(test.data, test_again.data);
    }

    #[test]
    fn test_train_test_split_invalid_size() {
        let data = array![[1.0, 2.0]];
        let dataset = Dataset::new(data, None);

        assert!(train_test_split(&dataset, 0.0, None).is_err());
        assert!(train_test_split(&dataset, 1.0, None).is_err());
        assert!(train_test_split(&dataset, 1.5, None).is_err());
        assert!(train_test_split(&dataset, f64::NAN, None).is_err());
    }

    #[test]
    fn test_k_fold_split() {
        let folds = k_fold_split(10, 3, false, Some(42)).unwrap();

        assert_eq!(folds.len(), 3);
        assert_valid_folds(&folds, 10);

        // Without shuffling, the leftover sample goes to the earliest fold.
        let sizes: Vec<usize> = folds.iter().map(|(_, validation)| validation.len()).collect();
        assert_eq!(sizes, vec![4, 3, 3]);

        let shuffled = k_fold_split(10, 3, true, Some(7)).unwrap();
        assert_valid_folds(&shuffled, 10);
        assert_eq!(shuffled, k_fold_split(10, 3, true, Some(7)).unwrap());
    }

    #[test]
    fn test_k_fold_split_invalid_params() {
        assert!(k_fold_split(10, 1, false, None).is_err());
        assert!(k_fold_split(5, 6, false, None).is_err());
    }

    #[test]
    fn test_stratified_k_fold_split() {
        let targets = array![0.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let folds = stratified_k_fold_split(&targets, 2, false, Some(42)).unwrap();

        assert_eq!(folds.len(), 2);
        assert_valid_folds(&folds, 6);

        // Each class has 3 samples, so every validation fold holds 1 or 2 of each class.
        for (_, validation) in &folds {
            for class in [0.0, 1.0] {
                let count = validation.iter().filter(|&&i| targets[i] == class).count();
                assert!((1..=2).contains(&count), "class {class} appears {count} times");
            }
        }
    }

    #[test]
    fn test_time_series_split() {
        let folds = time_series_split(20, 3, 5, 1).unwrap();

        assert_eq!(folds.len(), 3);

        // Training windows grow from one split to the next.
        for pair in folds.windows(2) {
            assert!(pair[1].0.len() > pair[0].0.len());
        }

        for (train, test) in &folds {
            assert_eq!(test.len(), 5);
            // The gap of 1 keeps the training window strictly before the test window.
            assert!(train.last().unwrap() + 1 < *test.first().unwrap());
        }

        let expected: CrossValidationFolds = vec![
            ((0..4).collect(), (5..10).collect()),
            ((0..8).collect(), (9..14).collect()),
            ((0..12).collect(), (13..18).collect()),
        ];
        assert_eq!(folds, expected);
    }

    #[test]
    fn test_time_series_split_insufficient_data() {
        assert!(time_series_split(5, 3, 5, 1).is_err());
        assert!(time_series_split(100, 0, 10, 0).is_err());
        assert!(time_series_split(100, 5, 0, 0).is_err());
    }
}
490}