sklears_model_selection/cv/
time_series_cv.rs

1//! Time series cross-validation iterators
2
3use super::CrossValidator;
4use scirs2_core::ndarray::Array1;
5
/// Time Series Split cross-validator with gap and overlapping support
///
/// Test sets are contiguous runs of samples taken from the end of the data;
/// each training set is drawn from the samples that precede its test set.
#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    /// Number of train/test splits to generate.
    n_splits: usize,
    /// Optional cap on the number of samples in any one training set.
    max_train_size: Option<usize>,
    /// Number of samples per test set; when `None` the splitter uses
    /// `n_samples / (n_splits + 1)`.
    test_size: Option<usize>,
    /// Number of samples dropped between the end of a training set and the
    /// start of its test set.
    gap: usize,
    /// When non-zero, training sets after the first begin `overlap` samples
    /// before the previous split's test set instead of at sample 0.
    overlap: usize,
}
15
16impl TimeSeriesSplit {
17    /// Create a new TimeSeriesSplit cross-validator
18    pub fn new(n_splits: usize) -> Self {
19        assert!(n_splits >= 2, "n_splits must be at least 2");
20        Self {
21            n_splits,
22            max_train_size: None,
23            test_size: None,
24            gap: 0,
25            overlap: 0,
26        }
27    }
28
29    /// Set the maximum size for a single training set
30    pub fn max_train_size(mut self, size: usize) -> Self {
31        self.max_train_size = Some(size);
32        self
33    }
34
35    /// Set the size of the test set
36    pub fn test_size(mut self, size: usize) -> Self {
37        self.test_size = Some(size);
38        self
39    }
40
41    /// Set the gap between train and test set
42    pub fn gap(mut self, gap: usize) -> Self {
43        self.gap = gap;
44        self
45    }
46
47    /// Set the overlap between consecutive training sets
48    /// When overlap > 0, training sets will include overlapping data from previous splits
49    pub fn overlap(mut self, overlap: usize) -> Self {
50        self.overlap = overlap;
51        self
52    }
53}
54
55impl CrossValidator for TimeSeriesSplit {
56    fn n_splits(&self) -> usize {
57        self.n_splits
58    }
59
60    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
61        let n_splits = self.n_splits;
62        let n_folds = n_splits + 1;
63        let test_size = self.test_size.unwrap_or_else(|| n_samples / n_folds);
64
65        assert!(
66            n_folds * test_size <= n_samples,
67            "Too many splits {n_splits} for number of samples {n_samples}"
68        );
69
70        let mut splits = Vec::new();
71        let test_starts = (0..n_splits)
72            .map(|i| n_samples - (n_splits - i) * test_size)
73            .collect::<Vec<_>>();
74
75        for (split_idx, &test_start) in test_starts.iter().enumerate() {
76            let train_end = test_start - self.gap;
77            let test_end = test_start + test_size;
78
79            // Calculate training set with potential overlap
80            let mut train_start = 0;
81            if self.overlap > 0 && split_idx > 0 {
82                // For overlapping, start training from the overlap amount before previous test
83                let prev_test_start = test_starts[split_idx - 1];
84                train_start = prev_test_start.saturating_sub(self.overlap);
85            }
86
87            let mut train_indices: Vec<usize> = (train_start..train_end).collect();
88
89            // Apply max_train_size if set
90            if let Some(max_size) = self.max_train_size {
91                if train_indices.len() > max_size {
92                    let start_idx = train_indices.len() - max_size;
93                    train_indices = train_indices[start_idx..].to_vec();
94                }
95            }
96
97            let test_indices: Vec<usize> = (test_start..test_end).collect();
98            splits.push((train_indices, test_indices));
99        }
100
101        splits
102    }
103}
104
/// Blocked Time Series Cross-Validation
///
/// This cross-validator provides multiple non-contiguous training blocks
/// for time series data, with gap control to prevent data leakage.
#[derive(Debug, Clone)]
pub struct BlockedTimeSeriesCV {
    /// Number of train/test splits to generate.
    n_splits: usize,
    /// Number of separate training blocks placed before each test set.
    n_blocks: usize,
    /// Number of samples directly before each test set that are kept out of
    /// the training blocks.
    gap: usize,
    /// Number of samples per test set; when `None` the splitter uses
    /// `n_samples / (n_splits + 1)`.
    test_size: Option<usize>,
}
116
117impl BlockedTimeSeriesCV {
118    /// Create a new BlockedTimeSeriesCV cross-validator
119    pub fn new(n_splits: usize, n_blocks: usize) -> Self {
120        assert!(n_splits >= 2, "n_splits must be at least 2");
121        assert!(n_blocks >= 1, "n_blocks must be at least 1");
122        Self {
123            n_splits,
124            n_blocks,
125            gap: 0,
126            test_size: None,
127        }
128    }
129
130    /// Set the gap between blocks and test sets
131    pub fn gap(mut self, gap: usize) -> Self {
132        self.gap = gap;
133        self
134    }
135
136    /// Set the size of the test set
137    pub fn test_size(mut self, size: usize) -> Self {
138        self.test_size = Some(size);
139        self
140    }
141}
142
143impl CrossValidator for BlockedTimeSeriesCV {
144    fn n_splits(&self) -> usize {
145        self.n_splits
146    }
147
148    fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
149        let test_size = self.test_size.unwrap_or(n_samples / (self.n_splits + 1));
150        let mut splits = Vec::new();
151
152        for i in 0..self.n_splits {
153            let test_start = n_samples - (self.n_splits - i) * test_size;
154            let test_end = test_start + test_size;
155            let test_indices: Vec<usize> = (test_start..test_end).collect();
156
157            // Create multiple training blocks before the test set
158            let mut train_indices = Vec::new();
159            let available_train_space = test_start.saturating_sub(self.gap);
160            let block_size = available_train_space / (self.n_blocks + self.n_blocks - 1); // Include gaps between blocks
161
162            for block in 0..self.n_blocks {
163                let block_start = block * 2 * block_size; // 2x for block + gap
164                let block_end = block_start + block_size;
165
166                if block_end <= available_train_space {
167                    train_indices.extend(block_start..block_end);
168                }
169            }
170
171            splits.push((train_indices, test_indices));
172        }
173
174        splits
175    }
176}
177
/// Purged Group Time Series Split for financial data
///
/// Advanced time series cross-validation with purging and embargo periods
/// to prevent data leakage in financial modeling.
#[derive(Debug, Clone)]
pub struct PurgedGroupTimeSeriesSplit {
    /// Number of train/test splits to generate.
    n_splits: usize,
    /// Maximum size for training groups.
    /// NOTE(review): stored by the builder, but the splitting code in this
    /// file does not appear to read it — confirm intended use.
    max_train_group_size: Option<usize>,
    /// Gap between groups.
    /// NOTE(review): stored by the builder, but the splitting code in this
    /// file does not appear to read it — confirm intended use.
    group_gap: usize,
    /// Length of the purge window: groups positioned just before the test
    /// groups that are kept out of training.
    purge_length: usize,
    /// Length of the embargo window: groups positioned just after the test
    /// groups that are kept out of training.
    embargo_length: usize,
}
190
191impl PurgedGroupTimeSeriesSplit {
192    /// Create a new PurgedGroupTimeSeriesSplit cross-validator
193    pub fn new(n_splits: usize) -> Self {
194        assert!(n_splits >= 2, "n_splits must be at least 2");
195        Self {
196            n_splits,
197            max_train_group_size: None,
198            group_gap: 0,
199            purge_length: 0,
200            embargo_length: 0,
201        }
202    }
203
204    /// Set the maximum size for training groups
205    pub fn max_train_group_size(mut self, size: usize) -> Self {
206        self.max_train_group_size = Some(size);
207        self
208    }
209
210    /// Set the gap between groups
211    pub fn group_gap(mut self, gap: usize) -> Self {
212        self.group_gap = gap;
213        self
214    }
215
216    /// Set the purge length (remove samples before test set)
217    pub fn purge_length(mut self, length: usize) -> Self {
218        self.purge_length = length;
219        self
220    }
221
222    /// Set the embargo length (remove samples after test set)
223    pub fn embargo_length(mut self, length: usize) -> Self {
224        self.embargo_length = length;
225        self
226    }
227
228    /// Split with group information for financial time series
229    pub fn split_with_groups(
230        &self,
231        n_samples: usize,
232        groups: &Array1<i32>,
233    ) -> Vec<(Vec<usize>, Vec<usize>)> {
234        assert_eq!(
235            groups.len(),
236            n_samples,
237            "groups must have same length as n_samples"
238        );
239
240        // Get unique groups in order
241        let mut group_positions: std::collections::HashMap<i32, Vec<usize>> =
242            std::collections::HashMap::new();
243        for (idx, &group) in groups.iter().enumerate() {
244            group_positions.entry(group).or_default().push(idx);
245        }
246
247        let mut unique_groups: Vec<i32> = group_positions.keys().cloned().collect();
248        unique_groups.sort();
249
250        let groups_per_test = unique_groups.len() / self.n_splits;
251        let mut splits = Vec::new();
252
253        for i in 0..self.n_splits {
254            let test_group_start = i * groups_per_test;
255            let test_group_end = if i == self.n_splits - 1 {
256                unique_groups.len()
257            } else {
258                (i + 1) * groups_per_test
259            };
260
261            let test_groups = &unique_groups[test_group_start..test_group_end];
262
263            // Collect test indices
264            let mut test_indices = Vec::new();
265            for &group in test_groups {
266                test_indices.extend(&group_positions[&group]);
267            }
268
269            // Collect train indices with purging and embargo
270            let mut train_indices = Vec::new();
271            for &group in &unique_groups {
272                if !test_groups.contains(&group) {
273                    let group_indices = &group_positions[&group];
274
275                    // Check if this group should be purged or embargoed
276                    let should_include =
277                        self.should_include_group(group, test_groups, &unique_groups);
278
279                    if should_include {
280                        train_indices.extend(group_indices);
281                    }
282                }
283            }
284
285            splits.push((train_indices, test_indices));
286        }
287
288        splits
289    }
290
291    fn should_include_group(&self, group: i32, test_groups: &[i32], all_groups: &[i32]) -> bool {
292        let group_pos = all_groups.iter().position(|&g| g == group).unwrap();
293        let test_start = all_groups
294            .iter()
295            .position(|&g| g == test_groups[0])
296            .unwrap();
297        let test_end = all_groups
298            .iter()
299            .position(|&g| g == test_groups[test_groups.len() - 1])
300            .unwrap();
301
302        // Check purge period (before test)
303        if group_pos + self.purge_length > test_start && group_pos < test_start {
304            return false;
305        }
306
307        // Check embargo period (after test)
308        if group_pos > test_end && group_pos <= test_end + self.embargo_length {
309            return false;
310        }
311
312        true
313    }
314}
315
316impl CrossValidator for PurgedGroupTimeSeriesSplit {
317    fn n_splits(&self) -> usize {
318        self.n_splits
319    }
320
321    fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
322        let groups = y.expect("PurgedGroupTimeSeriesSplit requires group labels in y parameter");
323        self.split_with_groups(n_samples, groups)
324    }
325}