sklears_model_selection/
temporal_validation.rs

//! Advanced time series validation with temporal dependencies
//!
//! This module provides sophisticated validation methods for time series data
//! that respect temporal dependencies and prevent data leakage.
5
6use sklears_core::error::{Result, SklearsError};
7use std::collections::HashMap;
8
/// Configuration shared by the temporal cross-validators in this module.
#[derive(Debug, Clone)]
pub struct TemporalValidationConfig {
    /// Minimum separation kept between training data and test data.
    ///
    /// The unit depends on the validator: `TemporalCrossValidator` applies it
    /// to sorted positions when laying out folds and to timestamp values when
    /// pruning, `SeasonalCrossValidator` applies it to timestamp values, and
    /// `BlockedTemporalCV` counts it in whole blocks.
    pub min_gap: usize,
    /// Lookback window used by `TemporalCrossValidator`: folds are not placed
    /// in the first `max_lookback` sorted positions, and training samples
    /// older than `latest_train_time - max_lookback` are dropped.
    pub max_lookback: usize,
    /// `true` selects forward chaining (training on everything before the
    /// test fold) in `TemporalCrossValidator`; `false` selects sliding windows.
    pub forward_chaining: bool,
    /// Number of folds the validators try to build. Fewer may be returned.
    pub n_splits: usize,
    /// Fraction (not a percentage) of the samples, or blocks, assigned to each
    /// test fold, e.g. `0.2`. The resulting count is truncated toward zero.
    pub test_size: f64,
    /// Sliding-window mode only: advance consecutive windows by half a test
    /// fold instead of a whole one.
    pub allow_overlap: bool,
    /// Seasonal period for seasonal adjustments.
    ///
    /// NOTE(review): no validator visible in this file reads this field;
    /// `SeasonalCrossValidator` takes its period as a constructor argument.
    pub seasonal_period: Option<usize>,
}
27
28impl Default for TemporalValidationConfig {
29    fn default() -> Self {
30        Self {
31            min_gap: 1,
32            max_lookback: 10,
33            forward_chaining: true,
34            n_splits: 5,
35            test_size: 0.2,
36            allow_overlap: false,
37            seasonal_period: None,
38        }
39    }
40}
41
42/// Time series cross-validator with temporal dependency awareness
43#[derive(Debug, Clone)]
44pub struct TemporalCrossValidator {
45    config: TemporalValidationConfig,
46}
47
48impl TemporalCrossValidator {
49    pub fn new(config: TemporalValidationConfig) -> Self {
50        Self { config }
51    }
52
53    /// Generate temporal splits that respect time dependencies
54    pub fn split(
55        &self,
56        n_samples: usize,
57        time_index: &[usize],
58    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
59        if time_index.len() != n_samples {
60            return Err(SklearsError::InvalidInput(
61                "Time index length must match number of samples".to_string(),
62            ));
63        }
64
65        let sorted_indices = self.sort_by_time(time_index)?;
66
67        let mut splits = if self.config.forward_chaining {
68            self.forward_chaining_splits(&sorted_indices)?
69        } else {
70            self.sliding_window_splits(&sorted_indices)?
71        };
72
73        // Apply temporal constraints
74        self.apply_temporal_constraints(&mut splits, time_index)?;
75
76        Ok(splits)
77    }
78
79    /// Forward chaining: each training set includes all previous data
80    fn forward_chaining_splits(
81        &self,
82        sorted_indices: &[usize],
83    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
84        let mut splits = Vec::new();
85        let n_samples = sorted_indices.len();
86        let test_samples = (n_samples as f64 * self.config.test_size) as usize;
87
88        for i in 0..self.config.n_splits {
89            let test_start = n_samples - (self.config.n_splits - i) * test_samples;
90            let test_end = test_start + test_samples;
91
92            if test_start < self.config.max_lookback {
93                continue;
94            }
95
96            let train_end = test_start.saturating_sub(self.config.min_gap);
97
98            let train_indices = sorted_indices[0..train_end].to_vec();
99            let test_indices = if test_end <= n_samples {
100                sorted_indices[test_start..test_end].to_vec()
101            } else {
102                sorted_indices[test_start..].to_vec()
103            };
104
105            if !train_indices.is_empty() && !test_indices.is_empty() {
106                splits.push((train_indices, test_indices));
107            }
108        }
109
110        Ok(splits)
111    }
112
113    /// Sliding window: fixed-size training windows
114    fn sliding_window_splits(
115        &self,
116        sorted_indices: &[usize],
117    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
118        let mut splits = Vec::new();
119        let n_samples = sorted_indices.len();
120        let test_samples = (n_samples as f64 * self.config.test_size) as usize;
121        let train_samples = n_samples - test_samples - self.config.min_gap;
122
123        let step_size = if self.config.allow_overlap {
124            test_samples / 2
125        } else {
126            test_samples
127        };
128
129        let mut start = self.config.max_lookback;
130
131        while start + train_samples + self.config.min_gap + test_samples <= n_samples {
132            let train_end = start + train_samples;
133            let test_start = train_end + self.config.min_gap;
134            let test_end = test_start + test_samples;
135
136            let train_indices = sorted_indices[start..train_end].to_vec();
137            let test_indices = sorted_indices[test_start..test_end].to_vec();
138
139            splits.push((train_indices, test_indices));
140            start += step_size;
141        }
142
143        Ok(splits)
144    }
145
146    /// Sort indices by time
147    fn sort_by_time(&self, time_index: &[usize]) -> Result<Vec<usize>> {
148        let mut indexed_times: Vec<(usize, usize)> = time_index
149            .iter()
150            .enumerate()
151            .map(|(idx, &time)| (idx, time))
152            .collect();
153
154        indexed_times.sort_by_key(|&(_, time)| time);
155
156        Ok(indexed_times.into_iter().map(|(idx, _)| idx).collect())
157    }
158
159    /// Apply temporal constraints to ensure no data leakage
160    fn apply_temporal_constraints(
161        &self,
162        splits: &mut [(Vec<usize>, Vec<usize>)],
163        time_index: &[usize],
164    ) -> Result<()> {
165        for (train_indices, test_indices) in splits.iter_mut() {
166            // Ensure no future information in training set
167            let max_test_time = test_indices
168                .iter()
169                .map(|&idx| time_index[idx])
170                .min()
171                .unwrap_or(0);
172
173            train_indices.retain(|&idx| time_index[idx] + self.config.min_gap <= max_test_time);
174
175            // Apply lookback window constraint
176            if let Some(min_train_time) = train_indices.iter().map(|&idx| time_index[idx]).max() {
177                let cutoff_time = min_train_time.saturating_sub(self.config.max_lookback);
178                train_indices.retain(|&idx| time_index[idx] >= cutoff_time);
179            }
180        }
181
182        Ok(())
183    }
184}
185
186/// Seasonal cross-validator for time series with seasonal patterns
187#[derive(Debug, Clone)]
188pub struct SeasonalCrossValidator {
189    config: TemporalValidationConfig,
190    seasonal_period: usize,
191}
192
193impl SeasonalCrossValidator {
194    pub fn new(config: TemporalValidationConfig, seasonal_period: usize) -> Self {
195        Self {
196            config,
197            seasonal_period,
198        }
199    }
200
201    /// Generate seasonal splits that maintain seasonal patterns
202    pub fn split(
203        &self,
204        _n_samples: usize,
205        time_index: &[usize],
206    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
207        let mut splits = Vec::new();
208        let sorted_indices = self.sort_by_time(time_index)?;
209
210        // Group indices by seasonal period
211        let seasonal_groups = self.group_by_season(&sorted_indices, time_index)?;
212
213        // Create splits ensuring each split has representation from all seasons
214        for split_idx in 0..self.config.n_splits {
215            let (train_indices, test_indices) =
216                self.create_seasonal_split(&seasonal_groups, split_idx, time_index)?;
217
218            if !train_indices.is_empty() && !test_indices.is_empty() {
219                splits.push((train_indices, test_indices));
220            }
221        }
222
223        Ok(splits)
224    }
225
226    fn sort_by_time(&self, time_index: &[usize]) -> Result<Vec<usize>> {
227        let mut indexed_times: Vec<(usize, usize)> = time_index
228            .iter()
229            .enumerate()
230            .map(|(idx, &time)| (idx, time))
231            .collect();
232
233        indexed_times.sort_by_key(|&(_, time)| time);
234
235        Ok(indexed_times.into_iter().map(|(idx, _)| idx).collect())
236    }
237
238    fn group_by_season(
239        &self,
240        sorted_indices: &[usize],
241        time_index: &[usize],
242    ) -> Result<HashMap<usize, Vec<usize>>> {
243        let mut seasonal_groups: HashMap<usize, Vec<usize>> = HashMap::new();
244
245        for &idx in sorted_indices {
246            let season = time_index[idx] % self.seasonal_period;
247            seasonal_groups.entry(season).or_default().push(idx);
248        }
249
250        Ok(seasonal_groups)
251    }
252
253    fn create_seasonal_split(
254        &self,
255        seasonal_groups: &HashMap<usize, Vec<usize>>,
256        split_idx: usize,
257        time_index: &[usize],
258    ) -> Result<(Vec<usize>, Vec<usize>)> {
259        let mut train_indices = Vec::new();
260        let mut test_indices = Vec::new();
261
262        for indices in seasonal_groups.values() {
263            let n_season_samples = indices.len();
264            let test_size = (n_season_samples as f64 * self.config.test_size) as usize;
265            let samples_per_split = test_size.max(1);
266
267            let test_start = split_idx * samples_per_split;
268            let test_end = ((split_idx + 1) * samples_per_split).min(n_season_samples);
269
270            if test_start < n_season_samples {
271                // Add test samples
272                test_indices.extend_from_slice(&indices[test_start..test_end]);
273
274                // Add training samples (excluding test and respecting temporal constraints)
275                for (i, &idx) in indices.iter().enumerate() {
276                    if i < test_start || i >= test_end {
277                        // Check temporal constraints
278                        let sample_time = time_index[idx];
279                        let can_use_for_training = test_indices.iter().all(|&test_idx| {
280                            sample_time + self.config.min_gap <= time_index[test_idx]
281                        });
282
283                        if can_use_for_training {
284                            train_indices.push(idx);
285                        }
286                    }
287                }
288            }
289        }
290
291        Ok((train_indices, test_indices))
292    }
293}
294
295/// Blocked temporal cross-validator for handling irregular time series
296#[derive(Debug, Clone)]
297pub struct BlockedTemporalCV {
298    config: TemporalValidationConfig,
299    block_size: usize,
300}
301
302impl BlockedTemporalCV {
303    pub fn new(config: TemporalValidationConfig, block_size: usize) -> Self {
304        Self { config, block_size }
305    }
306
307    /// Generate blocked temporal splits
308    pub fn split(
309        &self,
310        _n_samples: usize,
311        time_index: &[usize],
312    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
313        let sorted_indices = self.sort_by_time(time_index)?;
314        let blocks = self.create_blocks(&sorted_indices)?;
315
316        let mut splits = Vec::new();
317
318        for i in 0..self.config.n_splits {
319            let (train_blocks, test_blocks) = self.select_blocks(&blocks, i)?;
320
321            let train_indices: Vec<usize> = train_blocks.into_iter().flatten().collect();
322            let test_indices: Vec<usize> = test_blocks.into_iter().flatten().collect();
323
324            if !train_indices.is_empty() && !test_indices.is_empty() {
325                splits.push((train_indices, test_indices));
326            }
327        }
328
329        Ok(splits)
330    }
331
332    fn sort_by_time(&self, time_index: &[usize]) -> Result<Vec<usize>> {
333        let mut indexed_times: Vec<(usize, usize)> = time_index
334            .iter()
335            .enumerate()
336            .map(|(idx, &time)| (idx, time))
337            .collect();
338
339        indexed_times.sort_by_key(|&(_, time)| time);
340
341        Ok(indexed_times.into_iter().map(|(idx, _)| idx).collect())
342    }
343
344    fn create_blocks(&self, sorted_indices: &[usize]) -> Result<Vec<Vec<usize>>> {
345        let mut blocks = Vec::new();
346
347        for chunk in sorted_indices.chunks(self.block_size) {
348            blocks.push(chunk.to_vec());
349        }
350
351        Ok(blocks)
352    }
353
354    fn select_blocks(
355        &self,
356        blocks: &[Vec<usize>],
357        split_idx: usize,
358    ) -> Result<(Vec<Vec<usize>>, Vec<Vec<usize>>)> {
359        let n_blocks = blocks.len();
360        let test_blocks_count = (n_blocks as f64 * self.config.test_size) as usize;
361        let test_blocks_count = test_blocks_count.max(1);
362
363        let test_start = split_idx * test_blocks_count;
364        let test_end = ((split_idx + 1) * test_blocks_count).min(n_blocks);
365
366        let mut train_blocks = Vec::new();
367        let mut test_blocks = Vec::new();
368
369        for (i, block) in blocks.iter().enumerate() {
370            if i >= test_start && i < test_end {
371                test_blocks.push(block.clone());
372            } else if i < test_start.saturating_sub(self.config.min_gap) {
373                train_blocks.push(block.clone());
374            }
375        }
376
377        Ok((train_blocks, test_blocks))
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the invariants every validator in this module should uphold:
    /// at least one fold, no empty side, and all training timestamps strictly
    /// before all test timestamps.
    fn assert_valid_splits(splits: &[(Vec<usize>, Vec<usize>)], time_index: &[usize]) {
        assert!(!splits.is_empty(), "Should generate at least one split");

        for (train_indices, test_indices) in splits {
            let max_train_time = train_indices
                .iter()
                .map(|&i| time_index[i])
                .max()
                .expect("Training set should not be empty");
            let min_test_time = test_indices
                .iter()
                .map(|&i| time_index[i])
                .min()
                .expect("Test set should not be empty");

            assert!(
                max_train_time < min_test_time,
                "Temporal constraint violated"
            );
        }
    }

    #[test]
    fn test_temporal_cross_validator() {
        let config = TemporalValidationConfig::default();
        let cv = TemporalCrossValidator::new(config);

        let time_index: Vec<usize> = (0..100).collect();
        let splits = cv.split(100, &time_index).unwrap();

        assert_valid_splits(&splits, &time_index);
    }

    #[test]
    fn test_temporal_cross_validator_rejects_length_mismatch() {
        let config = TemporalValidationConfig::default();
        let cv = TemporalCrossValidator::new(config);

        let time_index: Vec<usize> = (0..100).collect();

        assert!(cv.split(99, &time_index).is_err());
    }

    #[test]
    fn test_seasonal_cross_validator() {
        let config = TemporalValidationConfig::default();
        let cv = SeasonalCrossValidator::new(config, 12); // Monthly seasonality

        let time_index: Vec<usize> = (0..120).collect(); // 10 years of monthly data
        let splits = cv.split(120, &time_index).unwrap();

        assert_valid_splits(&splits, &time_index);
    }

    #[test]
    fn test_blocked_temporal_cv() {
        let config = TemporalValidationConfig::default();
        let cv = BlockedTemporalCV::new(config, 10);

        let time_index: Vec<usize> = (0..100).collect();
        let splits = cv.split(100, &time_index).unwrap();

        assert_valid_splits(&splits, &time_index);
    }
}