Skip to main content

scirs2_fft/scattering/
filter_bank.rs

1//! Morlet wavelet filter bank for scattering transforms
2//!
3//! Constructs a dyadic filter bank of Morlet wavelets in the frequency domain,
4//! along with a low-pass scaling function. The filter bank is parameterized by:
5//! - J: number of octaves (scales)
6//! - Q: quality factor (wavelets per octave)
7//! - signal_length: length of the input signal (determines FFT size)
8
9use std::f64::consts::PI;
10
11use scirs2_core::numeric::Complex64;
12
13use crate::error::{FFTError, FFTResult};
14
/// Configuration for a Morlet wavelet filter bank.
///
/// Build one with [`FilterBankConfig::new`], optionally refine it with the
/// `with_*` builder methods, and pass it to [`FilterBank::new`].
#[derive(Debug, Clone, PartialEq)]
pub struct FilterBankConfig {
    /// Number of octaves J (logarithmic scale range); must be at least 1.
    pub j_max: usize,
    /// Quality factors, one per scattering order (wavelets per octave for that
    /// order), each at least 1 — e.g. `[8, 1]` for Q1 = 8, Q2 = 1.
    pub quality_factors: Vec<usize>,
    /// Length of the input signal; the FFT size is this value rounded up to
    /// the next power of two.
    pub signal_length: usize,
    /// Center frequency of the mother wavelet in radians per sample
    /// (default: `PI`).
    pub xi0: f64,
    /// Width of the mother wavelet's Gaussian envelope in the *time* domain.
    /// Larger values give a *narrower* frequency band. When `None` (the
    /// default) it is derived from each order's quality factor; when set, the
    /// same value is used for every order and for the scaling function.
    pub sigma: Option<f64>,
}
29
30impl FilterBankConfig {
31    /// Create a new filter bank configuration with default parameters.
32    ///
33    /// # Arguments
34    /// * `j_max` - Number of octaves
35    /// * `quality_factors` - Quality factors per order (e.g., `[8, 1]` for Q1=8, Q2=1)
36    /// * `signal_length` - Length of the input signal
37    pub fn new(j_max: usize, quality_factors: Vec<usize>, signal_length: usize) -> Self {
38        Self {
39            j_max,
40            quality_factors,
41            signal_length,
42            xi0: PI,
43            sigma: None,
44        }
45    }
46
47    /// Set the center frequency of the mother wavelet.
48    #[must_use]
49    pub fn with_xi0(mut self, xi0: f64) -> Self {
50        self.xi0 = xi0;
51        self
52    }
53
54    /// Set a custom bandwidth parameter.
55    #[must_use]
56    pub fn with_sigma(mut self, sigma: f64) -> Self {
57        self.sigma = Some(sigma);
58        self
59    }
60}
61
62/// A single Morlet wavelet: psi(t) = C * exp(-t^2 / (2*sigma^2)) * exp(j*xi*t)
63///
64/// Stored in the frequency domain for efficient convolution.
65#[derive(Debug, Clone)]
66pub struct MorletWavelet {
67    /// Center frequency (radians per sample)
68    pub xi: f64,
69    /// Bandwidth parameter
70    pub sigma: f64,
71    /// Scale index j
72    pub j: usize,
73    /// Sub-octave index within octave
74    pub q_index: usize,
75    /// Linear index (j * Q + q_index)
76    pub linear_index: usize,
77    /// Frequency-domain representation (complex-valued, length = fft_size)
78    pub freq_response: Vec<Complex64>,
79}
80
81/// A wavelet filter bank containing all wavelets and the scaling function.
82#[derive(Debug, Clone)]
83pub struct FilterBank {
84    /// Configuration used to build this filter bank
85    pub config: FilterBankConfig,
86    /// FFT size (next power of 2 >= signal_length)
87    pub fft_size: usize,
88    /// Wavelets for each scattering order (outer: order, inner: wavelet index)
89    pub wavelets: Vec<Vec<MorletWavelet>>,
90    /// Low-pass scaling function in frequency domain
91    pub phi: Vec<Complex64>,
92}
93
94impl FilterBank {
95    /// Construct a new filter bank from configuration.
96    ///
97    /// Builds Morlet wavelets at dyadic scales 2^(j/Q) for j = 0..J*Q-1,
98    /// plus a low-pass scaling function phi that captures content below 2^J.
99    pub fn new(config: FilterBankConfig) -> FFTResult<Self> {
100        if config.j_max == 0 {
101            return Err(FFTError::ValueError("j_max must be at least 1".to_string()));
102        }
103        if config.quality_factors.is_empty() {
104            return Err(FFTError::ValueError(
105                "quality_factors must have at least one entry".to_string(),
106            ));
107        }
108        for (i, &q) in config.quality_factors.iter().enumerate() {
109            if q == 0 {
110                return Err(FFTError::ValueError(format!(
111                    "quality_factors[{i}] must be at least 1"
112                )));
113            }
114        }
115        if config.signal_length == 0 {
116            return Err(FFTError::ValueError(
117                "signal_length must be positive".to_string(),
118            ));
119        }
120
121        let fft_size = config.signal_length.next_power_of_two();
122
123        // Build wavelets for each order
124        let mut all_wavelets = Vec::new();
125        for (order, &q) in config.quality_factors.iter().enumerate() {
126            let sigma_base = compute_sigma_from_q(q, config.xi0, config.sigma);
127            let wavelets =
128                build_morlet_wavelets(config.j_max, q, config.xi0, sigma_base, fft_size, order)?;
129            all_wavelets.push(wavelets);
130        }
131
132        // Build low-pass scaling function
133        let sigma_phi = compute_sigma_from_q(config.quality_factors[0], config.xi0, config.sigma);
134        let phi = build_scaling_function(config.j_max, sigma_phi, fft_size)?;
135
136        Ok(Self {
137            config,
138            fft_size,
139            wavelets: all_wavelets,
140            phi,
141        })
142    }
143
144    /// Number of first-order wavelets (J * Q1).
145    pub fn num_first_order(&self) -> usize {
146        self.wavelets.first().map_or(0, |w| w.len())
147    }
148
149    /// Number of second-order wavelets (J * Q2), if a second order exists.
150    pub fn num_second_order(&self) -> usize {
151        self.wavelets.get(1).map_or(0, |w| w.len())
152    }
153
154    /// Total number of wavelets across all orders.
155    pub fn total_wavelets(&self) -> usize {
156        self.wavelets.iter().map(|w| w.len()).sum()
157    }
158}
159
/// Compute the time-domain width `sigma` of the mother wavelet from the
/// quality factor `q` and the center frequency `xi0`.
///
/// Wavelets are built in the frequency domain as
/// `exp(-(omega - xi)^2 * sigma^2 / 2)`. That makes `sigma` a *time-domain*
/// width, and the frequency-domain full width at half maximum is
/// `2 * sqrt(2 * ln 2) / sigma`. Setting that bandwidth to `xi0 / q`
/// (quality factor `q`) gives
///
/// `sigma = 2 * sqrt(2 * ln 2) * q / xi0`,
///
/// so a larger `q` yields a narrower band. Each wavelet scales `xi` by
/// `1/scale` and `sigma` by `scale`, so the same quality factor holds at
/// every scale.
///
/// If `custom_sigma` is provided, it is returned unchanged. `xi0` is assumed
/// to be non-zero (a positive center frequency).
fn compute_sigma_from_q(q: usize, xi0: f64, custom_sigma: Option<f64>) -> f64 {
    if let Some(s) = custom_sigma {
        return s;
    }
    // FWHM of exp(-x^2 / 2) is 2 * sqrt(2 ln 2); here the frequency-domain
    // standard deviation is 1 / sigma, hence sigma = factor * q / xi0.
    let fwhm_factor = 2.0 * (2.0 * std::f64::consts::LN_2).sqrt();
    fwhm_factor * q as f64 / xi0
}
174
175/// Build Morlet wavelets at dyadic scales for a given quality factor.
176fn build_morlet_wavelets(
177    j_max: usize,
178    q: usize,
179    xi0: f64,
180    sigma_base: f64,
181    fft_size: usize,
182    _order: usize,
183) -> FFTResult<Vec<MorletWavelet>> {
184    let total = j_max * q;
185    let mut wavelets = Vec::with_capacity(total);
186    let n = fft_size;
187
188    for idx in 0..total {
189        let j = idx / q;
190        let q_index = idx % q;
191
192        // Scale factor: 2^(idx / Q)
193        let scale = 2.0_f64.powf(idx as f64 / q as f64);
194
195        // Center frequency at this scale
196        let xi = xi0 / scale;
197
198        // Bandwidth at this scale
199        let sigma = sigma_base * scale;
200
201        // Build frequency-domain Morlet wavelet
202        // Psi_hat(omega) = C * exp(-(omega - xi)^2 * sigma^2 / 2)
203        // with correction term to ensure zero mean
204        let mut freq_response = vec![Complex64::new(0.0, 0.0); n];
205        let n_f64 = n as f64;
206
207        for k in 0..n {
208            // Normalized frequency: omega = 2*pi*k/N
209            let omega = 2.0 * PI * k as f64 / n_f64;
210
211            // Gaussian centered at xi
212            let diff_pos = omega - xi;
213            let gauss_pos = (-0.5 * diff_pos * diff_pos * sigma * sigma).exp();
214
215            // Correction term for zero mean (subtract Gaussian at omega=0)
216            let gauss_correction = (-0.5 * xi * xi * sigma * sigma).exp();
217            let gauss_zero = (-0.5 * omega * omega * sigma * sigma).exp();
218
219            let value = gauss_pos - gauss_correction * gauss_zero;
220            freq_response[k] = Complex64::new(value, 0.0);
221        }
222
223        // Normalize: L2 norm in frequency domain = 1
224        let energy: f64 = freq_response.iter().map(|c| c.norm_sqr()).sum();
225        if energy > 1e-15 {
226            let norm_factor = 1.0 / energy.sqrt();
227            for c in &mut freq_response {
228                *c = Complex64::new(c.re * norm_factor, c.im * norm_factor);
229            }
230        }
231
232        wavelets.push(MorletWavelet {
233            xi,
234            sigma,
235            j,
236            q_index,
237            linear_index: idx,
238            freq_response,
239        });
240    }
241
242    Ok(wavelets)
243}
244
245/// Build the low-pass scaling function phi in the frequency domain.
246///
247/// phi_hat(omega) = exp(-omega^2 * sigma_J^2 / 2) where sigma_J = sigma_base * 2^J
248fn build_scaling_function(
249    j_max: usize,
250    sigma_base: f64,
251    fft_size: usize,
252) -> FFTResult<Vec<Complex64>> {
253    let n = fft_size;
254    let n_f64 = n as f64;
255    let sigma_j = sigma_base * 2.0_f64.powi(j_max as i32);
256
257    let mut phi = vec![Complex64::new(0.0, 0.0); n];
258
259    for k in 0..n {
260        let omega = 2.0 * PI * k as f64 / n_f64;
261        // Wrap frequency to [-pi, pi]
262        let omega_wrapped = if omega > PI { omega - 2.0 * PI } else { omega };
263        let value = (-0.5 * omega_wrapped * omega_wrapped * sigma_j * sigma_j).exp();
264        phi[k] = Complex64::new(value, 0.0);
265    }
266
267    // Normalize
268    let energy: f64 = phi.iter().map(|c| c.norm_sqr()).sum();
269    if energy > 1e-15 {
270        let norm_factor = 1.0 / energy.sqrt();
271        for c in &mut phi {
272            *c = Complex64::new(c.re * norm_factor, c.im * norm_factor);
273        }
274    }
275
276    Ok(phi)
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Index of the bin with the largest squared magnitude in `response`.
    fn peak_bin(response: &[Complex64]) -> usize {
        response
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                a.norm_sqr()
                    .partial_cmp(&b.norm_sqr())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(idx, _)| idx)
            .expect("should find peak")
    }

    #[test]
    fn test_filter_bank_creation() {
        let config = FilterBankConfig::new(4, vec![8, 1], 1024);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        assert_eq!(fb.num_first_order(), 32); // J=4, Q=8 => 32 wavelets
        assert_eq!(fb.num_second_order(), 4); // J=4, Q=1 => 4 wavelets
        assert_eq!(fb.total_wavelets(), 36);
        assert_eq!(fb.fft_size, 1024);
        assert_eq!(fb.phi.len(), 1024);
    }

    #[test]
    fn test_single_order_counts() {
        let config = FilterBankConfig::new(5, vec![2], 256);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        assert_eq!(fb.num_first_order(), 10);
        assert_eq!(fb.num_second_order(), 0);
        assert_eq!(fb.total_wavelets(), 10);
    }

    #[test]
    fn test_fft_size_rounds_up_to_power_of_two() {
        let config = FilterBankConfig::new(2, vec![1], 1000);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        assert_eq!(fb.fft_size, 1024);
        assert_eq!(fb.phi.len(), 1024);
        assert!(fb.wavelets[0]
            .iter()
            .all(|w| w.freq_response.len() == 1024));
    }

    #[test]
    fn test_builder_overrides() {
        let config = FilterBankConfig::new(3, vec![2], 128)
            .with_xi0(2.0)
            .with_sigma(1.5);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        // The first wavelet sits at scale 1, so it carries the raw parameters.
        let first = &fb.wavelets[0][0];
        assert!((first.xi - 2.0).abs() < 1e-12);
        assert!((first.sigma - 1.5).abs() < 1e-12);
    }

    #[test]
    fn test_wavelet_frequency_peaks() {
        // Each wavelet should peak near its center frequency
        let config = FilterBankConfig::new(3, vec![4], 512);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        for w in &fb.wavelets[0] {
            let peak_omega = 2.0 * PI * peak_bin(&w.freq_response) as f64 / fb.fft_size as f64;

            // Peak should be near the wavelet's center frequency xi.
            // Allow generous tolerance due to discretization.
            let rel_error = if w.xi > 1e-6 {
                (peak_omega - w.xi).abs() / w.xi
            } else {
                peak_omega.abs()
            };
            assert!(
                rel_error < 0.5,
                "wavelet j={} q={}: peak_omega={:.4} vs xi={:.4}, rel_error={:.4}",
                w.j,
                w.q_index,
                peak_omega,
                w.xi,
                rel_error
            );
        }
    }

    #[test]
    fn test_dyadic_scaling() {
        // Wavelets should be spaced at octave intervals when Q=1
        let config = FilterBankConfig::new(4, vec![1], 1024);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        // Center frequencies should decrease by a factor of ~2 per octave.
        for (i, pair) in fb.wavelets[0].windows(2).enumerate() {
            let ratio = pair[0].xi / pair[1].xi;
            assert!(
                (ratio - 2.0).abs() < 0.1,
                "octave {i} to {}: ratio={:.4}, expected ~2.0",
                i + 1,
                ratio
            );
        }
    }

    #[test]
    fn test_filter_bank_invalid_config() {
        // j_max = 0
        let config = FilterBankConfig::new(0, vec![8], 1024);
        assert!(FilterBank::new(config).is_err());

        // empty quality factors
        let config = FilterBankConfig::new(4, vec![], 1024);
        assert!(FilterBank::new(config).is_err());

        // quality factor = 0
        let config = FilterBankConfig::new(4, vec![0], 1024);
        assert!(FilterBank::new(config).is_err());

        // signal_length = 0
        let config = FilterBankConfig::new(4, vec![8], 0);
        assert!(FilterBank::new(config).is_err());
    }

    #[test]
    fn test_wavelet_l2_normalization() {
        let config = FilterBankConfig::new(3, vec![4], 256);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        for w in &fb.wavelets[0] {
            let energy: f64 = w.freq_response.iter().map(|c| c.norm_sqr()).sum();
            assert!(
                (energy - 1.0).abs() < 1e-10,
                "wavelet j={} q={} has energy {:.6}, expected 1.0",
                w.j,
                w.q_index,
                energy
            );
        }
    }

    #[test]
    fn test_wavelets_have_zero_dc() {
        let config = FilterBankConfig::new(3, vec![4, 1], 256);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        for order in &fb.wavelets {
            for w in order {
                assert!(
                    w.freq_response[0].norm_sqr() < 1e-20,
                    "wavelet j={} q={} has non-zero DC response",
                    w.j,
                    w.q_index
                );
            }
        }
    }

    #[test]
    fn test_scaling_function_is_lowpass() {
        let config = FilterBankConfig::new(3, vec![4], 512);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        // phi should peak at DC (bin 0)
        let dc_mag = fb.phi[0].norm_sqr();
        let nyquist_mag = fb.phi[fb.fft_size / 2].norm_sqr();

        assert!(
            dc_mag > nyquist_mag,
            "scaling function should peak at DC: dc={:.6} vs nyquist={:.6}",
            dc_mag,
            nyquist_mag
        );
    }

    #[test]
    fn test_scaling_function_l2_normalization() {
        let config = FilterBankConfig::new(3, vec![4], 256);
        let fb = FilterBank::new(config).expect("filter bank creation should succeed");

        let energy: f64 = fb.phi.iter().map(|c| c.norm_sqr()).sum();
        assert!(
            (energy - 1.0).abs() < 1e-10,
            "scaling function has energy {:.6}, expected 1.0",
            energy
        );
    }
}