1use crate::error::{FFTError, FFTResult};
12
13use super::scattering::{ScatteringConfig, ScatteringResult, ScatteringTransform};
14
/// How raw scattering coefficients are normalized before they are turned into features.
///
/// `Hash` is derived (in addition to `Eq`) so the value can be used as a key in
/// hash-based collections, e.g. when caching extractors per configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureNormalization {
    /// Leave the coefficients unchanged.
    None,
    /// Compress every coefficient `v` to `ln(1 + |v|)`.
    Log,
    /// Scale each path's coefficient row to unit Euclidean norm
    /// (rows whose norm is numerically zero are left unchanged).
    L2,
    /// Shift and scale to zero mean and unit (population) variance.
    Standardize,
    /// [`FeatureNormalization::Log`] followed by [`FeatureNormalization::Standardize`].
    LogStandardize,
}
29
/// Layout of the feature vector produced from the per-path scattering coefficients.
///
/// `Hash` is derived (in addition to `Eq`) so the mode can be used as a key in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimeFrequencyMode {
    /// Average each path's coefficients over time: one value per path.
    TimeAveraged,
    /// Keep every path's full coefficient sequence, concatenated path by path.
    TimeSeries,
}
38
/// Turns raw 1-D signals into feature vectors using a scattering transform.
///
/// Holds a [`ScatteringTransform`] built once for a fixed signal length, together with
/// the normalization and time/frequency layout applied to every transform output.
#[derive(Debug, Clone)]
pub struct ScatteringFeatureExtractor {
    /// Transform constructed in `new` from the config and signal length.
    transform: ScatteringTransform,
    /// Normalization applied to the scattering coefficients.
    normalization: FeatureNormalization,
    /// Whether coefficients are time-averaged or kept as full sequences.
    mode: TimeFrequencyMode,
}
48
49impl ScatteringFeatureExtractor {
50 pub fn new(
58 config: ScatteringConfig,
59 signal_length: usize,
60 normalization: FeatureNormalization,
61 mode: TimeFrequencyMode,
62 ) -> FFTResult<Self> {
63 let transform = ScatteringTransform::new(config, signal_length)?;
64 Ok(Self {
65 transform,
66 normalization,
67 mode,
68 })
69 }
70
71 pub fn extract(&self, signal: &[f64]) -> FFTResult<ScatteringFeatures> {
75 let result = self.transform.transform(signal)?;
76 let features = self.normalize_result(&result)?;
77 Ok(features)
78 }
79
80 fn normalize_result(&self, result: &ScatteringResult) -> FFTResult<ScatteringFeatures> {
82 let num_paths = result.coefficients.len();
83 let output_length = result.output_length;
84
85 let mut matrix: Vec<Vec<f64>> = result
87 .coefficients
88 .iter()
89 .map(|c| c.values.clone())
90 .collect();
91
92 match self.normalization {
94 FeatureNormalization::None => {}
95 FeatureNormalization::Log => {
96 apply_log_normalization(&mut matrix);
97 }
98 FeatureNormalization::L2 => {
99 apply_l2_normalization(&mut matrix);
100 }
101 FeatureNormalization::Standardize => {
102 apply_standardization(&mut matrix);
103 }
104 FeatureNormalization::LogStandardize => {
105 apply_log_normalization(&mut matrix);
106 apply_standardization(&mut matrix);
107 }
108 }
109
110 let feature_vector = match self.mode {
112 TimeFrequencyMode::TimeAveraged => {
113 matrix
115 .iter()
116 .map(|row| {
117 if row.is_empty() {
118 0.0
119 } else {
120 row.iter().sum::<f64>() / row.len() as f64
121 }
122 })
123 .collect()
124 }
125 TimeFrequencyMode::TimeSeries => {
126 matrix.iter().flat_map(|row| row.iter().copied()).collect()
128 }
129 };
130
131 Ok(ScatteringFeatures {
132 feature_vector,
133 num_paths,
134 output_length,
135 num_zeroth: result.num_zeroth,
136 num_first: result.num_first,
137 num_second: result.num_second,
138 normalization: self.normalization,
139 mode: self.mode,
140 })
141 }
142}
143
/// Feature vector produced by [`ScatteringFeatureExtractor::extract`], plus layout metadata.
#[derive(Debug, Clone)]
pub struct ScatteringFeatures {
    /// The features themselves. In `TimeAveraged` mode there is one value per path.
    /// In `TimeSeries` mode the length is `num_paths * output_length`, concatenated path by path.
    pub feature_vector: Vec<f64>,
    /// Number of scattering paths (coefficient rows) in the transform result.
    pub num_paths: usize,
    /// `output_length` reported by the transform result (length of each path's row).
    pub output_length: usize,
    /// Number of zeroth-order paths, copied from the transform result.
    pub num_zeroth: usize,
    /// Number of first-order paths, copied from the transform result.
    pub num_first: usize,
    /// Number of second-order paths, copied from the transform result.
    pub num_second: usize,
    /// Normalization that was applied to produce `feature_vector`.
    pub normalization: FeatureNormalization,
    /// Layout of `feature_vector`.
    pub mode: TimeFrequencyMode,
}
164
165impl ScatteringFeatures {
166 pub fn dim(&self) -> usize {
168 self.feature_vector.len()
169 }
170
171 pub fn norm(&self) -> f64 {
173 self.feature_vector
174 .iter()
175 .map(|v| v * v)
176 .sum::<f64>()
177 .sqrt()
178 }
179}
180
/// Scattering features of a 2-D row-major array, computed separately along rows and columns.
#[derive(Debug, Clone)]
pub struct JointScatteringFeatures {
    /// Time-averaged features for each row, in row order.
    pub row_features: Vec<ScatteringFeatures>,
    /// Time-averaged features for each column, in column order.
    pub col_features: Vec<ScatteringFeatures>,
}
191
192impl JointScatteringFeatures {
193 pub fn compute(
202 data: &[f64],
203 rows: usize,
204 cols: usize,
205 config: ScatteringConfig,
206 normalization: FeatureNormalization,
207 ) -> FFTResult<Self> {
208 if data.len() != rows * cols {
209 return Err(FFTError::DimensionError(format!(
210 "data length {} does not match rows={} * cols={}",
211 data.len(),
212 rows,
213 cols
214 )));
215 }
216
217 let row_extractor = ScatteringFeatureExtractor::new(
219 config.clone(),
220 cols,
221 normalization,
222 TimeFrequencyMode::TimeAveraged,
223 )?;
224
225 let mut row_features = Vec::with_capacity(rows);
226 for r in 0..rows {
227 let row_data = &data[r * cols..(r + 1) * cols];
228 let features = row_extractor.extract(row_data)?;
229 row_features.push(features);
230 }
231
232 let col_extractor = ScatteringFeatureExtractor::new(
234 config,
235 rows,
236 normalization,
237 TimeFrequencyMode::TimeAveraged,
238 )?;
239
240 let mut col_features = Vec::with_capacity(cols);
241 for c in 0..cols {
242 let col_data: Vec<f64> = (0..rows).map(|r| data[r * cols + c]).collect();
243 let features = col_extractor.extract(&col_data)?;
244 col_features.push(features);
245 }
246
247 Ok(Self {
248 row_features,
249 col_features,
250 })
251 }
252
253 pub fn flatten(&self) -> Vec<f64> {
255 let mut result = Vec::new();
256 for f in &self.row_features {
257 result.extend_from_slice(&f.feature_vector);
258 }
259 for f in &self.col_features {
260 result.extend_from_slice(&f.feature_vector);
261 }
262 result
263 }
264}
265
/// Applies the element-wise log compression `v -> ln(1 + |v|)` to every entry in place.
///
/// Uses [`f64::ln_1p`] rather than `(1.0 + |v|).ln()`. For small `|v|`, forming
/// `1.0 + |v|` first rounds away most of the significant digits. `ln_1p` keeps full
/// relative precision.
///
/// For finite input the result is finite and non-negative, and zero maps to zero.
fn apply_log_normalization(matrix: &mut [Vec<f64>]) {
    for v in matrix.iter_mut().flat_map(|row| row.iter_mut()) {
        *v = v.abs().ln_1p();
    }
}
274
/// Rescales each row in place to unit Euclidean norm.
///
/// Rows whose norm does not exceed `1e-15` (including empty and all-zero rows)
/// are left untouched to avoid dividing by (nearly) zero.
fn apply_l2_normalization(matrix: &mut [Vec<f64>]) {
    const MIN_NORM: f64 = 1e-15;
    matrix.iter_mut().for_each(|row| {
        let magnitude = row.iter().map(|x| x * x).sum::<f64>().sqrt();
        if magnitude > MIN_NORM {
            row.iter_mut().for_each(|x| *x /= magnitude);
        }
    });
}
286
/// Standardizes each non-empty row in place to zero mean and unit population variance.
///
/// The variance divides by `n`, not `n - 1`. Rows whose standard deviation does not
/// exceed `1e-15` (constant rows) are set to all zeros. Empty rows are skipped.
fn apply_standardization(matrix: &mut [Vec<f64>]) {
    for row in matrix.iter_mut().filter(|r| !r.is_empty()) {
        let count = row.len() as f64;
        let mean = row.iter().sum::<f64>() / count;
        let std_dev = (row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count).sqrt();

        if std_dev > 1e-15 {
            row.iter_mut().for_each(|x| *x = (*x - mean) / std_dev);
        } else {
            // Constant row: no spread to scale by, so it carries no information.
            row.fill(0.0);
        }
    }
}
310
/// Unit tests for the feature extractor, the joint 2-D features and the normalization helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Deterministic test signal: a 5-cycle sine plus a weaker 20-cycle cosine over `n` samples.
    fn make_test_signal(n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                (2.0 * PI * 5.0 * t).sin() + 0.3 * (2.0 * PI * 20.0 * t).cos()
            })
            .collect()
    }

    #[test]
    fn test_log_normalization_handles_zeros() {
        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        let extractor = ScatteringFeatureExtractor::new(
            config,
            128,
            FeatureNormalization::Log,
            TimeFrequencyMode::TimeAveraged,
        )
        .expect("extractor creation should succeed");

        // An all-zero input exercises log compression at zero, where a plain ln(v) would diverge.
        let signal = vec![0.0; 128];
        let features = extractor.extract(&signal).expect("extract should succeed");

        for v in &features.feature_vector {
            assert!(v.is_finite(), "log-scattering should handle zeros: got {v}");
        }
    }

    #[test]
    fn test_feature_extraction_time_averaged() {
        let config = ScatteringConfig::new(3, vec![4, 1]);
        let n = 256;
        let extractor = ScatteringFeatureExtractor::new(
            config,
            n,
            FeatureNormalization::None,
            TimeFrequencyMode::TimeAveraged,
        )
        .expect("extractor creation should succeed");

        let signal = make_test_signal(n);
        let features = extractor.extract(&signal).expect("extract should succeed");

        // Time-averaged mode yields exactly one feature per scattering path.
        assert_eq!(features.dim(), features.num_paths);
        assert!(features.norm() > 0.0, "features should be non-trivial");
    }

    #[test]
    fn test_feature_extraction_time_series() {
        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        let n = 128;
        let extractor = ScatteringFeatureExtractor::new(
            config,
            n,
            FeatureNormalization::None,
            TimeFrequencyMode::TimeSeries,
        )
        .expect("extractor creation should succeed");

        let signal = make_test_signal(n);
        let features = extractor.extract(&signal).expect("extract should succeed");

        // Time-series mode keeps every path's full sequence.
        assert_eq!(features.dim(), features.num_paths * features.output_length);
    }

    #[test]
    fn test_l2_normalization() {
        let mut matrix = vec![vec![3.0, 4.0], vec![0.0, 0.0], vec![1.0, 0.0]];
        apply_l2_normalization(&mut matrix);

        // 3-4-5 triangle: the row becomes (0.6, 0.8).
        assert!((matrix[0][0] - 0.6).abs() < 1e-10);
        assert!((matrix[0][1] - 0.8).abs() < 1e-10);

        // All-zero row stays zero (no division by a zero norm).
        assert!((matrix[1][0]).abs() < 1e-10);
        assert!((matrix[1][1]).abs() < 1e-10);

        // Already unit-norm row is unchanged.
        assert!((matrix[2][0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_standardization() {
        let mut matrix = vec![vec![2.0, 4.0, 6.0]];
        apply_standardization(&mut matrix);

        let mean: f64 = matrix[0].iter().sum::<f64>() / 3.0;
        assert!(
            mean.abs() < 1e-10,
            "standardized mean should be ~0, got {mean}"
        );

        // Population variance (divide by n), matching apply_standardization.
        let var: f64 = matrix[0].iter().map(|v| v * v).sum::<f64>() / 3.0;
        assert!(
            (var - 1.0).abs() < 1e-10,
            "standardized variance should be ~1, got {var}"
        );
    }

    #[test]
    fn test_log_standardize_normalization() {
        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        let extractor = ScatteringFeatureExtractor::new(
            config,
            128,
            FeatureNormalization::LogStandardize,
            TimeFrequencyMode::TimeAveraged,
        )
        .expect("extractor creation should succeed");

        let signal = make_test_signal(128);
        let features = extractor.extract(&signal).expect("extract should succeed");

        // Only checks finiteness; it does not check that the features are non-degenerate.
        for v in &features.feature_vector {
            assert!(v.is_finite(), "LogStandardize should produce finite values");
        }
    }

    #[test]
    fn test_joint_scattering_features() {
        let rows = 16;
        let cols = 32;
        // Row-major grid: a sine along rows plus a cosine along columns.
        let data: Vec<f64> = (0..rows * cols)
            .map(|i| {
                let r = (i / cols) as f64;
                let c = (i % cols) as f64;
                (2.0 * PI * r / rows as f64).sin() + (2.0 * PI * c / cols as f64).cos()
            })
            .collect();

        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        let joint =
            JointScatteringFeatures::compute(&data, rows, cols, config, FeatureNormalization::Log)
                .expect("joint scattering should succeed");

        // One feature set per row and per column.
        assert_eq!(joint.row_features.len(), rows);
        assert_eq!(joint.col_features.len(), cols);

        let flat = joint.flatten();
        assert!(!flat.is_empty(), "joint features should not be empty");
        for v in &flat {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        // 3 values cannot form a 2 x 3 matrix.
        let result = JointScatteringFeatures::compute(
            &[1.0, 2.0, 3.0],
            2,
            3,
            config,
            FeatureNormalization::None,
        );
        assert!(result.is_err());
    }
}
479}