Skip to main content

scirs2_fft/scattering/
scattering.rs

1//! Core scattering transform implementation
2//!
3//! Implements Mallat's scattering transform (2012):
4//! - Zeroth order: S0 = x * phi (low-pass average)
5//! - First order: S1 = |x * psi_{lambda1}| * phi
6//! - Second order: S2 = ||x * psi_{lambda1}| * psi_{lambda2}| * phi
7//!
8//! The transform is translation invariant up to scale 2^J and Lipschitz
9//! continuous to deformations, making it suitable as a feature extractor
10//! for classification tasks.
11
12use scirs2_core::numeric::Complex64;
13
14use crate::error::{FFTError, FFTResult};
15use crate::fft::{fft, ifft};
16
17use super::filter_bank::{FilterBank, FilterBankConfig};
18
/// Configuration for the scattering transform.
///
/// Build one with [`ScatteringConfig::new`] and adjust it with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScatteringConfig {
    /// Number of octaves `J`; averaged outputs are subsampled by up to `2^J`.
    pub j_max: usize,
    /// Quality factors (wavelets per octave) per order, e.g. `[8, 1]`; handed to the
    /// filter bank unchanged.
    pub quality_factors: Vec<usize>,
    /// Maximum scattering order (0, 1, or 2). Values above 2 written directly to this
    /// field behave like 2; [`ScatteringConfig::with_max_order`] clamps them to 2.
    pub max_order: usize,
    /// Whether to average (convolve with phi) the output.
    pub average: bool,
    /// Oversampling exponent for averaged outputs: they are subsampled by
    /// `2^(j_max - min(oversampling, j_max))` instead of `2^j_max`. `0` gives the
    /// coarsest output grid; any value `>= j_max` disables subsampling.
    pub oversampling: usize,
}
33
34impl ScatteringConfig {
35    /// Create a default scattering config.
36    ///
37    /// # Arguments
38    /// * `j_max` - Number of octaves
39    /// * `quality_factors` - Quality factors per order
40    pub fn new(j_max: usize, quality_factors: Vec<usize>) -> Self {
41        Self {
42            j_max,
43            quality_factors,
44            max_order: 2,
45            average: true,
46            oversampling: 0,
47        }
48    }
49
50    /// Set maximum scattering order.
51    #[must_use]
52    pub fn with_max_order(mut self, order: usize) -> Self {
53        self.max_order = order.min(2);
54        self
55    }
56
57    /// Set whether to average the output with the scaling function.
58    #[must_use]
59    pub fn with_average(mut self, average: bool) -> Self {
60        self.average = average;
61        self
62    }
63
64    /// Set oversampling factor.
65    #[must_use]
66    pub fn with_oversampling(mut self, oversampling: usize) -> Self {
67        self.oversampling = oversampling;
68        self
69    }
70}
71
/// Labels identifying which scattering order and path a coefficient belongs to.
///
/// Wavelet indices are positions in the filter bank's wavelet lists: `lambda1` indexes the
/// first-order wavelets, and `lambda2` indexes the second-order list (the first-order list
/// when the bank has only one).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ScatteringOrder {
    /// Zeroth order: S0 = x * phi
    Zeroth,
    /// First order: S1 at wavelet index lambda1
    First {
        /// Index of the first-order wavelet `psi_{lambda1}`.
        lambda1: usize,
    },
    /// Second order: S2 at wavelet indices (lambda1, lambda2)
    Second {
        /// Index of the first-order wavelet `psi_{lambda1}`.
        lambda1: usize,
        /// Index of the second-order wavelet `psi_{lambda2}`.
        lambda2: usize,
    },
}
82
83/// A single scattering coefficient (time series or scalar).
84#[derive(Debug, Clone)]
85pub struct ScatteringCoefficients {
86    /// Which order and path this coefficient belongs to
87    pub order: ScatteringOrder,
88    /// The coefficient values (time samples after subsampling)
89    pub values: Vec<f64>,
90}
91
92/// Complete result of a scattering transform.
93#[derive(Debug, Clone)]
94pub struct ScatteringResult {
95    /// All scattering coefficients, ordered: S0, then S1, then S2
96    pub coefficients: Vec<ScatteringCoefficients>,
97    /// Number of zeroth-order coefficients (always 1)
98    pub num_zeroth: usize,
99    /// Number of first-order paths
100    pub num_first: usize,
101    /// Number of second-order paths
102    pub num_second: usize,
103    /// Subsampled output length
104    pub output_length: usize,
105}
106
107impl ScatteringResult {
108    /// Get zeroth-order coefficients.
109    pub fn zeroth_order(&self) -> &[ScatteringCoefficients] {
110        &self.coefficients[..self.num_zeroth]
111    }
112
113    /// Get first-order coefficients.
114    pub fn first_order(&self) -> &[ScatteringCoefficients] {
115        &self.coefficients[self.num_zeroth..self.num_zeroth + self.num_first]
116    }
117
118    /// Get second-order coefficients.
119    pub fn second_order(&self) -> &[ScatteringCoefficients] {
120        &self.coefficients[self.num_zeroth + self.num_first..]
121    }
122
123    /// Flatten all coefficients into a single feature vector.
124    pub fn flatten(&self) -> Vec<f64> {
125        let mut result = Vec::new();
126        for coeff in &self.coefficients {
127            result.extend_from_slice(&coeff.values);
128        }
129        result
130    }
131
132    /// Total energy across all scattering coefficients.
133    pub fn total_energy(&self) -> f64 {
134        self.coefficients
135            .iter()
136            .flat_map(|c| c.values.iter())
137            .map(|v| v * v)
138            .sum()
139    }
140}
141
142/// The scattering transform engine.
143#[derive(Debug, Clone)]
144pub struct ScatteringTransform {
145    /// Scattering configuration
146    config: ScatteringConfig,
147    /// Pre-built filter bank
148    filter_bank: FilterBank,
149}
150
151impl ScatteringTransform {
152    /// Create a new scattering transform for signals of a given length.
153    ///
154    /// # Arguments
155    /// * `config` - Scattering configuration
156    /// * `signal_length` - Length of input signals
157    pub fn new(config: ScatteringConfig, signal_length: usize) -> FFTResult<Self> {
158        if signal_length == 0 {
159            return Err(FFTError::ValueError(
160                "signal_length must be positive".to_string(),
161            ));
162        }
163
164        let fb_config =
165            FilterBankConfig::new(config.j_max, config.quality_factors.clone(), signal_length);
166        let filter_bank = FilterBank::new(fb_config)?;
167
168        Ok(Self {
169            config,
170            filter_bank,
171        })
172    }
173
174    /// Access the underlying filter bank.
175    pub fn filter_bank(&self) -> &FilterBank {
176        &self.filter_bank
177    }
178
179    /// Compute the scattering transform of a real-valued signal.
180    ///
181    /// Returns scattering coefficients organized by order.
182    pub fn transform(&self, signal: &[f64]) -> FFTResult<ScatteringResult> {
183        if signal.is_empty() {
184            return Err(FFTError::ValueError(
185                "Input signal must not be empty".to_string(),
186            ));
187        }
188
189        let fft_size = self.filter_bank.fft_size;
190
191        // Pad signal to FFT size
192        let mut padded = vec![0.0_f64; fft_size];
193        let copy_len = signal.len().min(fft_size);
194        padded[..copy_len].copy_from_slice(&signal[..copy_len]);
195
196        // Compute FFT of input
197        let x_hat = fft(&padded, Some(fft_size))?;
198
199        // Subsampling factor
200        let subsample = if self.config.average {
201            let base = 2_usize.pow(self.config.j_max as u32);
202            base >> self.config.oversampling.min(self.config.j_max)
203        } else {
204            1
205        };
206        let output_length = fft_size.div_ceil(subsample);
207
208        let mut coefficients = Vec::new();
209
210        let mut num_first = 0;
211        let mut num_second = 0;
212
213        // --- Zeroth order: S0 = x * phi ---
214        let s0 = convolve_and_subsample(&x_hat, &self.filter_bank.phi, fft_size, subsample)?;
215        coefficients.push(ScatteringCoefficients {
216            order: ScatteringOrder::Zeroth,
217            values: s0,
218        });
219        let num_zeroth = 1;
220
221        if self.config.max_order == 0 {
222            return Ok(ScatteringResult {
223                coefficients,
224                num_zeroth,
225                num_first,
226                num_second,
227                output_length,
228            });
229        }
230
231        // --- First order: S1 = |x * psi_{lambda1}| * phi ---
232        let first_order_wavelets = self
233            .filter_bank
234            .wavelets
235            .first()
236            .ok_or_else(|| FFTError::ComputationError("No first-order wavelets".to_string()))?;
237
238        // Store U1 (unaveraged first-order) for second-order computation
239        let mut u1_hats: Vec<Vec<Complex64>> = Vec::new();
240
241        for (lambda1, wavelet) in first_order_wavelets.iter().enumerate() {
242            // x * psi_{lambda1} in frequency domain
243            let convolved: Vec<Complex64> = x_hat
244                .iter()
245                .zip(wavelet.freq_response.iter())
246                .map(|(x, w)| x * w)
247                .collect();
248
249            // IFFT to get time-domain convolution
250            let u1_time = ifft(&convolved, None)?;
251
252            // Modulus: |x * psi_{lambda1}|
253            let u1_mod: Vec<f64> = u1_time.iter().map(|c| c.norm()).collect();
254
255            // Store FFT of modulus for second-order computation
256            if self.config.max_order >= 2 {
257                let u1_mod_hat = fft(&u1_mod, Some(fft_size))?;
258                u1_hats.push(u1_mod_hat);
259            }
260
261            // Average with phi: |x * psi_{lambda1}| * phi
262            let u1_mod_hat_for_avg = if self.config.max_order >= 2 {
263                // Already computed above; reuse
264                u1_hats.last().ok_or_else(|| {
265                    FFTError::ComputationError("u1_hats should not be empty".to_string())
266                })?
267            } else {
268                // Compute just for averaging
269                &fft(&u1_mod, Some(fft_size))?
270            };
271
272            let s1 = convolve_and_subsample(
273                u1_mod_hat_for_avg,
274                &self.filter_bank.phi,
275                fft_size,
276                subsample,
277            )?;
278
279            coefficients.push(ScatteringCoefficients {
280                order: ScatteringOrder::First { lambda1 },
281                values: s1,
282            });
283            num_first += 1;
284        }
285
286        if self.config.max_order < 2 {
287            return Ok(ScatteringResult {
288                coefficients,
289                num_zeroth,
290                num_first,
291                num_second,
292                output_length,
293            });
294        }
295
296        // --- Second order: S2 = ||x * psi_{lambda1}| * psi_{lambda2}| * phi ---
297        // Only for lambda2 > lambda1 (coarser scale than lambda1)
298        let second_order_wavelets = if self.filter_bank.wavelets.len() > 1 {
299            &self.filter_bank.wavelets[1]
300        } else {
301            // Use first-order wavelets if no separate second-order bank
302            &self.filter_bank.wavelets[0]
303        };
304
305        for (lambda1, u1_hat) in u1_hats.iter().enumerate() {
306            for (lambda2, wavelet2) in second_order_wavelets.iter().enumerate() {
307                // Only compute when lambda2 represents a coarser scale
308                // For second-order wavelets with different Q, compare scale indices
309                let first_scale = if !first_order_wavelets.is_empty() {
310                    first_order_wavelets[lambda1].j
311                } else {
312                    0
313                };
314                let second_scale = wavelet2.j;
315
316                if second_scale <= first_scale {
317                    continue;
318                }
319
320                // |U1| * psi_{lambda2}
321                let convolved2: Vec<Complex64> = u1_hat
322                    .iter()
323                    .zip(wavelet2.freq_response.iter())
324                    .map(|(u, w)| u * w)
325                    .collect();
326
327                let u2_time = ifft(&convolved2, None)?;
328
329                // Modulus
330                let u2_mod: Vec<f64> = u2_time.iter().map(|c| c.norm()).collect();
331
332                // Average with phi
333                let u2_mod_hat = fft(&u2_mod, Some(fft_size))?;
334                let s2 = convolve_and_subsample(
335                    &u2_mod_hat,
336                    &self.filter_bank.phi,
337                    fft_size,
338                    subsample,
339                )?;
340
341                coefficients.push(ScatteringCoefficients {
342                    order: ScatteringOrder::Second { lambda1, lambda2 },
343                    values: s2,
344                });
345                num_second += 1;
346            }
347        }
348
349        Ok(ScatteringResult {
350            coefficients,
351            num_zeroth,
352            num_first,
353            num_second,
354            output_length,
355        })
356    }
357
358    /// Compute the scattering transform and return only the feature vector.
359    pub fn features(&self, signal: &[f64]) -> FFTResult<Vec<f64>> {
360        let result = self.transform(signal)?;
361        Ok(result.flatten())
362    }
363}
364
365/// Multiply two spectra element-wise, IFFT, take real part, and subsample.
366fn convolve_and_subsample(
367    x_hat: &[Complex64],
368    filter_hat: &[Complex64],
369    fft_size: usize,
370    subsample: usize,
371) -> FFTResult<Vec<f64>> {
372    // Pointwise multiplication in frequency domain
373    let product: Vec<Complex64> = x_hat
374        .iter()
375        .zip(filter_hat.iter())
376        .map(|(x, f)| x * f)
377        .collect();
378
379    // IFFT
380    let time_domain = ifft(&product, None)?;
381
382    // Subsample and take real part
383    let output_len = fft_size.div_ceil(subsample);
384    let mut result = Vec::with_capacity(output_len);
385    for i in 0..output_len {
386        let idx = i * subsample;
387        if idx < time_domain.len() {
388            result.push(time_domain[idx].re);
389        } else {
390            result.push(0.0);
391        }
392    }
393
394    Ok(result)
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Sum of squared values of one scattering path.
    fn path_energy(coeff: &ScatteringCoefficients) -> f64 {
        coeff.values.iter().map(|v| v * v).sum()
    }

    #[test]
    fn test_scattering_basic() {
        let config = ScatteringConfig::new(3, vec![2, 1]);
        let st = ScatteringTransform::new(config, 256)
            .expect("scattering transform creation should succeed");

        // Simple sine wave
        let signal: Vec<f64> = (0..256)
            .map(|i| (2.0 * PI * 10.0 * i as f64 / 256.0).sin())
            .collect();

        let result = st.transform(&signal).expect("transform should succeed");

        assert_eq!(result.num_zeroth, 1);
        assert!(result.num_first > 0);
        // Second order may legitimately be empty if no lambda2 is coarser than any lambda1.
    }

    #[test]
    fn test_translation_invariance() {
        // Translation invariance is demonstrated by comparing the total energy
        // of first-order scattering coefficients for a signal and its circular shift.
        // The scattering transform averages over a window of 2^J samples, so
        // circular translations should produce similar total first-order energies.
        let config = ScatteringConfig::new(3, vec![4, 1]).with_max_order(1);
        let n = 512;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        // Signal: a Gaussian pulse
        let signal1: Vec<f64> = (0..n)
            .map(|i| {
                let t = (i as f64 - 128.0) / 20.0;
                (-0.5 * t * t).exp()
            })
            .collect();

        // Circularly shifted version
        let shift = 64;
        let signal2: Vec<f64> = (0..n).map(|i| signal1[(i + n - shift) % n]).collect();

        let r1 = st.transform(&signal1).expect("transform should succeed");
        let r2 = st.transform(&signal2).expect("transform should succeed");

        // Each S1 path is |x * psi| * phi, so translating x circularly
        // should give a similar total first-order energy.
        let total_e1: f64 = r1.first_order().iter().map(path_energy).sum();
        let total_e2: f64 = r2.first_order().iter().map(path_energy).sum();

        if total_e1 > 1e-15 {
            let rel_error = ((total_e1 - total_e2) / total_e1).abs();
            assert!(
                rel_error < 0.3,
                "First-order total energy should be approximately translation invariant, \
                 rel_error={:.4} (e1={:.4}, e2={:.4})",
                rel_error,
                total_e1,
                total_e2
            );
        }
    }

    #[test]
    fn test_output_dimensions() {
        let j = 3;
        let q1 = 4;
        let q2 = 1;
        let config = ScatteringConfig::new(j, vec![q1, q2]);
        let n = 256;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 5.0 * i as f64 / n as f64).sin())
            .collect();

        let result = st.transform(&signal).expect("transform should succeed");

        // First order: J * Q1 = 3 * 4 = 12 paths
        assert_eq!(result.num_first, j * q1);

        // The second-order count depends on the scale-ordering rule, but the per-order
        // counters must agree with the accessors and the total coefficient count.
        assert_eq!(result.zeroth_order().len(), result.num_zeroth);
        assert_eq!(result.first_order().len(), result.num_first);
        assert_eq!(result.second_order().len(), result.num_second);
        assert_eq!(
            result.coefficients.len(),
            result.num_zeroth + result.num_first + result.num_second
        );

        // All coefficients should have the same output length
        let expected_len = result.output_length;
        for coeff in &result.coefficients {
            assert_eq!(
                coeff.values.len(),
                expected_len,
                "coefficient output length mismatch"
            );
        }
    }

    #[test]
    fn test_unaveraged_output_lengths_consistent() {
        let config = ScatteringConfig::new(3, vec![4, 1]).with_average(false);
        let n = 256;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 5.0 * i as f64 / n as f64).cos())
            .collect();

        let result = st.transform(&signal).expect("transform should succeed");

        for coeff in &result.coefficients {
            assert_eq!(
                coeff.values.len(),
                result.output_length,
                "unaveraged coefficient output length mismatch"
            );
        }
    }

    #[test]
    fn test_energy_approximate_preservation() {
        let config = ScatteringConfig::new(3, vec![4, 1]);
        let n = 256;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                (2.0 * PI * 8.0 * t).sin() + 0.5 * (2.0 * PI * 32.0 * t).cos()
            })
            .collect();

        let input_energy: f64 = signal.iter().map(|v| v * v).sum();
        let result = st.transform(&signal).expect("transform should succeed");
        let scatter_energy = result.total_energy();

        // The exact ratio to the input energy depends on filter normalization and
        // subsampling; here we only require a well-defined, non-trivial result.
        assert!(
            scatter_energy.is_finite() && scatter_energy > 0.0,
            "scattering energy should be positive and finite \
             (scatter={scatter_energy:.4}, input={input_energy:.4})"
        );
    }

    #[test]
    fn test_sine_wave_first_order() {
        let config = ScatteringConfig::new(4, vec![8]).with_max_order(1);
        let n = 1024;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        // Pure sine wave at a known frequency
        let freq = 20.0; // cycles per signal length
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * freq * i as f64 / n as f64).sin())
            .collect();

        let result = st.transform(&signal).expect("transform should succeed");

        // First-order coefficients: the strongest response should come from
        // the wavelet whose center frequency is closest to the sine frequency.
        let first = result.first_order();
        assert!(!first.is_empty(), "should have first-order coefficients");

        // Find the path with maximum energy
        let max_path = first
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                path_energy(a)
                    .partial_cmp(&path_energy(b))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(idx, _)| idx);

        assert!(max_path.is_some(), "should find a path with maximum energy");
    }

    #[test]
    fn test_zeroth_order_only() {
        let config = ScatteringConfig::new(3, vec![4]).with_max_order(0);
        let n = 128;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        let signal: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
        let result = st.transform(&signal).expect("transform should succeed");

        assert_eq!(result.num_zeroth, 1);
        assert_eq!(result.num_first, 0);
        assert_eq!(result.num_second, 0);
    }

    #[test]
    fn test_empty_signal_error() {
        let config = ScatteringConfig::new(3, vec![4]);
        let st = ScatteringTransform::new(config, 128)
            .expect("scattering transform creation should succeed");

        let result = st.transform(&[]);
        assert!(result.is_err());
    }
}
597}