1use sklears_core::error::{Result, SklearsError};
7use std::collections::HashMap;
8
/// Configuration shared by the temporal cross-validators in this module.
///
/// Not every validator reads every field; each field documents who uses it.
#[derive(Debug, Clone, PartialEq)]
pub struct TemporalValidationConfig {
    /// Minimum separation between training data and the test fold.
    ///
    /// `TemporalCrossValidator` applies it both in sorted sample positions and
    /// as `train_time + min_gap <= earliest_test_time`.
    /// `SeasonalCrossValidator` applies the time-based rule.
    /// `BlockedTemporalCV` counts it in whole blocks.
    pub min_gap: usize,
    /// History bound used by `TemporalCrossValidator`.
    ///
    /// Forward-chaining test windows that start before sorted position
    /// `max_lookback` are skipped. Sliding windows start at that position.
    /// Each training set is trimmed to samples whose timestamp is within
    /// `max_lookback` time units of the latest training timestamp.
    pub max_lookback: usize,
    /// When `true`, `TemporalCrossValidator` builds expanding (forward-chaining)
    /// windows; otherwise it builds sliding windows.
    pub forward_chaining: bool,
    /// Number of folds attempted. Fewer may be returned when folds do not fit.
    pub n_splits: usize,
    /// Fraction of samples used per test fold, truncated toward zero.
    ///
    /// The base differs by validator: all samples (`TemporalCrossValidator`),
    /// each season (`SeasonalCrossValidator`), or the blocks (`BlockedTemporalCV`).
    pub test_size: f64,
    /// Sliding-window mode only: when `true`, consecutive windows advance by
    /// half a test window, so test windows overlap.
    pub allow_overlap: bool,
    /// Optional season length.
    ///
    /// NOTE(review): not read by the validators visible alongside this config;
    /// `SeasonalCrossValidator::new` takes its period as a separate argument.
    pub seasonal_period: Option<usize>,
}

