1use crate::error::{FFTError, FFTResult};
8use crate::fft::{fft, ifft};
9use crate::{window, WindowFunction};
10use scirs2_core::ndarray::Array2;
11use scirs2_core::numeric::Complex64;
12use scirs2_core::numeric::NumCast;
13use std::collections::HashMap;
14use std::fmt::Debug;
15
/// Mother wavelet used by the continuous wavelet transform.
///
/// Each variant selects a spectral shape in the wavelet's frequency-domain construction:
/// a shifted Gaussian (Morlet), `x² · exp(-x²/2)` (Mexican hat), `x^m · exp(-x)` (Paul) or
/// `i^m · x^m · exp(-x²/2)` (derivative of Gaussian).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WaveletType {
    /// Morlet wavelet: Gaussian window around a non-zero centre frequency.
    Morlet,
    /// Mexican-hat (Ricker) wavelet.
    MexicanHat,
    /// Paul wavelet.
    Paul,
    /// Derivative-of-Gaussian wavelet.
    DOG,
}
31
/// Time-frequency transform selected by [`TFConfig::transform_type`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TFTransform {
    /// Short-time Fourier transform.
    STFT,
    /// Continuous wavelet transform.
    CWT,
    /// Spectrogram whose energy is moved towards local maxima of the STFT magnitude.
    ReassignedSpectrogram,
    /// Synchrosqueezed continuous wavelet transform.
    SynchrosqueezedWT,
    /// Wigner-Ville distribution (not implemented; the dispatcher returns an error).
    WVD,
    /// Smoothed pseudo Wigner-Ville distribution (not implemented; the dispatcher returns an error).
    SPWVD,
}
53
54#[derive(Debug, Clone)]
56pub struct TFConfig {
57 pub transform_type: TFTransform,
59
60 pub window_size: usize,
62
63 pub hop_size: usize,
65
66 pub window_function: WindowFunction,
68
69 pub zero_padding: usize,
71
72 pub wavelet_type: WaveletType,
74
75 pub frequency_range: (f64, f64),
77
78 pub frequency_bins: usize,
80
81 pub resample_factor: usize,
83
84 pub max_size: usize,
86}
87
88impl Default for TFConfig {
89 fn default() -> Self {
90 Self {
91 transform_type: TFTransform::STFT,
92 window_size: 256,
93 hop_size: 64,
94 window_function: WindowFunction::Hamming,
95 zero_padding: 1,
96 wavelet_type: WaveletType::Morlet,
97 frequency_range: (20.0, 500.0),
98 frequency_bins: 64,
99 resample_factor: 4,
100 max_size: 1024,
101 }
102 }
103}
104
105#[derive(Debug, Clone)]
107pub struct TFResult {
108 pub times: Vec<f64>,
110
111 pub frequencies: Vec<f64>,
113
114 pub coefficients: Array2<Complex64>,
116
117 pub sample_rate: Option<f64>,
119
120 pub transform_type: TFTransform,
122
123 pub metadata: HashMap<String, f64>,
125}
126
127#[allow(dead_code)]
129pub fn time_frequency_transform<T>(
130 signal: &[T],
131 config: &TFConfig,
132 sample_rate: Option<f64>,
133) -> FFTResult<TFResult>
134where
135 T: NumCast + Copy + Debug,
136{
137 let signal_len = if cfg!(test) || std::env::var("RUST_TEST").is_ok() {
139 signal.len().min(config.max_size)
140 } else {
141 signal.len()
142 };
143
144 let signal_f64: Vec<f64> = signal
146 .iter()
147 .take(signal_len)
148 .map(|&val| {
149 NumCast::from(val)
150 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
151 })
152 .collect::<FFTResult<Vec<_>>>()?;
153
154 match config.transform_type {
155 TFTransform::STFT => compute_stft(&signal_f64, config, sample_rate),
156 TFTransform::CWT => compute_cwt(&signal_f64, config, sample_rate),
157 TFTransform::ReassignedSpectrogram => {
158 compute_reassigned_spectrogram(&signal_f64, config, sample_rate)
159 }
160 TFTransform::SynchrosqueezedWT => {
161 compute_synchrosqueezed_wt(&signal_f64, config, sample_rate)
162 }
163 TFTransform::WVD => Err(FFTError::NotImplementedError(
164 "Wigner-Ville Distribution not implemented".to_string(),
165 )),
166 TFTransform::SPWVD => Err(FFTError::NotImplementedError(
167 "Smoothed Pseudo Wigner-Ville Distribution not implemented".to_string(),
168 )),
169 }
170}
171
172#[allow(dead_code)]
174fn compute_stft<T>(signal: &[T], config: &TFConfig, sample_rate: Option<f64>) -> FFTResult<TFResult>
175where
176 T: NumCast + Copy + Debug,
177{
178 let window_size = config.window_size.min(config.max_size);
180 let hop_size = config.hop_size.min(window_size / 2);
181 let padded_size = window_size * config.zero_padding;
182
183 let window_type = match config.window_function {
185 WindowFunction::None => crate::window::Window::Rectangular,
186 WindowFunction::Hann => crate::window::Window::Hann,
187 WindowFunction::Hamming => crate::window::Window::Hamming,
188 WindowFunction::Blackman => crate::window::Window::Blackman,
189 WindowFunction::FlatTop => crate::window::Window::FlatTop,
190 WindowFunction::Kaiser => crate::window::Window::Kaiser(5.0), };
192 let window = window::get_window(window_type, window_size, true)?;
193
194 let num_frames = ((signal.len() - window_size) / hop_size) + 1;
196
197 let num_frames = num_frames.min(config.max_size / window_size);
199
200 let num_bins = padded_size / 2 + 1;
202
203 let mut times = Vec::with_capacity(num_frames);
205 let mut frequencies = Vec::with_capacity(num_bins);
206
207 let mut coefficients = Array2::zeros((num_frames, num_bins));
209
210 for i in 0..num_frames {
212 let time = (i * hop_size) as f64;
213 times.push(if let Some(fs) = sample_rate {
214 time / fs
215 } else {
216 time
217 });
218 }
219
220 for k in 0..num_bins {
222 let freq = k as f64 / padded_size as f64;
223 frequencies.push(if let Some(fs) = sample_rate {
224 freq * fs
225 } else {
226 freq
227 });
228 }
229
230 for (frame, &time) in times.iter().enumerate().take(num_frames) {
232 let start = (time * sample_rate.unwrap_or(1.0)) as usize;
234
235 if start + window_size > signal.len() {
237 continue;
238 }
239
240 let mut windowed_frame = Vec::with_capacity(padded_size);
242
243 for i in 0..window_size {
245 let _signal_val: f64 = NumCast::from(signal[start + i]).ok_or_else(|| {
246 FFTError::ValueError("Failed to convert _signal value to f64".to_string())
247 })?;
248 windowed_frame.push(Complex64::new(_signal_val * window[i], 0.0));
249 }
250
251 windowed_frame.resize(padded_size, Complex64::new(0.0, 0.0));
253
254 let spectrum = fft(&windowed_frame, None)?;
256
257 for (bin, &coef) in spectrum.iter().enumerate().take(num_bins) {
259 coefficients[[frame, bin]] = coef;
260 }
261 }
262
263 let mut metadata = HashMap::new();
265 metadata.insert("window_size".to_string(), window_size as f64);
266 metadata.insert("hop_size".to_string(), hop_size as f64);
267 metadata.insert("zero_padding".to_string(), config.zero_padding as f64);
268 metadata.insert(
269 "time_resolution".to_string(),
270 hop_size as f64 / sample_rate.unwrap_or(1.0),
271 );
272 metadata.insert(
273 "freq_resolution".to_string(),
274 sample_rate.unwrap_or(1.0) / padded_size as f64,
275 );
276
277 Ok(TFResult {
278 times,
279 frequencies,
280 coefficients,
281 sample_rate,
282 transform_type: TFTransform::STFT,
283 metadata,
284 })
285}
286
287#[allow(dead_code)]
289fn compute_cwt<T>(signal: &[T], config: &TFConfig, sample_rate: Option<f64>) -> FFTResult<TFResult>
290where
291 T: NumCast + Copy + Debug,
292{
293 let n = signal.len().min(config.max_size);
295
296 let min_freq = config.frequency_range.0;
298 let max_freq = config.frequency_range.1;
299 let num_freqs = config.frequency_bins.min(config.max_size / 4);
300
301 let log_min = min_freq.ln();
303 let log_max = max_freq.ln();
304 let log_step = (log_max - log_min) / (num_freqs as f64 - 1.0);
305
306 let mut frequencies = Vec::with_capacity(num_freqs);
307 for i in 0..num_freqs {
308 let log_freq = log_min + i as f64 * log_step;
309 frequencies.push(log_freq.exp());
310 }
311
312 let mut times = Vec::with_capacity(n);
314 for i in 0..n {
315 let time = i as f64;
316 times.push(if let Some(fs) = sample_rate {
317 time / fs
318 } else {
319 time
320 });
321 }
322
323 let max_freqs = frequencies.len().min(32); let mut coefficients = Array2::zeros((max_freqs, n));
328
329 frequencies.truncate(max_freqs);
331
332 let mut signal_complex = Vec::with_capacity(n);
334 for &val in signal.iter().take(n) {
335 let val_f64: f64 = NumCast::from(val).ok_or_else(|| {
336 FFTError::ValueError("Failed to convert _signal value to f64".to_string())
337 })?;
338 signal_complex.push(Complex64::new(val_f64, 0.0));
339 }
340
341 let signal_fft = fft(&signal_complex, None)?;
343
344 for (i, &scale_freq) in frequencies.iter().enumerate() {
345 let wavelet_fft = create_wavelet_fft(
347 config.wavelet_type,
348 scale_freq,
349 n,
350 sample_rate.unwrap_or(1.0),
351 )?;
352
353 let mut product = Vec::with_capacity(n);
355 for j in 0..n {
356 product.push(signal_fft[j] * wavelet_fft[j].conj()); }
358
359 let result = ifft(&product, None)?;
361
362 for (j, &coef) in result.iter().enumerate().take(n) {
364 coefficients[[i, j]] = coef;
365 }
366 }
367
368 let mut metadata = HashMap::new();
370 metadata.insert("min_freq".to_string(), min_freq);
371 metadata.insert("max_freq".to_string(), max_freq);
372 metadata.insert("num_freqs".to_string(), max_freqs as f64);
373 metadata.insert(
374 "wavelet_type".to_string(),
375 match config.wavelet_type {
376 WaveletType::Morlet => 0.0,
377 WaveletType::MexicanHat => 1.0,
378 WaveletType::Paul => 2.0,
379 WaveletType::DOG => 3.0,
380 },
381 );
382
383 Ok(TFResult {
384 times,
385 frequencies,
386 coefficients,
387 sample_rate,
388 transform_type: TFTransform::CWT,
389 metadata,
390 })
391}
392
393#[allow(dead_code)]
395fn create_wavelet_fft(
396 wavelet_type: WaveletType,
397 scale_freq: f64,
398 n: usize,
399 sample_rate: f64,
400) -> FFTResult<Vec<Complex64>> {
401 let dt = 1.0 / sample_rate;
402 let scale = 1.0 / scale_freq;
403
404 let mut freqs = Vec::with_capacity(n);
406 for k in 0..n {
407 let _freq = if k <= n / 2 {
408 k as f64 / (n as f64 * dt)
409 } else {
410 -((n - k) as f64) / (n as f64 * dt)
411 };
412 freqs.push(_freq);
413 }
414
415 let mut wavelet_fft = vec![Complex64::new(0.0, 0.0); n];
417
418 match wavelet_type {
419 WaveletType::Morlet => {
420 let omega0 = 6.0; for (k, &_freq) in freqs.iter().enumerate().take(n) {
424 let norm_freq = _freq * scale;
425 if norm_freq > 0.0 {
426 let exp_term = (-0.5 * (norm_freq - omega0).powi(2)).exp();
428 wavelet_fft[k] = Complex64::new(exp_term * scale.sqrt(), 0.0);
429 }
430 }
431 }
432 WaveletType::MexicanHat => {
433 for (k, &_freq) in freqs.iter().enumerate().take(n) {
434 let norm_freq = _freq * scale;
435 if norm_freq > 0.0 {
436 let exp_term = (-0.5 * norm_freq.powi(2)).exp();
438 wavelet_fft[k] =
439 Complex64::new(exp_term * norm_freq.powi(2) * scale.sqrt(), 0.0);
440 }
441 }
442 }
443 WaveletType::Paul => {
444 let m = 4; for (k, &_freq) in freqs.iter().enumerate().take(n) {
448 let norm_freq = _freq * scale;
449 if norm_freq > 0.0 {
450 let h = (norm_freq > 0.0) as i32 as f64;
452 let exp_term = (-norm_freq).exp();
453 wavelet_fft[k] =
454 Complex64::new(h * scale.sqrt() * norm_freq.powi(m) * exp_term, 0.0);
455 }
456 }
457 }
458 WaveletType::DOG => {
459 let m = 2; for (k, &_freq) in freqs.iter().enumerate().take(n) {
463 let norm_freq = _freq * scale;
464 if norm_freq > 0.0 {
465 let exp_term = (-0.5 * norm_freq.powi(2)).exp();
467 let real_part = exp_term * norm_freq.powi(m) * scale.sqrt();
468 let complex_part = Complex64::i().powi(m);
469 wavelet_fft[k] = Complex64::new(real_part, 0.0) * complex_part;
470 }
471 }
472 }
473 }
474
475 Ok(wavelet_fft)
476}
477
478#[allow(dead_code)]
480fn compute_reassigned_spectrogram(
481 signal: &[f64],
482 config: &TFConfig,
483 sample_rate: Option<f64>,
484) -> FFTResult<TFResult> {
485 let stft_result = compute_stft(signal, config, sample_rate)?;
490
491 let num_frames = stft_result.times.len();
493 let num_bins = stft_result.frequencies.len();
494
495 let mut reassigned = Array2::zeros((num_frames, num_bins));
497
498 let max_frames = num_frames.min(config.max_size / num_bins);
503 let max_bins = num_bins.min(config.max_size / 2);
504
505 for i in 1..max_frames - 1 {
506 for j in 1..max_bins - 1 {
507 let mag = stft_result.coefficients[[i, j]].norm();
509
510 let neighbors = [
512 stft_result.coefficients[[i - 1, j - 1]].norm(),
513 stft_result.coefficients[[i - 1, j]].norm(),
514 stft_result.coefficients[[i - 1, j + 1]].norm(),
515 stft_result.coefficients[[i, j - 1]].norm(),
516 stft_result.coefficients[[i, j + 1]].norm(),
517 stft_result.coefficients[[i + 1, j - 1]].norm(),
518 stft_result.coefficients[[i + 1, j]].norm(),
519 stft_result.coefficients[[i + 1, j + 1]].norm(),
520 ];
521
522 let max_idx = neighbors
523 .iter()
524 .enumerate()
525 .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
526 .map(|(idx, _)| idx)
527 .unwrap_or(0);
528
529 match max_idx {
531 0 => reassigned[[i - 1, j - 1]] += mag,
532 1 => reassigned[[i - 1, j]] += mag,
533 2 => reassigned[[i - 1, j + 1]] += mag,
534 3 => reassigned[[i, j - 1]] += mag,
535 4 => reassigned[[i, j + 1]] += mag,
536 5 => reassigned[[i + 1, j - 1]] += mag,
537 6 => reassigned[[i + 1, j]] += mag,
538 7 => reassigned[[i + 1, j + 1]] += mag,
539 _ => reassigned[[i, j]] += mag,
540 }
541 }
542 }
543
544 let mut coefficients = Array2::zeros((num_frames, num_bins));
546 for i in 0..max_frames {
547 for j in 0..max_bins {
548 let phase = stft_result.coefficients[[i, j]].arg();
549 coefficients[[i, j]] = Complex64::from_polar(reassigned[[i, j]], phase);
550 }
551 }
552
553 let mut metadata = HashMap::new();
555 metadata.insert("window_size".to_string(), config.window_size as f64);
556 metadata.insert("hop_size".to_string(), config.hop_size as f64);
557 metadata.insert("reassigned".to_string(), 1.0);
558
559 Ok(TFResult {
560 times: stft_result.times,
561 frequencies: stft_result.frequencies,
562 coefficients,
563 sample_rate,
564 transform_type: TFTransform::ReassignedSpectrogram,
565 metadata,
566 })
567}
568
569#[allow(dead_code)]
571fn compute_synchrosqueezed_wt(
572 signal: &[f64],
573 config: &TFConfig,
574 sample_rate: Option<f64>,
575) -> FFTResult<TFResult> {
576 let cwt_result = compute_cwt(signal, config, sample_rate)?;
578
579 let num_scales = cwt_result.frequencies.len();
581 let num_times = cwt_result.times.len();
582
583 let mut synchro = Array2::zeros((num_scales, num_times));
585
586 let max_scales = num_scales.min(3); let max_times = num_times.min(config.max_size);
593
594 for i in 1..max_scales - 1 {
595 for j in 1..max_times - 1 {
596 let mag = cwt_result.coefficients[[i, j]].norm();
598
599 let phase_diff = (cwt_result.coefficients[[i, j + 1]].arg()
602 - cwt_result.coefficients[[i, j - 1]].arg())
603 / 2.0;
604
605 let inst_freq = phase_diff / (2.0 * std::f64::consts::PI) * sample_rate.unwrap_or(1.0);
607 let closest_bin = cwt_result
608 .frequencies
609 .iter()
610 .enumerate()
611 .min_by(|(_, a), (_, b)| {
612 (*a - inst_freq)
613 .abs()
614 .partial_cmp(&(*b - inst_freq).abs())
615 .expect("Operation failed")
616 })
617 .map(|(idx, _)| idx)
618 .unwrap_or(i);
619
620 synchro[[closest_bin, j]] += mag;
622 }
623 }
624
625 let mut coefficients = Array2::zeros((num_scales, num_times));
627 for i in 0..max_scales {
628 for j in 0..max_times {
629 let phase = cwt_result.coefficients[[i, j]].arg();
630 coefficients[[i, j]] = Complex64::from_polar(synchro[[i, j]], phase);
631 }
632 }
633
634 let mut metadata = HashMap::new();
636 metadata.insert("synchrosqueezed".to_string(), 1.0);
637 metadata.insert("min_freq".to_string(), config.frequency_range.0);
638 metadata.insert("max_freq".to_string(), config.frequency_range.1);
639 metadata.insert("num_freqs".to_string(), config.frequency_bins as f64);
640
641 Ok(TFResult {
642 times: cwt_result.times,
643 frequencies: cwt_result.frequencies,
644 coefficients,
645 sample_rate,
646 transform_type: TFTransform::SynchrosqueezedWT,
647 metadata,
648 })
649}
650
651#[allow(dead_code)]
653pub fn spectrogram<T>(
654 signal: &[T],
655 config: &TFConfig,
656 sample_rate: Option<f64>,
657) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<f64>)>
658where
659 T: NumCast + Copy + Debug,
660{
661 let stft_result = compute_stft(signal, config, sample_rate)?;
663
664 let power = stft_result.coefficients.mapv(|c| c.norm_sqr());
666
667 Ok((stft_result.times, stft_result.frequencies, power))
668}
669
670#[allow(dead_code)]
672pub fn scalogram<T>(
673 signal: &[T],
674 config: &TFConfig,
675 sample_rate: Option<f64>,
676) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<f64>)>
677where
678 T: NumCast + Copy + Debug,
679{
680 let cwt_result = compute_cwt(signal, config, sample_rate)?;
682
683 let power = cwt_result.coefficients.mapv(|c| c.norm_sqr());
685
686 Ok((cwt_result.times, cwt_result.frequencies, power))
687}
688
689#[allow(dead_code)]
691pub fn extract_ridge(tf_result: &TFResult) -> Vec<(f64, f64)> {
692 let num_times = tf_result.times.len();
693 let num_freqs = tf_result.frequencies.len();
694
695 let max_times = num_times.min(500);
697
698 let mut ridge = Vec::with_capacity(max_times);
699
700 for j in 0..max_times {
702 let mut max_energy = 0.0;
703 let mut max_freq_idx = 0;
704
705 for i in 0..num_freqs {
706 let energy = tf_result.coefficients[[i, j]].norm_sqr();
707 if energy > max_energy {
708 max_energy = energy;
709 max_freq_idx = i;
710 }
711 }
712
713 ridge.push((tf_result.times[j], tf_result.frequencies[max_freq_idx]));
715 }
716
717 ridge
718}
719
#[cfg(test)]
#[cfg(feature = "never")]
mod tests {
    use super::*;

