sklears_model_selection/cv/
custom_cv.rs

//! Custom cross-validation implementations
//!
//! This module provides flexible cross-validation strategies that allow users to define
//! custom splitting logic or use predefined splits for specific use cases.
5
6use scirs2_core::ndarray::Array1;
7
8use crate::CrossValidator;
9
10/// Custom cross-validation iterator that allows users to define their own splitting logic
11///
12/// This provides a flexible way to implement custom cross-validation strategies
13/// by allowing users to provide their own splitting function.
14///
15/// # Example
16/// ```rust
17/// use sklears_model_selection::{CustomCrossValidator, CrossValidator};
18/// use scirs2_core::ndarray::Array1;
19///
20/// // Example: Simple 2-fold split that alternates samples
21/// let custom_cv = CustomCrossValidator::new(
22///     2,
23///     Box::new(|n_samples, _y: Option<&Array1<i32>>| {
24///         let mut splits = Vec::new();
25///         
26///         // First fold: even indices for train, odd for test
27///         let train1: `Vec<usize>` = (0..n_samples).filter(|&i| i % 2 == 0).collect();
28///         let test1: `Vec<usize>` = (0..n_samples).filter(|&i| i % 2 == 1).collect();
29///         splits.push((train1, test1));
30///         
31///         // Second fold: odd indices for train, even for test
32///         let train2: `Vec<usize>` = (0..n_samples).filter(|&i| i % 2 == 1).collect();
33///         let test2: `Vec<usize>` = (0..n_samples).filter(|&i| i % 2 == 0).collect();
34///         splits.push((train2, test2));
35///         
36///         splits
37///     })
38/// );
39/// ```
40/// Type alias for the split function
41type SplitFn =
42    Box<dyn Fn(usize, Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> + Send + Sync>;
43
44pub struct CustomCrossValidator {
45    n_splits: usize,
46    split_fn: SplitFn,
47}
48
49impl CustomCrossValidator {
50    /// Create a new custom cross-validator
51    ///
52    /// # Arguments
53    /// * `n_splits` - Number of splits this validator will generate
54    /// * `split_fn` - Function that takes n_samples and optional labels and returns train/test index pairs
55    pub fn new<F>(n_splits: usize, split_fn: F) -> Self
56    where
57        F: Fn(usize, Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> + Send + Sync + 'static,
58    {
59        Self {
60            n_splits,
61            split_fn: Box::new(split_fn),
62        }
63    }
64}
65
66impl CrossValidator for CustomCrossValidator {
67    fn n_splits(&self) -> usize {
68        self.n_splits
69    }
70
71    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
72        (self.split_fn)(n_samples, y)
73    }
74}
75
/// Block cross-validation for time series or sequential data
///
/// Splits the data into blocks where each fold uses a contiguous block for testing
/// and the preceding data for training. This is useful for time series data where
/// we want to respect temporal order.
///
/// The first block has no preceding data to train on, so `split` skips it:
/// a validator built with `n_splits = k` yields at most `k - 1` folds.
///
/// # Example
/// ```rust
/// use sklears_model_selection::{BlockCrossValidator, CrossValidator};
///
/// let block_cv = BlockCrossValidator::new(3);
/// let splits = block_cv.split(12, None);
/// // Three blocks of 4 samples each; the first block has no training data
/// // before it, so two folds are produced.
/// assert_eq!(splits.len(), 2);
/// assert_eq!(splits[0].0, vec![0, 1, 2, 3]);
/// assert_eq!(splits[0].1, vec![4, 5, 6, 7]);
/// ```
#[derive(Debug, Clone)]
pub struct BlockCrossValidator {
    /// Number of blocks the samples are divided into (always >= 2).
    n_splits: usize,
    /// Explicit test block size; `None` means `n_samples / n_splits` is used.
    test_size: Option<usize>,
    /// Number of samples directly before each test block that are left out of training.
    gap: usize,
}

impl BlockCrossValidator {
    /// Create a new block cross-validator
    ///
    /// The test size is left unset and the gap defaults to 0.
    ///
    /// # Arguments
    /// * `n_splits` - Number of blocks to create
    ///
    /// # Panics
    /// Panics if `n_splits` is less than 2.
    pub fn new(n_splits: usize) -> Self {
        assert!(n_splits >= 2, "n_splits must be at least 2");
        Self {
            n_splits,
            test_size: None,
            gap: 0,
        }
    }

    /// Set the size of each test block
    ///
    /// Consumes and returns the validator, so the result must be kept.
    #[must_use]
    pub fn test_size(mut self, test_size: usize) -> Self {
        self.test_size = Some(test_size);
        self
    }

    /// Set the gap between train and test sets to avoid data leakage
    ///
    /// Consumes and returns the validator, so the result must be kept.
    #[must_use]
    pub fn gap(mut self, gap: usize) -> Self {
        self.gap = gap;
        self
    }
}
123
124impl CrossValidator for BlockCrossValidator {
125    fn n_splits(&self) -> usize {
126        self.n_splits
127    }
128
129    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
130        let test_size = self.test_size.unwrap_or(n_samples / self.n_splits);
131
132        assert!(
133            test_size * self.n_splits <= n_samples,
134            "Test size too large for number of samples"
135        );
136
137        let mut splits = Vec::new();
138
139        for fold in 0..self.n_splits {
140            let test_start = fold * test_size;
141            let test_end = (test_start + test_size).min(n_samples);
142
143            // Train on all data before test set (with gap)
144            let train_end = test_start.saturating_sub(self.gap);
145
146            let train_indices: Vec<usize> = (0..train_end).collect();
147            let test_indices: Vec<usize> = (test_start..test_end).collect();
148
149            if !train_indices.is_empty() && !test_indices.is_empty() {
150                splits.push((train_indices, test_indices));
151            }
152        }
153
154        splits
155    }
156}
157
158/// Predefined Split cross-validator
159///
160/// Uses user-provided split indices to generate train/test splits.
161/// This is useful when you have a specific validation strategy or when
162/// cross-validation folds are predefined by the problem.
163#[derive(Debug, Clone)]
164pub struct PredefinedSplit {
165    test_fold: Array1<i32>,
166}
167
168impl PredefinedSplit {
169    /// Create a new PredefinedSplit cross-validator
170    ///
171    /// # Arguments
172    /// * `test_fold` - Array where test_fold\[i\] is the fold number for sample i.
173    ///   A value of -1 indicates that the sample should always be in the training set.
174    pub fn new(test_fold: Array1<i32>) -> Self {
175        // Validate that fold indices are valid (-1 or non-negative)
176        for &fold in test_fold.iter() {
177            assert!(
178                fold >= -1,
179                "test_fold values must be -1 or non-negative, got {fold}"
180            );
181        }
182        Self { test_fold }
183    }
184
185    /// Get the number of unique folds (excluding -1)
186    fn get_n_splits(&self) -> usize {
187        let unique_folds: std::collections::HashSet<i32> = self
188            .test_fold
189            .iter()
190            .filter(|&&x| x >= 0)
191            .cloned()
192            .collect();
193        unique_folds.len()
194    }
195}
196
197impl CrossValidator for PredefinedSplit {
198    fn n_splits(&self) -> usize {
199        self.get_n_splits()
200    }
201
202    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
203        assert_eq!(
204            self.test_fold.len(),
205            n_samples,
206            "test_fold must have the same length as n_samples"
207        );
208
209        // Get unique fold numbers (excluding -1)
210        let mut unique_folds: Vec<i32> = self
211            .test_fold
212            .iter()
213            .filter(|&&x| x >= 0)
214            .cloned()
215            .collect::<std::collections::HashSet<_>>()
216            .into_iter()
217            .collect();
218        unique_folds.sort();
219
220        let mut splits = Vec::new();
221
222        for &fold in &unique_folds {
223            let mut train_indices = Vec::new();
224            let mut test_indices = Vec::new();
225
226            for (idx, &sample_fold) in self.test_fold.iter().enumerate() {
227                if sample_fold == fold {
228                    test_indices.push(idx);
229                } else if sample_fold == -1 || sample_fold != fold {
230                    train_indices.push(idx);
231                }
232            }
233
234            splits.push((train_indices, test_indices));
235        }
236
237        splits
238    }
239}
240
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_custom_cross_validator() {
        // Parity-based splitter: fold 1 trains on even indices and tests on odd
        // ones, fold 2 swaps the roles.
        let custom_cv = CustomCrossValidator::new(2, |n_samples, _y| {
            let evens: Vec<usize> = (0..n_samples).step_by(2).collect();
            let odds: Vec<usize> = (1..n_samples).step_by(2).collect();
            vec![(evens.clone(), odds.clone()), (odds, evens)]
        });

        let splits = custom_cv.split(6, None);
        assert_eq!(splits.len(), 2);

        // Fold 1: even train, odd test
        let (train, test) = &splits[0];
        assert_eq!(*train, vec![0, 2, 4]);
        assert_eq!(*test, vec![1, 3, 5]);

        // Fold 2: odd train, even test
        let (train, test) = &splits[1];
        assert_eq!(*train, vec![1, 3, 5]);
        assert_eq!(*test, vec![0, 2, 4]);
    }

    #[test]
    fn test_block_cross_validator() {
        let splits = BlockCrossValidator::new(3).test_size(2).split(8, None);

        // Only two folds come back: the first block has no training data before it.
        let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
            // train on [0, 1], test on [2, 3]
            (vec![0, 1], vec![2, 3]),
            // train on [0, 1, 2, 3], test on [4, 5]
            (vec![0, 1, 2, 3], vec![4, 5]),
        ];
        assert_eq!(splits, expected);
    }

    #[test]
    fn test_block_cross_validator_with_gap() {
        let splits = BlockCrossValidator::new(3)
            .test_size(2)
            .gap(1)
            .split(8, None);

        // The first block is still skipped (nothing to train on); the remaining
        // folds leave one sample out between the train and test ranges.
        let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
            // train on [0], test on [2, 3] (index 1 is the gap)
            (vec![0], vec![2, 3]),
            // train on [0, 1, 2], test on [4, 5] (index 3 is the gap)
            (vec![0, 1, 2], vec![4, 5]),
        ];
        assert_eq!(splits, expected);
    }

    #[test]
    fn test_predefined_split() {
        // -1 pins a sample to the training set; 0, 1 and 2 are fold numbers.
        let cv = PredefinedSplit::new(array![-1, 0, 0, 1, 1, 2, 2, -1]);
        assert_eq!(cv.n_splits(), 3);

        let splits = cv.split(8, None);
        assert_eq!(splits.len(), 3);

        // Folds come back in ascending fold order, each testing exactly the
        // samples assigned to it.
        let expected_tests: [Vec<usize>; 3] = [vec![1, 2], vec![3, 4], vec![5, 6]];
        for ((train, test), expected) in splits.iter().zip(expected_tests.iter()) {
            assert_eq!(test, expected);

            // Samples marked -1 (indices 0 and 7) are in every training set.
            for pinned in &[0usize, 7] {
                assert!(train.contains(pinned));
            }
        }
    }

    #[test]
    fn test_predefined_split_edge_cases() {
        // Every sample pinned to training: there are no folds at all.
        let all_train = PredefinedSplit::new(array![-1, -1, -1, -1]);
        assert_eq!(all_train.n_splits(), 0);
        assert!(all_train.split(4, None).is_empty());

        // A single fold: one split with the pinned samples as training data.
        let single = PredefinedSplit::new(array![0, 0, -1, -1]);
        assert_eq!(single.n_splits(), 1);
        let splits = single.split(4, None);
        assert_eq!(splits.len(), 1);

        let (train, test) = &splits[0];
        assert_eq!(*test, vec![0, 1]);
        assert_eq!(*train, vec![2, 3]);
    }
}