sklears_model_selection/cv/
basic_cv.rs

1//! Basic cross-validation iterators
2
3use super::CrossValidator;
4use scirs2_core::ndarray::Array1;
5use scirs2_core::random::rngs::StdRng;
6use scirs2_core::random::SeedableRng;
7use scirs2_core::SliceRandomExt;
8use std::collections::HashMap;
9
/// K-Fold cross-validation iterator.
///
/// Partitions the sample indices into `n_splits` contiguous folds; each fold
/// serves once as the test set while the remaining folds form the training
/// set. Indices may optionally be shuffled before the folds are cut.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds to produce.
    n_splits: usize,
    /// Whether the indices are shuffled before being divided into folds.
    shuffle: bool,
    /// Seed used for shuffling; `None` means no fixed seed was supplied.
    random_state: Option<u64>,
}

impl KFold {
    /// Create a new KFold cross-validator with shuffling disabled.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is smaller than 2.
    pub fn new(n_splits: usize) -> Self {
        assert!(n_splits >= 2, "n_splits must be at least 2");
        KFold {
            random_state: None,
            shuffle: false,
            n_splits,
        }
    }

    /// Set whether to shuffle the data before splitting
    pub fn shuffle(self, shuffle: bool) -> Self {
        Self { shuffle, ..self }
    }

    /// Set the random state for shuffling
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Calculate the size of each fold.
    ///
    /// Every fold receives `n_samples / n_splits` samples, and the first
    /// `n_samples % n_splits` folds receive one extra sample each.
    fn calculate_fold_sizes(&self, n_samples: usize) -> Vec<usize> {
        let base = n_samples / self.n_splits;
        let remainder = n_samples % self.n_splits;

        (0..self.n_splits)
            .map(|fold| if fold < remainder { base + 1 } else { base })
            .collect()
    }
}
54
55impl CrossValidator for KFold {
56    fn n_splits(&self) -> usize {
57        self.n_splits
58    }
59
60    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
61        assert!(
62            self.n_splits <= n_samples,
63            "Cannot have number of splits {} greater than the number of samples {}",
64            self.n_splits,
65            n_samples
66        );
67
68        // Create indices
69        let mut indices: Vec<usize> = (0..n_samples).collect();
70
71        // Shuffle if requested
72        if self.shuffle {
73            let mut rng = match self.random_state {
74                Some(seed) => StdRng::seed_from_u64(seed),
75                None => {
76                    use scirs2_core::random::thread_rng;
77                    StdRng::from_rng(&mut thread_rng())
78                }
79            };
80            indices.shuffle(&mut rng);
81        }
82
83        // Generate train/test splits
84        let mut splits = Vec::new();
85        let fold_sizes = self.calculate_fold_sizes(n_samples);
86        let mut current = 0;
87
88        for fold_size in fold_sizes.iter().take(self.n_splits) {
89            let test_start = current;
90            let test_end = current + *fold_size;
91
92            let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
93            let train_indices: Vec<usize> = indices[..test_start]
94                .iter()
95                .chain(indices[test_end..].iter())
96                .cloned()
97                .collect();
98
99            splits.push((train_indices, test_indices));
100            current = test_end;
101        }
102
103        splits
104    }
105}
106
/// Stratified K-Fold cross-validation iterator.
///
/// Like K-Fold, but the folds are built per class so that every fold receives
/// a share of each class's samples.
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds to produce.
    n_splits: usize,
    /// Whether each class's indices are shuffled before being divided.
    shuffle: bool,
    /// Seed used for shuffling; `None` means no fixed seed was supplied.
    random_state: Option<u64>,
}

impl StratifiedKFold {
    /// Create a new StratifiedKFold cross-validator with shuffling disabled.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is smaller than 2.
    pub fn new(n_splits: usize) -> Self {
        assert!(n_splits >= 2, "n_splits must be at least 2");
        StratifiedKFold {
            random_state: None,
            shuffle: false,
            n_splits,
        }
    }

    /// Set whether to shuffle the data before splitting
    pub fn shuffle(self, shuffle: bool) -> Self {
        Self { shuffle, ..self }
    }