impl Default for TemporalValidationConfig {
    /// Five forward-chaining folds of 20% each.
    ///
    /// Also uses a one-unit gap, a lookback of 10, no overlap, and no seasonal period.
    fn default() -> Self {
        Self {
            min_gap: 1,
            max_lookback: 10,
            forward_chaining: true,
            n_splits: 5,
            test_size: 0.2,
            allow_overlap: false,
            seasonal_period: None,
        }
    }
}
41
42#[derive(Debug, Clone)]
44pub struct TemporalCrossValidator {
45 config: TemporalValidationConfig,
46}
47
48impl TemporalCrossValidator {
49 pub fn new(config: TemporalValidationConfig) -> Self {
50 Self { config }
51 }
52
53 pub fn split(
55 &self,
56 n_samples: usize,
57 time_index: &[usize],
58 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
59 if time_index.len() != n_samples {
60 return Err(SklearsError::InvalidInput(
61 "Time index length must match number of samples".to_string(),
62 ));
63 }
64
65 let sorted_indices = self.sort_by_time(time_index)?;
66
67 let mut splits = if self.config.forward_chaining {
68 self.forward_chaining_splits(&sorted_indices)?
69 } else {
70 self.sliding_window_splits(&sorted_indices)?
71 };
72
73 self.apply_temporal_constraints(&mut splits, time_index)?;
75
76 Ok(splits)
77 }
78
79 fn forward_chaining_splits(
81 &self,
82 sorted_indices: &[usize],
83 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
84 let mut splits = Vec::new();
85 let n_samples = sorted_indices.len();
86 let test_samples = (n_samples as f64 * self.config.test_size) as usize;
87
88 for i in 0..self.config.n_splits {
89 let test_start = n_samples - (self.config.n_splits - i) * test_samples;
90 let test_end = test_start + test_samples;
91
92 if test_start < self.config.max_lookback {
93 continue;
94 }
95
96 let train_end = test_start.saturating_sub(self.config.min_gap);
97
98 let train_indices = sorted_indices[0..train_end].to_vec();
99 let test_indices = if test_end <= n_samples {
100 sorted_indices[test_start..test_end].to_vec()
101 } else {
102 sorted_indices[test_start..].to_vec()
103 };
104
105 if !train_indices.is_empty() && !test_indices.is_empty() {
106 splits.push((train_indices, test_indices));
107 }
108 }
109
110 Ok(splits)
111 }
112
113 fn sliding_window_splits(
115 &self,
116 sorted_indices: &[usize],
117 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
118 let mut splits = Vec::new();
119 let n_samples = sorted_indices.len();
120 let test_samples = (n_samples as f64 * self.config.test_size) as usize;
121 let train_samples = n_samples - test_samples - self.config.min_gap;
122
123 let step_size = if self.config.allow_overlap {
124 test_samples / 2
125 } else {
126 test_samples
127 };
128
129 let mut start = self.config.max_lookback;
130
131 while start + train_samples + self.config.min_gap + test_samples <= n_samples {
132 let train_end = start + train_samples;
133 let test_start = train_end + self.config.min_gap;
134 let test_end = test_start + test_samples;
135
136 let train_indices = sorted_indices[start..train_end].to_vec();
137 let test_indices = sorted_indices[test_start..test_end].to_vec();
138
139 splits.push((train_indices, test_indices));
140 start += step_size;
141 }
142
143 Ok(splits)
144 }
145
146 fn sort_by_time(&self, time_index: &[usize]) -> Result<Vec<usize>> {
148 let mut indexed_times: Vec<(usize, usize)> = time_index
149 .iter()
150 .enumerate()
151 .map(|(idx, &time)| (idx, time))
152 .collect();
153
154 indexed_times.sort_by_key(|&(_, time)| time);
155
156 Ok(indexed_times.into_iter().map(|(idx, _)| idx).collect())
157 }
158
159 fn apply_temporal_constraints(
161 &self,
162 splits: &mut [(Vec<usize>, Vec<usize>)],
163 time_index: &[usize],
164 ) -> Result<()> {
165 for (train_indices, test_indices) in splits.iter_mut() {
166 let max_test_time = test_indices
168 .iter()
169 .map(|&idx| time_index[idx])
170 .min()
171 .unwrap_or(0);
172
173 train_indices.retain(|&idx| time_index[idx] + self.config.min_gap <= max_test_time);
174
175 if let Some(min_train_time) = train_indices.iter().map(|&idx| time_index[idx]).max() {
177 let cutoff_time = min_train_time.saturating_sub(self.config.max_lookback);
178 train_indices.retain(|&idx| time_index[idx] >= cutoff_time);
179 }
180 }
181
182 Ok(())
183 }
184}
185
186#[derive(Debug, Clone)]
188pub struct SeasonalCrossValidator {
189 config: TemporalValidationConfig,
190 seasonal_period: usize,
191}
192
193impl SeasonalCrossValidator {
194 pub fn new(config: TemporalValidationConfig, seasonal_period: usize) -> Self {
195 Self {
196 config,
197 seasonal_period,
198 }
199 }
200
201 pub fn split(
203 &self,
204 _n_samples: usize,
205 time_index: &[usize],
206 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
207 let mut splits = Vec::new();
208 let sorted_indices = self.sort_by_time(time_index)?;
209
210 let seasonal_groups = self.group_by_season(&sorted_indices, time_index)?;
212
213 for split_idx in 0..self.config.n_splits {
215 let (train_indices, test_indices) =
216 self.create_seasonal_split(&seasonal_groups, split_idx, time_index)?;
217
218 if !train_indices.is_empty() && !test_indices.is_empty() {
219 splits.push((train_indices, test_indices));
220 }
221 }
222
223 Ok(splits)
224 }
225
226 fn sort_by_time(&self, time_index: &[usize]) -> Result<Vec<usize>> {
227 let mut indexed_times: Vec<(usize, usize)> = time_index
228 .iter()
229 .enumerate()
230 .map(|(idx, &time)| (idx, time))
231 .collect();
232
233 indexed_times.sort_by_key(|&(_, time)| time);
234
235 Ok(indexed_times.into_iter().map(|(idx, _)| idx).collect())
236 }
237
238 fn group_by_season(
239 &self,
240 sorted_indices: &[usize],
241 time_index: &[usize],
242 ) -> Result<HashMap<usize, Vec<usize>>> {
243 let mut seasonal_groups: HashMap<usize, Vec<usize>> = HashMap::new();
244
245 for &idx in sorted_indices {
246 let season = time_index[idx] % self.seasonal_period;
247 seasonal_groups.entry(season).or_default().push(idx);
248 }
249
250 Ok(seasonal_groups)
251 }
252
253 fn create_seasonal_split(
254 &self,
255 seasonal_groups: &HashMap<usize, Vec<usize>>,
256 split_idx: usize,
257 time_index: &[usize],
258 ) -> Result<(Vec<usize>, Vec<usize>)> {
259 let mut train_indices = Vec::new();
260 let mut test_indices = Vec::new();
261
262 for indices in seasonal_groups.values() {
263 let n_season_samples = indices.len();
264 let test_size = (n_season_samples as f64 * self.config.test_size) as usize;
265 let samples_per_split = test_size.max(1);
266
267 let test_start = split_idx * samples_per_split;
268 let test_end = ((split_idx + 1) * samples_per_split).min(n_season_samples);
269
270 if test_start < n_season_samples {
271 test_indices.extend_from_slice(&indices[test_start..test_end]);
273
274 for (i, &idx) in indices.iter().enumerate() {
276 if i < test_start || i >= test_end {
277 let sample_time = time_index[idx];
279 let can_use_for_training = test_indices.iter().all(|&test_idx| {
280 sample_time + self.config.min_gap <= time_index[test_idx]
281 });
282
283 if can_use_for_training {
284 train_indices.push(idx);
285 }
286 }
287 }
288 }
289 }
290
291 Ok((train_indices, test_indices))
292 }
293}
294
295#[derive(Debug, Clone)]
297pub struct BlockedTemporalCV {
298 config: TemporalValidationConfig,
299 block_size: usize,
300}
301
302impl BlockedTemporalCV {
303 pub fn new(config: TemporalValidationConfig, block_size: usize) -> Self {
304 Self { config, block_size }
305 }
306
307 pub fn split(
309 &self,
310 _n_samples: usize,
311 time_index: &[usize],
312 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
313 let sorted_indices = self.sort_by_time(time_index)?;
314 let blocks = self.create_blocks(&sorted_indices)?;
315
316 let mut splits = Vec::new();
317
318 for i in 0..self.config.n_splits {
319 let (train_blocks, test_blocks) = self.select_blocks(&blocks, i)?;
320
321 let train_indices: Vec<usize> = train_blocks.into_iter().flatten().collect();
322 let test_indices: Vec<usize> = test_blocks.into_iter().flatten().collect();
323
324 if !train_indices.is_empty() && !test_indices.is_empty() {
325 splits.push((train_indices, test_indices));
326 }
327 }
328
329 Ok(splits)
330 }
331
332 fn sort_by_time(&self, time_index: &[usize]) -> Result<Vec<usize>> {
333 let mut indexed_times: Vec<(usize, usize)> = time_index
334 .iter()
335 .enumerate()
336 .map(|(idx, &time)| (idx, time))
337 .collect();
338
339 indexed_times.sort_by_key(|&(_, time)| time);
340
341 Ok(indexed_times.into_iter().map(|(idx, _)| idx).collect())
342 }
343
344 fn create_blocks(&self, sorted_indices: &[usize]) -> Result<Vec<Vec<usize>>> {
345 let mut blocks = Vec::new();
346
347 for chunk in sorted_indices.chunks(self.block_size) {
348 blocks.push(chunk.to_vec());
349 }
350
351 Ok(blocks)
352 }
353
354 fn select_blocks(
355 &self,
356 blocks: &[Vec<usize>],
357 split_idx: usize,
358 ) -> Result<(Vec<Vec<usize>>, Vec<Vec<usize>>)> {
359 let n_blocks = blocks.len();
360 let test_blocks_count = (n_blocks as f64 * self.config.test_size) as usize;
361 let test_blocks_count = test_blocks_count.max(1);
362
363 let test_start = split_idx * test_blocks_count;
364 let test_end = ((split_idx + 1) * test_blocks_count).min(n_blocks);
365
366 let mut train_blocks = Vec::new();
367 let mut test_blocks = Vec::new();
368
369 for (i, block) in blocks.iter().enumerate() {
370 if i >= test_start && i < test_end {
371 test_blocks.push(block.clone());
372 } else if i < test_start.saturating_sub(self.config.min_gap) {
373 train_blocks.push(block.clone());
374 }
375 }
376
377 Ok((train_blocks, test_blocks))
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `(latest training timestamp, earliest test timestamp)` for one split.
    ///
    /// Uses `0` / `usize::MAX` respectively when that side is empty.
    fn time_bounds(
        train_indices: &[usize],
        test_indices: &[usize],
        time_index: &[usize],
    ) -> (usize, usize) {
        let max_train_time = train_indices
            .iter()
            .map(|&i| time_index[i])
            .max()
            .unwrap_or(0);
        let min_test_time = test_indices
            .iter()
            .map(|&i| time_index[i])
            .min()
            .unwrap_or(usize::MAX);
        (max_train_time, min_test_time)
    }

    #[test]
    fn test_temporal_cross_validator() {
        let config = TemporalValidationConfig::default();
        let cv = TemporalCrossValidator::new(config);

        let time_index: Vec<usize> = (0..100).collect();
        let splits = cv.split(100, &time_index).unwrap();

        assert!(!splits.is_empty(), "Should generate at least one split");

        for (train_indices, test_indices) in &splits {
            let (max_train_time, min_test_time) =
                time_bounds(train_indices, test_indices, &time_index);
            // The default `min_gap` is 1, so training must end strictly before testing.
            assert!(max_train_time < min_test_time, "Temporal constraint violated");
        }
    }

    #[test]
    fn test_temporal_cross_validator_rejects_length_mismatch() {
        let cv = TemporalCrossValidator::new(TemporalValidationConfig::default());
        let time_index: Vec<usize> = (0..10).collect();

        assert!(cv.split(11, &time_index).is_err());
    }

    #[test]
    fn test_seasonal_cross_validator() {
        let config = TemporalValidationConfig::default();
        let cv = SeasonalCrossValidator::new(config, 12);

        let time_index: Vec<usize> = (0..120).collect();
        let splits = cv.split(120, &time_index).unwrap();

        assert!(!splits.is_empty(), "Should generate at least one split");

        for (train_indices, test_indices) in &splits {
            let (max_train_time, min_test_time) =
                time_bounds(train_indices, test_indices, &time_index);
            assert!(max_train_time < min_test_time, "Temporal constraint violated");
        }
    }

    #[test]
    fn test_blocked_temporal_cv() {
        let config = TemporalValidationConfig::default();
        let cv = BlockedTemporalCV::new(config, 10);

        let time_index: Vec<usize> = (0..100).collect();
        let splits = cv.split(100, &time_index).unwrap();

        assert!(!splits.is_empty(), "Should generate at least one split");

        for (train_indices, test_indices) in &splits {
            let (max_train_time, min_test_time) =
                time_bounds(train_indices, test_indices, &time_index);
            assert!(max_train_time < min_test_time, "Temporal constraint violated");
        }
    }
}