sklears_preprocessing/temporal/
lag_features.rs

1//! Lag feature generation for time series data
2//!
3//! This module provides utilities for creating lag features from time series data,
4//! which are essential for many temporal modeling tasks.
5
6use scirs2_core::ndarray::Array2;
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::marker::PhantomData;
13
/// Configuration for LagFeatureGenerator
///
/// Holds the settings that control which lagged copies of the input columns
/// are produced and how rows without enough history are handled.
#[derive(Debug, Clone)]
pub struct LagFeatureGeneratorConfig {
    /// Lag offsets, measured in rows. Each entry `k` yields one lagged copy of
    /// every input column, where row `i` holds the value from row `i - k`.
    pub lags: Vec<usize>,
    /// Whether to drop the leading rows for which the largest lag is not
    /// available (the first `max(lags)` rows of the output)
    pub drop_na: bool,
    /// Value written into lag cells that have no history (row index smaller
    /// than the lag). `None` is treated as `0.0` by the transform.
    pub fill_value: Option<Float>,
}
24
25impl Default for LagFeatureGeneratorConfig {
26    fn default() -> Self {
27        Self {
28            lags: vec![1, 2, 3], // Default to 3 lags
29            drop_na: false,
30            fill_value: Some(0.0),
31        }
32    }
33}
34
/// LagFeatureGenerator for creating lag features from time series data
///
/// The type parameter `S` is a compile-time state marker (`Untrained` or
/// `Trained`): the builder methods exist only on the untrained state, and
/// fitting produces the trained state.
#[derive(Debug, Clone)]
pub struct LagFeatureGenerator<S> {
    /// Lag settings supplied through the builder methods
    config: LagFeatureGeneratorConfig,
    /// Output column count; `None` until `fit` stores
    /// `n_input_features * (lags.len() + 1)`
    n_features_out_: Option<usize>,
    /// Zero-sized marker carrying the state type `S`
    _phantom: PhantomData<S>,
}
42
43impl LagFeatureGenerator<Untrained> {
44    /// Create a new LagFeatureGenerator
45    pub fn new() -> Self {
46        Self {
47            config: LagFeatureGeneratorConfig::default(),
48            n_features_out_: None,
49            _phantom: PhantomData,
50        }
51    }
52
53    /// Set the lag periods
54    pub fn lags(mut self, lags: Vec<usize>) -> Self {
55        self.config.lags = lags;
56        self
57    }
58
59    /// Set whether to drop missing values
60    pub fn drop_na(mut self, drop_na: bool) -> Self {
61        self.config.drop_na = drop_na;
62        self
63    }
64
65    /// Set the fill value for missing lag values
66    pub fn fill_value(mut self, fill_value: Float) -> Self {
67        self.config.fill_value = Some(fill_value);
68        self
69    }
70}
71
72impl LagFeatureGenerator<Trained> {
73    /// Get the number of output features
74    pub fn n_features_out(&self) -> usize {
75        self.n_features_out_.expect("Generator should be fitted")
76    }
77}
78
79impl Default for LagFeatureGenerator<Untrained> {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl Fit<Array2<Float>, ()> for LagFeatureGenerator<Untrained> {
86    type Fitted = LagFeatureGenerator<Trained>;
87
88    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
89        let (_, n_input_features) = x.dim();
90        let n_features_out = n_input_features * (self.config.lags.len() + 1); // Original + lags
91
92        if self.config.lags.is_empty() {
93            return Err(SklearsError::InvalidParameter {
94                name: "lags".to_string(),
95                reason: "At least one lag must be specified".to_string(),
96            });
97        }
98
99        Ok(LagFeatureGenerator {
100            config: self.config,
101            n_features_out_: Some(n_features_out),
102            _phantom: PhantomData,
103        })
104    }
105}
106
107impl Transform<Array2<Float>, Array2<Float>> for LagFeatureGenerator<Trained> {
108    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
109        let (n_samples, n_features) = x.dim();
110        let max_lag = *self.config.lags.iter().max().unwrap_or(&0);
111
112        if max_lag >= n_samples {
113            return Err(SklearsError::InvalidInput(
114                "Maximum lag cannot be greater than or equal to number of samples".to_string(),
115            ));
116        }
117
118        let n_features_out = self.n_features_out();
119        let mut result = Array2::<Float>::zeros((n_samples, n_features_out));
120
121        // Copy original features
122        for i in 0..n_samples {
123            for j in 0..n_features {
124                result[[i, j]] = x[[i, j]];
125            }
126        }
127
128        // Generate lag features
129        let mut feature_idx = n_features;
130        for &lag in &self.config.lags {
131            for j in 0..n_features {
132                for i in 0..n_samples {
133                    if i >= lag {
134                        result[[i, feature_idx]] = x[[i - lag, j]];
135                    } else {
136                        // Fill missing values
137                        result[[i, feature_idx]] = self.config.fill_value.unwrap_or(0.0);
138                    }
139                }
140                feature_idx += 1;
141            }
142        }
143
144        // Drop rows with missing values if requested
145        if self.config.drop_na && max_lag > 0 {
146            let valid_rows = n_samples - max_lag;
147            let mut trimmed_result = Array2::<Float>::zeros((valid_rows, n_features_out));
148            for i in 0..valid_rows {
149                for j in 0..n_features_out {
150                    trimmed_result[[i, j]] = result[[i + max_lag, j]];
151                }
152            }
153            Ok(trimmed_result)
154        } else {
155            Ok(result)
156        }
157    }
158}