scirs2_transform/signal_transforms/
cwt.rs1use crate::error::{Result, TransformError};
10use rayon::prelude::*;
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
12use scirs2_core::numeric::Complex;
13use scirs2_fft::{fft, ifft};
14use std::f64::consts::PI;
15
16pub trait ContinuousWavelet: Send + Sync {
18 fn wavelet(&self, t: f64, scale: f64) -> Complex<f64>;
20
21 fn name(&self) -> &str;
23
24 fn central_frequency(&self) -> f64 {
26 1.0
27 }
28
29 fn wavelet_fft(&self, omega: f64, scale: f64) -> Complex<f64> {
31 let norm = (2.0 * PI).sqrt();
33 Complex::new((omega * scale).cos() * norm, -(omega * scale).sin() * norm)
34 }
35}
36
/// Parameters of the Morlet wavelet.
#[derive(Debug, Clone, Copy)]
pub struct MorletWavelet {
    /// Angular frequency of the oscillating carrier.
    pub omega0: f64,
}

impl MorletWavelet {
    /// Creates a Morlet wavelet with carrier angular frequency `omega0`.
    pub fn new(omega0: f64) -> Self {
        Self { omega0 }
    }
}

// Implemented as the standard trait (rather than an inherent `default`) so the
// type works with `Default::default()`, `unwrap_or_default()` and struct-update
// syntax; `MorletWavelet::default()` keeps resolving exactly as before.
impl Default for MorletWavelet {
    /// Returns a Morlet wavelet with `omega0 = 6.0`.
    fn default() -> Self {
        Self::new(6.0)
    }
}
55
56impl ContinuousWavelet for MorletWavelet {
57 fn wavelet(&self, t: f64, scale: f64) -> Complex<f64> {
58 let scaled_t = t / scale;
59 let exp_term = (-0.5 * scaled_t * scaled_t).exp();
60 let cos_term = (self.omega0 * scaled_t).cos();
61 let correction = (-0.5 * self.omega0 * self.omega0).exp();
62
63 let value = (exp_term * cos_term - correction * exp_term) / scale.sqrt();
64 Complex::new(value, 0.0)
65 }
66
67 fn name(&self) -> &str {
68 "Morlet"
69 }
70
71 fn central_frequency(&self) -> f64 {
72 self.omega0 / (2.0 * PI)
73 }
74
75 fn wavelet_fft(&self, omega: f64, scale: f64) -> Complex<f64> {
76 let scaled_omega = omega * scale;
77 let arg = -0.5 * (scaled_omega - self.omega0).powi(2);
78 let value = (PI.sqrt() * 2.0).sqrt() * scale.sqrt() * arg.exp();
79 Complex::new(value, 0.0)
80 }
81}
82
/// Parameters of the complex Morlet wavelet.
#[derive(Debug, Clone, Copy)]
pub struct ComplexMorletWavelet {
    /// Angular frequency of the complex exponential carrier.
    pub omega0: f64,
    /// Width parameter of the Gaussian envelope.
    pub sigma: f64,
}

impl ComplexMorletWavelet {
    /// Creates a complex Morlet wavelet from its carrier frequency and envelope width.
    pub fn new(omega0: f64, sigma: f64) -> Self {
        Self { omega0, sigma }
    }
}

// Implemented as the standard trait (rather than an inherent `default`) so the
// type works with `Default::default()` and generic code bounded on `Default`;
// `ComplexMorletWavelet::default()` keeps resolving exactly as before.
impl Default for ComplexMorletWavelet {
    /// Returns a complex Morlet wavelet with `omega0 = 6.0` and `sigma = 1.0`.
    fn default() -> Self {
        Self::new(6.0, 1.0)
    }
}
103
104impl ContinuousWavelet for ComplexMorletWavelet {
105 fn wavelet(&self, t: f64, scale: f64) -> Complex<f64> {
106 let scaled_t = t / scale;
107 let exp_term = (-0.5 * scaled_t * scaled_t / (self.sigma * self.sigma)).exp();
108 let complex_exp = Complex::new(
109 (self.omega0 * scaled_t).cos(),
110 (self.omega0 * scaled_t).sin(),
111 );
112
113 (complex_exp * exp_term) / scale.sqrt()
114 }
115
116 fn name(&self) -> &str {
117 "Complex Morlet"
118 }
119
120 fn central_frequency(&self) -> f64 {
121 self.omega0 / (2.0 * PI)
122 }
123}
124
/// Parameters of the Mexican hat wavelet.
#[derive(Debug, Clone, Copy)]
pub struct MexicanHatWavelet {
    /// Width parameter of the wavelet.
    pub sigma: f64,
}

impl MexicanHatWavelet {
    /// Creates a Mexican hat wavelet with width parameter `sigma`.
    pub fn new(sigma: f64) -> Self {
        Self { sigma }
    }
}

