sklears_model_selection/cv/
basic_cv.rs1use super::CrossValidator;
4use scirs2_core::ndarray::Array1;
5use scirs2_core::random::rngs::StdRng;
6use scirs2_core::random::SeedableRng;
7use scirs2_core::SliceRandomExt;
8use std::collections::HashMap;
9
/// K-fold cross-validation splitter configuration.
///
/// Values are created with [`KFold::new`] and refined through the chainable
/// [`KFold::shuffle`] and [`KFold::random_state`] setters.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds; [`KFold::new`] only accepts values of 2 or more.
    n_splits: usize,
    /// Whether sample indices are shuffled before being cut into folds.
    shuffle: bool,
    /// Seed used for shuffling; `None` means no fixed seed was requested.
    random_state: Option<u64>,
}

impl KFold {
    /// Creates a splitter with `n_splits` folds, shuffling disabled and no seed.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is smaller than 2.
    pub fn new(n_splits: usize) -> Self {
        assert!(n_splits >= 2, "n_splits must be at least 2");
        KFold {
            n_splits,
            shuffle: false,
            random_state: None,
        }
    }

    /// Returns the splitter with shuffling switched on or off.
    pub fn shuffle(self, shuffle: bool) -> Self {
        Self { shuffle, ..self }
    }

    /// Returns the splitter with its shuffling seed set to `seed`.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Computes the size of every fold for `n_samples` samples.
    ///
    /// Each fold receives `n_samples / n_splits` samples, and the first
    /// `n_samples % n_splits` folds receive one more, so the sizes always sum
    /// to `n_samples`.
    fn calculate_fold_sizes(&self, n_samples: usize) -> Vec<usize> {
        let base = n_samples / self.n_splits;
        let remainder = n_samples % self.n_splits;

        (0..self.n_splits)
            .map(|fold| if fold < remainder { base + 1 } else { base })
            .collect()
    }
}
54
55impl CrossValidator for KFold {
56 fn n_splits(&self) -> usize {
57 self.n_splits
58 }
59
60 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
61 assert!(
62 self.n_splits <= n_samples,
63 "Cannot have number of splits {} greater than the number of samples {}",
64 self.n_splits,
65 n_samples
66 );
67
68 let mut indices: Vec<usize> = (0..n_samples).collect();
70
71 if self.shuffle {
73 let mut rng = match self.random_state {
74 Some(seed) => StdRng::seed_from_u64(seed),
75 None => {
76 use scirs2_core::random::thread_rng;
77 StdRng::from_rng(&mut thread_rng())
78 }
79 };
80 indices.shuffle(&mut rng);
81 }
82
83 let mut splits = Vec::new();
85 let fold_sizes = self.calculate_fold_sizes(n_samples);
86 let mut current = 0;
87
88 for fold_size in fold_sizes.iter().take(self.n_splits) {
89 let test_start = current;
90 let test_end = current + *fold_size;
91
92 let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
93 let train_indices: Vec<usize> = indices[..test_start]
94 .iter()
95 .chain(indices[test_end..].iter())
96 .cloned()
97 .collect();
98
99 splits.push((train_indices, test_indices));
100 current = test_end;
101 }
102
103 splits
104 }
105}
106
/// Stratified K-fold cross-validation splitter configuration.
///
/// Values are created with [`StratifiedKFold::new`] and refined through the
/// chainable [`StratifiedKFold::shuffle`] and
/// [`StratifiedKFold::random_state`] setters.
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds; [`StratifiedKFold::new`] only accepts 2 or more.
    n_splits: usize,
    /// Whether the indices of each class are shuffled before being assigned
    /// to folds.
    shuffle: bool,
    /// Seed used for shuffling; `None` means no fixed seed was requested.
    random_state: Option<u64>,
}

impl StratifiedKFold {
    /// Creates a splitter with `n_splits` folds, shuffling disabled and no seed.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is smaller than 2.
    pub fn new(n_splits: usize) -> Self {
        assert!(n_splits >= 2, "n_splits must be at least 2");
        StratifiedKFold {
            n_splits,
            shuffle: false,
            random_state: None,
        }
    }

    /// Returns the splitter with shuffling switched on or off.
    pub fn shuffle(self, shuffle: bool) -> Self {
        Self { shuffle, ..self }
    }

    /// Returns the splitter with its shuffling seed set to `seed`.
    pub fn random_state(self, seed: u64) -> Self {
        Self {
            random_state: Some(seed),
            ..self
        }
    }

