sklears_preprocessing/temporal/
fourier.rs1use scirs2_core::ndarray::Array1;
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
/// Configuration for [`FourierFeatureGenerator`].
///
/// Controls which discrete Fourier transform (DFT) bins are turned into
/// features and how those features are post-processed.
#[derive(Debug, Clone)]
pub struct FourierFeatureGeneratorConfig {
    /// Number of DFT frequency bins to emit, counted from the zero-frequency
    /// (DC) bin `k = 0`; bins `k` in `0..n_components` are the candidates.
    pub n_components: usize,
    /// Whether the DC bin (`k = 0`) is emitted. When `false`, bins start at
    /// `k = 1`, which yields one fewer component.
    pub include_dc: bool,
    /// Whether each bin's phase (in radians, computed with `atan2`) is emitted
    /// right after its magnitude, doubling the number of output features.
    pub include_phase: bool,
    /// Whether `transform` rescales its output by a maximum value so that it
    /// lies on a unit scale; see `transform` for exactly which values are
    /// scaled.
    pub normalize: bool,
}

impl Default for FourierFeatureGeneratorConfig {
    /// Ten components, DC bin included, magnitudes only, normalization on.
    fn default() -> Self {
        Self {
            n_components: 10,
            include_dc: true,
            include_phase: false,
            normalize: true,
        }
    }
}
37
38#[derive(Debug, Clone)]
40pub struct FourierFeatureGenerator<S> {
41 config: FourierFeatureGeneratorConfig,
42 n_features_out_: Option<usize>,
43 _phantom: PhantomData<S>,
44}
45
46impl FourierFeatureGenerator<Untrained> {
47 pub fn new() -> Self {
49 Self {
50 config: FourierFeatureGeneratorConfig::default(),
51 n_features_out_: None,
52 _phantom: PhantomData,
53 }
54 }
55
56 pub fn n_components(mut self, n_components: usize) -> Self {
58 self.config.n_components = n_components;
59 self
60 }
61
62 pub fn include_dc(mut self, include_dc: bool) -> Self {
64 self.config.include_dc = include_dc;
65 self
66 }
67
68 pub fn include_phase(mut self, include_phase: bool) -> Self {
70 self.config.include_phase = include_phase;
71 self
72 }
73
74 pub fn normalize(mut self, normalize: bool) -> Self {
76 self.config.normalize = normalize;
77 self
78 }
79
80 fn calculate_n_features_out(&self) -> usize {
82 let mut count = self.config.n_components;
83 if !self.config.include_dc {
84 count = count.saturating_sub(1);
85 }
86 if self.config.include_phase {
87 count *= 2; }
89 count
90 }
91}
92
93impl FourierFeatureGenerator<Trained> {
94 pub fn n_features_out(&self) -> usize {
96 self.n_features_out_.expect("Generator should be fitted")
97 }
98
99 fn compute_dft(&self, data: &Array1<Float>) -> Vec<(Float, Float)> {
101 let n = data.len();
102 let mut result = Vec::new();
103
104 let max_k = self.config.n_components.min(n / 2 + 1);
105 let start_k = if self.config.include_dc { 0 } else { 1 };
106
107 for k in start_k..max_k {
108 let mut real_sum = 0.0;
109 let mut imag_sum = 0.0;
110
111 for (i, &x) in data.iter().enumerate() {
112 let angle = -2.0 * std::f64::consts::PI * k as Float * i as Float / n as Float;
113 real_sum += x * angle.cos();
114 imag_sum += x * angle.sin();
115 }
116
117 let magnitude = (real_sum * real_sum + imag_sum * imag_sum).sqrt();
118 let phase = imag_sum.atan2(real_sum);
119
120 result.push((magnitude, phase));
121 }
122
123 result
124 }
125}
126
127impl Default for FourierFeatureGenerator<Untrained> {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl Fit<Array1<Float>, ()> for FourierFeatureGenerator<Untrained> {
134 type Fitted = FourierFeatureGenerator<Trained>;
135
136 fn fit(self, x: &Array1<Float>, _y: &()) -> Result<Self::Fitted> {
137 let n_features_out = self.calculate_n_features_out();
138
139 if n_features_out == 0 {
140 return Err(SklearsError::InvalidParameter {
141 name: "n_components".to_string(),
142 reason: "Number of output features must be greater than 0".to_string(),
143 });
144 }
145
146 if x.len() < 2 {
147 return Err(SklearsError::InvalidInput(
148 "Data must have at least 2 points for Fourier analysis".to_string(),
149 ));
150 }
151
152 Ok(FourierFeatureGenerator {
153 config: self.config,
154 n_features_out_: Some(n_features_out),
155 _phantom: PhantomData,
156 })
157 }
158}
159
160impl Transform<Array1<Float>, Array1<Float>> for FourierFeatureGenerator<Trained> {
161 fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
162 let fourier_components = self.compute_dft(x);
163 let mut features = Vec::new();
164
165 for (magnitude, phase) in fourier_components {
166 features.push(magnitude);
167 if self.config.include_phase {
168 features.push(phase);
169 }
170 }
171
172 let mut result = Array1::from_vec(features);
173
174 if self.config.normalize {
176 let max_val = result.iter().cloned().fold(0.0, Float::max);
177 if max_val > 1e-10 {
178 result.mapv_inplace(|x| x / max_val);
179 }
180 }
181
182 Ok(result)
183 }
184}