    /// Set the random state for shuffling
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Calculate the size of each fold.
    ///
    /// The first `n_samples % n_splits` folds hold one sample more than the
    /// remaining ones, so the sizes always sum to `n_samples`.
    fn calculate_fold_sizes(&self, n_samples: usize) -> Vec<usize> {
        let quotient = n_samples / self.n_splits;
        let extra = n_samples % self.n_splits;

        // Start with the enlarged folds, then pad out with regular ones
        let mut sizes = vec![quotient + 1; extra];
        sizes.resize(self.n_splits, quotient);
        sizes
    }
}
151
152impl CrossValidator for StratifiedKFold {
153    fn n_splits(&self) -> usize {
154        self.n_splits
155    }
156
157    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
158        let y = y.expect("StratifiedKFold requires y to be provided");
159        assert_eq!(
160            y.len(),
161            n_samples,
162            "y must have the same length as n_samples"
163        );
164        assert!(
165            self.n_splits <= n_samples,
166            "Cannot have number of splits {} greater than the number of samples {}",
167            self.n_splits,
168            n_samples
169        );
170
171        // Group indices by class
172        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
173        for (idx, &label) in y.iter().enumerate() {
174            class_indices.entry(label).or_default().push(idx);
175        }
176
177        // Check we have enough samples in each class
178        for indices in class_indices.values() {
179            assert!(
180                indices.len() >= self.n_splits,
181                "The least populated class has only {} members, which is less than n_splits={}",
182                indices.len(),
183                self.n_splits
184            );
185        }
186
187        // Shuffle within each class if requested
188        if self.shuffle {
189            let mut rng = match self.random_state {
190                Some(seed) => StdRng::seed_from_u64(seed),
191                None => {
192                    use scirs2_core::random::thread_rng;
193                    StdRng::from_rng(&mut thread_rng())
194                }
195            };
196            for indices in class_indices.values_mut() {
197                indices.shuffle(&mut rng);
198            }
199        }
200
201        // Create stratified folds
202        let mut splits = vec![(Vec::new(), Vec::new()); self.n_splits];
203
204        for (_class, indices) in class_indices {
205            let fold_sizes = self.calculate_fold_sizes(indices.len());
206            let mut current = 0;
207
208            for i in 0..self.n_splits {
209                let fold_size = fold_sizes[i];
210                let test_end = current + fold_size;
211
212                // Add to test set for this fold
213                splits[i].1.extend(&indices[current..test_end]);
214
215                // Add to train sets for other folds
216                for (j, split) in splits.iter_mut().enumerate().take(self.n_splits) {
217                    if i != j {
218                        split.0.extend(&indices[current..test_end]);
219                    }
220                }
221
222                current = test_end;
223            }
224        }
225
226        splits
227    }
228}
229
/// Leave-One-Out cross-validator.
///
/// A stateless marker type: it carries no configuration.
#[derive(Debug, Clone)]
pub struct LeaveOneOut;

impl LeaveOneOut {
    /// Create a new LeaveOneOut cross-validator
    pub fn new() -> Self {
        Self
    }
}

impl Default for LeaveOneOut {
    /// Equivalent to [`LeaveOneOut::new`].
    fn default() -> Self {
        LeaveOneOut::new()
    }
}
246
247impl CrossValidator for LeaveOneOut {
248    fn n_splits(&self) -> usize {
249        // This is dynamic based on the number of samples
250        0 // Will be determined during split
251    }
252
253    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
254        let mut splits = Vec::new();
255
256        for i in 0..n_samples {
257            let test_indices = vec![i];
258            let train_indices: Vec<usize> = (0..i).chain(i + 1..n_samples).collect();
259            splits.push((train_indices, test_indices));
260        }
261
262        splits
263    }
264}
265
/// Leave-P-Out cross-validator.
///
/// Every possible combination of `p` samples is used once as the test set.
#[derive(Debug, Clone)]
pub struct LeavePOut {
    /// Number of samples held out in each test set.
    p: usize,
}

impl LeavePOut {
    /// Create a new LeavePOut cross-validator.
    ///
    /// # Panics
    ///
    /// Panics if `p` is zero.
    pub fn new(p: usize) -> Self {
        assert!(p > 0, "p must be at least 1");
        LeavePOut { p }
    }

    /// Number of samples left out per split.
    pub fn p(&self) -> usize {
        self.p
    }
}
286
287impl CrossValidator for LeavePOut {
288    fn n_splits(&self) -> usize {
289        // This is dynamic based on the number of samples
290        0 // Will be determined during split
291    }
292
293    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
294        assert!(
295            self.p <= n_samples,
296            "p ({}) cannot be greater than the number of samples ({})",
297            self.p,
298            n_samples
299        );
300
301        let mut splits = Vec::new();
302        let all_indices: Vec<usize> = (0..n_samples).collect();
303
304        // Generate all combinations of p indices
305        for test_indices in combinations(&all_indices, self.p) {
306            let test_set: std::collections::HashSet<usize> = test_indices.iter().cloned().collect();
307            let train_indices: Vec<usize> = all_indices
308                .iter()
309                .cloned()
310                .filter(|&i| !test_set.contains(&i))
311                .collect();
312
313            splits.push((train_indices, test_indices));
314        }
315
316        splits
317    }
318}
319
/// Generate all combinations of `k` elements from a slice.
///
/// Combinations are returned in lexicographic order of the chosen positions,
/// and the elements inside each combination keep their order from `items`.
///
/// * `k == 0` yields a single empty combination.
/// * `k > items.len()` yields no combinations.
///
/// The generation is iterative, so the call depth does not grow with the
/// length of `items`.
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k == 0 {
        return vec![Vec::new()];
    }
    if k > n {
        return Vec::new();
    }

    // `positions` always holds k strictly increasing indices into `items`
    let mut positions: Vec<usize> = (0..k).collect();
    let mut result = Vec::new();

    loop {
        result.push(positions.iter().map(|&pos| items[pos].clone()).collect());

        // Rightmost slot that has not yet reached its maximum value
        // (slot `s` can go no higher than `n - k + s`).
        let pivot = match (0..k).rev().find(|&slot| positions[slot] < n - k + slot) {
            Some(slot) => slot,
            None => return result,
        };

        // Advance the pivot and reset every slot after it to the smallest
        // increasing run that follows.
        let base = positions[pivot] + 1;
        for (offset, slot) in positions[pivot..].iter_mut().enumerate() {
            *slot = base + offset;
        }
    }
}