    /// Computes how many of `n_samples` items go to each fold.
    ///
    /// Every fold gets the integer quotient `n_samples / n_splits`; the
    /// leftover `n_samples % n_splits` items are handed out one apiece to the
    /// leading folds.
    fn calculate_fold_sizes(&self, n_samples: usize) -> Vec<usize> {
        let quotient = n_samples / self.n_splits;
        let leftover = n_samples % self.n_splits;

        (0..self.n_splits)
            .map(|fold| quotient + usize::from(fold < leftover))
            .collect()
    }
}
151
152impl CrossValidator for StratifiedKFold {
153 fn n_splits(&self) -> usize {
154 self.n_splits
155 }
156
157 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
158 let y = y.expect("StratifiedKFold requires y to be provided");
159 assert_eq!(
160 y.len(),
161 n_samples,
162 "y must have the same length as n_samples"
163 );
164 assert!(
165 self.n_splits <= n_samples,
166 "Cannot have number of splits {} greater than the number of samples {}",
167 self.n_splits,
168 n_samples
169 );
170
171 let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
173 for (idx, &label) in y.iter().enumerate() {
174 class_indices.entry(label).or_default().push(idx);
175 }
176
177 for indices in class_indices.values() {
179 assert!(
180 indices.len() >= self.n_splits,
181 "The least populated class has only {} members, which is less than n_splits={}",
182 indices.len(),
183 self.n_splits
184 );
185 }
186
187 if self.shuffle {
189 let mut rng = match self.random_state {
190 Some(seed) => StdRng::seed_from_u64(seed),
191 None => {
192 use scirs2_core::random::thread_rng;
193 StdRng::from_rng(&mut thread_rng())
194 }
195 };
196 for indices in class_indices.values_mut() {
197 indices.shuffle(&mut rng);
198 }
199 }
200
201 let mut splits = vec![(Vec::new(), Vec::new()); self.n_splits];
203
204 for (_class, indices) in class_indices {
205 let fold_sizes = self.calculate_fold_sizes(indices.len());
206 let mut current = 0;
207
208 for i in 0..self.n_splits {
209 let fold_size = fold_sizes[i];
210 let test_end = current + fold_size;
211
212 splits[i].1.extend(&indices[current..test_end]);
214
215 for (j, split) in splits.iter_mut().enumerate().take(self.n_splits) {
217 if i != j {
218 split.0.extend(&indices[current..test_end]);
219 }
220 }
221
222 current = test_end;
223 }
224 }
225
226 splits
227 }
228}
229
/// Leave-one-out cross-validation splitter.
///
/// The type carries no configuration; [`LeaveOneOut::new`] and
/// `LeaveOneOut::default()` produce the same value.
#[derive(Debug, Clone, Default)]
pub struct LeaveOneOut;

impl LeaveOneOut {
    /// Creates a leave-one-out splitter.
    pub fn new() -> Self {
        Self
    }
}
246
247impl CrossValidator for LeaveOneOut {
248 fn n_splits(&self) -> usize {
249 0 }
252
253 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
254 let mut splits = Vec::new();
255
256 for i in 0..n_samples {
257 let test_indices = vec![i];
258 let train_indices: Vec<usize> = (0..i).chain(i + 1..n_samples).collect();
259 splits.push((train_indices, test_indices));
260 }
261
262 splits
263 }
264}
265
/// Leave-P-out cross-validation splitter configuration.
#[derive(Debug, Clone)]
pub struct LeavePOut {
    /// Number of samples placed in each test set; [`LeavePOut::new`] only
    /// accepts values of 1 or more.
    p: usize,
}

impl LeavePOut {
    /// Creates a splitter whose test sets each contain `p` samples.
    ///
    /// # Panics
    ///
    /// Panics if `p` is 0.
    pub fn new(p: usize) -> Self {
        assert!(p >= 1, "p must be at least 1");
        LeavePOut { p }
    }

    /// Returns the configured test-set size.
    pub fn p(&self) -> usize {
        let Self { p } = self;
        *p
    }
}
286
287impl CrossValidator for LeavePOut {
288 fn n_splits(&self) -> usize {
289 0 }
292
293 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
294 assert!(
295 self.p <= n_samples,
296 "p ({}) cannot be greater than the number of samples ({})",
297 self.p,
298 n_samples
299 );
300
301 let mut splits = Vec::new();
302 let all_indices: Vec<usize> = (0..n_samples).collect();
303
304 for test_indices in combinations(&all_indices, self.p) {
306 let test_set: std::collections::HashSet<usize> = test_indices.iter().cloned().collect();
307 let train_indices: Vec<usize> = all_indices
308 .iter()
309 .cloned()
310 .filter(|&i| !test_set.contains(&i))
311 .collect();
312
313 splits.push((train_indices, test_indices));
314 }
315
316 splits
317 }
318}
319
/// Returns every `k`-element combination of `items`.
///
/// Elements keep their relative order inside each combination, and the
/// combinations are emitted in lexicographic order of element positions
/// (all combinations containing `items[0]` come first, and so on).
///
/// * `k == 0` yields a single empty combination.
/// * `k > items.len()` yields no combinations.
///
/// The enumeration is iterative: it advances a vector of `k` positions rather
/// than recursing once per element, so stack usage does not grow with
/// `items.len()` and no intermediate result lists are rebuilt.
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k > n {
        return Vec::new();
    }

    // `positions` always holds a strictly increasing list of `k` indices.
    let mut positions: Vec<usize> = (0..k).collect();
    let mut result = Vec::new();

    loop {
        result.push(positions.iter().map(|&i| items[i].clone()).collect());

        // Slot `s` can grow at most to `n - k + s`; find the rightmost slot
        // that has not reached its ceiling yet. If there is none, the
        // combination just emitted was the last one.
        let pivot = match (0..k).rev().find(|&slot| positions[slot] < n - k + slot) {
            Some(slot) => slot,
            None => return result,
        };

        // Advance the pivot and pack the slots after it directly behind it.
        positions[pivot] += 1;
        for slot in pivot + 1..k {
            positions[slot] = positions[slot - 1] + 1;
        }
    }
}