// sklears_preprocessing/temporal/lag_features.rs
use scirs2_core::ndarray::Array2;
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14#[derive(Debug, Clone)]
16pub struct LagFeatureGeneratorConfig {
17 pub lags: Vec<usize>,
19 pub drop_na: bool,
21 pub fill_value: Option<Float>,
23}
24
25impl Default for LagFeatureGeneratorConfig {
26 fn default() -> Self {
27 Self {
28 lags: vec![1, 2, 3], drop_na: false,
30 fill_value: Some(0.0),
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
37pub struct LagFeatureGenerator<S> {
38 config: LagFeatureGeneratorConfig,
39 n_features_out_: Option<usize>,
40 _phantom: PhantomData<S>,
41}
42
43impl LagFeatureGenerator<Untrained> {
44 pub fn new() -> Self {
46 Self {
47 config: LagFeatureGeneratorConfig::default(),
48 n_features_out_: None,
49 _phantom: PhantomData,
50 }
51 }
52
53 pub fn lags(mut self, lags: Vec<usize>) -> Self {
55 self.config.lags = lags;
56 self
57 }
58
59 pub fn drop_na(mut self, drop_na: bool) -> Self {
61 self.config.drop_na = drop_na;
62 self
63 }
64
65 pub fn fill_value(mut self, fill_value: Float) -> Self {
67 self.config.fill_value = Some(fill_value);
68 self
69 }
70}
71
72impl LagFeatureGenerator<Trained> {
73 pub fn n_features_out(&self) -> usize {
75 self.n_features_out_.expect("Generator should be fitted")
76 }
77}
78
79impl Default for LagFeatureGenerator<Untrained> {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl Fit<Array2<Float>, ()> for LagFeatureGenerator<Untrained> {
86 type Fitted = LagFeatureGenerator<Trained>;
87
88 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
89 let (_, n_input_features) = x.dim();
90 let n_features_out = n_input_features * (self.config.lags.len() + 1); if self.config.lags.is_empty() {
93 return Err(SklearsError::InvalidParameter {
94 name: "lags".to_string(),
95 reason: "At least one lag must be specified".to_string(),
96 });
97 }
98
99 Ok(LagFeatureGenerator {
100 config: self.config,
101 n_features_out_: Some(n_features_out),
102 _phantom: PhantomData,
103 })
104 }
105}
106
107impl Transform<Array2<Float>, Array2<Float>> for LagFeatureGenerator<Trained> {
108 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
109 let (n_samples, n_features) = x.dim();
110 let max_lag = *self.config.lags.iter().max().unwrap_or(&0);
111
112 if max_lag >= n_samples {
113 return Err(SklearsError::InvalidInput(
114 "Maximum lag cannot be greater than or equal to number of samples".to_string(),
115 ));
116 }
117
118 let n_features_out = self.n_features_out();
119 let mut result = Array2::<Float>::zeros((n_samples, n_features_out));
120
121 for i in 0..n_samples {
123 for j in 0..n_features {
124 result[[i, j]] = x[[i, j]];
125 }
126 }
127
128 let mut feature_idx = n_features;
130 for &lag in &self.config.lags {
131 for j in 0..n_features {
132 for i in 0..n_samples {
133 if i >= lag {
134 result[[i, feature_idx]] = x[[i - lag, j]];
135 } else {
136 result[[i, feature_idx]] = self.config.fill_value.unwrap_or(0.0);
138 }
139 }
140 feature_idx += 1;
141 }
142 }
143
144 if self.config.drop_na && max_lag > 0 {
146 let valid_rows = n_samples - max_lag;
147 let mut trimmed_result = Array2::<Float>::zeros((valid_rows, n_features_out));
148 for i in 0..valid_rows {
149 for j in 0..n_features_out {
150 trimmed_result[[i, j]] = result[[i + max_lag, j]];
151 }
152 }
153 Ok(trimmed_result)
154 } else {
155 Ok(result)
156 }
157 }
158}