1use std::f64::consts::PI;
29
30use scirs2_core::ndarray::Array2;
31use scirs2_core::numeric::Complex64;
32
33use crate::error::{FFTError, FFTResult};
34
/// Shape of the analysis window multiplied into each frame.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum WindowType {
    /// Gaussian bell centred on the frame, with standard deviation `size / 6`.
    Gaussian,
    /// Hamming window: `0.54 - 0.46 * cos(2πn / (size - 1))`.
    Hamming,
    /// Hann window: `0.5 * (1 - cos(2πn / (size - 1)))`.
    Hann,
    /// Blackman window: `0.42 - 0.5 * cos(t) + 0.08 * cos(2t)` with `t = 2πn / (size - 1)`.
    Blackman,
    /// Every coefficient equal to 1.
    Rectangular,
}

impl WindowType {
    /// Returns the `size` coefficients of this window.
    ///
    /// `size == 0` yields an empty vector and `size == 1` yields `[1.0]` for every
    /// window type.
    pub fn samples(&self, size: usize) -> Vec<f64> {
        // The cosine windows divide by `size - 1`; for a single sample that is 0/0 and
        // would produce NaN. A one-point window has no shape, so it is simply 1.
        if size <= 1 {
            return vec![1.0; size];
        }

        let len = size as f64;
        let span = len - 1.0;
        match self {
            WindowType::Gaussian => {
                let sigma = len / 6.0;
                let centre = span / 2.0;
                (0..size)
                    .map(|n| {
                        let x = (n as f64 - centre) / sigma;
                        (-0.5 * x * x).exp()
                    })
                    .collect()
            }
            WindowType::Hamming => (0..size)
                .map(|n| 0.54 - 0.46 * (2.0 * PI * n as f64 / span).cos())
                .collect(),
            WindowType::Hann => (0..size)
                .map(|n| 0.5 * (1.0 - (2.0 * PI * n as f64 / span).cos()))
                .collect(),
            WindowType::Blackman => (0..size)
                .map(|n| {
                    let t = 2.0 * PI * n as f64 / span;
                    0.42 - 0.5 * t.cos() + 0.08 * (2.0 * t).cos()
                })
                .collect(),
            WindowType::Rectangular => vec![1.0; size],
        }
    }
}
84
/// Parameters controlling the short-time fractional Fourier transform.
#[derive(Debug, Clone)]
pub struct StfrftConfig {
    /// Fractional order applied to every frame.
    pub alpha: f64,
    /// Number of samples in each analysis frame; `stfrft` rejects 0.
    pub window_size: usize,
    /// Distance in samples between the starts of consecutive frames; `stfrft` rejects 0.
    pub hop_size: usize,
    /// Window multiplied into each frame before it is transformed.
    pub window_type: WindowType,
    /// NOTE(review): no code visible in this file reads this flag — confirm its intended effect.
    pub oversample: bool,
}
103
104impl Default for StfrftConfig {
105 fn default() -> Self {
106 Self {
107 alpha: 1.0,
108 window_size: 256,
109 hop_size: 64,
110 window_type: WindowType::Gaussian,
111 oversample: false,
112 }
113 }
114}
115
/// Output of `stfrft`: one row of transform coefficients per analysis frame.
#[derive(Debug, Clone)]
pub struct StfrftResult {
    /// Coefficient matrix of shape `(n_frames, window_size)`; row `f` is the transform of frame `f`.
    pub coefficients: Array2<Complex64>,
    /// Per-frame centre position in samples: frame start plus half the window size.
    pub time_centers: Vec<f64>,
    /// Normalised bin positions `k / window_size` for `k` in `0..window_size`.
    pub fractional_freqs: Vec<f64>,
    /// Fractional order the coefficients were computed with (copied from the config).
    pub alpha: f64,
}
128
129fn grunbaum_eigendecomp(n: usize) -> FFTResult<(Vec<Vec<f64>>, Vec<i32>)> {
145 if n == 0 {
146 return Ok((vec![], vec![]));
147 }
148 if n == 1 {
149 return Ok((vec![vec![1.0]], vec![0]));
150 }
151
152 let mut mat = vec![0.0_f64; n * n];
158 for j in 0..n {
159 mat[j * n + j] = 2.0 * (2.0 * PI * j as f64 / n as f64).cos();
160 if j + 1 < n {
161 mat[j * n + j + 1] = 1.0;
162 mat[(j + 1) * n + j] = 1.0;
163 }
164 }
165 mat[n - 1] = 1.0;
167 mat[(n - 1) * n] = 1.0;
168
169 let (eigenvalues, eigenvectors) = symmetric_jacobi_eig(&mut mat, n);
171
172 let mut order_idx: Vec<usize> = (0..n).collect();
183 order_idx.sort_by(|&a, &b| {
184 eigenvalues[b]
185 .partial_cmp(&eigenvalues[a])
186 .unwrap_or(std::cmp::Ordering::Equal)
187 });
188
189 let mut ev_orders = vec![0_i32; n];
190 for (rank, &idx) in order_idx.iter().enumerate() {
191 ev_orders[idx] = (rank % 4) as i32;
192 }
193
194 Ok((eigenvectors, ev_orders))
195}
196
/// Diagonalises a real symmetric matrix with cyclic Jacobi rotations.
///
/// # Arguments
///
/// * `mat_flat` - the `n × n` matrix in row-major order. It is overwritten: on return
///   its diagonal holds the eigenvalues and its off-diagonal part is (numerically) zero.
/// * `n` - matrix dimension; `mat_flat` must hold at least `n * n` elements.
///
/// # Returns
///
/// `(eigenvalues, eigenvectors)` where `eigenvectors[j]` is the unit-length eigenvector
/// belonging to `eigenvalues[j]`. The eigenvalues are not sorted.
///
/// # Panics
///
/// Panics if `mat_flat` is shorter than `n * n`.
fn symmetric_jacobi_eig(mat_flat: &mut [f64], n: usize) -> (Vec<f64>, Vec<Vec<f64>>) {
    // Upper bound on full passes over the strict upper triangle. One pass performs up to
    // n(n-1)/2 rotations, so the work budget grows with the matrix instead of being a
    // fixed number of single rotations.
    const MAX_SWEEP: usize = 100;
    // Stop once every off-diagonal entry is below this magnitude.
    const EPS: f64 = 1e-15;

    // Accumulated rotations, starting from the identity; columns become eigenvectors.
    let mut z = vec![0.0_f64; n * n];
    for i in 0..n {
        z[i * n + i] = 1.0;
    }

    for _ in 0..MAX_SWEEP {
        let mut max_off = 0.0_f64;
        for i in 0..n {
            for j in (i + 1)..n {
                let v = mat_flat[i * n + j].abs();
                if v > max_off {
                    max_off = v;
                }
            }
        }
        if max_off < EPS {
            break;
        }

        for p in 0..n {
            for q in (p + 1)..n {
                let apq = mat_flat[p * n + q];
                // Nothing to annihilate; also keeps `tau` below from becoming 0/0.
                if apq == 0.0 {
                    continue;
                }
                let app = mat_flat[p * n + p];
                let aqq = mat_flat[q * n + q];

                // Smaller-magnitude root of t² + 2·tau·t − 1 = 0, so |angle| ≤ π/4.
                let tau = (aqq - app) / (2.0 * apq);
                let t = if tau >= 0.0 {
                    1.0 / (tau + (1.0 + tau * tau).sqrt())
                } else {
                    -1.0 / (-tau + (1.0 + tau * tau).sqrt())
                };
                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = t * c;

                mat_flat[p * n + p] = app - t * apq;
                mat_flat[q * n + q] = aqq + t * apq;
                mat_flat[p * n + q] = 0.0;
                mat_flat[q * n + p] = 0.0;

                // Update the remaining entries of rows/columns p and q, keeping symmetry.
                for r in 0..n {
                    if r == p || r == q {
                        continue;
                    }
                    let arp = mat_flat[r * n + p];
                    let arq = mat_flat[r * n + q];
                    let new_rp = c * arp - s * arq;
                    let new_rq = s * arp + c * arq;
                    mat_flat[r * n + p] = new_rp;
                    mat_flat[p * n + r] = new_rp;
                    mat_flat[r * n + q] = new_rq;
                    mat_flat[q * n + r] = new_rq;
                }

                // Apply the same rotation to the eigenvector accumulator.
                for r in 0..n {
                    let zrp = z[r * n + p];
                    let zrq = z[r * n + q];
                    z[r * n + p] = c * zrp - s * zrq;
                    z[r * n + q] = s * zrp + c * zrq;
                }
            }
        }
    }

    let eigenvalues: Vec<f64> = (0..n).map(|i| mat_flat[i * n + i]).collect();
    // Column j of `z` is the eigenvector for eigenvalue j.
    let eigenvectors: Vec<Vec<f64>> = (0..n)
        .map(|j| (0..n).map(|i| z[i * n + j]).collect())
        .collect();

    (eigenvalues, eigenvectors)
}
289
290pub fn dfrft(signal: &[Complex64], alpha: f64) -> FFTResult<Vec<Complex64>> {
321 let n = signal.len();
322 if n == 0 {
323 return Ok(vec![]);
324 }
325
326 let (eigvecs, ev_orders) = grunbaum_eigendecomp(n)?;
327
328 let mut vhx = vec![Complex64::new(0.0, 0.0); n];
331 for k in 0..n {
332 let mut sum = Complex64::new(0.0, 0.0);
333 for j in 0..n {
334 sum += Complex64::new(eigvecs[k][j], 0.0) * signal[j];
336 }
337 vhx[k] = sum;
338 }
339
340 let mut dvhx = vec![Complex64::new(0.0, 0.0); n];
342 for k in 0..n {
343 let angle = -PI / 2.0 * ev_orders[k] as f64 * alpha;
344 let frac_eig = Complex64::new(angle.cos(), angle.sin());
345 dvhx[k] = frac_eig * vhx[k];
346 }
347
348 let mut result = vec![Complex64::new(0.0, 0.0); n];
350 for j in 0..n {
351 let mut sum = Complex64::new(0.0, 0.0);
352 for k in 0..n {
353 sum += Complex64::new(eigvecs[k][j], 0.0) * dvhx[k];
354 }
355 result[j] = sum;
356 }
357
358 Ok(result)
359}
360
361pub fn stfrft(signal: &[f64], config: &StfrftConfig) -> FFTResult<StfrftResult> {
393 let sig_len = signal.len();
394 if sig_len == 0 {
395 return Err(FFTError::ValueError("Signal must be non-empty".to_string()));
396 }
397 let win_size = config.window_size;
398 if win_size == 0 {
399 return Err(FFTError::ValueError("Window size must be > 0".to_string()));
400 }
401 let hop = config.hop_size;
402 if hop == 0 {
403 return Err(FFTError::ValueError("Hop size must be > 0".to_string()));
404 }
405
406 let window = config.window_type.samples(win_size);
407
408 let n_frames = if sig_len <= win_size {
410 1
411 } else {
412 (sig_len - win_size) / hop + 1
413 };
414
415 let mut coefficients = Array2::zeros((n_frames, win_size));
416 let mut time_centers = Vec::with_capacity(n_frames);
417
418 for frame_idx in 0..n_frames {
419 let start = frame_idx * hop;
420 let centre = start as f64 + win_size as f64 / 2.0;
421 time_centers.push(centre);
422
423 let frame_complex: Vec<Complex64> = (0..win_size)
425 .map(|i| {
426 let sig_idx = start + i;
427 let sample = if sig_idx < sig_len {
428 signal[sig_idx]
429 } else {
430 0.0
431 };
432 Complex64::new(sample * window[i], 0.0)
433 })
434 .collect();
435
436 let dfrft_out = dfrft(&frame_complex, config.alpha)?;
438
439 for (k, val) in dfrft_out.into_iter().enumerate() {
440 coefficients[[frame_idx, k]] = val;
441 }
442 }
443
444 let fractional_freqs: Vec<f64> = (0..win_size).map(|k| k as f64 / win_size as f64).collect();
445
446 Ok(StfrftResult {
447 coefficients,
448 time_centers,
449 fractional_freqs,
450 alpha: config.alpha,
451 })
452}
453
454pub fn istfrft(
485 result: &StfrftResult,
486 signal_length: usize,
487 hop_size: usize,
488) -> FFTResult<Vec<f64>> {
489 let n_frames = result.coefficients.shape()[0];
490 let win_size = result.coefficients.shape()[1];
491 let hop = if hop_size > 0 { hop_size } else { 1 };
492
493 let mut output = vec![0.0_f64; signal_length];
494 let mut norm = vec![0.0_f64; signal_length];
495
496 let inv_alpha = -result.alpha;
498
499 for frame_idx in 0..n_frames {
500 let start = frame_idx * hop;
501
502 let frame: Vec<Complex64> = (0..win_size)
504 .map(|k| result.coefficients[[frame_idx, k]])
505 .collect();
506
507 let recon = dfrft(&frame, inv_alpha)?;
509
510 for i in 0..win_size {
512 let sig_idx = start + i;
513 if sig_idx < signal_length {
514 output[sig_idx] += recon[i].re;
515 norm[sig_idx] += 1.0;
516 }
517 }
518 }
519
520 for i in 0..signal_length {
522 if norm[i] > 0.0 {
523 output[i] /= norm[i];
524 }
525 }
526
527 Ok(output)
528}
529
530#[cfg(test)]
531mod tests {
532 use super::*;
533 use approx::assert_relative_eq;
534 use std::f64::consts::PI;
535
536 fn naive_dft(signal: &[f64]) -> Vec<Complex64> {
538 let n = signal.len();
539 (0..n)
540 .map(|k| {
541 (0..n).fold(Complex64::new(0.0, 0.0), |acc, j| {
542 let angle = -2.0 * PI * (j * k) as f64 / n as f64;
543 acc + Complex64::new(signal[j] * angle.cos(), signal[j] * angle.sin())
544 })
545 })
546 .collect()
547 }
548
549 #[test]
550 fn test_dfrft_alpha_0_identity() {
551 let n = 16;
553 let signal: Vec<Complex64> = (0..n)
554 .map(|i| Complex64::new((2.0 * PI * i as f64 / n as f64).sin(), 0.0))
555 .collect();
556 let out = dfrft(&signal, 0.0).expect("dfrft failed");
557 for (a, b) in signal.iter().zip(out.iter()) {
558 assert!((a.re - b.re).abs() < 1e-6, "real mismatch alpha=0 at re");
559 assert!((a.im - b.im).abs() < 1e-6, "imag mismatch alpha=0 at im");
560 }
561 }
562
563 #[test]
564 fn test_dfrft_alpha_1_matches_dft() {
565 let n = 8;
567 let signal: Vec<f64> = (0..n)
568 .map(|i| (2.0 * PI * i as f64 / n as f64).cos())
569 .collect();
570 let signal_complex: Vec<Complex64> =
571 signal.iter().map(|&v| Complex64::new(v, 0.0)).collect();
572
573 let dfrft_out = dfrft(&signal_complex, 1.0).expect("dfrft failed");
574 let dft_ref = naive_dft(&signal);
575
576 let dfrft_mags: Vec<f64> = dfrft_out
579 .iter()
580 .map(|c| (c.re * c.re + c.im * c.im).sqrt())
581 .collect();
582 let dft_mags: Vec<f64> = dft_ref
583 .iter()
584 .map(|c| (c.re * c.re + c.im * c.im).sqrt())
585 .collect();
586
587 let dfrft_peak = dfrft_mags
588 .iter()
589 .enumerate()
590 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
591 .map(|(i, _)| i)
592 .unwrap_or(0);
593
594 let dft_peak = dft_mags
595 .iter()
596 .enumerate()
597 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
598 .map(|(i, _)| i)
599 .unwrap_or(0);
600
601 assert_eq!(
602 dfrft_peak, dft_peak,
603 "DFrFT and DFT should have same peak frequency bin"
604 );
605 }
606
607 #[test]
608 fn test_dfrft_alpha_2_time_reversal() {
609 let n = 8;
616 let signal: Vec<Complex64> = (0..n)
617 .map(|i| Complex64::new((PI * i as f64 / n as f64).sin(), 0.0))
618 .collect();
619
620 let step1 = dfrft(&signal, 1.0).expect("dfrft step1 failed");
622 let step2 = dfrft(&step1, 1.0).expect("dfrft step2 failed");
623 let step3 = dfrft(&step2, 1.0).expect("dfrft step3 failed");
624 let step4 = dfrft(&step3, 1.0).expect("dfrft step4 failed");
625
626 for i in 0..n {
628 let diff = (step4[i].re - signal[i].re).abs() + (step4[i].im - signal[i].im).abs();
629 assert!(
630 diff < 0.3,
631 "Group property DFrFT(4)≈I violated at index {i}: diff={diff}"
632 );
633 }
634 }
635
636 #[test]
637 fn test_stfrft_output_shape() {
638 let signal: Vec<f64> = (0..512).map(|i| (i as f64 * 0.1).sin()).collect();
639 let cfg = StfrftConfig {
640 alpha: 0.8,
641 window_size: 64,
642 hop_size: 16,
643 window_type: WindowType::Hann,
644 oversample: false,
645 };
646 let result = stfrft(&signal, &cfg).expect("stfrft failed");
647 let n_frames = result.coefficients.shape()[0];
648 let n_freqs = result.coefficients.shape()[1];
649
650 assert_eq!(
652 n_freqs, 64,
653 "Expected window_size={} frequency bins",
654 cfg.window_size
655 );
656 assert!(n_frames > 0, "Expected at least one frame");
657 assert_eq!(result.time_centers.len(), n_frames);
658 assert_eq!(result.fractional_freqs.len(), n_freqs);
659 assert_eq!(result.alpha, 0.8);
660 }
661
662 #[test]
663 fn test_stfrft_alpha_1_resembles_stft() {
664 let n = 256;
666 let signal: Vec<f64> = (0..n)
667 .map(|i| (2.0 * PI * 5.0 * i as f64 / n as f64).sin())
668 .collect();
669
670 let cfg = StfrftConfig {
671 alpha: 1.0,
672 window_size: 32,
673 hop_size: 8,
674 window_type: WindowType::Rectangular,
675 oversample: false,
676 };
677 let result = stfrft(&signal, &cfg).expect("stfrft failed");
678 assert_eq!(result.coefficients.shape()[1], 32);
679 assert!(result.coefficients.shape()[0] > 1);
680 }
681
682 #[test]
683 fn test_stfrft_alpha_0_recovers_windowed_signal() {
684 let n = 128;
686 let signal: Vec<f64> = (0..n).map(|i| i as f64 * 0.01).collect();
687
688 let win_size = 16;
689 let hop = 4;
690 let cfg = StfrftConfig {
691 alpha: 0.0,
692 window_size: win_size,
693 hop_size: hop,
694 window_type: WindowType::Rectangular,
695 oversample: false,
696 };
697 let result = stfrft(&signal, &cfg).expect("stfrft failed");
698
699 let frame_0: Vec<f64> = (0..win_size)
701 .map(|k| result.coefficients[[0, k]].re)
702 .collect();
703 for k in 0..win_size {
704 assert!(
705 (frame_0[k] - signal[k]).abs() < 1e-6,
706 "Frame 0 coefficient mismatch at k={k}: {} vs {}",
707 frame_0[k],
708 signal[k]
709 );
710 }
711 }
712
713 #[test]
714 fn test_istfrft_output_length() {
715 let n = 256;
716 let signal: Vec<f64> = (0..n).map(|i| (i as f64 * 0.05).cos()).collect();
717 let cfg = StfrftConfig {
718 alpha: 0.5,
719 window_size: 32,
720 hop_size: 16,
721 window_type: WindowType::Hamming,
722 oversample: false,
723 };
724 let result = stfrft(&signal, &cfg).expect("stfrft failed");
725 let recon = istfrft(&result, n, cfg.hop_size).expect("istfrft failed");
726 assert_eq!(recon.len(), n, "Reconstructed signal has wrong length");
727 }
728
729 #[test]
730 fn test_istfrft_roundtrip_rectangular_window() {
731 let n = 64;
734 let signal: Vec<f64> = (0..n)
735 .map(|i| (2.0 * PI * 3.0 * i as f64 / n as f64).sin())
736 .collect();
737
738 let win_size = 16;
739 let hop = 4;
740 let cfg = StfrftConfig {
741 alpha: 0.0,
742 window_size: win_size,
743 hop_size: hop,
744 window_type: WindowType::Rectangular,
745 oversample: false,
746 };
747 let result = stfrft(&signal, &cfg).expect("stfrft failed");
748 let recon = istfrft(&result, n, hop).expect("istfrft failed");
749
750 let start = win_size;
752 let end = n.saturating_sub(win_size);
753 if start < end {
754 for i in start..end {
755 assert!(
756 (recon[i] - signal[i]).abs() < 0.1,
757 "Roundtrip mismatch at index {i}: {} vs {}",
758 recon[i],
759 signal[i]
760 );
761 }
762 }
763 }
764
765 #[test]
766 fn test_window_type_samples_correct_length() {
767 for size in [8, 16, 64, 256] {
768 for wt in [
769 WindowType::Gaussian,
770 WindowType::Hamming,
771 WindowType::Hann,
772 WindowType::Blackman,
773 WindowType::Rectangular,
774 ] {
775 let samples = wt.samples(size);
776 assert_eq!(
777 samples.len(),
778 size,
779 "Window {wt:?} has wrong sample count for size={size}"
780 );
781 }
782 }
783 }
784
785 #[test]
786 fn test_dfrft_energy_approximately_preserved() {
787 let n = 16;
790 let signal: Vec<Complex64> = (0..n)
791 .map(|i| Complex64::new((PI * i as f64 / n as f64).sin(), 0.0))
792 .collect();
793
794 let out = dfrft(&signal, 1.0).expect("dfrft failed");
795
796 let energy_in: f64 = signal.iter().map(|c| c.re * c.re + c.im * c.im).sum();
797 let energy_out: f64 = out.iter().map(|c| c.re * c.re + c.im * c.im).sum();
798
799 let ratio = energy_out / energy_in;
802 assert!(
803 ratio > 0.1,
804 "Energy ratio {ratio} too small — DFrFT destroyed energy"
805 );
806 }
807}