Skip to main content

scirs2_transform/signal_transforms/
cwt.rs

1//! Continuous Wavelet Transform (CWT) Implementation
2//!
3//! Provides continuous wavelet transform with multiple mother wavelets including:
4//! - Morlet wavelet
5//! - Mexican Hat wavelet (Ricker wavelet)
6//! - Complex Morlet wavelet
7//! - Gaussian derivatives
8
9use crate::error::{Result, TransformError};
10use rayon::prelude::*;
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
12use scirs2_core::numeric::Complex;
13use scirs2_fft::{fft, ifft};
14use std::f64::consts::PI;
15
16/// Trait for continuous wavelet functions
17pub trait ContinuousWavelet: Send + Sync {
18    /// Compute the wavelet at a given scale and position
19    fn wavelet(&self, t: f64, scale: f64) -> Complex<f64>;
20
21    /// Get the wavelet name
22    fn name(&self) -> &str;
23
24    /// Get the central frequency
25    fn central_frequency(&self) -> f64 {
26        1.0
27    }
28
29    /// Compute the wavelet in frequency domain (for FFT-based CWT)
30    fn wavelet_fft(&self, omega: f64, scale: f64) -> Complex<f64> {
31        // Default implementation - can be overridden for efficiency
32        let norm = (2.0 * PI).sqrt();
33        Complex::new((omega * scale).cos() * norm, -(omega * scale).sin() * norm)
34    }
35}
36
/// Morlet wavelet (real-valued)
#[derive(Debug, Clone, Copy)]
pub struct MorletWavelet {
    /// Central frequency parameter (omega0)
    pub omega0: f64,
}

impl MorletWavelet {
    /// Create a new Morlet wavelet with the given central frequency parameter.
    pub fn new(omega0: f64) -> Self {
        Self { omega0 }
    }
}

// Implemented as the standard trait (rather than an inherent `default()`)
// so the type works with `T: Default` bounds and `..Default::default()`;
// `MorletWavelet::default()` keeps resolving through the prelude.
impl Default for MorletWavelet {
    /// Create with default omega0 = 6.0
    fn default() -> Self {
        Self::new(6.0)
    }
}
55
56impl ContinuousWavelet for MorletWavelet {
57    fn wavelet(&self, t: f64, scale: f64) -> Complex<f64> {
58        let scaled_t = t / scale;
59        let exp_term = (-0.5 * scaled_t * scaled_t).exp();
60        let cos_term = (self.omega0 * scaled_t).cos();
61        let correction = (-0.5 * self.omega0 * self.omega0).exp();
62
63        let value = (exp_term * cos_term - correction * exp_term) / scale.sqrt();
64        Complex::new(value, 0.0)
65    }
66
67    fn name(&self) -> &str {
68        "Morlet"
69    }
70
71    fn central_frequency(&self) -> f64 {
72        self.omega0 / (2.0 * PI)
73    }
74
75    fn wavelet_fft(&self, omega: f64, scale: f64) -> Complex<f64> {
76        let scaled_omega = omega * scale;
77        let arg = -0.5 * (scaled_omega - self.omega0).powi(2);
78        let value = (PI.sqrt() * 2.0).sqrt() * scale.sqrt() * arg.exp();
79        Complex::new(value, 0.0)
80    }
81}
82
/// Complex Morlet wavelet
#[derive(Debug, Clone, Copy)]
pub struct ComplexMorletWavelet {
    /// Central frequency parameter
    pub omega0: f64,
    /// Bandwidth parameter
    pub sigma: f64,
}

impl ComplexMorletWavelet {
    /// Create a new complex Morlet wavelet from its central frequency
    /// parameter `omega0` and bandwidth parameter `sigma`.
    pub fn new(omega0: f64, sigma: f64) -> Self {
        Self { omega0, sigma }
    }
}

