Skip to main content

scirs2_fft/fractional/
stfrft.rs

1//! Short-Time Fractional Fourier Transform (STFRFT) and Discrete FrFT.
2//!
3//! # Discrete Fractional Fourier Transform (DFrFT)
4//!
5//! The DFrFT of order α is defined as the α-th power of the DFT matrix.
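//! Using the commuting-matrix approach (implemented by `grunbaum_eigendecomp`),
//! the DFT eigenvectors are taken to be the eigenvectors of the real symmetric
//! matrix `S` introduced by Dickinson & Steiglitz (and used by Candan, Kutay &
//! Ozaktas). It is tridiagonal apart from two circular corner entries:
//!
//! ```text
//! S[j,j]           = 2 cos(2πj/N)
//! S[j,(j±1) mod N] = 1          (neighbours wrap around circularly)
//! ```
//!
//! `S` commutes with the DFT matrix, so its eigenvectors are also DFT
//! eigenvectors; they are discrete analogues of the Hermite–Gauss functions.
//! Each eigenvector `v_k` is assigned an integer order `m_k`, and its DFT
//! eigenvalue is `λ_k = (-i)^{m_k} ∈ {1, −i, −1, i}`. The DFrFT is then:
//!
//! ```text
//! DFrFT(α)[x] = V * diag(exp(-iπ/2 · m_k · α)) * V^T * x
//! ```
//!
//! where the columns of `V` are the real, orthonormal eigenvectors `v_k`; the
//! phase `exp(-iπ/2 · m_k · α)` is the α-th power of `λ_k` on the branch fixed
//! by the integer order `m_k`.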
22//!
23//! # Short-Time FrFT (STFRFT)
24//!
25//! STFRFT computes the DFrFT on overlapping windowed segments, producing a
26//! 2-D time-fractional-frequency representation analogous to the STFT spectrogram.
27
28use std::f64::consts::PI;
29
30use scirs2_core::ndarray::Array2;
31use scirs2_core::numeric::Complex64;
32
33use crate::error::{FFTError, FFTResult};
34
35// ── Window type ───────────────────────────────────────────────────────────────
36
/// Window function to apply to each segment before the DFrFT.
///
/// The tapered windows use the symmetric `L − 1` denominator, so the first and
/// last samples are equal. A single-sample window is `[1.0]` for every variant.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowType {
    /// Gaussian window: `w[n] = exp(-0.5 (n - (L-1)/2)^2 / sigma^2)`, sigma = L/6.
    Gaussian,
    /// Hamming window: 0.54 − 0.46 cos(2πn/(L−1)).
    Hamming,
    /// Hann window: 0.5 (1 − cos(2πn/(L−1))).
    Hann,
    /// Blackman window: 0.42 − 0.5 cos(2πn/(L−1)) + 0.08 cos(4πn/(L−1)).
    Blackman,
    /// Rectangular (boxcar) window: all ones.
    Rectangular,
}

impl WindowType {
    /// Evaluate the window at the `size` sample positions `n = 0, …, size − 1`.
    ///
    /// Returns an empty vector for `size == 0`. For `size == 1` every variant
    /// returns `[1.0]`; the cosine-sum formulas would otherwise divide by
    /// `L − 1 = 0` and produce `NaN`.
    pub fn samples(&self, size: usize) -> Vec<f64> {
        if size == 1 {
            return vec![1.0];
        }
        let l = size as f64;
        // Phase 2πn/(L−1) shared by the cosine-sum windows (L ≥ 2 here).
        let phase = |n: usize| 2.0 * PI * n as f64 / (l - 1.0);
        match self {
            WindowType::Gaussian => {
                let sigma = l / 6.0;
                let centre = (l - 1.0) / 2.0;
                (0..size)
                    .map(|n| {
                        let x = (n as f64 - centre) / sigma;
                        (-0.5 * x * x).exp()
                    })
                    .collect()
            }
            WindowType::Hamming => (0..size).map(|n| 0.54 - 0.46 * phase(n).cos()).collect(),
            WindowType::Hann => (0..size)
                .map(|n| 0.5 * (1.0 - phase(n).cos()))
                .collect(),
            WindowType::Blackman => (0..size)
                .map(|n| {
                    let t = phase(n);
                    0.42 - 0.5 * t.cos() + 0.08 * (2.0 * t).cos()
                })
                .collect(),
            WindowType::Rectangular => vec![1.0; size],
        }
    }
}
84
85// ── STFRFT config & result ────────────────────────────────────────────────────
86
87/// Configuration for the Short-Time Fractional Fourier Transform.
88#[derive(Debug, Clone)]
89pub struct StfrftConfig {
90    /// Fractional order α ∈ [0, 4]. α=0: identity, α=1: FFT, α=2: time-reversal.
91    pub alpha: f64,
92    /// Length of each analysis window in samples. Must be a power of two for
93    /// the DFrFT eigenvector computation. Default: 256.
94    pub window_size: usize,
95    /// Number of samples to advance between consecutive frames. Default: 64.
96    pub hop_size: usize,
97    /// Window function applied to each frame before the DFrFT. Default: Gaussian.
98    pub window_type: WindowType,
99    /// If `true`, zero-pad signal to the next power-of-two before framing.
100    /// Currently unused; reserved for future use.
101    pub oversample: bool,
102}
103
104impl Default for StfrftConfig {
105    fn default() -> Self {
106        Self {
107            alpha: 1.0,
108            window_size: 256,
109            hop_size: 64,
110            window_type: WindowType::Gaussian,
111            oversample: false,
112        }
113    }
114}
115
116/// Output of [`stfrft`].
117#[derive(Debug, Clone)]
118pub struct StfrftResult {
119    /// Complex STFRFT coefficients, shape `[n_frames, window_size]`.
120    pub coefficients: Array2<Complex64>,
121    /// Time (in samples) of each frame centre.
122    pub time_centers: Vec<f64>,
123    /// Fractional frequency axis values (normalised: 0..1).
124    pub fractional_freqs: Vec<f64>,
125    /// The fractional order α used.
126    pub alpha: f64,
127}
128
129// ── Grünbaum DFrFT implementation ─────────────────────────────────────────────
130
131/// Compute the Grünbaum tridiagonal commuting matrix eigendecomposition for
132/// a DFT of size `n`.
133///
134/// Returns `(eigenvectors, eigenvalue_orders)`:
135/// - `eigenvectors`: column-major N×N matrix (stored row-major as Vec<Vec>)
136/// - `eigenvalue_orders`: for each eigenvector, an integer k ∈ {0,1,2,3} such
137///   that the corresponding DFT eigenvalue is `(-i)^k`.
138///
139/// The Grünbaum matrix is real-symmetric tridiagonal:
140/// ```text
141/// T[j,j]   = 2 cos(2πj/N)
142/// T[j,j±1] = 1   (circular boundary)
143/// ```
144fn grunbaum_eigendecomp(n: usize) -> FFTResult<(Vec<Vec<f64>>, Vec<i32>)> {
145    if n == 0 {
146        return Ok((vec![], vec![]));
147    }
148    if n == 1 {
149        return Ok((vec![vec![1.0]], vec![0]));
150    }
151
152    // Build the full real symmetric Grünbaum matrix (tridiagonal with circular BC):
153    // T[j,j]   = 2 cos(2πj/N)
154    // T[j,j+1] = T[j+1,j] = 1
155    // For the tridiagonal (non-circular) approximation we drop the (0,n-1) corner
156    // elements to keep it strictly tridiagonal; this is the standard approach.
157    let mut mat = vec![0.0_f64; n * n];
158    for j in 0..n {
159        mat[j * n + j] = 2.0 * (2.0 * PI * j as f64 / n as f64).cos();
160        if j + 1 < n {
161            mat[j * n + j + 1] = 1.0;
162            mat[(j + 1) * n + j] = 1.0;
163        }
164    }
165    // Add the circular corner elements
166    mat[n - 1] = 1.0;
167    mat[(n - 1) * n] = 1.0;
168
169    // Symmetric Jacobi eigensolver (no convergence issues)
170    let (eigenvalues, eigenvectors) = symmetric_jacobi_eig(&mut mat, n);
171
172    // Determine DFT eigenvalue order for each eigenvector.
173    // DFT eigenvalues are 1, −i, −1, i with approximate multiplicities.
174    // The Grünbaum eigenvalues split each degenerate subspace so we assign
175    // eigenvalue orders 0..3 in sorted eigenvalue order.
176    //
177    // Standard assignment: eigenvalues are real but map to DFT eigenvalue orders
178    // via the approximate DFT spectrum parity structure. We use the sign/index
179    // parity of sorted Grünbaum eigenvalues as a proxy:
180    // the j-th eigenvector (sorted descending by Grünbaum eigenvalue) corresponds
181    // to DFT eigenvalue (-i)^j.
182    let mut order_idx: Vec<usize> = (0..n).collect();
183    order_idx.sort_by(|&a, &b| {
184        eigenvalues[b]
185            .partial_cmp(&eigenvalues[a])
186            .unwrap_or(std::cmp::Ordering::Equal)
187    });
188
189    let mut ev_orders = vec![0_i32; n];
190    for (rank, &idx) in order_idx.iter().enumerate() {
191        ev_orders[idx] = (rank % 4) as i32;
192    }
193
194    Ok((eigenvectors, ev_orders))
195}
196
/// Eigendecompose a dense real symmetric matrix with the cyclic Jacobi method.
///
/// `mat_flat[i * n + j]` is element (i, j) of the `n × n` symmetric matrix.
/// The buffer is overwritten: on return its diagonal holds the eigenvalues.
///
/// Each sweep visits every off-diagonal pair (p, q) once and applies the Jacobi
/// rotation that annihilates element (p, q). A sweep costs O(n³), and the method
/// converges quadratically, so a handful of sweeps suffice. Iteration stops
/// when the upper-triangle off-diagonal norm drops to `f64::EPSILON` times the
/// Frobenius norm of the input, or after `MAX_SWEEPS` sweeps.
///
/// Returns `(eigenvalues, eigenvectors)`, where `eigenvectors[k]` is the unit
/// eigenvector for `eigenvalues[k]` (column k of the orthogonal matrix V).
/// Eigenvalues are not sorted.
fn symmetric_jacobi_eig(mat_flat: &mut [f64], n: usize) -> (Vec<f64>, Vec<Vec<f64>>) {
    debug_assert_eq!(mat_flat.len(), n * n, "matrix buffer must hold n*n entries");

    const MAX_SWEEPS: usize = 100;

    // z holds V in row-major order (z[i*n + j] = V[i][j]); starts as identity.
    let mut z = vec![0.0_f64; n * n];
    for i in 0..n {
        z[i * n + i] = 1.0;
    }

    let frobenius = mat_flat.iter().map(|v| v * v).sum::<f64>().sqrt();
    let tol = f64::EPSILON * frobenius;
    // Entries this small are zeroed outright; the total perturbation this
    // introduces is bounded by `tol`, and it guarantees termination.
    let negligible = if n > 0 { tol / n as f64 } else { 0.0 };

    for _ in 0..MAX_SWEEPS {
        let mut off_sq = 0.0_f64;
        for i in 0..n {
            for j in (i + 1)..n {
                off_sq += mat_flat[i * n + j] * mat_flat[i * n + j];
            }
        }
        if off_sq.sqrt() <= tol {
            break;
        }

        for p in 0..n {
            for q in (p + 1)..n {
                let apq = mat_flat[p * n + q];
                if apq.abs() <= negligible {
                    mat_flat[p * n + q] = 0.0;
                    mat_flat[q * n + p] = 0.0;
                    continue;
                }
                let app = mat_flat[p * n + p];
                let aqq = mat_flat[q * n + q];

                // t = tan φ is the smaller root of t² + 2θt − 1 = 0;
                // `hypot` avoids overflow of θ² for nearly-diagonal pairs.
                let theta = (aqq - app) / (2.0 * apq);
                let t = if theta >= 0.0 {
                    1.0 / (theta + theta.hypot(1.0))
                } else {
                    -1.0 / (-theta + theta.hypot(1.0))
                };
                let c = 1.0 / t.hypot(1.0);
                let s = t * c;

                // A ← Jᵀ A J
                mat_flat[p * n + p] = app - t * apq;
                mat_flat[q * n + q] = aqq + t * apq;
                mat_flat[p * n + q] = 0.0;
                mat_flat[q * n + p] = 0.0;
                for r in 0..n {
                    if r == p || r == q {
                        continue;
                    }
                    let arp = mat_flat[r * n + p];
                    let arq = mat_flat[r * n + q];
                    let new_rp = c * arp - s * arq;
                    let new_rq = s * arp + c * arq;
                    mat_flat[r * n + p] = new_rp;
                    mat_flat[p * n + r] = new_rp;
                    mat_flat[r * n + q] = new_rq;
                    mat_flat[q * n + r] = new_rq;
                }

                // V ← V J
                for r in 0..n {
                    let zrp = z[r * n + p];
                    let zrq = z[r * n + q];
                    z[r * n + p] = c * zrp - s * zrq;
                    z[r * n + q] = s * zrp + c * zrq;
                }
            }
        }
    }

    let eigenvalues: Vec<f64> = (0..n).map(|i| mat_flat[i * n + i]).collect();
    let eigenvectors: Vec<Vec<f64>> = (0..n)
        .map(|j| (0..n).map(|i| z[i * n + j]).collect())
        .collect();

    (eigenvalues, eigenvectors)
}
289
290/// Compute the Discrete Fractional Fourier Transform of order α.
291///
292/// Uses the Grünbaum tridiagonal commuting matrix approach:
293/// 1. Compute eigenvectors V and eigenvalue orders of the DFT matrix (Grünbaum).
294/// 2. DFrFT(alpha) = V * D^alpha * V^H where `D[k,k]` = (-i)^{order\_k}.
295///
296/// The fractional eigenvalue is: `(-i)^(order * α)` = `exp(-i π/2 · order · α)`.
297///
298/// # Arguments
299/// * `signal` – Complex input of arbitrary length.
300/// * `alpha` – Fractional order in [0, 4].
301///
302/// # Errors
303/// Returns `FFTError` if the eigensolver fails.
304///
305/// # Examples
306/// ```
307/// use scirs2_fft::fractional::dfrft;
308/// use scirs2_core::numeric::Complex64;
309/// use approx::assert_relative_eq;
310///
311/// let n = 8;
312/// let signal: Vec<Complex64> = (0..n)
313///     .map(|i| Complex64::new(if i == 0 { 1.0 } else { 0.0 }, 0.0))
314///     .collect();
315///
316/// // α = 0 should be (approximately) the identity
317/// let out = dfrft(&signal, 0.0).unwrap();
318/// assert_relative_eq!(out[0].re, 1.0, epsilon = 1e-8);
319/// ```
320pub fn dfrft(signal: &[Complex64], alpha: f64) -> FFTResult<Vec<Complex64>> {
321    let n = signal.len();
322    if n == 0 {
323        return Ok(vec![]);
324    }
325
326    let (eigvecs, ev_orders) = grunbaum_eigendecomp(n)?;
327
328    // Compute V^H * x  (project onto eigenbasis)
329    // eigvecs[k] is the k-th eigenvector (real), so V^H[k,j] = eigvecs[k][j]
330    let mut vhx = vec![Complex64::new(0.0, 0.0); n];
331    for k in 0..n {
332        let mut sum = Complex64::new(0.0, 0.0);
333        for j in 0..n {
334            // eigenvector k, component j — multiply by conjugate (eigvecs are real)
335            sum += Complex64::new(eigvecs[k][j], 0.0) * signal[j];
336        }
337        vhx[k] = sum;
338    }
339
340    // Multiply by fractional eigenvalue D^α: λ_k^α = exp(-i π/2 · order_k · α)
341    let mut dvhx = vec![Complex64::new(0.0, 0.0); n];
342    for k in 0..n {
343        let angle = -PI / 2.0 * ev_orders[k] as f64 * alpha;
344        let frac_eig = Complex64::new(angle.cos(), angle.sin());
345        dvhx[k] = frac_eig * vhx[k];
346    }
347
348    // Multiply by V: result[j] = Σ_k eigvecs[k][j] * dvhx[k]
349    let mut result = vec![Complex64::new(0.0, 0.0); n];
350    for j in 0..n {
351        let mut sum = Complex64::new(0.0, 0.0);
352        for k in 0..n {
353            sum += Complex64::new(eigvecs[k][j], 0.0) * dvhx[k];
354        }
355        result[j] = sum;
356    }
357
358    Ok(result)
359}
360
361// ── STFRFT ────────────────────────────────────────────────────────────────────
362
363/// Compute the Short-Time Fractional Fourier Transform.
364///
365/// Slides a window over the signal, applies the DFrFT of order `config.alpha`
366/// to each frame, and stacks the results into a 2-D array of shape
367/// `[n_frames, window_size]`.
368///
369/// # Arguments
370/// * `signal` – Real-valued input signal.
371/// * `config` – STFRFT parameters.
372///
373/// # Returns
374/// [`StfrftResult`] containing the coefficient matrix and axis labels.
375///
376/// # Errors
377/// Returns `FFTError` if the signal is empty or DFrFT computation fails.
378///
379/// # Examples
380/// ```no_run
381/// use scirs2_fft::fractional::{StfrftConfig, stfrft};
382/// use std::f64::consts::PI;
383///
384/// let n = 1024;
385/// let signal: Vec<f64> = (0..n)
386///     .map(|i| (2.0 * PI * 50.0 * i as f64 / n as f64).sin())
387///     .collect();
388/// let cfg = StfrftConfig { alpha: 1.0, window_size: 64, hop_size: 16, ..Default::default() };
389/// let result = stfrft(&signal, &cfg).unwrap();
390/// assert_eq!(result.coefficients.shape()[1], 64);
391/// ```
392pub fn stfrft(signal: &[f64], config: &StfrftConfig) -> FFTResult<StfrftResult> {
393    let sig_len = signal.len();
394    if sig_len == 0 {
395        return Err(FFTError::ValueError("Signal must be non-empty".to_string()));
396    }
397    let win_size = config.window_size;
398    if win_size == 0 {
399        return Err(FFTError::ValueError("Window size must be > 0".to_string()));
400    }
401    let hop = config.hop_size;
402    if hop == 0 {
403        return Err(FFTError::ValueError("Hop size must be > 0".to_string()));
404    }
405
406    let window = config.window_type.samples(win_size);
407
408    // Zero-pad signal so all frames are fully inside
409    let n_frames = if sig_len <= win_size {
410        1
411    } else {
412        (sig_len - win_size) / hop + 1
413    };
414
415    let mut coefficients = Array2::zeros((n_frames, win_size));
416    let mut time_centers = Vec::with_capacity(n_frames);
417
418    for frame_idx in 0..n_frames {
419        let start = frame_idx * hop;
420        let centre = start as f64 + win_size as f64 / 2.0;
421        time_centers.push(centre);
422
423        // Extract and window the frame
424        let frame_complex: Vec<Complex64> = (0..win_size)
425            .map(|i| {
426                let sig_idx = start + i;
427                let sample = if sig_idx < sig_len {
428                    signal[sig_idx]
429                } else {
430                    0.0
431                };
432                Complex64::new(sample * window[i], 0.0)
433            })
434            .collect();
435
436        // Apply DFrFT
437        let dfrft_out = dfrft(&frame_complex, config.alpha)?;
438
439        for (k, val) in dfrft_out.into_iter().enumerate() {
440            coefficients[[frame_idx, k]] = val;
441        }
442    }
443
444    let fractional_freqs: Vec<f64> = (0..win_size).map(|k| k as f64 / win_size as f64).collect();
445
446    Ok(StfrftResult {
447        coefficients,
448        time_centers,
449        fractional_freqs,
450        alpha: config.alpha,
451    })
452}
453
454/// Inverse Short-Time Fractional Fourier Transform via overlap-add.
455///
456/// Reconstructs the time-domain signal from an [`StfrftResult`] by applying
457/// the inverse DFrFT (order −α) to each frame and overlap-adding the results.
458///
459/// # Arguments
460/// * `result` – STFRFT coefficients as returned by [`stfrft`].
461/// * `signal_length` – Expected output length in samples.
462/// * `hop_size` – Hop size used during analysis (samples per frame advance).
463///
464/// # Returns
465/// Reconstructed real-valued signal of length `signal_length`.
466///
467/// # Errors
468/// Returns `FFTError` if the inverse DFrFT computation fails.
469///
470/// # Examples
471/// ```no_run
472/// use scirs2_fft::fractional::{StfrftConfig, stfrft, istfrft};
473/// use std::f64::consts::PI;
474///
475/// let n = 512;
476/// let signal: Vec<f64> = (0..n)
477///     .map(|i| (2.0 * PI * 10.0 * i as f64 / n as f64).sin())
478///     .collect();
479/// let cfg = StfrftConfig { alpha: 0.5, window_size: 64, hop_size: 32, ..Default::default() };
480/// let result = stfrft(&signal, &cfg).unwrap();
481/// let reconstructed = istfrft(&result, n, cfg.hop_size).unwrap();
482/// assert_eq!(reconstructed.len(), n);
483/// ```
484pub fn istfrft(
485    result: &StfrftResult,
486    signal_length: usize,
487    hop_size: usize,
488) -> FFTResult<Vec<f64>> {
489    let n_frames = result.coefficients.shape()[0];
490    let win_size = result.coefficients.shape()[1];
491    let hop = if hop_size > 0 { hop_size } else { 1 };
492
493    let mut output = vec![0.0_f64; signal_length];
494    let mut norm = vec![0.0_f64; signal_length];
495
496    // Inverse DFrFT order = -alpha
497    let inv_alpha = -result.alpha;
498
499    for frame_idx in 0..n_frames {
500        let start = frame_idx * hop;
501
502        // Extract frame coefficients
503        let frame: Vec<Complex64> = (0..win_size)
504            .map(|k| result.coefficients[[frame_idx, k]])
505            .collect();
506
507        // Apply inverse DFrFT
508        let recon = dfrft(&frame, inv_alpha)?;
509
510        // Overlap-add real part
511        for i in 0..win_size {
512            let sig_idx = start + i;
513            if sig_idx < signal_length {
514                output[sig_idx] += recon[i].re;
515                norm[sig_idx] += 1.0;
516            }
517        }
518    }
519
520    // Normalise by overlap count
521    for i in 0..signal_length {
522        if norm[i] > 0.0 {
523            output[i] /= norm[i];
524        }
525    }
526
527    Ok(output)
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Helper: naive (unnormalised) DFT used as a reference.
    fn naive_dft(signal: &[f64]) -> Vec<Complex64> {
        let n = signal.len();
        (0..n)
            .map(|k| {
                (0..n).fold(Complex64::new(0.0, 0.0), |acc, j| {
                    let angle = -2.0 * PI * (j * k) as f64 / n as f64;
                    acc + Complex64::new(signal[j] * angle.cos(), signal[j] * angle.sin())
                })
            })
            .collect()
    }

    /// Helper: Σ |x_i|².
    fn energy(x: &[Complex64]) -> f64 {
        x.iter().map(|c| c.re * c.re + c.im * c.im).sum()
    }

    /// Helper: element-wise complex magnitudes.
    fn magnitudes(x: &[Complex64]) -> Vec<f64> {
        x.iter().map(|c| (c.re * c.re + c.im * c.im).sqrt()).collect()
    }

    /// Helper: maximum absolute element-wise difference of two complex vectors.
    fn max_abs_diff(a: &[Complex64], b: &[Complex64]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x.re - y.re).abs().max((x.im - y.im).abs()))
            .fold(0.0, f64::max)
    }

    fn test_signal(n: usize) -> Vec<Complex64> {
        (0..n)
            .map(|i| Complex64::new((PI * i as f64 / n as f64).sin(), 0.1 * i as f64))
            .collect()
    }

    #[test]
    fn test_dfrft_alpha_0_identity() {
        // DFrFT with α=0 should return (approximately) the input unchanged.
        let n = 16;
        let signal: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new((2.0 * PI * i as f64 / n as f64).sin(), 0.0))
            .collect();
        let out = dfrft(&signal, 0.0).expect("dfrft failed");
        for (a, b) in signal.iter().zip(out.iter()) {
            assert!((a.re - b.re).abs() < 1e-6, "real mismatch alpha=0 at re");
            assert!((a.im - b.im).abs() < 1e-6, "imag mismatch alpha=0 at im");
        }
    }

    #[test]
    fn test_dfrft_alpha_1_matches_dft() {
        // DFrFT(α=1) should approximate the DFT. For a real cosine the spectrum
        // has two equal peaks (bins 1 and N−1), so instead of comparing argmax
        // indices (decided by rounding noise) we check that the DFT's peak bin
        // is also a maximal bin of the DFrFT output.
        let n = 8;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * i as f64 / n as f64).cos())
            .collect();
        let signal_complex: Vec<Complex64> =
            signal.iter().map(|&v| Complex64::new(v, 0.0)).collect();

        let dfrft_mags = magnitudes(&dfrft(&signal_complex, 1.0).expect("dfrft failed"));
        let dft_mags = magnitudes(&naive_dft(&signal));

        let dft_peak = dft_mags
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        let dfrft_max = dfrft_mags.iter().copied().fold(0.0, f64::max);

        assert!(
            dfrft_mags[dft_peak] >= 0.99 * dfrft_max,
            "DFT peak bin {dft_peak} is not a peak of DFrFT(1): {dfrft_mags:?}"
        );
    }

    #[test]
    fn test_dfrft_four_unit_steps_is_identity() {
        // Group property: DFrFT(1) applied four times equals DFrFT(4) = identity.
        let n = 8;
        let signal: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new((PI * i as f64 / n as f64).sin(), 0.0))
            .collect();

        let step1 = dfrft(&signal, 1.0).expect("dfrft step1 failed");
        let step2 = dfrft(&step1, 1.0).expect("dfrft step2 failed");
        let step3 = dfrft(&step2, 1.0).expect("dfrft step3 failed");
        let step4 = dfrft(&step3, 1.0).expect("dfrft step4 failed");

        let diff = max_abs_diff(&step4, &signal);
        assert!(diff < 1e-9, "Group property DFrFT(4)≈I violated: diff={diff}");
    }

    #[test]
    fn test_dfrft_orders_add() {
        // DFrFT(0.7) ∘ DFrFT(0.3) = DFrFT(1.0).
        let signal = test_signal(16);
        let two_steps = dfrft(&dfrft(&signal, 0.3).expect("dfrft"), 0.7).expect("dfrft");
        let one_step = dfrft(&signal, 1.0).expect("dfrft");
        let diff = max_abs_diff(&two_steps, &one_step);
        assert!(diff < 1e-9, "additivity violated: diff={diff}");
    }

    #[test]
    fn test_dfrft_negative_order_inverts() {
        let signal = test_signal(12);
        let forward = dfrft(&signal, 0.37).expect("dfrft");
        let back = dfrft(&forward, -0.37).expect("dfrft");
        let diff = max_abs_diff(&back, &signal);
        assert!(diff < 1e-9, "DFrFT(-a) is not the inverse of DFrFT(a): diff={diff}");
    }

    #[test]
    fn test_stfrft_output_shape() {
        let signal: Vec<f64> = (0..512).map(|i| (i as f64 * 0.1).sin()).collect();
        let cfg = StfrftConfig {
            alpha: 0.8,
            window_size: 64,
            hop_size: 16,
            window_type: WindowType::Hann,
            oversample: false,
        };
        let result = stfrft(&signal, &cfg).expect("stfrft failed");
        let n_frames = result.coefficients.shape()[0];
        let n_freqs = result.coefficients.shape()[1];

        // n_frames = (512 - 64) / 16 + 1 = 29
        assert_eq!(n_frames, 29);
        assert_eq!(
            n_freqs, 64,
            "Expected window_size={} frequency bins",
            cfg.window_size
        );
        assert_eq!(result.time_centers.len(), n_frames);
        assert_eq!(result.fractional_freqs.len(), n_freqs);
        assert_eq!(result.alpha, 0.8);
    }

    #[test]
    fn test_stfrft_alpha_1_resembles_stft() {
        // With α=1, STFRFT should give the same shape as STFT (same DFrFT=DFT).
        let n = 256;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 5.0 * i as f64 / n as f64).sin())
            .collect();

        let cfg = StfrftConfig {
            alpha: 1.0,
            window_size: 32,
            hop_size: 8,
            window_type: WindowType::Rectangular,
            oversample: false,
        };
        let result = stfrft(&signal, &cfg).expect("stfrft failed");
        assert_eq!(result.coefficients.shape()[1], 32);
        assert!(result.coefficients.shape()[0] > 1);
    }

    #[test]
    fn test_stfrft_alpha_0_recovers_windowed_signal() {
        // With α=0 (identity), STFRFT coefficients should equal windowed samples.
        let n = 128;
        let signal: Vec<f64> = (0..n).map(|i| i as f64 * 0.01).collect();

        let win_size = 16;
        let hop = 4;
        let cfg = StfrftConfig {
            alpha: 0.0,
            window_size: win_size,
            hop_size: hop,
            window_type: WindowType::Rectangular,
            oversample: false,
        };
        let result = stfrft(&signal, &cfg).expect("stfrft failed");

        // For α=0 and rectangular window, DFrFT is identity, so coefficient[f,k] ≈ signal[f*hop+k]
        for k in 0..win_size {
            let coeff = result.coefficients[[0, k]].re;
            assert!(
                (coeff - signal[k]).abs() < 1e-6,
                "Frame 0 coefficient mismatch at k={k}: {} vs {}",
                coeff,
                signal[k]
            );
        }
    }

    #[test]
    fn test_stfrft_rejects_invalid_config() {
        let signal = vec![1.0; 32];
        assert!(stfrft(&[], &StfrftConfig::default()).is_err());
        let zero_window = StfrftConfig {
            window_size: 0,
            ..Default::default()
        };
        assert!(stfrft(&signal, &zero_window).is_err());
        let zero_hop = StfrftConfig {
            window_size: 8,
            hop_size: 0,
            ..Default::default()
        };
        assert!(stfrft(&signal, &zero_hop).is_err());
    }

    #[test]
    fn test_istfrft_output_length() {
        let n = 256;
        let signal: Vec<f64> = (0..n).map(|i| (i as f64 * 0.05).cos()).collect();
        let cfg = StfrftConfig {
            alpha: 0.5,
            window_size: 32,
            hop_size: 16,
            window_type: WindowType::Hamming,
            oversample: false,
        };
        let result = stfrft(&signal, &cfg).expect("stfrft failed");
        let recon = istfrft(&result, n, cfg.hop_size).expect("istfrft failed");
        assert_eq!(recon.len(), n, "Reconstructed signal has wrong length");
    }

    #[test]
    fn test_istfrft_roundtrip_rectangular_window() {
        // With a rectangular window and 75% overlap (hop = win_size / 4),
        // reconstruction should be nearly perfect for α=0.
        let n = 64;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 3.0 * i as f64 / n as f64).sin())
            .collect();

        let win_size = 16;
        let hop = 4;
        let cfg = StfrftConfig {
            alpha: 0.0,
            window_size: win_size,
            hop_size: hop,
            window_type: WindowType::Rectangular,
            oversample: false,
        };
        let result = stfrft(&signal, &cfg).expect("stfrft failed");
        let recon = istfrft(&result, n, hop).expect("istfrft failed");

        // Check reconstruction in the central region (away from boundaries)
        let start = win_size;
        let end = n.saturating_sub(win_size);
        for i in start..end {
            assert!(
                (recon[i] - signal[i]).abs() < 0.1,
                "Roundtrip mismatch at index {i}: {} vs {}",
                recon[i],
                signal[i]
            );
        }
    }

    #[test]
    fn test_window_type_samples_correct_length() {
        for size in [8, 16, 64, 256] {
            for wt in [
                WindowType::Gaussian,
                WindowType::Hamming,
                WindowType::Hann,
                WindowType::Blackman,
                WindowType::Rectangular,
            ] {
                let samples = wt.samples(size);
                assert_eq!(
                    samples.len(),
                    size,
                    "Window {wt:?} has wrong sample count for size={size}"
                );
            }
        }
    }

    #[test]
    fn test_dfrft_energy_approximately_preserved() {
        // The DFrFT is unitary (orthogonal eigenbasis, unit-modulus phases), so
        // signal energy is preserved for every order α.
        let n = 16;
        let signal: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new((PI * i as f64 / n as f64).sin(), 0.0))
            .collect();
        let energy_in = energy(&signal);

        for alpha in [0.5, 1.0, 1.7] {
            let out = dfrft(&signal, alpha).expect("dfrft failed");
            let ratio = energy(&out) / energy_in;
            assert!(
                (ratio - 1.0).abs() < 1e-9,
                "Energy ratio {ratio} at alpha={alpha} — DFrFT is not unitary"
            );
        }
    }
}