Skip to main content

scirs2_transform/signal_transforms/
stft.rs

1//! Short-Time Fourier Transform (STFT) and Spectrogram Implementation
2//!
3//! Provides time-frequency analysis using STFT with various window functions.
4
5use crate::error::{Result, TransformError};
6use rayon::prelude::*;
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8use scirs2_core::numeric::Complex;
9use scirs2_fft::fft;
10use std::f64::consts::PI;
11
/// Tapering window applied to each STFT frame before the FFT.
///
/// The parameterised variants carry their shape parameter inline.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WindowType {
    /// Hann (raised-cosine) window.
    Hann,
    /// Hamming window (raised cosine with 0.54 / 0.46 coefficients).
    Hamming,
    /// Blackman window (three-term cosine sum: 0.42, 0.5, 0.08).
    Blackman,
    /// Bartlett window: a triangle that reaches zero at both ends.
    Bartlett,
    /// Rectangular window: every sample is 1.0, i.e. no tapering.
    Rectangular,
    /// Kaiser window; the payload is the `beta` shape parameter.
    Kaiser(f64),
    /// Tukey (tapered-cosine) window; the payload is the `alpha` taper
    /// fraction, which is clamped to `[0, 1]` when the window is generated.
    Tukey(f64),
}
30
31impl WindowType {
32    /// Generate window function
33    pub fn generate(&self, n: usize) -> Array1<f64> {
34        match self {
35            WindowType::Hann => Self::hann(n),
36            WindowType::Hamming => Self::hamming(n),
37            WindowType::Blackman => Self::blackman(n),
38            WindowType::Bartlett => Self::bartlett(n),
39            WindowType::Rectangular => Array1::ones(n),
40            WindowType::Kaiser(beta) => Self::kaiser(n, *beta),
41            WindowType::Tukey(alpha) => Self::tukey(n, *alpha),
42        }
43    }
44
45    fn hann(n: usize) -> Array1<f64> {
46        Array1::from_vec(
47            (0..n)
48                .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f64 / (n - 1) as f64).cos()))
49                .collect(),
50        )
51    }
52
53    fn hamming(n: usize) -> Array1<f64> {
54        Array1::from_vec(
55            (0..n)
56                .map(|i| 0.54 - 0.46 * (2.0 * PI * i as f64 / (n - 1) as f64).cos())
57                .collect(),
58        )
59    }
60
61    fn blackman(n: usize) -> Array1<f64> {
62        Array1::from_vec(
63            (0..n)
64                .map(|i| {
65                    let angle = 2.0 * PI * i as f64 / (n - 1) as f64;
66                    0.42 - 0.5 * angle.cos() + 0.08 * (2.0 * angle).cos()
67                })
68                .collect(),
69        )
70    }
71
72    fn bartlett(n: usize) -> Array1<f64> {
73        Array1::from_vec(
74            (0..n)
75                .map(|i| 1.0 - (2.0 * (i as f64 - (n - 1) as f64 / 2.0).abs() / (n - 1) as f64))
76                .collect(),
77        )
78    }
79
80    fn kaiser(n: usize, beta: f64) -> Array1<f64> {
81        let i0_beta = Self::bessel_i0(beta);
82        Array1::from_vec(
83            (0..n)
84                .map(|i| {
85                    let x = 2.0 * i as f64 / (n - 1) as f64 - 1.0;
86                    let arg = beta * (1.0 - x * x).sqrt();
87                    Self::bessel_i0(arg) / i0_beta
88                })
89                .collect(),
90        )
91    }
92
93    fn tukey(n: usize, alpha: f64) -> Array1<f64> {
94        let alpha = alpha.clamp(0.0, 1.0);
95        Array1::from_vec(
96            (0..n)
97                .map(|i| {
98                    let x = i as f64 / (n - 1) as f64;
99                    if x < alpha / 2.0 {
100                        0.5 * (1.0 + (2.0 * PI * x / alpha - PI).cos())
101                    } else if x > 1.0 - alpha / 2.0 {
102                        0.5 * (1.0 + (2.0 * PI * (1.0 - x) / alpha - PI).cos())
103                    } else {
104                        1.0
105                    }
106                })
107                .collect(),
108        )
109    }
110
111    /// Modified Bessel function of the first kind, order 0
112    fn bessel_i0(x: f64) -> f64 {
113        let mut sum = 1.0;
114        let mut term = 1.0;
115        let threshold = 1e-12;
116
117        for k in 1..50 {
118            term *= (x / 2.0) * (x / 2.0) / (k as f64 * k as f64);
119            sum += term;
120            if term < threshold {
121                break;
122            }
123        }
124
125        sum
126    }
127}
128
129/// STFT configuration
130#[derive(Debug, Clone)]
131pub struct STFTConfig {
132    /// Window size (number of samples)
133    pub window_size: usize,
134    /// Hop size (number of samples to advance between windows)
135    pub hop_size: usize,
136    /// Window function type
137    pub window_type: WindowType,
138    /// FFT size (zero-padding if > window_size)
139    pub nfft: Option<usize>,
140    /// Whether to return only positive frequencies
141    pub onesided: bool,
142    /// Padding mode for signal edges
143    pub padding: PaddingMode,
144}
145
/// Strategy for filling frame samples that fall past the end of the signal.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PaddingMode {
    /// No padding: only samples that exist in the signal are copied.
    None,
    /// Missing samples are zero.
    Zero,
    /// Missing samples repeat the last sample of the signal.
    Edge,
    /// Missing samples mirror the signal around its last sample.
    Reflect,
}
158
159impl Default for STFTConfig {
160    fn default() -> Self {
161        STFTConfig {
162            window_size: 256,
163            hop_size: 128,
164            window_type: WindowType::Hann,
165            nfft: None,
166            onesided: true,
167            padding: PaddingMode::Zero,
168        }
169    }
170}
171
172/// Short-Time Fourier Transform
173#[derive(Debug, Clone)]
174pub struct STFT {
175    config: STFTConfig,
176    window: Array1<f64>,
177}
178
179impl STFT {
180    /// Create a new STFT instance with configuration
181    pub fn new(config: STFTConfig) -> Self {
182        let window = config.window_type.generate(config.window_size);
183        STFT { config, window }
184    }
185
186    /// Create with default configuration
187    pub fn default() -> Self {
188        Self::new(STFTConfig::default())
189    }
190
191    /// Create with specified window size and hop size
192    pub fn with_params(window_size: usize, hop_size: usize) -> Self {
193        Self::new(STFTConfig {
194            window_size,
195            hop_size,
196            ..Default::default()
197        })
198    }
199
200    /// Compute the STFT of a signal
201    pub fn transform(&self, signal: &ArrayView1<f64>) -> Result<Array2<Complex<f64>>> {
202        let signal_len = signal.len();
203        if signal_len == 0 {
204            return Err(TransformError::InvalidInput("Empty signal".to_string()));
205        }
206
207        let nfft = self.config.nfft.unwrap_or(self.config.window_size);
208        if nfft < self.config.window_size {
209            return Err(TransformError::InvalidInput(
210                "FFT size must be >= window size".to_string(),
211            ));
212        }
213
214        // Calculate number of frames
215        let n_frames = self.calculate_n_frames(signal_len);
216        let n_freqs = if self.config.onesided {
217            nfft / 2 + 1
218        } else {
219            nfft
220        };
221
222        let mut stft = Array2::from_elem((n_freqs, n_frames), Complex::new(0.0, 0.0));
223
224        // Process each frame
225        for (frame_idx, frame_start) in (0..signal_len)
226            .step_by(self.config.hop_size)
227            .take(n_frames)
228            .enumerate()
229        {
230            let frame = self.extract_frame(signal, frame_start)?;
231            let spectrum = self.compute_frame_spectrum(&frame, nfft)?;
232
233            for (freq_idx, &val) in spectrum.iter().enumerate() {
234                if freq_idx < n_freqs {
235                    stft[[freq_idx, frame_idx]] = val;
236                }
237            }
238        }
239
240        Ok(stft)
241    }
242
243    /// Compute the inverse STFT
244    pub fn inverse(&self, stft: &Array2<Complex<f64>>) -> Result<Array1<f64>> {
245        let (n_freqs, n_frames) = stft.dim();
246
247        if n_frames == 0 {
248            return Err(TransformError::InvalidInput(
249                "No frames in STFT".to_string(),
250            ));
251        }
252
253        let nfft = self.config.nfft.unwrap_or(self.config.window_size);
254
255        // Estimate output length
256        let output_len = (n_frames - 1) * self.config.hop_size + self.config.window_size;
257        let mut output = Array1::zeros(output_len);
258        let mut window_sum: Array1<f64> = Array1::zeros(output_len);
259
260        // Overlap-add synthesis
261        for frame_idx in 0..n_frames {
262            // Extract frame spectrum
263            let mut spectrum = Vec::with_capacity(nfft);
264            for freq_idx in 0..n_freqs {
265                spectrum.push(stft[[freq_idx, frame_idx]]);
266            }
267
268            // Mirror spectrum for onesided case
269            if self.config.onesided && nfft > 1 {
270                for freq_idx in (1..(nfft - n_freqs + 1)).rev() {
271                    if freq_idx < n_freqs {
272                        spectrum.push(spectrum[freq_idx].conj());
273                    }
274                }
275            }
276
277            // Inverse FFT
278            let time_frame = scirs2_fft::ifft(&spectrum, None)?;
279
280            // Overlap-add with windowing
281            let frame_start = frame_idx * self.config.hop_size;
282            for (i, &val) in time_frame.iter().take(self.config.window_size).enumerate() {
283                let idx = frame_start + i;
284                if idx < output_len {
285                    output[idx] += val.re * self.window[i];
286                    window_sum[idx] += self.window[i] * self.window[i];
287                }
288            }
289        }
290
291        // Normalize by window sum
292        for i in 0..output_len {
293            if window_sum[i] > 1e-10 {
294                output[i] /= window_sum[i];
295            }
296        }
297
298        Ok(output)
299    }
300
301    fn extract_frame(&self, signal: &ArrayView1<f64>, start: usize) -> Result<Array1<f64>> {
302        let signal_len = signal.len();
303        let mut frame = Array1::zeros(self.config.window_size);
304
305        match self.config.padding {
306            PaddingMode::None => {
307                let end = (start + self.config.window_size).min(signal_len);
308                for i in 0..(end - start) {
309                    frame[i] = signal[start + i] * self.window[i];
310                }
311            }
312            PaddingMode::Zero => {
313                for i in 0..self.config.window_size {
314                    let idx = start + i;
315                    if idx < signal_len {
316                        frame[i] = signal[idx] * self.window[i];
317                    }
318                }
319            }
320            PaddingMode::Edge => {
321                for i in 0..self.config.window_size {
322                    let idx = (start + i).min(signal_len - 1);
323                    frame[i] = signal[idx] * self.window[i];
324                }
325            }
326            PaddingMode::Reflect => {
327                for i in 0..self.config.window_size {
328                    let mut idx = start as i64 + i as i64;
329                    if idx >= signal_len as i64 {
330                        idx = 2 * signal_len as i64 - idx - 2;
331                    }
332                    if idx < 0 {
333                        idx = -idx;
334                    }
335                    let idx = (idx as usize).min(signal_len - 1);
336                    frame[i] = signal[idx] * self.window[i];
337                }
338            }
339        }
340
341        Ok(frame)
342    }
343
344    fn compute_frame_spectrum(
345        &self,
346        frame: &Array1<f64>,
347        nfft: usize,
348    ) -> Result<Vec<Complex<f64>>> {
349        // Zero-pad if necessary
350        let mut padded = vec![0.0; nfft];
351        for (i, &val) in frame.iter().enumerate() {
352            if i < nfft {
353                padded[i] = val;
354            }
355        }
356
357        Ok(fft(&padded, None)?)
358    }
359
360    fn calculate_n_frames(&self, signal_len: usize) -> usize {
361        if signal_len < self.config.window_size {
362            return 1;
363        }
364        ((signal_len - self.config.window_size) / self.config.hop_size) + 1
365    }
366
367    /// Get the window function
368    pub fn window(&self) -> &Array1<f64> {
369        &self.window
370    }
371
372    /// Get the configuration
373    pub fn config(&self) -> &STFTConfig {
374        &self.config
375    }
376}
377
378/// Spectrogram computation
379#[derive(Debug, Clone)]
380pub struct Spectrogram {
381    stft: STFT,
382    scaling: SpectrogramScaling,
383}
384
/// How STFT coefficients are converted into spectrogram values.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SpectrogramScaling {
    /// Squared magnitude of each coefficient.
    Power,
    /// Magnitude of each coefficient.
    Magnitude,
    /// Power expressed in decibels (`10 * log10(power)`).
    Decibel,
}
395
396impl Spectrogram {
397    /// Create a new spectrogram with STFT configuration
398    pub fn new(config: STFTConfig) -> Self {
399        Spectrogram {
400            stft: STFT::new(config),
401            scaling: SpectrogramScaling::Power,
402        }
403    }
404
405    /// Set the scaling mode
406    pub fn with_scaling(mut self, scaling: SpectrogramScaling) -> Self {
407        self.scaling = scaling;
408        self
409    }
410
411    /// Compute the spectrogram
412    pub fn compute(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
413        let stft = self.stft.transform(signal)?;
414        let (n_freqs, n_frames) = stft.dim();
415
416        let mut spectrogram = Array2::zeros((n_freqs, n_frames));
417
418        for i in 0..n_freqs {
419            for j in 0..n_frames {
420                let mag = stft[[i, j]].norm();
421                spectrogram[[i, j]] = match self.scaling {
422                    SpectrogramScaling::Power => mag * mag,
423                    SpectrogramScaling::Magnitude => mag,
424                    SpectrogramScaling::Decibel => {
425                        let power = mag * mag;
426                        if power > 1e-10 {
427                            10.0 * power.log10()
428                        } else {
429                            -100.0 // Floor value
430                        }
431                    }
432                };
433            }
434        }
435
436        Ok(spectrogram)
437    }
438
439    /// Get frequency bins in Hz
440    pub fn frequency_bins(&self, sampling_rate: f64) -> Vec<f64> {
441        let nfft = self
442            .stft
443            .config
444            .nfft
445            .unwrap_or(self.stft.config.window_size);
446        let n_freqs = if self.stft.config.onesided {
447            nfft / 2 + 1
448        } else {
449            nfft
450        };
451
452        (0..n_freqs)
453            .map(|i| i as f64 * sampling_rate / nfft as f64)
454            .collect()
455    }
456
457    /// Get time bins in seconds
458    pub fn time_bins(&self, signal_len: usize, sampling_rate: f64) -> Vec<f64> {
459        let n_frames = self.stft.calculate_n_frames(signal_len);
460        (0..n_frames)
461            .map(|i| (i * self.stft.config.hop_size) as f64 / sampling_rate)
462            .collect()
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds `len` samples of `sin(i * step)`.
    fn sine(len: usize, step: f64) -> Array1<f64> {
        Array1::from_vec((0..len).map(|i| (i as f64 * step).sin()).collect())
    }

    /// Hann must be zero at both ends and peak near the centre; Hamming must
    /// stay strictly positive at the edge.
    #[test]
    fn test_window_generation() {
        let hann = WindowType::Hann.generate(64);
        assert_eq!(hann.len(), 64);
        assert_abs_diff_eq!(hann[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(hann[63], 0.0, epsilon = 1e-10);
        assert!(hann[32] > 0.9);

        let hamming = WindowType::Hamming.generate(64);
        assert_eq!(hamming.len(), 64);
        assert!(hamming[0] > 0.0);
    }

    /// The forward transform of a sine produces a non-empty array.
    #[test]
    fn test_stft_simple() -> Result<()> {
        let signal = sine(256, 0.1);
        let stft = STFT::with_params(64, 32);

        let (n_freqs, n_frames) = stft.transform(&signal.view())?.dim();

        assert!(n_freqs > 0);
        assert!(n_frames > 0);

        Ok(())
    }

    /// Forward followed by inverse transform yields a non-empty signal.
    #[test]
    fn test_stft_inverse() -> Result<()> {
        let signal = sine(256, 0.1);
        let stft = STFT::with_params(64, 32);

        let transformed = stft.transform(&signal.view())?;
        let reconstructed = stft.inverse(&transformed)?;

        assert!(!reconstructed.is_empty());

        Ok(())
    }

    /// The default (power) spectrogram is non-empty and non-negative.
    #[test]
    fn test_spectrogram() -> Result<()> {
        let signal = sine(512, 0.05);
        let config = STFTConfig {
            window_size: 128,
            hop_size: 64,
            ..Default::default()
        };

        let spec = Spectrogram::new(config).compute(&signal.view())?;
        let (n_freqs, n_frames) = spec.dim();

        assert!(n_freqs > 0);
        assert!(n_frames > 0);
        assert!(spec.iter().all(|&x| x >= 0.0));

        Ok(())
    }

    /// Every scaling mode produces an array of the same shape.
    #[test]
    fn test_spectrogram_scaling() -> Result<()> {
        let signal = sine(256, 0.1);
        let config = STFTConfig::default();

        let compute_with = |scaling: SpectrogramScaling| {
            Spectrogram::new(config.clone())
                .with_scaling(scaling)
                .compute(&signal.view())
        };

        let spec_power = compute_with(SpectrogramScaling::Power)?;
        let spec_mag = compute_with(SpectrogramScaling::Magnitude)?;
        let spec_db = compute_with(SpectrogramScaling::Decibel)?;

        assert_eq!(spec_power.dim(), spec_mag.dim());
        assert_eq!(spec_power.dim(), spec_db.dim());

        Ok(())
    }

    /// Frequency and time axes are non-empty and the first frequency is DC.
    #[test]
    fn test_frequency_time_bins() {
        let spectrogram = Spectrogram::new(STFTConfig {
            window_size: 256,
            hop_size: 128,
            ..Default::default()
        });

        let freqs = spectrogram.frequency_bins(1000.0);
        let times = spectrogram.time_bins(1000, 1000.0);

        assert!(!freqs.is_empty());
        assert!(!times.is_empty());
        assert_abs_diff_eq!(freqs[0], 0.0, epsilon = 1e-10);
    }
}