sklears_model_selection/cv/
time_series_cv.rs1use super::CrossValidator;
4use scirs2_core::ndarray::Array1;
5
/// Cross-validation splitter for time-ordered samples.
///
/// Holds the settings consumed by the `CrossValidator::split` implementation
/// for this type: every fold tests on a window near the end of the sample
/// range and trains on indices that come before that window.
#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    /// Number of train/test folds to generate.
    n_splits: usize,
    /// Optional cap on the training indices kept per fold; the latest ones are kept.
    max_train_size: Option<usize>,
    /// Optional fixed test-window length; derived from the sample count if unset.
    test_size: Option<usize>,
    /// Samples skipped between the end of training and the start of testing.
    gap: usize,
    /// When non-zero, folds after the first start training this many samples
    /// before the previous fold's test window instead of at index 0.
    overlap: usize,
}

impl TimeSeriesSplit {
    /// Builds a splitter with `n_splits` folds and every optional setting at
    /// its neutral value (no training cap, derived test size, no gap, no overlap).
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is smaller than 2.
    pub fn new(n_splits: usize) -> Self {
        assert!(n_splits >= 2, "n_splits must be at least 2");
        Self {
            n_splits,
            gap: 0,
            overlap: 0,
            test_size: None,
            max_train_size: None,
        }
    }

    /// Returns the splitter with the per-fold training size capped at `size`.
    pub fn max_train_size(self, size: usize) -> Self {
        Self {
            max_train_size: Some(size),
            ..self
        }
    }

    /// Returns the splitter with a fixed test-window length of `size`.
    pub fn test_size(self, size: usize) -> Self {
        Self {
            test_size: Some(size),
            ..self
        }
    }

    /// Returns the splitter with `gap` samples left out before each test window.
    pub fn gap(self, gap: usize) -> Self {
        Self { gap, ..self }
    }

    /// Returns the splitter with the training-window overlap set to `overlap`.
    pub fn overlap(self, overlap: usize) -> Self {
        Self { overlap, ..self }
    }
}
54
55impl CrossValidator for TimeSeriesSplit {
56 fn n_splits(&self) -> usize {
57 self.n_splits
58 }
59
60 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
61 let n_splits = self.n_splits;
62 let n_folds = n_splits + 1;
63 let test_size = self.test_size.unwrap_or_else(|| n_samples / n_folds);
64
65 assert!(
66 n_folds * test_size <= n_samples,
67 "Too many splits {n_splits} for number of samples {n_samples}"
68 );
69
70 let mut splits = Vec::new();
71 let test_starts = (0..n_splits)
72 .map(|i| n_samples - (n_splits - i) * test_size)
73 .collect::<Vec<_>>();
74
75 for (split_idx, &test_start) in test_starts.iter().enumerate() {
76 let train_end = test_start - self.gap;
77 let test_end = test_start + test_size;
78
79 let mut train_start = 0;
81 if self.overlap > 0 && split_idx > 0 {
82 let prev_test_start = test_starts[split_idx - 1];
84 train_start = prev_test_start.saturating_sub(self.overlap);
85 }
86
87 let mut train_indices: Vec<usize> = (train_start..train_end).collect();
88
89 if let Some(max_size) = self.max_train_size {
91 if train_indices.len() > max_size {
92 let start_idx = train_indices.len() - max_size;
93 train_indices = train_indices[start_idx..].to_vec();
94 }
95 }
96
97 let test_indices: Vec<usize> = (test_start..test_end).collect();
98 splits.push((train_indices, test_indices));
99 }
100
101 splits
102 }
103}
104
/// Time-series cross-validation whose training data is taken as several
/// separated blocks located before each test window.
///
/// The fields are consumed by the `CrossValidator::split` implementation for
/// this type.
#[derive(Debug, Clone)]
pub struct BlockedTimeSeriesCV {
    /// Number of train/test folds to generate.
    n_splits: usize,
    /// Number of training blocks per fold; consecutive blocks are separated by
    /// an unused stretch as long as one block.
    n_blocks: usize,
    /// Samples excluded from training immediately before each test window.
    gap: usize,
    /// Optional fixed test-window length; derived from the sample count if unset.
    test_size: Option<usize>,
}

impl BlockedTimeSeriesCV {
    /// Builds a splitter with `n_splits` folds and `n_blocks` training blocks
    /// per fold, with no gap and a derived test size.
    ///
    /// # Panics
    ///
    /// Panics if `n_splits` is smaller than 2 or `n_blocks` is 0.
    pub fn new(n_splits: usize, n_blocks: usize) -> Self {
        assert!(n_splits >= 2, "n_splits must be at least 2");
        assert!(n_blocks >= 1, "n_blocks must be at least 1");
        Self {
            n_splits,
            n_blocks,
            test_size: None,
            gap: 0,
        }
    }

    /// Returns the splitter with `gap` samples left out before each test window.
    pub fn gap(self, gap: usize) -> Self {
        Self { gap, ..self }
    }