// Implemented as the standard trait (rather than an inherent `default`) so the
// type works with `Default::default()` and generic code bounded on `Default`;
// `MexicanHatWavelet::default()` keeps resolving exactly as before.
impl Default for MexicanHatWavelet {
    /// Returns a Mexican hat wavelet with `sigma = 1.0`.
    fn default() -> Self {
        Self::new(1.0)
    }
}
143
144impl ContinuousWavelet for MexicanHatWavelet {
145 fn wavelet(&self, t: f64, scale: f64) -> Complex<f64> {
146 let scaled_t = t / scale;
147 let sigma2 = self.sigma * self.sigma;
148 let t2 = scaled_t * scaled_t;
149
150 let norm = 2.0 / (3.0 * self.sigma).sqrt() / PI.powf(0.25);
151 let exp_term = (-t2 / (2.0 * sigma2)).exp();
152 let poly_term = 1.0 - t2 / sigma2;
153
154 let value = norm * poly_term * exp_term / scale.sqrt();
155 Complex::new(value, 0.0)
156 }
157
158 fn name(&self) -> &str {
159 "Mexican Hat"
160 }
161
162 fn central_frequency(&self) -> f64 {
163 1.0 / (2.0 * PI)
164 }
165}
166
/// Parameters of the Gaussian-family wavelet.
#[derive(Debug, Clone, Copy)]
pub struct GaussianWavelet {
    /// Order selector read by the `ContinuousWavelet` implementation.
    pub order: usize,
}

