sklears_model_selection/cv/
custom_cv.rs1use scirs2_core::ndarray::Array1;
7
8use crate::CrossValidator;
9
10type SplitFn =
42 Box<dyn Fn(usize, Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> + Send + Sync>;
43
44pub struct CustomCrossValidator {
45 n_splits: usize,
46 split_fn: SplitFn,
47}
48
49impl CustomCrossValidator {
50 pub fn new<F>(n_splits: usize, split_fn: F) -> Self
56 where
57 F: Fn(usize, Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> + Send + Sync + 'static,
58 {
59 Self {
60 n_splits,
61 split_fn: Box::new(split_fn),
62 }
63 }
64}
65
66impl CrossValidator for CustomCrossValidator {
67 fn n_splits(&self) -> usize {
68 self.n_splits
69 }
70
71 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
72 (self.split_fn)(n_samples, y)
73 }
74}
75
/// Expanding-window cross-validator over consecutive blocks of ordered samples.
///
/// Each fold tests on one contiguous block and trains on every sample that comes
/// before it, minus an optional `gap` of samples left out in between.
#[derive(Clone, Debug)]
pub struct BlockCrossValidator {
    /// Number of blocks the samples are cut into (the constructor enforces >= 2).
    n_splits: usize,
    /// Explicit block length; `None` means `n_samples / n_splits` at split time.
    test_size: Option<usize>,
    /// Samples dropped between the end of the training range and the test block.
    gap: usize,
}
96
97impl BlockCrossValidator {
98 pub fn new(n_splits: usize) -> Self {
103 assert!(n_splits >= 2, "n_splits must be at least 2");
104 Self {
105 n_splits,
106 test_size: None,
107 gap: 0,
108 }
109 }
110
111 pub fn test_size(mut self, test_size: usize) -> Self {
113 self.test_size = Some(test_size);
114 self
115 }
116
117 pub fn gap(mut self, gap: usize) -> Self {
119 self.gap = gap;
120 self
121 }
122}
123
124impl CrossValidator for BlockCrossValidator {
125 fn n_splits(&self) -> usize {
126 self.n_splits
127 }
128
129 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
130 let test_size = self.test_size.unwrap_or(n_samples / self.n_splits);
131
132 assert!(
133 test_size * self.n_splits <= n_samples,
134 "Test size too large for number of samples"
135 );
136
137 let mut splits = Vec::new();
138
139 for fold in 0..self.n_splits {
140 let test_start = fold * test_size;
141 let test_end = (test_start + test_size).min(n_samples);
142
143 let train_end = test_start.saturating_sub(self.gap);
145
146 let train_indices: Vec<usize> = (0..train_end).collect();
147 let test_indices: Vec<usize> = (test_start..test_end).collect();
148
149 if !train_indices.is_empty() && !test_indices.is_empty() {
150 splits.push((train_indices, test_indices));
151 }
152 }
153
154 splits
155 }
156}
157
158#[derive(Debug, Clone)]
164pub struct PredefinedSplit {
165 test_fold: Array1<i32>,
166}
167
168impl PredefinedSplit {
169 pub fn new(test_fold: Array1<i32>) -> Self {
175 for &fold in test_fold.iter() {
177 assert!(
178 fold >= -1,
179 "test_fold values must be -1 or non-negative, got {fold}"
180 );
181 }
182 Self { test_fold }
183 }
184
185 fn get_n_splits(&self) -> usize {
187 let unique_folds: std::collections::HashSet<i32> = self
188 .test_fold
189 .iter()
190 .filter(|&&x| x >= 0)
191 .cloned()
192 .collect();
193 unique_folds.len()
194 }
195}
196
197impl CrossValidator for PredefinedSplit {
198 fn n_splits(&self) -> usize {
199 self.get_n_splits()
200 }
201
202 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
203 assert_eq!(
204 self.test_fold.len(),
205 n_samples,
206 "test_fold must have the same length as n_samples"
207 );
208
209 let mut unique_folds: Vec<i32> = self
211 .test_fold
212 .iter()
213 .filter(|&&x| x >= 0)
214 .cloned()
215 .collect::<std::collections::HashSet<_>>()
216 .into_iter()
217 .collect();
218 unique_folds.sort();
219
220 let mut splits = Vec::new();
221
222 for &fold in &unique_folds {
223 let mut train_indices = Vec::new();
224 let mut test_indices = Vec::new();
225
226 for (idx, &sample_fold) in self.test_fold.iter().enumerate() {
227 if sample_fold == fold {
228 test_indices.push(idx);
229 } else if sample_fold == -1 || sample_fold != fold {
230 train_indices.push(idx);
231 }
232 }
233
234 splits.push((train_indices, test_indices));
235 }
236
237 splits
238 }
239}
240
// NOTE(review): nothing in this module uses a non-snake-case name, so this `allow`
// looks vestigial; it is kept so lint configuration does not change.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// A closure-driven validator returns exactly the folds its closure builds.
    #[test]
    fn test_custom_cross_validator() {
        // Fold 1 trains on even indices and tests on odd ones; fold 2 swaps them.
        let custom_cv = CustomCrossValidator::new(2, |n_samples, _y| {
            let (evens, odds): (Vec<usize>, Vec<usize>) =
                (0..n_samples).partition(|&i| i % 2 == 0);
            vec![(evens.clone(), odds.clone()), (odds, evens)]
        });

        let splits = custom_cv.split(6, None);
        assert_eq!(
            splits,
            vec![
                (vec![0, 2, 4], vec![1, 3, 5]),
                (vec![1, 3, 5], vec![0, 2, 4]),
            ]
        );
    }

    /// Expanding-window folds: the first block has no history and is skipped.
    #[test]
    fn test_block_cross_validator() {
        let splits = BlockCrossValidator::new(3).test_size(2).split(8, None);

        assert_eq!(
            splits,
            vec![(vec![0, 1], vec![2, 3]), (vec![0, 1, 2, 3], vec![4, 5])]
        );
    }

    /// A gap removes the samples directly preceding each test block from training.
    #[test]
    fn test_block_cross_validator_with_gap() {
        let splits = BlockCrossValidator::new(3)
            .test_size(2)
            .gap(1)
            .split(8, None);

        assert_eq!(
            splits,
            vec![(vec![0], vec![2, 3]), (vec![0, 1, 2], vec![4, 5])]
        );
    }

    /// Each label becomes a test fold; `-1` samples always stay in training.
    #[test]
    fn test_predefined_split() {
        let cv = PredefinedSplit::new(array![-1, 0, 0, 1, 1, 2, 2, -1]);
        assert_eq!(cv.n_splits(), 3);

        let splits = cv.split(8, None::<&Array1<i32>>);
        assert_eq!(splits.len(), 3);

        let expected_tests: [Vec<usize>; 3] = [vec![1, 2], vec![3, 4], vec![5, 6]];
        for ((train, test), expected) in splits.iter().zip(&expected_tests) {
            assert_eq!(test, expected);
            assert!(train.contains(&0));
            assert!(train.contains(&7));
        }
    }

    /// No labelled samples means no folds; a single label means a single fold.
    #[test]
    fn test_predefined_split_edge_cases() {
        let all_unassigned = PredefinedSplit::new(array![-1, -1, -1, -1]);
        assert_eq!(all_unassigned.n_splits(), 0);
        assert!(all_unassigned.split(4, None::<&Array1<i32>>).is_empty());

        let single_fold = PredefinedSplit::new(array![0, 0, -1, -1]);
        assert_eq!(single_fold.n_splits(), 1);
        assert_eq!(
            single_fold.split(4, None::<&Array1<i32>>),
            vec![(vec![2, 3], vec![0, 1])]
        );
    }
}