scirs2_fft/
spectrogram.rs

1//! Spectrogram module for time-frequency analysis
2//!
3//! This module provides functions for computing spectrograms of signals,
4//! which are visual representations of the spectrum of frequencies as they
5//! vary with time. It builds on the Short-Time Fourier Transform (STFT).
6//!
7//! A spectrogram is useful for analyzing audio signals, vibration data,
8//! and other time-varying signals to understand how frequency content
9//! changes over time.
10
11use crate::error::{FFTError, FFTResult};
12use crate::window::{get_window, Window};
13use scirs2_core::ndarray::{Array2, Axis};
14use scirs2_core::numeric::Complex64;
15use scirs2_core::numeric::NumCast;
16use std::f64::consts::PI;
17
18/// Compute the Short-Time Fourier Transform (STFT) of a signal.
19///
20/// The STFT is used to determine the sinusoidal frequency and phase content
21/// of local sections of a signal as it changes over time.
22///
23/// # Arguments
24///
25/// * `x` - Input signal array
26/// * `window` - Window specification (function or array of length `nperseg`)
27/// * `nperseg` - Length of each segment
28/// * `noverlap` - Number of points to overlap between segments (default: `nperseg // 2`)
29/// * `nfft` - Length of the FFT used (default: `nperseg`)
30/// * `fs` - Sampling frequency of the `x` time series (default: 1.0)
31/// * `detrend` - Whether to remove the mean from each segment (default: true)
32/// * `return_onesided` - If true, return half of the spectrum (real signals) (default: true)
33/// * `boundary` - Behavior at boundaries (default: None)
34///
35/// # Returns
36///
37/// * A tuple containing:
38///   - Frequencies vector (f)
39///   - Time vector (t)
40///   - STFT result matrix (Zxx) where rows are frequencies and columns are time segments
41///
42/// # Errors
43///
44/// Returns an error if the computation fails or if parameters are invalid.
45///
46/// # Examples
47///
48/// ```
49/// use scirs2_fft::spectrogram::stft;
50/// use scirs2_fft::window::Window;
51/// use std::f64::consts::PI;
52///
53/// // Generate a chirp signal
54/// let fs = 1000.0; // 1 kHz sampling rate
55/// let time = (0..1000).map(|i| i as f64 / fs).collect::<Vec<_>>();
56/// let chirp = time.iter().map(|&ti| (2.0 * PI * (10.0 + 10.0 * ti) * ti).sin()).collect::<Vec<_>>();
57///
58/// // Compute STFT
59/// let (f, t, zxx) = stft(
60///     &chirp,
61///     Window::Hann,
62///     256,
63///     Some(128),
64///     None,
65///     Some(fs),
66///     Some(true),
67///     Some(true),
68///     None,
69/// ).unwrap();
70///
71/// // Check dimensions
72/// assert_eq!(f.len(), zxx.shape()[0]);
73/// assert_eq!(t.len(), zxx.shape()[1]);
74/// ```
75#[allow(clippy::too_many_arguments)]
76#[allow(dead_code)]
77pub fn stft<T>(
78    x: &[T],
79    window: Window,
80    nperseg: usize,
81    noverlap: Option<usize>,
82    nfft: Option<usize>,
83    fs: Option<f64>,
84    detrend: Option<bool>,
85    return_onesided: Option<bool>,
86    boundary: Option<&str>,
87) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<Complex64>)>
88where
89    T: NumCast + Copy + std::fmt::Debug,
90{
91    // Input validation
92    if x.is_empty() {
93        return Err(FFTError::ValueError("Input signal is empty".to_string()));
94    }
95
96    if nperseg == 0 {
97        return Err(FFTError::ValueError(
98            "Segment length must be positive".to_string(),
99        ));
100    }
101
102    // Set default parameters
103    let fs = fs.unwrap_or(1.0);
104    if fs <= 0.0 {
105        return Err(FFTError::ValueError(
106            "Sampling frequency must be positive".to_string(),
107        ));
108    }
109
110    let nfft = nfft.unwrap_or(nperseg);
111    if nfft < nperseg {
112        return Err(FFTError::ValueError(
113            "FFT length must be greater than or equal to segment length".to_string(),
114        ));
115    }
116
117    let noverlap = noverlap.unwrap_or(nperseg / 2);
118    if noverlap >= nperseg {
119        return Err(FFTError::ValueError(
120            "Overlap must be less than segment length".to_string(),
121        ));
122    }
123
124    let detrend = detrend.unwrap_or(true);
125    let return_onesided = return_onesided.unwrap_or(true);
126
127    // Convert input to f64
128    let x_f64: Vec<f64> = x
129        .iter()
130        .map(|&val| {
131            NumCast::from(val).ok_or_else(|| {
132                FFTError::ValueError(format!("Could not convert value to f64: {val:?}"))
133            })
134        })
135        .collect::<Result<Vec<_>, _>>()?;
136
137    // Generate window function
138    let win = get_window(window, nperseg, true)?;
139
140    // Calculate step size
141    let step = nperseg - noverlap;
142
143    // Compute number of segments
144    let mut num_frames = 1 + (x_f64.len() - nperseg) / step;
145
146    // Handle boundary conditions
147    let mut padded = x_f64.clone();
148    match boundary {
149        Some("reflect") => {
150            // Reflect signal at boundaries
151            let pad_size = nperseg;
152            let mut reflected = Vec::with_capacity(x_f64.len() + 2 * pad_size);
153            // Left padding
154            for i in (0..pad_size).rev() {
155                reflected.push(x_f64[i]);
156            }
157            // Original signal
158            reflected.extend_from_slice(&x_f64);
159            // Right padding
160            let len = x_f64.len();
161            for i in (len - pad_size..len).rev() {
162                reflected.push(x_f64[i]);
163            }
164            padded = reflected;
165            num_frames = 1 + (padded.len() - nperseg) / step;
166        }
167        Some("zeros") | Some("constant") => {
168            // Pad with zeros or last value
169            let pad_size = nperseg;
170            let mut padded_signal = Vec::with_capacity(x_f64.len() + 2 * pad_size);
171
172            // Left padding
173            if boundary == Some("zeros") {
174                padded_signal.extend(vec![0.0; pad_size]);
175            } else {
176                padded_signal.extend(vec![x_f64[0]; pad_size]);
177            }
178
179            // Original signal
180            padded_signal.extend_from_slice(&x_f64);
181
182            // Right padding
183            if boundary == Some("zeros") {
184                padded_signal.extend(vec![0.0; pad_size]);
185            } else {
186                padded_signal.extend(vec![*x_f64.last().unwrap_or(&0.0); pad_size]);
187            }
188
189            padded = padded_signal;
190            num_frames = 1 + (padded.len() - nperseg) / step;
191        }
192        _ => {}
193    }
194
195    // Calculate frequency values
196    let freq_len = if return_onesided { nfft / 2 + 1 } else { nfft };
197    let frequencies: Vec<f64> = (0..freq_len).map(|i| i as f64 * fs / nfft as f64).collect();
198
199    // Calculate time values (center of each segment)
200    let times: Vec<f64> = (0..num_frames)
201        .map(|i| (i * step + nperseg / 2) as f64 / fs)
202        .collect();
203
204    // Compute STFT
205    let mut stft_matrix = Array2::zeros((freq_len, num_frames));
206
207    for (i, time_idx) in (0..padded.len() - nperseg + 1).step_by(step).enumerate() {
208        if i >= num_frames {
209            break;
210        }
211
212        // Extract segment
213        let segment: Vec<f64> = padded[time_idx..time_idx + nperseg].to_vec();
214
215        // Detrend if required
216        let mut detrended = segment;
217        if detrend {
218            let mean = detrended.iter().sum::<f64>() / detrended.len() as f64;
219            detrended.iter_mut().for_each(|x| *x -= mean);
220        }
221
222        // Apply window
223        let windowed: Vec<f64> = detrended
224            .iter()
225            .zip(win.iter())
226            .map(|(&x, &w)| x * w)
227            .collect();
228
229        // Pad with zeros if nfft > nperseg
230        let mut padded_segment = windowed;
231        if nfft > nperseg {
232            padded_segment.extend(vec![0.0; nfft - nperseg]);
233        }
234
235        // Compute FFT
236        let fft_result = crate::fft::fft(&padded_segment, None)?;
237
238        // Store result (keep only first half for real signals if return_onesided is true)
239        let relevant_fft = if return_onesided {
240            fft_result[0..freq_len].to_vec()
241        } else {
242            fft_result
243        };
244
245        for (j, &value) in relevant_fft.iter().enumerate() {
246            stft_matrix[[j, i]] = value;
247        }
248    }
249
250    Ok((frequencies, times, stft_matrix))
251}
252
253/// Compute a spectrogram of a time-domain signal.
254///
255/// A spectrogram is a visual representation of the frequency spectrum of
256/// a signal as it varies with time. It is often displayed as a heatmap
257/// where the x-axis represents time, the y-axis represents frequency,
258/// and the color intensity represents signal power.
259///
260/// # Arguments
261///
262/// * `x` - Input signal array
263/// * `fs` - Sampling frequency of the signal (default: 1.0)
264/// * `window` - Window specification (function or array of length `nperseg`) (default: Hann)
265/// * `nperseg` - Length of each segment (default: 256)
266/// * `noverlap` - Number of points to overlap between segments (default: `nperseg // 2`)
267/// * `nfft` - Length of the FFT used (default: `nperseg`)
268/// * `detrend` - Whether to remove the mean from each segment (default: true)
269/// * `scaling` - Power spectrum scaling mode: "density" or "spectrum" (default: "density")
270/// * `mode` - Power spectrum mode: "psd", "magnitude", "angle", "phase" (default: "psd")
271///
272/// # Returns
273///
274/// * A tuple containing:
275///   - Frequencies vector (f)
276///   - Time vector (t)
277///   - Spectrogram result matrix (Sxx) where rows are frequencies and columns are time segments
278///
279/// # Errors
280///
281/// Returns an error if the computation fails or if parameters are invalid.
282///
283/// # Examples
284///
285/// ```
286/// use scirs2_fft::spectrogram;
287/// use scirs2_fft::window::Window;
288/// use std::f64::consts::PI;
289///
290/// // Generate a chirp signal
291/// let fs = 1000.0; // 1 kHz sampling rate
292/// let time = (0..1000).map(|i| i as f64 / fs).collect::<Vec<_>>();
293/// let chirp = time.iter().map(|&ti| (2.0 * PI * (10.0 + 50.0 * ti) * ti).sin()).collect::<Vec<_>>();
294///
295/// // Compute spectrogram
296/// let (f, t, sxx) = spectrogram(
297///     &chirp,
298///     Some(fs),
299///     Some(Window::Hann),
300///     Some(128),
301///     Some(64),
302///     None,
303///     Some(true),
304///     Some("density"),
305///     Some("psd"),
306/// ).unwrap();
307///
308/// // Check dimensions
309/// assert_eq!(f.len(), sxx.shape()[0]);
310/// assert_eq!(t.len(), sxx.shape()[1]);
311/// ```
312#[allow(clippy::too_many_arguments)]
313#[allow(dead_code)]
314pub fn spectrogram<T>(
315    x: &[T],
316    fs: Option<f64>,
317    window: Option<Window>,
318    nperseg: Option<usize>,
319    noverlap: Option<usize>,
320    nfft: Option<usize>,
321    detrend: Option<bool>,
322    scaling: Option<&str>,
323    mode: Option<&str>,
324) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<f64>)>
325where
326    T: NumCast + Copy + std::fmt::Debug,
327{
328    // Set default parameters
329    let fs = fs.unwrap_or(1.0);
330    let window = window.unwrap_or(Window::Hann);
331    let nperseg = nperseg.unwrap_or(256);
332
333    // Compute STFT
334    let (frequencies, times, stft_result) = stft(
335        x,
336        window.clone(),
337        nperseg,
338        noverlap,
339        nfft,
340        Some(fs),
341        detrend,
342        Some(true), // Always use onesided for real signals
343        None,
344    )?;
345
346    // Determine scaling factor
347    let win_vals = get_window(window, nperseg, true)?;
348    let win_sum_sq = win_vals.iter().map(|&x| x * x).sum::<f64>();
349
350    let scaling = scaling.unwrap_or("density");
351    let scale_factor = match scaling {
352        "density" => 1.0 / (fs * win_sum_sq),
353        "spectrum" => 1.0 / win_sum_sq,
354        _ => {
355            return Err(FFTError::ValueError(format!(
356                "Unknown scaling mode: {scaling}. Use 'density' or 'spectrum'."
357            )));
358        }
359    };
360
361    // Compute spectrogram based on the requested mode
362    let mode = mode.unwrap_or("psd");
363    let spectrogram_result = match mode {
364        "psd" => {
365            // Power spectral density
366            let mut psd = Array2::zeros(stft_result.dim());
367            for (i, row) in stft_result.axis_iter(Axis(0)).enumerate() {
368                for (j, &val) in row.iter().enumerate() {
369                    psd[[i, j]] = val.norm_sqr() * scale_factor;
370                }
371            }
372            psd
373        }
374        "magnitude" => {
375            // Magnitude spectrum (linear scale)
376            let mut magnitude = Array2::zeros(stft_result.dim());
377            for (i, row) in stft_result.axis_iter(Axis(0)).enumerate() {
378                for (j, &val) in row.iter().enumerate() {
379                    magnitude[[i, j]] = val.norm() * scale_factor.sqrt();
380                }
381            }
382            magnitude
383        }
384        "angle" | "phase" => {
385            // Phase spectrum in radians or degrees
386            let mut phase = Array2::zeros(stft_result.dim());
387            for (i, row) in stft_result.axis_iter(Axis(0)).enumerate() {
388                for (j, &val) in row.iter().enumerate() {
389                    phase[[i, j]] = val.arg();
390                    if mode == "angle" {
391                        // Convert to degrees
392                        phase[[i, j]] = phase[[i, j]] * 180.0 / PI;
393                    }
394                }
395            }
396            phase
397        }
398        _ => {
399            return Err(FFTError::ValueError(format!(
400                "Unknown mode: {mode}. Use 'psd', 'magnitude', 'angle', or 'phase'."
401            )));
402        }
403    };
404
405    Ok((frequencies, times, spectrogram_result))
406}
407
408/// Compute a normalized spectrogram suitable for display as a heatmap.
409///
410/// This is a convenience function that computes a spectrogram and normalizes
411/// its values to a range suitable for visualization. It also applies a
412/// logarithmic scaling to better visualize the dynamic range of the signal.
413///
414/// # Arguments
415///
416/// * `x` - Input signal array
417/// * `fs` - Sampling frequency of the signal (default: 1.0)
418/// * `nperseg` - Length of each segment (default: 256)
419/// * `noverlap` - Number of points to overlap between segments (default: `nperseg // 2`)
420/// * `db_range` - Dynamic range in dB for normalization (default: 80.0)
421///
422/// # Returns
423///
424/// * A tuple containing:
425///   - Frequencies vector (f)
426///   - Time vector (t)
427///   - Normalized spectrogram result matrix (Sxx_norm) with values in [0, 1]
428///
429/// # Errors
430///
431/// Returns an error if the computation fails or if parameters are invalid.
432///
433/// # Examples
434///
435/// ```
436/// use scirs2_fft::spectrogram_normalized;
437/// use std::f64::consts::PI;
438///
439/// // Generate a chirp signal
440/// let fs = 1000.0; // 1 kHz sampling rate
441/// let time = (0..1000).map(|i| i as f64 / fs).collect::<Vec<_>>();
442/// let chirp = time.iter().map(|&ti| (2.0 * PI * (10.0 + 50.0 * ti) * ti).sin()).collect::<Vec<_>>();
443///
444/// // Compute normalized spectrogram
445/// let (f, t, sxx_norm) = spectrogram_normalized(
446///     &chirp,
447///     Some(fs),
448///     Some(128),
449///     Some(64),
450///     Some(80.0),
451/// ).unwrap();
452///
453/// // Values should be normalized to [0, 1]
454/// for row in sxx_norm.axis_iter(scirs2_core::ndarray::Axis(0)) {
455///     for &val in row {
456///         assert!((0.0..=1.0).contains(&val));
457///     }
458/// }
459/// ```
460#[allow(dead_code)]
461pub fn spectrogram_normalized<T>(
462    x: &[T],
463    fs: Option<f64>,
464    nperseg: Option<usize>,
465    noverlap: Option<usize>,
466    db_range: Option<f64>,
467) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<f64>)>
468where
469    T: NumCast + Copy + std::fmt::Debug,
470{
471    // Set default parameters
472    let fs = fs.unwrap_or(1.0);
473    let nperseg = nperseg.unwrap_or(256);
474    let db_range = db_range.unwrap_or(80.0);
475
476    // Compute spectrogram
477    let (frequencies, times, spectrogram_result) = spectrogram(
478        x,
479        Some(fs),
480        Some(Window::Hann),
481        Some(nperseg),
482        noverlap,
483        None,
484        Some(true),
485        Some("density"),
486        Some("psd"),
487    )?;
488
489    // Convert to dB scale with reference to maximum value
490    let max_val = spectrogram_result.iter().fold(f64::MIN, |a, &b| a.max(b));
491
492    if max_val <= 0.0 {
493        return Err(FFTError::ValueError(
494            "Spectrogram has no positive values".to_string(),
495        ));
496    }
497
498    // Convert to dB
499    let mut spec_db = Array2::zeros(spectrogram_result.dim());
500    for (i, row) in spectrogram_result.axis_iter(Axis(0)).enumerate() {
501        for (j, &val) in row.iter().enumerate() {
502            // Avoid taking log of zero
503            let val_db = if val > 0.0 {
504                10.0 * (val / max_val).log10()
505            } else {
506                -db_range
507            };
508            spec_db[[i, j]] = val_db;
509        }
510    }
511
512    // Normalize to [0, 1] _range
513    let mut spec_norm = Array2::zeros(spec_db.dim());
514    for (i, row) in spec_db.axis_iter(Axis(0)).enumerate() {
515        for (j, &val) in row.iter().enumerate() {
516            spec_norm[[i, j]] = (val + db_range).max(0.0).min(db_range) / db_range;
517        }
518    }
519
520    Ok((frequencies, times, spec_norm))
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Generate `n_samples` of a unit-amplitude sine wave at `freq` Hz,
    /// sampled at `fs` Hz.
    fn generate_sine_wave(freq: f64, fs: f64, n_samples: usize) -> Vec<f64> {
        (0..n_samples)
            .map(|i| (2.0 * PI * freq * (i as f64 / fs)).sin())
            .collect()
    }

    #[test]
    fn test_stft_dimensions() {
        // Generate a sine wave
        let fs = 1000.0;
        let signal = generate_sine_wave(100.0, fs, 1000);

        // Compute STFT
        let nperseg = 256;
        let noverlap = 128;
        let (f, t, zxx) = stft(
            &signal,
            Window::Hann,
            nperseg,
            Some(noverlap),
            None,
            Some(fs),
            Some(true),
            Some(true),
            None,
        )
        .expect("STFT computation should succeed for test data");

        // Check dimensions
        let expected_num_freqs = nperseg / 2 + 1;
        let expected_num_frames = 1 + (signal.len() - nperseg) / (nperseg - noverlap);

        assert_eq!(f.len(), expected_num_freqs);
        assert_eq!(t.len(), expected_num_frames);
        assert_eq!(zxx.shape(), &[expected_num_freqs, expected_num_frames]);
    }

    #[test]
    fn test_stft_boundary_padding_extends_frames() {
        let fs = 1000.0;
        let signal = generate_sine_wave(100.0, fs, 1000);

        let nperseg = 256;
        let noverlap = 128;
        let expected_num_freqs = nperseg / 2 + 1;
        // Every padding mode adds `nperseg` samples on each side.
        let padded_len = signal.len() + 2 * nperseg;
        let expected_num_frames = 1 + (padded_len - nperseg) / (nperseg - noverlap);

        let boundaries = ["reflect", "zeros", "constant"];
        for &boundary in &boundaries {
            let (f, t, zxx) = stft(
                &signal,
                Window::Hann,
                nperseg,
                Some(noverlap),
                None,
                Some(fs),
                Some(true),
                Some(true),
                Some(boundary),
            )
            .expect("STFT with boundary padding should succeed");

            assert_eq!(f.len(), expected_num_freqs);
            assert_eq!(t.len(), expected_num_frames);
            assert_eq!(zxx.shape(), &[expected_num_freqs, expected_num_frames]);
        }
    }

    #[test]
    fn test_stft_rejects_invalid_parameters() {
        let signal = generate_sine_wave(100.0, 1000.0, 512);
        let empty: Vec<f64> = Vec::new();

        // Empty input
        let result = stft(&empty, Window::Hann, 256, None, None, None, None, None, None);
        assert!(matches!(result, Err(FFTError::ValueError(_))));

        // Zero-length segments
        let result = stft(&signal, Window::Hann, 0, None, None, None, None, None, None);
        assert!(matches!(result, Err(FFTError::ValueError(_))));

        // Non-positive sampling frequency
        let result = stft(
            &signal,
            Window::Hann,
            256,
            None,
            None,
            Some(0.0),
            None,
            None,
            None,
        );
        assert!(matches!(result, Err(FFTError::ValueError(_))));

        // FFT shorter than the segment
        let result = stft(
            &signal,
            Window::Hann,
            256,
            None,
            Some(128),
            None,
            None,
            None,
            None,
        );
        assert!(matches!(result, Err(FFTError::ValueError(_))));

        // Overlap not smaller than the segment
        let result = stft(
            &signal,
            Window::Hann,
            256,
            Some(256),
            None,
            None,
            None,
            None,
            None,
        );
        assert!(matches!(result, Err(FFTError::ValueError(_))));
    }

    #[test]
    fn test_stft_frequency_content() {
        // Generate a sine wave with known frequency
        let fs = 1000.0;
        let freq = 100.0;
        let signal = generate_sine_wave(freq, fs, 1000);

        // Compute STFT
        let nperseg = 256;
        let (f_freq, _t, zxx) = stft(
            &signal,
            Window::Hann,
            nperseg,
            Some(128),
            None,
            Some(fs),
            Some(true),
            Some(true),
            None,
        )
        .expect("STFT computation should succeed for frequency test");

        // Find the frequency bin closest to our signal frequency
        let freq_idx = f_freq
            .iter()
            .enumerate()
            .min_by(|(_, &a), (_, &b)| {
                (a - freq)
                    .abs()
                    .partial_cmp(&(b - freq).abs())
                    .expect("Frequency comparison should succeed")
            })
            .expect("Should find minimum frequency difference")
            .0;

        // Check that the power at this frequency is higher than at other frequencies
        let mean_frame_idx = zxx.shape()[1] / 2; // Use middle frame
        let power_at_freq = zxx[[freq_idx, mean_frame_idx]].norm_sqr();

        // Calculate average power across all frequencies
        let total_power: f64 = (0..zxx.shape()[0])
            .map(|i| zxx[[i, mean_frame_idx]].norm_sqr())
            .sum();
        let avg_power = total_power / zxx.shape()[0] as f64;

        // The power at our signal frequency should be much higher than average
        assert!(power_at_freq > 5.0 * avg_power);
    }

    #[test]
    fn test_spectrogram() {
        // Generate a chirp signal (frequency increasing with time)
        let fs = 1000.0;
        let n_samples = 1000;
        let t: Vec<f64> = (0..n_samples).map(|i| i as f64 / fs).collect();
        let chirp: Vec<f64> = t
            .iter()
            .map(|&ti| (2.0 * PI * (10.0 + 50.0 * ti) * ti).sin())
            .collect();

        // Compute spectrogram
        let (f, t, sxx) = spectrogram(
            &chirp,
            Some(fs),
            Some(Window::Hann),
            Some(128),
            Some(64),
            None,
            Some(true),
            Some("density"),
            Some("psd"),
        )
        .expect("Spectrogram computation should succeed for test data");

        // Verify basic properties
        assert!(!f.is_empty());
        assert!(!t.is_empty());
        assert_eq!(sxx.shape(), &[f.len(), t.len()]);

        // Values should be non-negative for PSD
        for &val in sxx.iter() {
            assert!(val >= 0.0);
        }
    }

    #[test]
    fn test_spectrogram_modes() {
        // Generate a sine wave
        let fs = 1000.0;
        let signal = generate_sine_wave(100.0, fs, 1000);

        // Test different modes
        let modes = ["psd", "magnitude", "angle", "phase"];

        for &mode in &modes {
            let (f, t, sxx) = spectrogram(
                &signal,
                Some(fs),
                Some(Window::Hann),
                Some(128),
                Some(64),
                None,
                Some(true),
                Some("density"),
                Some(mode),
            )
            .expect("Spectrogram mode computation should succeed");

            // Check dimensions
            assert!(!f.is_empty());
            assert!(!t.is_empty());
            assert_eq!(sxx.shape(), &[f.len(), t.len()]);

            // For phase/angle modes, values should be in expected range
            if mode == "phase" {
                for &val in sxx.iter() {
                    assert!((-PI..=PI).contains(&val));
                }
            } else if mode == "angle" {
                for &val in sxx.iter() {
                    assert!((-180.0..=180.0).contains(&val));
                }
            }
        }
    }

    #[test]
    fn test_spectrogram_rejects_unknown_options() {
        let fs = 1000.0;
        let signal = generate_sine_wave(100.0, fs, 1000);

        // Unknown scaling name
        let result = spectrogram(
            &signal,
            Some(fs),
            Some(Window::Hann),
            Some(128),
            Some(64),
            None,
            Some(true),
            Some("not-a-scaling"),
            Some("psd"),
        );
        assert!(matches!(result, Err(FFTError::ValueError(_))));

        // Unknown mode name
        let result = spectrogram(
            &signal,
            Some(fs),
            Some(Window::Hann),
            Some(128),
            Some(64),
            None,
            Some(true),
            Some("density"),
            Some("not-a-mode"),
        );
        assert!(matches!(result, Err(FFTError::ValueError(_))));
    }

    #[test]
    fn test_spectrogram_normalized() {
        // Generate a sine wave
        let fs = 1000.0;
        let signal = generate_sine_wave(100.0, fs, 1000);

        // Compute normalized spectrogram
        let (f, t, sxx) =
            spectrogram_normalized(&signal, Some(fs), Some(128), Some(64), Some(80.0))
                .expect("Normalized spectrogram should succeed");

        // Check dimensions
        assert!(!f.is_empty());
        assert!(!t.is_empty());
        assert_eq!(sxx.shape(), &[f.len(), t.len()]);

        // Values should be in range [0, 1]
        for &val in sxx.iter() {
            assert!((0.0..=1.0).contains(&val));
        }
    }
}