    /// Returns the splitter with a fixed test-window length of `size`.
    pub fn test_size(self, size: usize) -> Self {
        Self {
            test_size: Some(size),
            ..self
        }
    }
}
142
143impl CrossValidator for BlockedTimeSeriesCV {
144 fn n_splits(&self) -> usize {
145 self.n_splits
146 }
147
148 fn split(&self, n_samples: usize, _y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
149 let test_size = self.test_size.unwrap_or(n_samples / (self.n_splits + 1));
150 let mut splits = Vec::new();
151
152 for i in 0..self.n_splits {
153 let test_start = n_samples - (self.n_splits - i) * test_size;
154 let test_end = test_start + test_size;
155 let test_indices: Vec<usize> = (test_start..test_end).collect();
156
157 let mut train_indices = Vec::new();
159 let available_train_space = test_start.saturating_sub(self.gap);
160 let block_size = available_train_space / (self.n_blocks + self.n_blocks - 1); for block in 0..self.n_blocks {
163 let block_start = block * 2 * block_size; let block_end = block_start + block_size;
165
166 if block_end <= available_train_space {
167 train_indices.extend(block_start..block_end);
168 }
169 }
170
171 splits.push((train_indices, test_indices));
172 }
173
174 splits
175 }
176}
177
178#[derive(Debug, Clone)]
183pub struct PurgedGroupTimeSeriesSplit {
184 n_splits: usize,
185 max_train_group_size: Option<usize>,
186 group_gap: usize,
187 purge_length: usize,
188 embargo_length: usize,
189}
190
191impl PurgedGroupTimeSeriesSplit {
192 pub fn new(n_splits: usize) -> Self {
194 assert!(n_splits >= 2, "n_splits must be at least 2");
195 Self {
196 n_splits,
197 max_train_group_size: None,
198 group_gap: 0,
199 purge_length: 0,
200 embargo_length: 0,
201 }
202 }
203
204 pub fn max_train_group_size(mut self, size: usize) -> Self {
206 self.max_train_group_size = Some(size);
207 self
208 }
209
210 pub fn group_gap(mut self, gap: usize) -> Self {
212 self.group_gap = gap;
213 self
214 }
215
216 pub fn purge_length(mut self, length: usize) -> Self {
218 self.purge_length = length;
219 self
220 }
221
222 pub fn embargo_length(mut self, length: usize) -> Self {
224 self.embargo_length = length;
225 self
226 }
227
228 pub fn split_with_groups(
230 &self,
231 n_samples: usize,
232 groups: &Array1<i32>,
233 ) -> Vec<(Vec<usize>, Vec<usize>)> {
234 assert_eq!(
235 groups.len(),
236 n_samples,
237 "groups must have same length as n_samples"
238 );
239
240 let mut group_positions: std::collections::HashMap<i32, Vec<usize>> =
242 std::collections::HashMap::new();
243 for (idx, &group) in groups.iter().enumerate() {
244 group_positions.entry(group).or_default().push(idx);
245 }
246
247 let mut unique_groups: Vec<i32> = group_positions.keys().cloned().collect();
248 unique_groups.sort();
249
250 let groups_per_test = unique_groups.len() / self.n_splits;
251 let mut splits = Vec::new();
252
253 for i in 0..self.n_splits {
254 let test_group_start = i * groups_per_test;
255 let test_group_end = if i == self.n_splits - 1 {
256 unique_groups.len()
257 } else {
258 (i + 1) * groups_per_test
259 };
260
261 let test_groups = &unique_groups[test_group_start..test_group_end];
262
263 let mut test_indices = Vec::new();
265 for &group in test_groups {
266 test_indices.extend(&group_positions[&group]);
267 }
268
269 let mut train_indices = Vec::new();
271 for &group in &unique_groups {
272 if !test_groups.contains(&group) {
273 let group_indices = &group_positions[&group];
274
275 let should_include =
277 self.should_include_group(group, test_groups, &unique_groups);
278
279 if should_include {
280 train_indices.extend(group_indices);
281 }
282 }
283 }
284
285 splits.push((train_indices, test_indices));
286 }
287
288 splits
289 }
290
291 fn should_include_group(&self, group: i32, test_groups: &[i32], all_groups: &[i32]) -> bool {
292 let group_pos = all_groups.iter().position(|&g| g == group).unwrap();
293 let test_start = all_groups
294 .iter()
295 .position(|&g| g == test_groups[0])
296 .unwrap();
297 let test_end = all_groups
298 .iter()
299 .position(|&g| g == test_groups[test_groups.len() - 1])
300 .unwrap();
301
302 if group_pos + self.purge_length > test_start && group_pos < test_start {
304 return false;
305 }
306
307 if group_pos > test_end && group_pos <= test_end + self.embargo_length {
309 return false;
310 }
311
312 true
313 }
314}
315
316impl CrossValidator for PurgedGroupTimeSeriesSplit {
317 fn n_splits(&self) -> usize {
318 self.n_splits
319 }
320
321 fn split(&self, n_samples: usize, y: Option<&Array1<i32>>) -> Vec<(Vec<usize>, Vec<usize>)> {
322 let groups = y.expect("PurgedGroupTimeSeriesSplit requires group labels in y parameter");
323 self.split_with_groups(n_samples, groups)
324 }
325}