// Implemented as the standard trait (rather than an inherent `default()`)
// so the type works with `T: Default` bounds and `..Default::default()`;
// `ComplexMorletWavelet::default()` keeps resolving through the prelude.
impl Default for ComplexMorletWavelet {
    /// Create with default parameters (`omega0 = 6.0`, `sigma = 1.0`)
    fn default() -> Self {
        Self::new(6.0, 1.0)
    }
}
103
104impl ContinuousWavelet for ComplexMorletWavelet {
105    fn wavelet(&self, t: f64, scale: f64) -> Complex<f64> {
106        let scaled_t = t / scale;
107        let exp_term = (-0.5 * scaled_t * scaled_t / (self.sigma * self.sigma)).exp();
108        let complex_exp = Complex::new(
109            (self.omega0 * scaled_t).cos(),
110            (self.omega0 * scaled_t).sin(),
111        );
112
113        (complex_exp * exp_term) / scale.sqrt()
114    }
115
116    fn name(&self) -> &str {
117        "Complex Morlet"
118    }
119
120    fn central_frequency(&self) -> f64 {
121        self.omega0 / (2.0 * PI)
122    }
123}
124
/// Mexican Hat wavelet (Ricker wavelet, 2nd derivative of Gaussian)
#[derive(Debug, Clone, Copy)]
pub struct MexicanHatWavelet {
    /// Scaling parameter
    pub sigma: f64,
}

impl MexicanHatWavelet {
    /// Create a new Mexican Hat wavelet with the given scaling parameter.
    pub fn new(sigma: f64) -> Self {
        Self { sigma }
    }
}

// Implemented as the standard trait (rather than an inherent `default()`)
// so the type works with `T: Default` bounds and `..Default::default()`;
// `MexicanHatWavelet::default()` keeps resolving through the prelude.
impl Default for MexicanHatWavelet {
    /// Create with default sigma = 1.0
    fn default() -> Self {
        Self::new(1.0)
    }
}
143
144impl ContinuousWavelet for MexicanHatWavelet {
145    fn wavelet(&self, t: f64, scale: f64) -> Complex<f64> {
146        let scaled_t = t / scale;
147        let sigma2 = self.sigma * self.sigma;
148        let t2 = scaled_t * scaled_t;
149
150        let norm = 2.0 / (3.0 * self.sigma).sqrt() / PI.powf(0.25);
151        let exp_term = (-t2 / (2.0 * sigma2)).exp();
152        let poly_term = 1.0 - t2 / sigma2;
153
154        let value = norm * poly_term * exp_term / scale.sqrt();
155        Complex::new(value, 0.0)
156    }
157
158    fn name(&self) -> &str {
159        "Mexican Hat"
160    }
161
162    fn central_frequency(&self) -> f64 {
163        1.0 / (2.0 * PI)
164    }
165}
166
/// Wavelet built from a derivative of the Gaussian.
#[derive(Debug, Clone, Copy)]
pub struct GaussianWavelet {
    /// Which derivative of the Gaussian to use
    pub order: usize,
}