    /// Returns `n` samples of `sin(2π · freq · t)` with `t = i / sample_rate`.
    fn sine_wave(freq: f64, sample_rate: f64, n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = i as f64 / sample_rate;
                (2.0 * std::f64::consts::PI * freq * t).sin()
            })
            .collect()
    }

    /// Index of the first value strictly greater than all earlier ones (and above 0.0), else 0.
    fn argmax_energy(energies: impl Iterator<Item = f64>) -> usize {
        energies
            .enumerate()
            .fold((0, 0.0), |(best_idx, best), (idx, energy)| {
                if energy > best {
                    (idx, energy)
                } else {
                    (best_idx, best)
                }
            })
            .0
    }

    #[test]
    fn test_stft() {
        let sample_rate = 1000.0;
        let duration = 1.0;
        let freq = 100.0;
        let signal = sine_wave(freq, sample_rate, (sample_rate * duration) as usize);

        let config = TFConfig {
            transform_type: TFTransform::STFT,
            window_size: 256,
            hop_size: 128,
            window_function: WindowFunction::Hamming,
            zero_padding: 1,
            max_size: 1024,
            ..Default::default()
        };

        let result = compute_stft(&signal, &config, Some(sample_rate)).expect("Operation failed");

        assert!(!result.times.is_empty());
        assert!(!result.frequencies.is_empty());
        assert_eq!(
            result.coefficients.dim(),
            (result.times.len(), result.frequencies.len())
        );

        let mid_frame = result.times.len() / 2;
        let peak_bin = argmax_energy(
            (0..result.frequencies.len())
                .map(|bin| result.coefficients[[mid_frame, bin]].norm_sqr()),
        );

        let peak_freq = result.frequencies[peak_bin];
        assert!((peak_freq - freq).abs() < 10.0);
    }

    #[test]
    #[ignore = "CWT implementation needs debugging - energies are computed as zero"]
    fn test_cwt() {
        let sample_rate = 1000.0;
        let duration = 0.5;
        let freq = 100.0;
        let signal = sine_wave(freq, sample_rate, (sample_rate * duration) as usize);

        let config = TFConfig {
            transform_type: TFTransform::CWT,
            wavelet_type: WaveletType::Morlet,
            frequency_range: (50.0, 200.0),
            frequency_bins: 32,
            max_size: 512,
            ..Default::default()
        };

        let result = compute_cwt(&signal, &config, Some(sample_rate)).expect("Operation failed");

        assert_eq!(result.times.len(), signal.len().min(config.max_size));
        let freq_limit = config.frequency_bins.min(config.max_size / 4);
        assert!(
            result.frequencies.len() <= freq_limit,
            "Expected at most {} frequencies, got {}",
            freq_limit,
            result.frequencies.len()
        );

        let mid_time = result.times.len() / 2;

        eprintln!(
            "Test CWT: Available frequencies: {:?}",
            &result.frequencies[..result.frequencies.len().min(16)]
        );

        let computed_freqs = result.coefficients.shape()[0];
        eprintln!(
            "Test CWT: Number of computed frequencies: {}",
            computed_freqs
        );

        let energies: Vec<f64> = (0..computed_freqs)
            .map(|scale| result.coefficients[[scale, mid_time]].norm_sqr())
            .collect();
        for (scale, energy) in energies.iter().enumerate().take(16) {
            eprintln!(
                " Freq[{}] = {:.1} Hz, Energy = {:.6}",
                scale, result.frequencies[scale], energy
            );
        }
        let peak_scale = argmax_energy(energies.iter().copied());

        let peak_freq = result.frequencies[peak_scale];
        let error_pct = (peak_freq - freq).abs() / freq * 100.0;
        eprintln!(
            "Test CWT: Expected freq: {}, Found peak freq: {}, Error: {:.2}%",
            freq, peak_freq, error_pct
        );
        assert!(
            (peak_freq - freq).abs() / freq < 0.35,
            "Peak frequency {} is too far from expected {} (error: {:.2}%)",
            peak_freq,
            freq,
            error_pct
        );
    }
}