impl GaussianWavelet {
    /// Creates a Gaussian wavelet of the given `order`.
    pub fn new(order: usize) -> Self {
        Self { order }
    }
}
180
181impl ContinuousWavelet for GaussianWavelet {
182 fn wavelet(&self, t: f64, scale: f64) -> Complex<f64> {
183 let scaled_t = t / scale;
184 let exp_term = (-0.5 * scaled_t * scaled_t).exp();
185
186 let value = match self.order {
187 0 => exp_term,
188 1 => -scaled_t * exp_term,
189 2 => (scaled_t * scaled_t - 1.0) * exp_term,
190 _ => {
191 (scaled_t * scaled_t - 1.0) * exp_term
193 }
194 };
195
196 Complex::new(value / scale.sqrt(), 0.0)
197 }
198
199 fn name(&self) -> &str {
200 "Gaussian"
201 }
202}
203
204#[derive(Debug, Clone)]
206pub struct CWT<W: ContinuousWavelet> {
207 wavelet: W,
208 scales: Vec<f64>,
209 sampling_period: f64,
210}
211
212impl<W: ContinuousWavelet> CWT<W> {
213 pub fn new(wavelet: W, scales: Vec<f64>) -> Self {
215 CWT {
216 wavelet,
217 scales,
218 sampling_period: 1.0,
219 }
220 }
221
222 pub fn with_sampling_period(mut self, period: f64) -> Self {
224 self.sampling_period = period;
225 self
226 }
227
228 pub fn with_log_scales(wavelet: W, n_scales: usize, min_scale: f64, max_scale: f64) -> Self {
230 let scales = Self::log_scales(n_scales, min_scale, max_scale);
231 CWT::new(wavelet, scales)
232 }
233
234 fn log_scales(n: usize, min_scale: f64, max_scale: f64) -> Vec<f64> {
236 let log_min = min_scale.ln();
237 let log_max = max_scale.ln();
238 let step = (log_max - log_min) / (n - 1) as f64;
239
240 (0..n).map(|i| (log_min + i as f64 * step).exp()).collect()
241 }
242
243 pub fn transform(&self, signal: &ArrayView1<f64>) -> Result<Array2<Complex<f64>>> {
245 let n = signal.len();
246 let n_scales = self.scales.len();
247
248 if n == 0 {
249 return Err(TransformError::InvalidInput("Empty signal".to_string()));
250 }
251
252 let mut coeffs = Array2::from_elem((n_scales, n), Complex::new(0.0, 0.0));
253
254 for (scale_idx, &scale) in self.scales.iter().enumerate() {
256 for t_idx in 0..n {
257 let mut sum = Complex::new(0.0, 0.0);
258
259 for tau_idx in 0..n {
260 let tau = (tau_idx as f64 - t_idx as f64) * self.sampling_period;
261 let wavelet_val = self.wavelet.wavelet(tau, scale);
262 sum = sum + wavelet_val * signal[tau_idx];
263 }
264
265 coeffs[[scale_idx, t_idx]] = sum * self.sampling_period;
266 }
267 }
268
269 Ok(coeffs)
270 }
271
272 pub fn transform_fft(&self, signal: &ArrayView1<f64>) -> Result<Array2<Complex<f64>>> {
274 let n = signal.len();
275 let n_scales = self.scales.len();
276
277 if n == 0 {
278 return Err(TransformError::InvalidInput("Empty signal".to_string()));
279 }
280
281 let signal_vec: Vec<f64> = signal.iter().copied().collect();
283
284 let signal_fft = fft(&signal_vec, None)?;
286
287 let freqs: Vec<f64> = (0..n)
289 .map(|i| {
290 if i <= n / 2 {
291 2.0 * PI * i as f64 / (n as f64 * self.sampling_period)
292 } else {
293 2.0 * PI * (i as f64 - n as f64) / (n as f64 * self.sampling_period)
294 }
295 })
296 .collect();
297
298 let mut coeffs = Array2::from_elem((n_scales, n), Complex::new(0.0, 0.0));
299
300 for (scale_idx, &scale) in self.scales.iter().enumerate() {
302 let wavelet_fft: Vec<Complex<f64>> = freqs
304 .iter()
305 .map(|&omega| {
306 if omega >= 0.0 {
307 self.wavelet.wavelet_fft(omega, scale).conj()
308 } else {
309 Complex::new(0.0, 0.0)
310 }
311 })
312 .collect();
313
314 let product: Vec<Complex<f64>> = signal_fft
316 .iter()
317 .zip(wavelet_fft.iter())
318 .map(|(&s, &w)| s * w)
319 .collect();
320
321 let cwt_scale = ifft(&product, None)?;
323
324 for (t_idx, &val) in cwt_scale.iter().enumerate() {
326 coeffs[[scale_idx, t_idx]] = val;
327 }
328 }
329
330 Ok(coeffs)
331 }
332
333 pub fn scalogram(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
335 let coeffs = self.transform_fft(signal)?;
336 let (n_scales, n_time) = coeffs.dim();
337
338 let mut scalogram = Array2::zeros((n_scales, n_time));
339 for i in 0..n_scales {
340 for j in 0..n_time {
341 scalogram[[i, j]] = coeffs[[i, j]].norm();
342 }
343 }
344
345 Ok(scalogram)
346 }
347
348 pub fn scales(&self) -> &[f64] {
350 &self.scales
351 }
352
353 pub fn frequencies(&self) -> Vec<f64> {
355 let fc = self.wavelet.central_frequency();
356 self.scales
357 .iter()
358 .map(|&s| fc / (s * self.sampling_period))
359 .collect()
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// 64 samples of `sin(0.1 · i)`, shared by the FFT-based tests.
    fn sine_signal() -> Array1<f64> {
        Array1::from_vec((0..64).map(|i| (i as f64 * 0.1).sin()).collect())
    }

    /// The Morlet wavelet is non-zero and purely real at the origin.
    #[test]
    fn test_morlet_wavelet() {
        let at_origin = MorletWavelet::default().wavelet(0.0, 1.0);

        assert!(at_origin.re.abs() > 0.0);
        assert_abs_diff_eq!(at_origin.im, 0.0, epsilon = 1e-10);
    }

    /// The Mexican hat wavelet is non-zero and purely real at the origin.
    #[test]
    fn test_mexican_hat_wavelet() {
        let at_origin = MexicanHatWavelet::default().wavelet(0.0, 1.0);

        assert!(at_origin.re.abs() > 0.0);
        assert_abs_diff_eq!(at_origin.im, 0.0, epsilon = 1e-10);
    }

    /// The direct transform yields one row per scale and one column per sample.
    #[test]
    fn test_cwt_simple() -> Result<()> {
        let signal = Array1::from_vec(vec![0.0, 1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0]);
        let cwt = CWT::new(MorletWavelet::default(), vec![1.0, 2.0, 4.0]);

        let coeffs = cwt.transform(&signal.view())?;
        assert_eq!(coeffs.dim(), (3, 8));

        Ok(())
    }

    /// The FFT-based transform yields one row per scale and one column per sample.
    #[test]
    fn test_cwt_fft() -> Result<()> {
        let signal = sine_signal();
        let cwt = CWT::with_log_scales(MorletWavelet::default(), 32, 1.0, 32.0);

        let coeffs = cwt.transform_fft(&signal.view())?;
        assert_eq!(coeffs.dim(), (32, 64));

        Ok(())
    }

    /// The scalogram has the coefficient shape and contains no negative entries.
    #[test]
    fn test_scalogram() -> Result<()> {
        let signal = sine_signal();
        let cwt = CWT::with_log_scales(MorletWavelet::default(), 16, 1.0, 16.0);

        let scalogram = cwt.scalogram(&signal.view())?;
        assert_eq!(scalogram.dim(), (16, 64));
        assert!(scalogram.iter().all(|&x| x >= 0.0));

        Ok(())
    }

    /// Log-spaced scales hit both end points and increase strictly.
    #[test]
    fn test_log_scales() {
        let scales = CWT::<MorletWavelet>::log_scales(10, 1.0, 100.0);

        assert_eq!(scales.len(), 10);
        assert_abs_diff_eq!(scales[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scales[9], 100.0, epsilon = 1e-10);

        for pair in scales.windows(2) {
            let ratio = pair[1] / pair[0];
            assert!(ratio > 1.0);
        }
    }
}