impl GaussianWavelet {
    /// Build a Gaussian wavelet of the requested derivative `order`.
    pub fn new(order: usize) -> Self {
        Self { order }
    }
}
180
181impl ContinuousWavelet for GaussianWavelet {
182    fn wavelet(&self, t: f64, scale: f64) -> Complex<f64> {
183        let scaled_t = t / scale;
184        let exp_term = (-0.5 * scaled_t * scaled_t).exp();
185
186        let value = match self.order {
187            0 => exp_term,
188            1 => -scaled_t * exp_term,
189            2 => (scaled_t * scaled_t - 1.0) * exp_term,
190            _ => {
191                // Hermite polynomial approximation for higher orders
192                (scaled_t * scaled_t - 1.0) * exp_term
193            }
194        };
195
196        Complex::new(value / scale.sqrt(), 0.0)
197    }
198
199    fn name(&self) -> &str {
200        "Gaussian"
201    }
202}
203
204/// Continuous Wavelet Transform
205#[derive(Debug, Clone)]
206pub struct CWT<W: ContinuousWavelet> {
207    wavelet: W,
208    scales: Vec<f64>,
209    sampling_period: f64,
210}
211
212impl<W: ContinuousWavelet> CWT<W> {
213    /// Create a new CWT with given wavelet and scales
214    pub fn new(wavelet: W, scales: Vec<f64>) -> Self {
215        CWT {
216            wavelet,
217            scales,
218            sampling_period: 1.0,
219        }
220    }
221
222    /// Set the sampling period
223    pub fn with_sampling_period(mut self, period: f64) -> Self {
224        self.sampling_period = period;
225        self
226    }
227
228    /// Create scales using logarithmic spacing
229    pub fn with_log_scales(wavelet: W, n_scales: usize, min_scale: f64, max_scale: f64) -> Self {
230        let scales = Self::log_scales(n_scales, min_scale, max_scale);
231        CWT::new(wavelet, scales)
232    }
233
234    /// Generate logarithmically spaced scales
235    fn log_scales(n: usize, min_scale: f64, max_scale: f64) -> Vec<f64> {
236        let log_min = min_scale.ln();
237        let log_max = max_scale.ln();
238        let step = (log_max - log_min) / (n - 1) as f64;
239
240        (0..n).map(|i| (log_min + i as f64 * step).exp()).collect()
241    }
242
243    /// Compute CWT using direct convolution
244    pub fn transform(&self, signal: &ArrayView1<f64>) -> Result<Array2<Complex<f64>>> {
245        let n = signal.len();
246        let n_scales = self.scales.len();
247
248        if n == 0 {
249            return Err(TransformError::InvalidInput("Empty signal".to_string()));
250        }
251
252        let mut coeffs = Array2::from_elem((n_scales, n), Complex::new(0.0, 0.0));
253
254        // Compute CWT for each scale
255        for (scale_idx, &scale) in self.scales.iter().enumerate() {
256            for t_idx in 0..n {
257                let mut sum = Complex::new(0.0, 0.0);
258
259                for tau_idx in 0..n {
260                    let tau = (tau_idx as f64 - t_idx as f64) * self.sampling_period;
261                    let wavelet_val = self.wavelet.wavelet(tau, scale);
262                    sum = sum + wavelet_val * signal[tau_idx];
263                }
264
265                coeffs[[scale_idx, t_idx]] = sum * self.sampling_period;
266            }
267        }
268
269        Ok(coeffs)
270    }
271
272    /// Compute CWT using FFT (more efficient for longer signals)
273    pub fn transform_fft(&self, signal: &ArrayView1<f64>) -> Result<Array2<Complex<f64>>> {
274        let n = signal.len();
275        let n_scales = self.scales.len();
276
277        if n == 0 {
278            return Err(TransformError::InvalidInput("Empty signal".to_string()));
279        }
280
281        // Convert signal to f64 vector for FFT
282        let signal_vec: Vec<f64> = signal.iter().copied().collect();
283
284        // Compute FFT of signal
285        let signal_fft = fft(&signal_vec, None)?;
286
287        // Prepare frequency array
288        let freqs: Vec<f64> = (0..n)
289            .map(|i| {
290                if i <= n / 2 {
291                    2.0 * PI * i as f64 / (n as f64 * self.sampling_period)
292                } else {
293                    2.0 * PI * (i as f64 - n as f64) / (n as f64 * self.sampling_period)
294                }
295            })
296            .collect();
297
298        let mut coeffs = Array2::from_elem((n_scales, n), Complex::new(0.0, 0.0));
299
300        // Compute CWT for each scale using FFT
301        for (scale_idx, &scale) in self.scales.iter().enumerate() {
302            // Compute wavelet in frequency domain
303            let wavelet_fft: Vec<Complex<f64>> = freqs
304                .iter()
305                .map(|&omega| {
306                    if omega >= 0.0 {
307                        self.wavelet.wavelet_fft(omega, scale).conj()
308                    } else {
309                        Complex::new(0.0, 0.0)
310                    }
311                })
312                .collect();
313
314            // Multiply in frequency domain
315            let product: Vec<Complex<f64>> = signal_fft
316                .iter()
317                .zip(wavelet_fft.iter())
318                .map(|(&s, &w)| s * w)
319                .collect();
320
321            // Inverse FFT
322            let cwt_scale = ifft(&product, None)?;
323
324            // Store results
325            for (t_idx, &val) in cwt_scale.iter().enumerate() {
326                coeffs[[scale_idx, t_idx]] = val;
327            }
328        }
329
330        Ok(coeffs)
331    }
332
333    /// Compute the scalogram (magnitude of CWT coefficients)
334    pub fn scalogram(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
335        let coeffs = self.transform_fft(signal)?;
336        let (n_scales, n_time) = coeffs.dim();
337
338        let mut scalogram = Array2::zeros((n_scales, n_time));
339        for i in 0..n_scales {
340            for j in 0..n_time {
341                scalogram[[i, j]] = coeffs[[i, j]].norm();
342            }
343        }
344
345        Ok(scalogram)
346    }
347
348    /// Get the scales
349    pub fn scales(&self) -> &[f64] {
350        &self.scales
351    }
352
353    /// Get the frequencies corresponding to scales
354    pub fn frequencies(&self) -> Vec<f64> {
355        let fc = self.wavelet.central_frequency();
356        self.scales
357            .iter()
358            .map(|&s| fc / (s * self.sampling_period))
359            .collect()
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// At `t = 0` and unit scale the real Morlet wavelet equals
    /// `1 - exp(-omega0^2 / 2)`, with `omega0 = 6` by default.
    #[test]
    fn test_morlet_wavelet() {
        let wavelet = MorletWavelet::default();
        let val = wavelet.wavelet(0.0, 1.0);

        assert_abs_diff_eq!(val.re, 1.0 - (-18.0_f64).exp(), epsilon = 1e-12);
        assert_abs_diff_eq!(val.im, 0.0, epsilon = 1e-10);
    }

    /// At `t = 0`, unit scale and `sigma = 1` the Mexican Hat wavelet equals
    /// its normalisation constant `2 / (sqrt(3) * pi^(1/4))`.
    #[test]
    fn test_mexican_hat_wavelet() {
        let wavelet = MexicanHatWavelet::default();
        let val = wavelet.wavelet(0.0, 1.0);

        let expected = 2.0 / (3.0_f64.sqrt() * PI.powf(0.25));
        assert_abs_diff_eq!(val.re, expected, epsilon = 1e-12);
        assert_abs_diff_eq!(val.im, 0.0, epsilon = 1e-10);
    }

    /// Direct transform: output shape, finiteness, and a purely real result
    /// for a real-valued wavelet.
    #[test]
    fn test_cwt_simple() -> Result<()> {
        let signal = Array1::from_vec(vec![0.0, 1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0]);
        let wavelet = MorletWavelet::default();
        let scales = vec![1.0, 2.0, 4.0];

        let cwt = CWT::new(wavelet, scales);
        let coeffs = cwt.transform(&signal.view())?;

        assert_eq!(coeffs.dim(), (3, 8));
        assert!(coeffs.iter().all(|c| c.re.is_finite() && c.im == 0.0));

        Ok(())
    }

    /// FFT transform: output shape and finiteness.
    #[test]
    fn test_cwt_fft() -> Result<()> {
        let signal = Array1::from_vec((0..64).map(|i| (i as f64 * 0.1).sin()).collect());
        let wavelet = MorletWavelet::default();
        let cwt = CWT::with_log_scales(wavelet, 32, 1.0, 32.0);

        let coeffs = cwt.transform_fft(&signal.view())?;

        assert_eq!(coeffs.dim(), (32, 64));
        assert!(coeffs.iter().all(|c| c.re.is_finite() && c.im.is_finite()));

        Ok(())
    }

    /// Scalogram: output shape, finite and non-negative magnitudes.
    #[test]
    fn test_scalogram() -> Result<()> {
        let signal = Array1::from_vec((0..64).map(|i| (i as f64 * 0.1).sin()).collect());
        let wavelet = MorletWavelet::default();
        let cwt = CWT::with_log_scales(wavelet, 16, 1.0, 16.0);

        let scalogram = cwt.scalogram(&signal.view())?;

        assert_eq!(scalogram.dim(), (16, 64));
        assert!(scalogram.iter().all(|&x| x.is_finite() && x >= 0.0));

        Ok(())
    }

    /// Both transform paths reject an empty signal.
    #[test]
    fn test_empty_signal_is_rejected() {
        let signal = Array1::<f64>::from_vec(vec![]);
        let cwt = CWT::new(MorletWavelet::default(), vec![1.0, 2.0]);

        assert!(cwt.transform(&signal.view()).is_err());
        assert!(cwt.transform_fft(&signal.view()).is_err());
    }

    /// Log-spaced scales hit both endpoints and have a constant ratio.
    #[test]
    fn test_log_scales() {
        let scales = CWT::<MorletWavelet>::log_scales(10, 1.0, 100.0);

        assert_eq!(scales.len(), 10);
        assert_abs_diff_eq!(scales[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scales[9], 100.0, epsilon = 1e-10);

        // Logarithmic spacing means every consecutive ratio is 100^(1/9).
        let expected_ratio = 100.0_f64.powf(1.0 / 9.0);
        for pair in scales.windows(2) {
            assert_abs_diff_eq!(pair[1] / pair[0], expected_ratio, epsilon = 1e-9);
        }
    }
}