1use crate::{UtilsError, UtilsResult};
7use scirs2_core::random::rngs::StdRng;
8use scirs2_core::random::{Rng, SeedableRng};
9use std::collections::HashMap;
10
/// One train/test partition produced by a cross-validation splitter.
///
/// Both vectors hold positions (0-based sample indices) into the label or
/// group slice that was handed to the splitter's `split` method.
///
/// `PartialEq`/`Eq` are derived so that two splits can be compared directly,
/// e.g. with `assert_eq!` in tests or when checking reproducibility.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CVSplit {
    /// Indices of the samples used for training in this split.
    pub train: Vec<usize>,
    /// Indices of the samples held out for evaluation in this split.
    pub test: Vec<usize>,
}
19
/// Configuration for k-fold cross-validation that deals the samples of every
/// class across all folds.
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds to produce.
    n_splits: usize,
    /// Whether the indices of each class are shuffled before being assigned to folds.
    shuffle: bool,
    /// Optional seed for the shuffle.
    random_state: Option<u64>,
}
30
31impl StratifiedKFold {
32 pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> UtilsResult<Self> {
39 if n_splits < 2 {
40 return Err(UtilsError::InvalidParameter(
41 "n_splits must be at least 2".to_string(),
42 ));
43 }
44
45 Ok(Self {
46 n_splits,
47 shuffle,
48 random_state,
49 })
50 }
51
52 pub fn split(&self, y: &[usize]) -> UtilsResult<Vec<CVSplit>> {
60 if y.is_empty() {
61 return Err(UtilsError::InvalidParameter(
62 "Cannot split empty label array".to_string(),
63 ));
64 }
65
66 let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
68 for (idx, &label) in y.iter().enumerate() {
69 class_indices.entry(label).or_default().push(idx);
70 }
71
72 for (class, indices) in &class_indices {
74 if indices.len() < self.n_splits {
75 return Err(UtilsError::InvalidParameter(format!(
76 "Class {class} has only {} samples, need at least {} for {}-fold CV",
77 indices.len(),
78 self.n_splits,
79 self.n_splits
80 )));
81 }
82 }
83
84 let mut rng = self
86 .random_state
87 .map(StdRng::seed_from_u64)
88 .unwrap_or_else(|| StdRng::seed_from_u64(42));
89
90 if self.shuffle {
91 for indices in class_indices.values_mut() {
92 Self::shuffle_indices(indices, &mut rng);
93 }
94 }
95
96 let mut fold_indices: Vec<Vec<usize>> = vec![Vec::new(); self.n_splits];
98
99 for indices in class_indices.values() {
100 let fold_sizes = Self::distribute_samples(indices.len(), self.n_splits);
101 let mut current_idx = 0;
102
103 for (fold_id, size) in fold_sizes.iter().enumerate() {
104 fold_indices[fold_id].extend(&indices[current_idx..current_idx + size]);
105 current_idx += size;
106 }
107 }
108
109 let mut splits = Vec::with_capacity(self.n_splits);
111 for test_fold_id in 0..self.n_splits {
112 let mut train = Vec::new();
113 for (fold_id, indices) in fold_indices.iter().enumerate() {
114 if fold_id != test_fold_id {
115 train.extend(indices);
116 }
117 }
118
119 splits.push(CVSplit {
120 train,
121 test: fold_indices[test_fold_id].clone(),
122 });
123 }
124
125 Ok(splits)
126 }
127
128 fn shuffle_indices(indices: &mut [usize], rng: &mut StdRng) {
129 for i in (1..indices.len()).rev() {
130 let j = rng.gen_range(0..=i);
131 indices.swap(i, j);
132 }
133 }
134
135 fn distribute_samples(n_samples: usize, n_folds: usize) -> Vec<usize> {
136 let base_size = n_samples / n_folds;
137 let remainder = n_samples % n_folds;
138
139 (0..n_folds)
140 .map(|i| {
141 if i < remainder {
142 base_size + 1
143 } else {
144 base_size
145 }
146 })
147 .collect()
148 }
149}
150
/// Configuration for forward-chaining cross-validation over ordered samples.
#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    /// Number of train/test splits to produce.
    n_splits: usize,
    /// Number of samples in every test window; derived from the sample count when `None`.
    test_size: Option<usize>,
    /// Number of samples skipped between the end of the training range and the test window.
    gap: usize,
}
161
162impl TimeSeriesSplit {
163 pub fn new(n_splits: usize, test_size: Option<usize>, gap: usize) -> UtilsResult<Self> {
170 if n_splits < 2 {
171 return Err(UtilsError::InvalidParameter(
172 "n_splits must be at least 2".to_string(),
173 ));
174 }
175
176 Ok(Self {
177 n_splits,
178 test_size,
179 gap,
180 })
181 }
182
183 pub fn split(&self, n_samples: usize) -> UtilsResult<Vec<CVSplit>> {
191 let test_size = self.test_size.unwrap_or_else(|| {
192 (n_samples - (n_samples / (self.n_splits + 1))) / self.n_splits
194 });
195
196 let min_train_size = n_samples / (self.n_splits + 1);
197
198 if min_train_size + self.gap + test_size > n_samples {
199 return Err(UtilsError::InvalidParameter(
200 "Not enough samples for requested split configuration".to_string(),
201 ));
202 }
203
204 let mut splits = Vec::with_capacity(self.n_splits);
205
206 for i in 0..self.n_splits {
207 let train_end = min_train_size + i * test_size;
208 let test_start = train_end + self.gap;
209 let test_end = test_start + test_size;
210
211 if test_end > n_samples {
212 break;
213 }
214
215 splits.push(CVSplit {
216 train: (0..train_end).collect(),
217 test: (test_start..test_end).collect(),
218 });
219 }
220
221 if splits.len() < self.n_splits {
222 return Err(UtilsError::InvalidParameter(
223 "Cannot generate requested number of splits with given parameters".to_string(),
224 ));
225 }
226
227 Ok(splits)
228 }
229}
230
/// Configuration for k-fold cross-validation that keeps all samples of a
/// group inside the same fold.
#[derive(Debug, Clone)]
pub struct GroupKFold {
    /// Number of folds to produce.
    n_splits: usize,
}
239
240impl GroupKFold {
241 pub fn new(n_splits: usize) -> UtilsResult<Self> {
243 if n_splits < 2 {
244 return Err(UtilsError::InvalidParameter(
245 "n_splits must be at least 2".to_string(),
246 ));
247 }
248
249 Ok(Self { n_splits })
250 }
251
252 pub fn split(&self, groups: &[usize]) -> UtilsResult<Vec<CVSplit>> {
260 if groups.is_empty() {
261 return Err(UtilsError::InvalidParameter(
262 "Cannot split empty groups array".to_string(),
263 ));
264 }
265
266 let mut group_to_indices: HashMap<usize, Vec<usize>> = HashMap::new();
268 for (idx, &group) in groups.iter().enumerate() {
269 group_to_indices.entry(group).or_default().push(idx);
270 }
271
272 let unique_groups: Vec<usize> = group_to_indices.keys().copied().collect();
273
274 if unique_groups.len() < self.n_splits {
275 return Err(UtilsError::InvalidParameter(format!(
276 "Number of unique groups ({}) must be >= n_splits ({})",
277 unique_groups.len(),
278 self.n_splits
279 )));
280 }
281
282 let fold_sizes = Self::distribute_groups(unique_groups.len(), self.n_splits);
284 let mut group_folds: Vec<Vec<usize>> = vec![Vec::new(); self.n_splits];
285 let mut current_idx = 0;
286
287 for (fold_id, size) in fold_sizes.iter().enumerate() {
288 group_folds[fold_id].extend(&unique_groups[current_idx..current_idx + size]);
289 current_idx += size;
290 }
291
292 let mut splits = Vec::with_capacity(self.n_splits);
294
295 for test_fold_id in 0..self.n_splits {
296 let mut train = Vec::new();
297 let mut test = Vec::new();
298
299 for (fold_id, groups_in_fold) in group_folds.iter().enumerate() {
300 let indices: Vec<usize> = groups_in_fold
301 .iter()
302 .flat_map(|g| group_to_indices.get(g).unwrap())
303 .copied()
304 .collect();
305
306 if fold_id == test_fold_id {
307 test.extend(indices);
308 } else {
309 train.extend(indices);
310 }
311 }
312
313 splits.push(CVSplit { train, test });
314 }
315
316 Ok(splits)
317 }
318
319 fn distribute_groups(n_groups: usize, n_folds: usize) -> Vec<usize> {
320 let base_size = n_groups / n_folds;
321 let remainder = n_groups % n_folds;
322
323 (0..n_folds)
324 .map(|i| {
325 if i < remainder {
326 base_size + 1
327 } else {
328 base_size
329 }
330 })
331 .collect()
332 }
333}
334
/// Cross-validation splitter that holds out one whole group per split.
///
/// The type carries no configuration.
#[derive(Debug, Clone)]
pub struct LeaveOneGroupOut;
340
341impl LeaveOneGroupOut {
342 pub fn new() -> Self {
344 Self
345 }
346
347 pub fn split(&self, groups: &[usize]) -> UtilsResult<Vec<CVSplit>> {
355 if groups.is_empty() {
356 return Err(UtilsError::InvalidParameter(
357 "Cannot split empty groups array".to_string(),
358 ));
359 }
360
361 let mut group_to_indices: HashMap<usize, Vec<usize>> = HashMap::new();
363 for (idx, &group) in groups.iter().enumerate() {
364 group_to_indices.entry(group).or_default().push(idx);
365 }
366
367 let unique_groups: Vec<usize> = group_to_indices.keys().copied().collect();
368 let mut splits = Vec::with_capacity(unique_groups.len());
369
370 for &test_group in &unique_groups {
371 let mut train = Vec::new();
372 let test = group_to_indices.get(&test_group).unwrap().clone();
373
374 for &group in &unique_groups {
375 if group != test_group {
376 train.extend(group_to_indices.get(&group).unwrap());
377 }
378 }
379
380 splits.push(CVSplit { train, test });
381 }
382
383 Ok(splits)
384 }
385}
386
387impl Default for LeaveOneGroupOut {
388 fn default() -> Self {
389 Self::new()
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    // Every test fold must contain at least one sample of every class, and
    // each split must partition the full sample set.
    #[test]
    fn test_stratified_kfold_basic() {
        let y = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
        let skf = StratifiedKFold::new(3, false, Some(42)).unwrap();
        let splits = skf.split(&y).unwrap();

        assert_eq!(splits.len(), 3);

        for split in &splits {
            let test_labels: Vec<usize> = split.test.iter().map(|&i| y[i]).collect();
            assert!(test_labels.contains(&0));
            assert!(test_labels.contains(&1));
            assert!(test_labels.contains(&2));

            assert_eq!(split.train.len() + split.test.len(), y.len());
            assert!(split.test.iter().all(|i| !split.train.contains(i)));
        }
    }

    // Across all splits, every sample appears in exactly one test fold.
    #[test]
    fn test_stratified_kfold_all_samples_used() {
        let y = vec![0, 0, 1, 1, 2, 2];
        let skf = StratifiedKFold::new(2, false, None).unwrap();
        let splits = skf.split(&y).unwrap();

        assert_eq!(splits.len(), 2);

        let mut all_test_indices: Vec<usize> = Vec::new();
        for split in &splits {
            all_test_indices.extend(&split.test);
        }
        all_test_indices.sort_unstable();

        assert_eq!(all_test_indices, vec![0, 1, 2, 3, 4, 5]);
    }

    // Training ranges grow from split to split and test windows have the
    // requested size.
    #[test]
    fn test_time_series_split_basic() {
        let tscv = TimeSeriesSplit::new(3, Some(2), 0).unwrap();
        let splits = tscv.split(10).unwrap();

        assert_eq!(splits.len(), 3);

        for (i, split) in splits.iter().enumerate() {
            assert!(!split.train.is_empty());
            assert_eq!(split.test.len(), 2);
            // With no gap the test window starts right after the training range.
            assert_eq!(split.test[0], split.train.len());
            if i > 0 {
                assert!(split.train.len() > splits[i - 1].train.len());
            }
        }
    }

    // Exactly `gap` samples must be skipped between the last training sample
    // and the first test sample.
    #[test]
    fn test_time_series_split_with_gap() {
        let gap = 1;
        let tscv = TimeSeriesSplit::new(2, Some(2), gap).unwrap();
        let splits = tscv.split(10).unwrap();

        assert_eq!(splits.len(), 2);

        for split in &splits {
            let train_max = *split.train.iter().max().unwrap();
            let test_min = *split.test.iter().min().unwrap();
            assert_eq!(test_min, train_max + 1 + gap);
        }
    }

    // No group may appear on both sides of a split, and every sample is
    // tested exactly once.
    #[test]
    fn test_group_kfold_basic() {
        let groups = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let gkf = GroupKFold::new(2).unwrap();
        let splits = gkf.split(&groups).unwrap();

        assert_eq!(splits.len(), 2);

        let mut all_test_indices: Vec<usize> = Vec::new();
        for split in &splits {
            let train_groups: Vec<usize> = split.train.iter().map(|&i| groups[i]).collect();
            let test_groups: Vec<usize> = split.test.iter().map(|&i| groups[i]).collect();

            for &test_group in &test_groups {
                assert!(!train_groups.contains(&test_group));
            }

            all_test_indices.extend(&split.test);
        }
        all_test_indices.sort_unstable();

        assert_eq!(all_test_indices, (0..groups.len()).collect::<Vec<usize>>());
    }

    // Each split holds out exactly one group, and that group never leaks into
    // the training set.
    #[test]
    fn test_leave_one_group_out() {
        let groups = vec![0, 0, 1, 1, 2, 2];
        let logo = LeaveOneGroupOut::new();
        let splits = logo.split(&groups).unwrap();

        // One split per unique group.
        assert_eq!(splits.len(), 3);

        for split in &splits {
            let test_groups: Vec<usize> = split.test.iter().map(|&idx| groups[idx]).collect();
            let unique_test_groups: std::collections::HashSet<usize> =
                test_groups.into_iter().collect();
            assert_eq!(unique_test_groups.len(), 1);

            assert!(split
                .train
                .iter()
                .all(|&idx| !unique_test_groups.contains(&groups[idx])));
        }
    }

    #[test]
    fn test_stratified_kfold_error_too_few_samples() {
        // Only one sample per class, but three folds requested.
        let y = vec![0, 1];
        let skf = StratifiedKFold::new(3, false, None).unwrap();
        assert!(skf.split(&y).is_err());
    }

    #[test]
    fn test_time_series_split_error_insufficient_samples() {
        let tscv = TimeSeriesSplit::new(5, Some(10), 0).unwrap();
        assert!(tscv.split(20).is_err());
    }
}