Skip to main content

torsh_functional/
spectral_stft.rs

1//! Complete STFT/ISTFT implementation with windowing and overlap-add
2//!
3//! This module provides production-ready Short-Time Fourier Transform implementations
4//! with proper windowing, overlap-add reconstruction, and all standard options.
5
6use std::f32::consts::PI;
7use torsh_core::{Result as TorshResult, TorshError};
8use torsh_tensor::Tensor;
9
10use crate::spectral::{fft, ifft, rfft};
11
/// Taper applied to every STFT frame before it is transformed.
///
/// All variants except `Rectangular` taper the frame toward its edges; the exact
/// sample formulas live in `generate_window`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WindowFunction {
    /// Constant window of ones (no tapering at all).
    Rectangular,
    /// Raised-cosine (Hann) window.
    Hann,
    /// Hamming window: a raised cosine lifted so its edges do not reach zero.
    Hamming,
    /// Three-term Blackman window.
    Blackman,
    /// Triangular (Bartlett) window.
    Bartlett,
    /// Kaiser window; the payload is the shape parameter `beta`, stored as an integer.
    Kaiser(i32),
}
28
29/// Generate window function
30///
31/// # Arguments
32///
33/// * `window_type` - Type of window function
34/// * `size` - Window size
35///
36/// # Mathematical Formulas
37///
38/// **Hann window:**
39/// ```text
40/// w[n] = 0.5 - 0.5 * cos(2π * n / (N-1))
41/// ```
42///
43/// **Hamming window:**
44/// ```text
45/// w[n] = 0.54 - 0.46 * cos(2π * n / (N-1))
46/// ```
47///
48/// **Blackman window:**
49/// ```text
50/// w[n] = 0.42 - 0.5 * cos(2π * n / (N-1)) + 0.08 * cos(4π * n / (N-1))
51/// ```
52///
53/// **Bartlett window:**
54/// ```text
55/// w[n] = 1 - |2n / (N-1) - 1|
56/// ```
57pub fn generate_window(window_type: WindowFunction, size: usize) -> TorshResult<Vec<f32>> {
58    if size == 0 {
59        return Err(TorshError::InvalidArgument(
60            "Window size must be positive".to_string(),
61        ));
62    }
63
64    let mut window = vec![0.0; size];
65
66    match window_type {
67        WindowFunction::Rectangular => {
68            window.fill(1.0);
69        }
70        WindowFunction::Hann => {
71            for (i, w) in window.iter_mut().enumerate() {
72                let n = i as f32;
73                let n_size = size as f32;
74                *w = 0.5 - 0.5 * (2.0 * PI * n / (n_size - 1.0)).cos();
75            }
76        }
77        WindowFunction::Hamming => {
78            for (i, w) in window.iter_mut().enumerate() {
79                let n = i as f32;
80                let n_size = size as f32;
81                *w = 0.54 - 0.46 * (2.0 * PI * n / (n_size - 1.0)).cos();
82            }
83        }
84        WindowFunction::Blackman => {
85            for (i, w) in window.iter_mut().enumerate() {
86                let n = i as f32;
87                let n_size = size as f32;
88                let factor = 2.0 * PI * n / (n_size - 1.0);
89                *w = 0.42 - 0.5 * factor.cos() + 0.08 * (2.0 * factor).cos();
90            }
91        }
92        WindowFunction::Bartlett => {
93            for (i, w) in window.iter_mut().enumerate() {
94                let n = i as f32;
95                let n_size = size as f32;
96                *w = 1.0 - (2.0 * n / (n_size - 1.0) - 1.0).abs();
97            }
98        }
99        WindowFunction::Kaiser(beta) => {
100            // Kaiser window implementation using modified Bessel function approximation
101            let beta_f = beta as f32;
102            let i0_beta = bessel_i0(beta_f);
103
104            for (i, w) in window.iter_mut().enumerate() {
105                let n = i as f32;
106                let n_size = size as f32;
107                let x = beta_f * (1.0 - (2.0 * n / (n_size - 1.0) - 1.0).powi(2)).sqrt();
108                *w = bessel_i0(x) / i0_beta;
109            }
110        }
111    }
112
113    Ok(window)
114}
115
/// Approximates `I0(x)`, the modified Bessel function of the first kind of order zero.
///
/// Uses the two-piece polynomial fit from Abramowitz & Stegun: eq. 9.8.1 in powers
/// of `(x / 3.75)^2` for `|x| < 3.75`, and eq. 9.8.2 in powers of `3.75 / |x|`
/// (scaled by `e^|x| / sqrt(|x|)`) otherwise. Used to build Kaiser windows.
fn bessel_i0(x: f32) -> f32 {
    // Small-argument coefficients, lowest order first.
    const SMALL: [f32; 7] = [
        1.0,
        3.5156229,
        3.0899424,
        1.2067492,
        0.2659732,
        0.360768e-1,
        0.45813e-2,
    ];
    // Large-argument coefficients, lowest order first.
    const LARGE: [f32; 9] = [
        0.39894228,
        0.1328592e-1,
        0.225319e-2,
        -0.157565e-2,
        0.916281e-2,
        -0.2057706e-1,
        0.2635537e-1,
        -0.1647633e-1,
        0.392377e-2,
    ];

    // Horner's rule, starting from the highest-order coefficient.
    let horner = |coeffs: &[f32], t: f32| coeffs.iter().rev().fold(0.0f32, |acc, &c| c + t * acc);

    let magnitude = x.abs();
    if magnitude < 3.75 {
        horner(&SMALL, (x / 3.75).powi(2))
    } else {
        let prefactor = magnitude.exp() / magnitude.sqrt();
        prefactor * horner(&LARGE, 3.75 / magnitude)
    }
}
138
139/// Complete Short-Time Fourier Transform with proper windowing
140///
141/// # Arguments
142///
143/// * `input` - Input signal tensor (1D or 2D for batched processing)
144/// * `n_fft` - FFT size
145/// * `hop_length` - Number of samples between successive frames
146/// * `win_length` - Window size (defaults to n_fft)
147/// * `window` - Window function type
148/// * `center` - If true, pad signal symmetrically
149/// * `normalized` - If true, normalize by window energy
150/// * `onesided` - If true, return only positive frequencies (for real signals)
151///
152/// # Returns
153///
154/// Complex spectrogram tensor with shape:
155/// - For 1D input \[signal_length\]: returns \[freq_bins, time_frames, 2\] (real, imag)
156/// - For 2D input \[batch, signal_length\]: returns \[batch, freq_bins, time_frames, 2\]
157///
158/// # Mathematical Formula
159///
160/// ```text
161/// STFT(m, ω) = Σ(n=0 to N-1) x[n + mH] * w[n] * exp(-jωn)
162/// ```
163///
164/// where:
165/// - m is the frame index
166/// - H is the hop length
167/// - w\[n\] is the window function
168/// - N is the window length
169///
170/// # Examples
171///
172/// ```rust,ignore
173/// use torsh_functional::spectral_stft::{stft_complete, WindowFunction};
174///
175/// let signal = randn(&[16384], None, None, None)?;
176/// let spec = stft_complete(
177///     &signal,
178///     512,                      // n_fft
179///     Some(128),                // hop_length
180///     None,                     // win_length (defaults to n_fft)
181///     WindowFunction::Hann,     // window
182///     true,                     // center padding
183///     false,                    // normalized
184///     true,                     // onesided
185/// )?;
186/// ```
187pub fn stft_complete(
188    input: &Tensor,
189    n_fft: usize,
190    hop_length: Option<usize>,
191    win_length: Option<usize>,
192    window: WindowFunction,
193    center: bool,
194    normalized: bool,
195    onesided: bool,
196) -> TorshResult<Tensor> {
197    let input_shape = input.shape();
198    let ndim = input_shape.ndim();
199
200    if ndim == 0 || ndim > 2 {
201        return Err(TorshError::InvalidArgument(
202            "STFT input must be 1D or 2D".to_string(),
203        ));
204    }
205
206    let hop_len = hop_length.unwrap_or(n_fft / 4);
207    let win_len = win_length.unwrap_or(n_fft);
208
209    if win_len > n_fft {
210        return Err(TorshError::InvalidArgument(
211            "Window length cannot exceed FFT size".to_string(),
212        ));
213    }
214
215    if hop_len == 0 {
216        return Err(TorshError::InvalidArgument(
217            "Hop length must be positive".to_string(),
218        ));
219    }
220
221    // Generate window
222    let window_data = generate_window(window, win_len)?;
223
224    // Normalize window if requested
225    let window_data = if normalized {
226        let energy: f32 = window_data.iter().map(|w| w * w).sum();
227        let scale = (energy / win_len as f32).sqrt();
228        window_data.iter().map(|w| w / scale).collect()
229    } else {
230        window_data
231    };
232
233    // Get signal data
234    let signal_data = input.data()?;
235    let dims = input_shape.dims();
236
237    let (batch_size, signal_len) = if ndim == 1 {
238        (1, dims[0])
239    } else {
240        (dims[0], dims[1])
241    };
242
243    // Apply center padding if requested
244    let (padded_signal, padded_len) = if center {
245        let pad_amount = n_fft / 2;
246        let new_len = signal_len + 2 * pad_amount;
247        let mut padded = vec![0.0; batch_size * new_len];
248
249        for b in 0..batch_size {
250            let src_start = b * signal_len;
251            let dst_start = b * new_len + pad_amount;
252
253            // Copy signal data
254            for i in 0..signal_len {
255                padded[dst_start + i] = signal_data[src_start + i];
256            }
257
258            // Reflect padding
259            for i in 0..pad_amount {
260                if i < signal_len {
261                    padded[b * new_len + i] = signal_data[src_start + pad_amount - i];
262                }
263            }
264            for i in 0..pad_amount {
265                if signal_len > i + 1 {
266                    padded[dst_start + signal_len + i] =
267                        signal_data[src_start + signal_len - 2 - i];
268                }
269            }
270        }
271
272        (padded, new_len)
273    } else {
274        (signal_data.to_vec(), signal_len)
275    };
276
277    // Calculate number of frames
278    let n_frames = if padded_len >= n_fft {
279        (padded_len - n_fft) / hop_len + 1
280    } else {
281        0
282    };
283
284    if n_frames == 0 {
285        return Err(TorshError::InvalidArgument(
286            "Signal too short for STFT".to_string(),
287        ));
288    }
289
290    // Frequency bins
291    let freq_bins = if onesided { n_fft / 2 + 1 } else { n_fft };
292
293    // Process each frame
294    let mut stft_data = Vec::with_capacity(batch_size * freq_bins * n_frames * 2);
295
296    for b in 0..batch_size {
297        let signal_start = b * padded_len;
298
299        for frame_idx in 0..n_frames {
300            let frame_start = signal_start + frame_idx * hop_len;
301
302            // Extract and window the frame
303            let mut frame = vec![0.0; n_fft];
304            for i in 0..win_len.min(n_fft) {
305                if frame_start + i < signal_start + padded_len {
306                    frame[i] = padded_signal[frame_start + i] * window_data[i];
307                }
308            }
309
310            // Create tensor for FFT
311            let frame_tensor = Tensor::from_data(frame, vec![n_fft], input.device())?;
312
313            // Apply FFT
314            let fft_result = if onesided {
315                rfft(&frame_tensor, Some(n_fft), None, None)?
316            } else {
317                // For two-sided, convert to complex and use full FFT
318                use torsh_core::dtype::Complex32;
319                let complex_frame: Vec<Complex32> = frame_tensor
320                    .data()?
321                    .iter()
322                    .map(|&x| Complex32::new(x, 0.0))
323                    .collect();
324                let complex_tensor = Tensor::from_data(complex_frame, vec![n_fft], input.device())?;
325                fft(&complex_tensor, Some(n_fft), None, None)?
326            };
327
328            // Extract real and imaginary parts
329            let fft_data = fft_result.data()?;
330            for val in fft_data.iter() {
331                stft_data.push(val.re);
332                stft_data.push(val.im);
333            }
334        }
335    }
336
337    // Create output tensor
338    let output_shape = if ndim == 1 {
339        vec![freq_bins, n_frames, 2]
340    } else {
341        vec![batch_size, freq_bins, n_frames, 2]
342    };
343
344    Tensor::from_data(stft_data, output_shape, input.device())
345}
346
347/// Inverse Short-Time Fourier Transform with overlap-add reconstruction
348///
349/// # Arguments
350///
351/// * `stft` - STFT tensor from stft_complete
352/// * `n_fft` - FFT size
353/// * `hop_length` - Hop length used in forward STFT
354/// * `win_length` - Window length
355/// * `window` - Window function (should match forward STFT)
356/// * `center` - Whether center padding was used in forward STFT
357/// * `normalized` - Whether normalization was used in forward STFT
358/// * `onesided` - Whether one-sided FFT was used
359/// * `length` - Desired output length (None infers from STFT shape)
360///
361/// # Returns
362///
363/// Reconstructed signal tensor
364///
365/// # Mathematical Formula
366///
367/// Overlap-add reconstruction:
368/// ```text
369/// x[n] = Σ(m) IFFT(STFT[m, :]) * w[n - mH] / Σ(m) w²[n - mH]
370/// ```
371///
372/// # Examples
373///
374/// ```rust,ignore
375/// let signal = randn(&[16384], None, None, None)?;
376/// let spec = stft_complete(&signal, 512, Some(128), None, WindowFunction::Hann, true, false, true)?;
377/// let reconstructed = istft_complete(&spec, 512, Some(128), None, WindowFunction::Hann, true, false, true, Some(16384))?;
378/// ```
379pub fn istft_complete(
380    stft: &Tensor,
381    n_fft: usize,
382    hop_length: Option<usize>,
383    win_length: Option<usize>,
384    window: WindowFunction,
385    center: bool,
386    normalized: bool,
387    onesided: bool,
388    length: Option<usize>,
389) -> TorshResult<Tensor> {
390    let stft_shape = stft.shape();
391    let ndim = stft_shape.ndim();
392
393    if ndim < 3 || ndim > 4 {
394        return Err(TorshError::InvalidArgument(
395            "ISTFT input must be 3D [freq, time, 2] or 4D [batch, freq, time, 2]".to_string(),
396        ));
397    }
398
399    let dims = stft_shape.dims();
400    if dims[ndim - 1] != 2 {
401        return Err(TorshError::InvalidArgument(
402            "Last dimension must be 2 (real, imag)".to_string(),
403        ));
404    }
405
406    let hop_len = hop_length.unwrap_or(n_fft / 4);
407    let win_len = win_length.unwrap_or(n_fft);
408
409    // Generate window
410    let window_data = generate_window(window, win_len)?;
411    let window_data = if normalized {
412        let energy: f32 = window_data.iter().map(|w| w * w).sum();
413        let scale = (energy / win_len as f32).sqrt();
414        window_data.iter().map(|w| w / scale).collect()
415    } else {
416        window_data
417    };
418
419    let (batch_size, freq_bins, n_frames) = if ndim == 3 {
420        (1, dims[0], dims[1])
421    } else {
422        (dims[0], dims[1], dims[2])
423    };
424
425    // Verify frequency bins
426    let expected_bins = if onesided { n_fft / 2 + 1 } else { n_fft };
427    if freq_bins != expected_bins {
428        return Err(TorshError::InvalidArgument(format!(
429            "Frequency bins mismatch: expected {}, got {}",
430            expected_bins, freq_bins
431        )));
432    }
433
434    // Calculate output length
435    let output_len = length.unwrap_or((n_frames - 1) * hop_len + n_fft);
436
437    // Prepare output and window sum for overlap-add
438    let mut output_data = vec![0.0; batch_size * output_len];
439    let mut window_sum = vec![0.0; output_len];
440
441    // Get STFT data
442    let stft_data = stft.data()?;
443
444    // Process each batch and frame
445    for b in 0..batch_size {
446        let batch_offset = if ndim == 3 {
447            0
448        } else {
449            b * freq_bins * n_frames * 2
450        };
451
452        for frame_idx in 0..n_frames {
453            // Extract complex frame
454            use torsh_core::dtype::Complex32;
455            let mut frame_fft = Vec::with_capacity(freq_bins);
456
457            for f in 0..freq_bins {
458                let idx = batch_offset + (f * n_frames + frame_idx) * 2;
459                if idx + 1 < stft_data.len() {
460                    frame_fft.push(Complex32::new(stft_data[idx], stft_data[idx + 1]));
461                } else {
462                    frame_fft.push(Complex32::new(0.0, 0.0));
463                }
464            }
465
466            // Apply IFFT
467            let fft_tensor = Tensor::from_data(frame_fft, vec![freq_bins], stft.device())?;
468
469            let frame_signal = if onesided {
470                super::spectral_advanced::irfft(&fft_tensor, Some(n_fft), None, None)?
471            } else {
472                let ifft_result = ifft(&fft_tensor, Some(n_fft), None, None)?;
473                let ifft_data = ifft_result.data()?;
474                let real_data: Vec<f32> = ifft_data.iter().map(|c| c.re).collect();
475                Tensor::from_data(real_data, vec![n_fft], ifft_result.device())?
476            };
477
478            let frame_data = frame_signal.data()?;
479
480            // Overlap-add with windowing
481            let frame_start = frame_idx * hop_len;
482            for i in 0..win_len.min(n_fft) {
483                let output_idx = b * output_len + frame_start + i;
484                if output_idx < (b + 1) * output_len && i < frame_data.len() {
485                    output_data[output_idx] += frame_data[i] * window_data[i];
486                    if b == 0 {
487                        // Only accumulate window sum once
488                        window_sum[frame_start + i] += window_data[i] * window_data[i];
489                    }
490                }
491            }
492        }
493    }
494
495    // Normalize by window sum to compensate for overlapping windows
496    for b in 0..batch_size {
497        for i in 0..output_len {
498            let idx = b * output_len + i;
499            if window_sum[i] > 1e-8 {
500                output_data[idx] /= window_sum[i];
501            }
502        }
503    }
504
505    // Remove center padding if it was applied
506    let final_data = if center {
507        let pad_amount = n_fft / 2;
508        let unpadded_len = output_len.saturating_sub(2 * pad_amount);
509        let mut unpadded = Vec::with_capacity(batch_size * unpadded_len);
510
511        for b in 0..batch_size {
512            let src_start = b * output_len + pad_amount;
513            for i in 0..unpadded_len {
514                if src_start + i < output_data.len() {
515                    unpadded.push(output_data[src_start + i]);
516                } else {
517                    unpadded.push(0.0);
518                }
519            }
520        }
521
522        unpadded
523    } else {
524        output_data
525    };
526
527    // Create output tensor
528    let final_len = if center {
529        output_len.saturating_sub(n_fft)
530    } else {
531        output_len
532    };
533
534    let output_shape = if batch_size == 1 {
535        vec![final_len]
536    } else {
537        vec![batch_size, final_len]
538    };
539
540    Tensor::from_data(final_data, output_shape, stft.device())
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    /// One-sided Hann STFT with `win_length = n_fft`, the configuration most tests share.
    fn hann_stft(
        signal: &Tensor,
        n_fft: usize,
        hop: usize,
        center: bool,
        normalized: bool,
    ) -> TorshResult<Tensor> {
        stft_complete(
            signal,
            n_fft,
            Some(hop),
            None,
            WindowFunction::Hann,
            center,
            normalized,
            true,
        )
    }

    /// Centered, unnormalized, one-sided Hann ISTFT producing `length` samples.
    fn hann_istft(spec: &Tensor, n_fft: usize, hop: usize, length: usize) -> TorshResult<Tensor> {
        istft_complete(
            spec,
            n_fft,
            Some(hop),
            None,
            WindowFunction::Hann,
            true,
            false,
            true,
            Some(length),
        )
    }

    /// Owned copy of a tensor's dimensions.
    fn dims_of(tensor: &Tensor) -> Vec<usize> {
        let shape = tensor.shape();
        shape.dims().to_vec()
    }

    #[test]
    fn test_window_generation() -> TorshResult<()> {
        let hann = generate_window(WindowFunction::Hann, 256)?;
        assert_eq!(hann.len(), 256);
        assert!(hann[0] < 0.01); // vanishes at the edge
        assert!(hann[128] > 0.99); // peaks in the middle

        let hamming = generate_window(WindowFunction::Hamming, 256)?;
        assert_eq!(hamming.len(), 256);

        let rect = generate_window(WindowFunction::Rectangular, 256)?;
        assert!(rect.iter().all(|&x| (x - 1.0).abs() < 1e-6));

        Ok(())
    }

    #[test]
    fn test_stft_basic() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;

        // Expected layout: [freq_bins, time_frames, 2]
        let dims = dims_of(&hann_stft(&signal, 256, 128, true, false)?);
        assert_eq!(dims.len(), 3);
        assert_eq!(dims[0], 129); // n_fft / 2 + 1
        assert_eq!(dims[2], 2); // (real, imag)

        Ok(())
    }

    #[test]
    fn test_stft_istft_roundtrip() -> TorshResult<()> {
        let signal_len = 2048;
        let (n_fft, hop) = (256, 64);
        let signal = randn(&[signal_len], None, None, None)?;

        let spec = hann_stft(&signal, n_fft, hop, true, false)?;
        let reconstructed = hann_istft(&spec, n_fft, hop, signal_len)?;

        let original = signal.data()?;
        let recon = reconstructed.data()?;
        let overlap = signal_len.min(recon.len());
        let max_error = original[..overlap]
            .iter()
            .zip(&recon[..overlap])
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        // Loose tolerance, kept from the original suite.
        assert!(max_error < 5.0, "Max reconstruction error: {}", max_error);

        Ok(())
    }

    #[test]
    fn test_stft_different_windows() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;

        for &window in &[
            WindowFunction::Hann,
            WindowFunction::Hamming,
            WindowFunction::Blackman,
            WindowFunction::Bartlett,
        ] {
            let spec = stft_complete(&signal, 256, Some(128), None, window, false, false, true)?;
            assert_eq!(spec.shape().ndim(), 3);
        }

        Ok(())
    }

    #[test]
    fn test_stft_batch_processing() -> TorshResult<()> {
        let (batch_size, signal_len) = (4, 1024);
        let batch_signal = randn(&[batch_size, signal_len], None, None, None)?;

        // Expected layout: [batch, freq_bins, time_frames, 2]
        let dims = dims_of(&hann_stft(&batch_signal, 256, 128, true, false)?);
        assert_eq!(dims.len(), 4);
        assert_eq!(dims[0], batch_size);
        assert_eq!(dims[1], 129); // n_fft / 2 + 1
        assert_eq!(dims[3], 2);

        Ok(())
    }

    #[test]
    fn test_error_handling() {
        let signal = randn(&[64], None, None, None).unwrap();

        // A zero hop length is rejected.
        assert!(hann_stft(&signal, 256, 0, false, false).is_err());

        // A window longer than the FFT is rejected.
        assert!(stft_complete(
            &signal,
            128,
            Some(64),
            Some(256),
            WindowFunction::Hann,
            false,
            false,
            true,
        )
        .is_err());

        // A signal shorter than one frame is rejected.
        let tiny_signal = randn(&[32], None, None, None).unwrap();
        assert!(hann_stft(&tiny_signal, 256, 128, false, false).is_err());
    }

    #[test]
    fn test_window_properties() -> TorshResult<()> {
        let size = 256;

        // The Hann window sums to roughly N / 2.
        let hann = generate_window(WindowFunction::Hann, size)?;
        let hann_sum: f32 = hann.iter().sum();
        assert!((hann_sum - size as f32 / 2.0).abs() < 10.0);

        // Hamming bottoms out around 0.08 and never reaches zero.
        let hamming = generate_window(WindowFunction::Hamming, size)?;
        assert!(hamming.iter().all(|&x| x > 0.08));

        // Blackman vanishes at both edges.
        let blackman = generate_window(WindowFunction::Blackman, size)?;
        assert!(blackman[0] < 0.01);
        assert!(blackman[size - 1] < 0.01);

        Ok(())
    }

    #[test]
    fn test_stft_energy_conservation() -> TorshResult<()> {
        let signal = randn(&[2048], None, None, None)?;
        let signal_energy: f32 = signal.data()?.iter().map(|&x| x * x).sum();

        let spec = hann_stft(&signal, 256, 64, false, false)?;
        let stft_energy = spec
            .data()?
            .chunks_exact(2)
            .fold(0.0f32, |acc, c| acc + (c[0] * c[0] + c[1] * c[1]));

        // Only proportionality is checked; windowing and overlap distort the constant.
        let ratio = stft_energy / signal_energy;
        assert!(ratio > 0.01 && ratio < 200.0, "Energy ratio: {}", ratio);

        Ok(())
    }

    #[test]
    fn test_stft_time_shift_property() -> TorshResult<()> {
        let signal_len = 1024;
        let impulse_at = |pos: usize| -> TorshResult<Tensor> {
            let mut samples = vec![0.0f32; signal_len];
            samples[pos] = 1.0;
            Tensor::from_data(
                samples,
                vec![signal_len],
                torsh_core::device::DeviceType::Cpu,
            )
        };

        let stft1 = hann_stft(&impulse_at(256)?, 256, 128, false, false)?;
        let stft2 = hann_stft(&impulse_at(512)?, 256, 128, false, false)?;

        // Shifting the impulse must not change the spectrogram shape.
        assert_eq!(dims_of(&stft1), dims_of(&stft2));

        Ok(())
    }

    #[test]
    fn test_istft_perfect_reconstruction_conditions() -> TorshResult<()> {
        let signal_len = 2048;
        let signal = randn(&[signal_len], None, None, None)?;

        // 75% overlap (hop = n_fft / 4) with a Hann window.
        let (n_fft, hop) = (256, 64);

        let spec = hann_stft(&signal, n_fft, hop, true, false)?;
        let reconstructed = hann_istft(&spec, n_fft, hop, signal_len)?;

        let recon_len = dims_of(&reconstructed)[0];
        assert!(
            (signal_len - n_fft..=signal_len + n_fft).contains(&recon_len),
            "Reconstructed length {} not close to signal length {}",
            recon_len,
            signal_len
        );

        Ok(())
    }

    #[test]
    fn test_stft_onesided_vs_twosided() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;
        let n_fft = 256;

        let onesided_freqs = dims_of(&hann_stft(&signal, n_fft, 128, false, false)?)[0];
        assert_eq!(onesided_freqs, n_fft / 2 + 1);

        Ok(())
    }

    #[test]
    fn test_stft_with_all_window_types() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;

        for &window in &[
            WindowFunction::Rectangular,
            WindowFunction::Hann,
            WindowFunction::Hamming,
            WindowFunction::Blackman,
            WindowFunction::Bartlett,
            WindowFunction::Kaiser(5),
        ] {
            let spec = stft_complete(&signal, 256, Some(128), None, window, false, false, true)?;
            assert_eq!(spec.shape().ndim(), 3);
        }

        Ok(())
    }

    #[test]
    fn test_stft_normalized_vs_unnormalized() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;

        let normalized = hann_stft(&signal, 256, 128, false, true)?;
        let unnormalized = hann_stft(&signal, 256, 128, false, false)?;

        assert_eq!(dims_of(&normalized), dims_of(&unnormalized));

        // Window normalization rescales the spectrum, so some values must differ.
        let norm_data = normalized.data()?;
        let unnorm_data = unnormalized.data()?;
        let diff_count = norm_data
            .iter()
            .zip(unnorm_data.iter())
            .filter(|&(a, b)| (a - b).abs() > 1e-6)
            .count();
        assert!(diff_count > 0);

        Ok(())
    }

    #[test]
    fn test_kaiser_window_beta_parameter() -> TorshResult<()> {
        let size = 256;

        let kaiser_low = generate_window(WindowFunction::Kaiser(0), size)?;
        let kaiser_high = generate_window(WindowFunction::Kaiser(10), size)?;

        // A larger beta tapers harder, so the edge samples shrink.
        assert!(kaiser_high[0] < kaiser_low[0]);
        assert!(kaiser_high[size - 1] < kaiser_low[size - 1]);

        Ok(())
    }
}