Skip to main content

scirs2_fft/scattering/
features.rs

1//! Scattering feature extraction and normalization
2//!
3//! Provides utilities for extracting usable feature vectors from scattering
4//! transform results, including:
5//! - Log-scattering normalization: log(1 + |Sx|)
6//! - L2 normalization per order
7//! - Standardization (zero mean, unit variance)
8//! - Feature concatenation and flattening
9//! - Joint time-frequency features for 2D signals
10
11use crate::error::{FFTError, FFTResult};
12
13use super::scattering::{ScatteringConfig, ScatteringResult, ScatteringTransform};
14
/// How raw scattering coefficients are rescaled before being exposed as features.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FeatureNormalization {
    /// Leave the coefficients untouched.
    None,
    /// Log-scattering compression: `log(1 + |Sx|)`.
    Log,
    /// Rescale each coefficient path to unit L2 norm.
    L2,
    /// Standardize to zero mean and unit variance: `(x - mean) / std`.
    Standardize,
    /// `Log` compression followed by `Standardize`.
    LogStandardize,
}
29
/// Selects what happens to the time axis of the scattering coefficients.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TimeFrequencyMode {
    /// Collapse the time axis by averaging, giving one value per scattering path.
    TimeAveraged,
    /// Retain the time axis, giving a (path x time) matrix flattened path by path.
    TimeSeries,
}
38
39/// Scattering feature extractor.
40///
41/// Wraps a `ScatteringTransform` and provides normalized feature extraction.
42#[derive(Debug, Clone)]
43pub struct ScatteringFeatureExtractor {
44    transform: ScatteringTransform,
45    normalization: FeatureNormalization,
46    mode: TimeFrequencyMode,
47}
48
49impl ScatteringFeatureExtractor {
50    /// Create a new feature extractor.
51    ///
52    /// # Arguments
53    /// * `config` - Scattering configuration
54    /// * `signal_length` - Expected input signal length
55    /// * `normalization` - Normalization method
56    /// * `mode` - Time-frequency handling mode
57    pub fn new(
58        config: ScatteringConfig,
59        signal_length: usize,
60        normalization: FeatureNormalization,
61        mode: TimeFrequencyMode,
62    ) -> FFTResult<Self> {
63        let transform = ScatteringTransform::new(config, signal_length)?;
64        Ok(Self {
65            transform,
66            normalization,
67            mode,
68        })
69    }
70
71    /// Extract features from a signal.
72    ///
73    /// Returns a `ScatteringFeatures` containing the normalized feature representation.
74    pub fn extract(&self, signal: &[f64]) -> FFTResult<ScatteringFeatures> {
75        let result = self.transform.transform(signal)?;
76        let features = self.normalize_result(&result)?;
77        Ok(features)
78    }
79
80    /// Normalize a scattering result into features.
81    fn normalize_result(&self, result: &ScatteringResult) -> FFTResult<ScatteringFeatures> {
82        let num_paths = result.coefficients.len();
83        let output_length = result.output_length;
84
85        // Collect all coefficient paths as rows
86        let mut matrix: Vec<Vec<f64>> = result
87            .coefficients
88            .iter()
89            .map(|c| c.values.clone())
90            .collect();
91
92        // Apply normalization
93        match self.normalization {
94            FeatureNormalization::None => {}
95            FeatureNormalization::Log => {
96                apply_log_normalization(&mut matrix);
97            }
98            FeatureNormalization::L2 => {
99                apply_l2_normalization(&mut matrix);
100            }
101            FeatureNormalization::Standardize => {
102                apply_standardization(&mut matrix);
103            }
104            FeatureNormalization::LogStandardize => {
105                apply_log_normalization(&mut matrix);
106                apply_standardization(&mut matrix);
107            }
108        }
109
110        // Reduce time dimension if requested
111        let feature_vector = match self.mode {
112            TimeFrequencyMode::TimeAveraged => {
113                // Average each path over time
114                matrix
115                    .iter()
116                    .map(|row| {
117                        if row.is_empty() {
118                            0.0
119                        } else {
120                            row.iter().sum::<f64>() / row.len() as f64
121                        }
122                    })
123                    .collect()
124            }
125            TimeFrequencyMode::TimeSeries => {
126                // Flatten: concatenate all paths
127                matrix.iter().flat_map(|row| row.iter().copied()).collect()
128            }
129        };
130
131        Ok(ScatteringFeatures {
132            feature_vector,
133            num_paths,
134            output_length,
135            num_zeroth: result.num_zeroth,
136            num_first: result.num_first,
137            num_second: result.num_second,
138            normalization: self.normalization,
139            mode: self.mode,
140        })
141    }
142}
143
144/// Normalized scattering features ready for downstream use.
145#[derive(Debug, Clone)]
146pub struct ScatteringFeatures {
147    /// The feature vector (flattened or time-averaged)
148    pub feature_vector: Vec<f64>,
149    /// Number of scattering paths
150    pub num_paths: usize,
151    /// Output length per path (before time-averaging)
152    pub output_length: usize,
153    /// Number of zeroth-order paths
154    pub num_zeroth: usize,
155    /// Number of first-order paths
156    pub num_first: usize,
157    /// Number of second-order paths
158    pub num_second: usize,
159    /// Normalization applied
160    pub normalization: FeatureNormalization,
161    /// Time-frequency mode used
162    pub mode: TimeFrequencyMode,
163}
164
165impl ScatteringFeatures {
166    /// Dimensionality of the feature vector.
167    pub fn dim(&self) -> usize {
168        self.feature_vector.len()
169    }
170
171    /// L2 norm of the feature vector.
172    pub fn norm(&self) -> f64 {
173        self.feature_vector
174            .iter()
175            .map(|v| v * v)
176            .sum::<f64>()
177            .sqrt()
178    }
179}
180
181/// Joint time-frequency scattering features for 2D signals (basic version).
182///
183/// Applies 1D scattering along rows and columns independently, then combines.
184#[derive(Debug, Clone)]
185pub struct JointScatteringFeatures {
186    /// Features from row-wise scattering
187    pub row_features: Vec<ScatteringFeatures>,
188    /// Features from column-wise scattering
189    pub col_features: Vec<ScatteringFeatures>,
190}
191
192impl JointScatteringFeatures {
193    /// Compute joint scattering features for a 2D signal (row-major layout).
194    ///
195    /// # Arguments
196    /// * `data` - 2D signal in row-major order
197    /// * `rows` - Number of rows
198    /// * `cols` - Number of columns
199    /// * `config` - Scattering configuration
200    /// * `normalization` - Normalization method
201    pub fn compute(
202        data: &[f64],
203        rows: usize,
204        cols: usize,
205        config: ScatteringConfig,
206        normalization: FeatureNormalization,
207    ) -> FFTResult<Self> {
208        if data.len() != rows * cols {
209            return Err(FFTError::DimensionError(format!(
210                "data length {} does not match rows={} * cols={}",
211                data.len(),
212                rows,
213                cols
214            )));
215        }
216
217        // Row-wise scattering
218        let row_extractor = ScatteringFeatureExtractor::new(
219            config.clone(),
220            cols,
221            normalization,
222            TimeFrequencyMode::TimeAveraged,
223        )?;
224
225        let mut row_features = Vec::with_capacity(rows);
226        for r in 0..rows {
227            let row_data = &data[r * cols..(r + 1) * cols];
228            let features = row_extractor.extract(row_data)?;
229            row_features.push(features);
230        }
231
232        // Column-wise scattering
233        let col_extractor = ScatteringFeatureExtractor::new(
234            config,
235            rows,
236            normalization,
237            TimeFrequencyMode::TimeAveraged,
238        )?;
239
240        let mut col_features = Vec::with_capacity(cols);
241        for c in 0..cols {
242            let col_data: Vec<f64> = (0..rows).map(|r| data[r * cols + c]).collect();
243            let features = col_extractor.extract(&col_data)?;
244            col_features.push(features);
245        }
246
247        Ok(Self {
248            row_features,
249            col_features,
250        })
251    }
252
253    /// Flatten into a single feature vector by concatenating row and column features.
254    pub fn flatten(&self) -> Vec<f64> {
255        let mut result = Vec::new();
256        for f in &self.row_features {
257            result.extend_from_slice(&f.feature_vector);
258        }
259        for f in &self.col_features {
260            result.extend_from_slice(&f.feature_vector);
261        }
262        result
263    }
264}
265
/// Apply log-scattering normalization in place: `x -> ln(1 + |x|)`.
///
/// Uses [`f64::ln_1p`], which stays accurate when `|x|` is much smaller than 1.
/// In that range `(1.0 + |x|).ln()` rounds `1 + |x|` to exactly 1 and returns 0.
/// Zero maps to zero, and NaN/infinity propagate unchanged.
fn apply_log_normalization(matrix: &mut [Vec<f64>]) {
    for v in matrix.iter_mut().flatten() {
        *v = v.abs().ln_1p();
    }
}
274
/// Rescale each row of `matrix` in place to unit Euclidean length.
///
/// A row is left unchanged when its norm is not greater than `1e-15`. This
/// covers all-zero rows, empty rows, and rows whose norm is NaN.
fn apply_l2_normalization(matrix: &mut [Vec<f64>]) {
    const MIN_NORM: f64 = 1e-15;
    matrix.iter_mut().for_each(|row| {
        let length = row.iter().map(|x| x * x).sum::<f64>().sqrt();
        if length > MIN_NORM {
            row.iter_mut().for_each(|x| *x /= length);
        }
    });
}
286
/// Standardize every non-empty row of `matrix` in place to zero mean and unit
/// (population) variance.
///
/// A row whose standard deviation is not greater than `1e-15` is overwritten
/// with zeros; this includes constant rows and rows containing NaN. Empty rows
/// are skipped.
fn apply_standardization(matrix: &mut [Vec<f64>]) {
    for row in matrix.iter_mut().filter(|r| !r.is_empty()) {
        let count = row.len() as f64;
        let mean = row.iter().sum::<f64>() / count;
        let spread = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count;
        let std_dev = spread.sqrt();

        if std_dev > 1e-15 {
            row.iter_mut().for_each(|x| *x = (*x - mean) / std_dev);
        } else {
            row.fill(0.0);
        }
    }
}
310
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{E, PI};

    /// Two-tone test signal: a 5-cycle sine plus a weaker 20-cycle cosine.
    fn make_test_signal(n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                (2.0 * PI * 5.0 * t).sin() + 0.3 * (2.0 * PI * 20.0 * t).cos()
            })
            .collect()
    }

    #[test]
    fn test_log_normalization_handles_zeros() {
        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        let extractor = ScatteringFeatureExtractor::new(
            config,
            128,
            FeatureNormalization::Log,
            TimeFrequencyMode::TimeAveraged,
        )
        .expect("extractor creation should succeed");

        // Zero signal should produce finite features
        let signal = vec![0.0; 128];
        let features = extractor.extract(&signal).expect("extract should succeed");

        for v in &features.feature_vector {
            assert!(v.is_finite(), "log-scattering should handle zeros: got {v}");
        }
    }

    #[test]
    fn test_log_normalization_values() {
        let mut matrix = vec![vec![0.0, E - 1.0, -(E - 1.0)]];
        apply_log_normalization(&mut matrix);

        // log(1 + 0) = 0 and log(1 + |±(e - 1)|) = 1
        assert_eq!(matrix[0][0], 0.0);
        assert!((matrix[0][1] - 1.0).abs() < 1e-12);
        assert!((matrix[0][2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_feature_extraction_time_averaged() {
        let config = ScatteringConfig::new(3, vec![4, 1]);
        let n = 256;
        let extractor = ScatteringFeatureExtractor::new(
            config,
            n,
            FeatureNormalization::None,
            TimeFrequencyMode::TimeAveraged,
        )
        .expect("extractor creation should succeed");

        let signal = make_test_signal(n);
        let features = extractor.extract(&signal).expect("extract should succeed");

        // Time-averaged: one value per path
        assert_eq!(features.dim(), features.num_paths);
        assert!(features.norm() > 0.0, "features should be non-trivial");
    }

    #[test]
    fn test_feature_extraction_time_series() {
        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        let n = 128;
        let extractor = ScatteringFeatureExtractor::new(
            config,
            n,
            FeatureNormalization::None,
            TimeFrequencyMode::TimeSeries,
        )
        .expect("extractor creation should succeed");

        let signal = make_test_signal(n);
        let features = extractor.extract(&signal).expect("extract should succeed");

        // Time-series: num_paths * output_length values
        assert_eq!(features.dim(), features.num_paths * features.output_length);
    }

    #[test]
    fn test_l2_normalization() {
        let mut matrix = vec![vec![3.0, 4.0], vec![0.0, 0.0], vec![1.0, 0.0]];
        apply_l2_normalization(&mut matrix);

        // [3,4] -> [0.6, 0.8] (norm=5)
        assert!((matrix[0][0] - 0.6).abs() < 1e-10);
        assert!((matrix[0][1] - 0.8).abs() < 1e-10);

        // [0,0] -> stays [0,0] (zero norm)
        assert!((matrix[1][0]).abs() < 1e-10);
        assert!((matrix[1][1]).abs() < 1e-10);

        // [1,0] -> [1,0] (norm=1)
        assert!((matrix[2][0] - 1.0).abs() < 1e-10);
        assert!(matrix[2][1].abs() < 1e-10);
    }

    #[test]
    fn test_standardization() {
        let mut matrix = vec![vec![2.0, 4.0, 6.0]];
        apply_standardization(&mut matrix);

        // mean=4, std=sqrt(8/3)
        let mean: f64 = matrix[0].iter().sum::<f64>() / 3.0;
        assert!(
            mean.abs() < 1e-10,
            "standardized mean should be ~0, got {mean}"
        );

        let var: f64 = matrix[0].iter().map(|v| v * v).sum::<f64>() / 3.0;
        assert!(
            (var - 1.0).abs() < 1e-10,
            "standardized variance should be ~1, got {var}"
        );
    }

    #[test]
    fn test_standardization_constant_and_empty_rows() {
        let mut matrix = vec![vec![7.0, 7.0, 7.0], vec![]];
        apply_standardization(&mut matrix);

        // Constant rows become all zeros; empty rows are left alone.
        assert!(matrix[0].iter().all(|&v| v == 0.0));
        assert!(matrix[1].is_empty());
    }

    #[test]
    fn test_log_standardize_normalization() {
        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        let extractor = ScatteringFeatureExtractor::new(
            config,
            128,
            FeatureNormalization::LogStandardize,
            TimeFrequencyMode::TimeAveraged,
        )
        .expect("extractor creation should succeed");

        let signal = make_test_signal(128);
        let features = extractor.extract(&signal).expect("extract should succeed");

        // Should produce finite features
        for v in &features.feature_vector {
            assert!(v.is_finite(), "LogStandardize should produce finite values");
        }
    }

    #[test]
    fn test_features_dim_and_norm() {
        let features = ScatteringFeatures {
            feature_vector: vec![3.0, 4.0],
            num_paths: 2,
            output_length: 1,
            num_zeroth: 1,
            num_first: 1,
            num_second: 0,
            normalization: FeatureNormalization::None,
            mode: TimeFrequencyMode::TimeAveraged,
        };
        assert_eq!(features.dim(), 2);
        assert!((features.norm() - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_joint_scattering_features() {
        let rows = 16;
        let cols = 32;
        let data: Vec<f64> = (0..rows * cols)
            .map(|i| {
                let r = (i / cols) as f64;
                let c = (i % cols) as f64;
                (2.0 * PI * r / rows as f64).sin() + (2.0 * PI * c / cols as f64).cos()
            })
            .collect();

        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        let joint =
            JointScatteringFeatures::compute(&data, rows, cols, config, FeatureNormalization::Log)
                .expect("joint scattering should succeed");

        assert_eq!(joint.row_features.len(), rows);
        assert_eq!(joint.col_features.len(), cols);

        let flat = joint.flatten();
        assert!(!flat.is_empty(), "joint features should not be empty");
        for v in &flat {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let config = ScatteringConfig::new(2, vec![2]).with_max_order(1);
        let result = JointScatteringFeatures::compute(
            &[1.0, 2.0, 3.0],
            2,
            3,
            config,
            FeatureNormalization::None,
        );
        assert!(result.is_err());
    }
}