scirs2_fft/spectrogram.rs
1//! Spectrogram module for time-frequency analysis
2//!
3//! This module provides functions for computing spectrograms of signals,
4//! which are visual representations of the spectrum of frequencies as they
5//! vary with time. It builds on the Short-Time Fourier Transform (STFT).
6//!
7//! A spectrogram is useful for analyzing audio signals, vibration data,
8//! and other time-varying signals to understand how frequency content
9//! changes over time.
10
11use crate::error::{FFTError, FFTResult};
12use crate::window::{get_window, Window};
13use scirs2_core::ndarray::{Array2, Axis};
14use scirs2_core::numeric::Complex64;
15use scirs2_core::numeric::NumCast;
16use std::f64::consts::PI;
17
18/// Compute the Short-Time Fourier Transform (STFT) of a signal.
19///
20/// The STFT is used to determine the sinusoidal frequency and phase content
21/// of local sections of a signal as it changes over time.
22///
23/// # Arguments
24///
25/// * `x` - Input signal array
26/// * `window` - Window specification (function or array of length `nperseg`)
27/// * `nperseg` - Length of each segment
28/// * `noverlap` - Number of points to overlap between segments (default: `nperseg // 2`)
29/// * `nfft` - Length of the FFT used (default: `nperseg`)
30/// * `fs` - Sampling frequency of the `x` time series (default: 1.0)
31/// * `detrend` - Whether to remove the mean from each segment (default: true)
32/// * `return_onesided` - If true, return half of the spectrum (real signals) (default: true)
33/// * `boundary` - Behavior at boundaries (default: None)
34///
35/// # Returns
36///
37/// * A tuple containing:
38/// - Frequencies vector (f)
39/// - Time vector (t)
40/// - STFT result matrix (Zxx) where rows are frequencies and columns are time segments
41///
42/// # Errors
43///
44/// Returns an error if the computation fails or if parameters are invalid.
45///
46/// # Examples
47///
48/// ```
49/// use scirs2_fft::spectrogram::stft;
50/// use scirs2_fft::window::Window;
51/// use std::f64::consts::PI;
52///
53/// // Generate a chirp signal
54/// let fs = 1000.0; // 1 kHz sampling rate
55/// let time = (0..1000).map(|i| i as f64 / fs).collect::<Vec<_>>();
56/// let chirp = time.iter().map(|&ti| (2.0 * PI * (10.0 + 10.0 * ti) * ti).sin()).collect::<Vec<_>>();
57///
58/// // Compute STFT
59/// let (f, t, zxx) = stft(
60/// &chirp,
61/// Window::Hann,
62/// 256,
63/// Some(128),
64/// None,
65/// Some(fs),
66/// Some(true),
67/// Some(true),
68/// None,
69/// ).unwrap();
70///
71/// // Check dimensions
72/// assert_eq!(f.len(), zxx.shape()[0]);
73/// assert_eq!(t.len(), zxx.shape()[1]);
74/// ```
75#[allow(clippy::too_many_arguments)]
76#[allow(dead_code)]
77pub fn stft<T>(
78 x: &[T],
79 window: Window,
80 nperseg: usize,
81 noverlap: Option<usize>,
82 nfft: Option<usize>,
83 fs: Option<f64>,
84 detrend: Option<bool>,
85 return_onesided: Option<bool>,
86 boundary: Option<&str>,
87) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<Complex64>)>
88where
89 T: NumCast + Copy + std::fmt::Debug,
90{
91 // Input validation
92 if x.is_empty() {
93 return Err(FFTError::ValueError("Input signal is empty".to_string()));
94 }
95
96 if nperseg == 0 {
97 return Err(FFTError::ValueError(
98 "Segment length must be positive".to_string(),
99 ));
100 }
101
102 // Set default parameters
103 let fs = fs.unwrap_or(1.0);
104 if fs <= 0.0 {
105 return Err(FFTError::ValueError(
106 "Sampling frequency must be positive".to_string(),
107 ));
108 }
109
110 let nfft = nfft.unwrap_or(nperseg);
111 if nfft < nperseg {
112 return Err(FFTError::ValueError(
113 "FFT length must be greater than or equal to segment length".to_string(),
114 ));
115 }
116
117 let noverlap = noverlap.unwrap_or(nperseg / 2);
118 if noverlap >= nperseg {
119 return Err(FFTError::ValueError(
120 "Overlap must be less than segment length".to_string(),
121 ));
122 }
123
124 let detrend = detrend.unwrap_or(true);
125 let return_onesided = return_onesided.unwrap_or(true);
126
127 // Convert input to f64
128 let x_f64: Vec<f64> = x
129 .iter()
130 .map(|&val| {
131 NumCast::from(val).ok_or_else(|| {
132 FFTError::ValueError(format!("Could not convert value to f64: {val:?}"))
133 })
134 })
135 .collect::<Result<Vec<_>, _>>()?;
136
137 // Generate window function
138 let win = get_window(window, nperseg, true)?;
139
140 // Calculate step size
141 let step = nperseg - noverlap;
142
143 // Compute number of segments
144 let mut num_frames = 1 + (x_f64.len() - nperseg) / step;
145
146 // Handle boundary conditions
147 let mut padded = x_f64.clone();
148 match boundary {
149 Some("reflect") => {
150 // Reflect signal at boundaries
151 let pad_size = nperseg;
152 let mut reflected = Vec::with_capacity(x_f64.len() + 2 * pad_size);
153 // Left padding
154 for i in (0..pad_size).rev() {
155 reflected.push(x_f64[i]);
156 }
157 // Original signal
158 reflected.extend_from_slice(&x_f64);
159 // Right padding
160 let len = x_f64.len();
161 for i in (len - pad_size..len).rev() {
162 reflected.push(x_f64[i]);
163 }
164 padded = reflected;
165 num_frames = 1 + (padded.len() - nperseg) / step;
166 }
167 Some("zeros") | Some("constant") => {
168 // Pad with zeros or last value
169 let pad_size = nperseg;
170 let mut padded_signal = Vec::with_capacity(x_f64.len() + 2 * pad_size);
171
172 // Left padding
173 if boundary == Some("zeros") {
174 padded_signal.extend(vec![0.0; pad_size]);
175 } else {
176 padded_signal.extend(vec![x_f64[0]; pad_size]);
177 }
178
179 // Original signal
180 padded_signal.extend_from_slice(&x_f64);
181
182 // Right padding
183 if boundary == Some("zeros") {
184 padded_signal.extend(vec![0.0; pad_size]);
185 } else {
186 padded_signal.extend(vec![*x_f64.last().unwrap_or(&0.0); pad_size]);
187 }
188
189 padded = padded_signal;
190 num_frames = 1 + (padded.len() - nperseg) / step;
191 }
192 _ => {}
193 }
194
195 // Calculate frequency values
196 let freq_len = if return_onesided { nfft / 2 + 1 } else { nfft };
197 let frequencies: Vec<f64> = (0..freq_len).map(|i| i as f64 * fs / nfft as f64).collect();
198
199 // Calculate time values (center of each segment)
200 let times: Vec<f64> = (0..num_frames)
201 .map(|i| (i * step + nperseg / 2) as f64 / fs)
202 .collect();
203
204 // Compute STFT
205 let mut stft_matrix = Array2::zeros((freq_len, num_frames));
206
207 for (i, time_idx) in (0..padded.len() - nperseg + 1).step_by(step).enumerate() {
208 if i >= num_frames {
209 break;
210 }
211
212 // Extract segment
213 let segment: Vec<f64> = padded[time_idx..time_idx + nperseg].to_vec();
214
215 // Detrend if required
216 let mut detrended = segment;
217 if detrend {
218 let mean = detrended.iter().sum::<f64>() / detrended.len() as f64;
219 detrended.iter_mut().for_each(|x| *x -= mean);
220 }
221
222 // Apply window
223 let windowed: Vec<f64> = detrended
224 .iter()
225 .zip(win.iter())
226 .map(|(&x, &w)| x * w)
227 .collect();
228
229 // Pad with zeros if nfft > nperseg
230 let mut padded_segment = windowed;
231 if nfft > nperseg {
232 padded_segment.extend(vec![0.0; nfft - nperseg]);
233 }
234
235 // Compute FFT
236 let fft_result = crate::fft::fft(&padded_segment, None)?;
237
238 // Store result (keep only first half for real signals if return_onesided is true)
239 let relevant_fft = if return_onesided {
240 fft_result[0..freq_len].to_vec()
241 } else {
242 fft_result
243 };
244
245 for (j, &value) in relevant_fft.iter().enumerate() {
246 stft_matrix[[j, i]] = value;
247 }
248 }
249
250 Ok((frequencies, times, stft_matrix))
251}
252
253/// Compute a spectrogram of a time-domain signal.
254///
255/// A spectrogram is a visual representation of the frequency spectrum of
256/// a signal as it varies with time. It is often displayed as a heatmap
257/// where the x-axis represents time, the y-axis represents frequency,
258/// and the color intensity represents signal power.
259///
260/// # Arguments
261///
262/// * `x` - Input signal array
263/// * `fs` - Sampling frequency of the signal (default: 1.0)
264/// * `window` - Window specification (function or array of length `nperseg`) (default: Hann)
265/// * `nperseg` - Length of each segment (default: 256)
266/// * `noverlap` - Number of points to overlap between segments (default: `nperseg // 2`)
267/// * `nfft` - Length of the FFT used (default: `nperseg`)
268/// * `detrend` - Whether to remove the mean from each segment (default: true)
269/// * `scaling` - Power spectrum scaling mode: "density" or "spectrum" (default: "density")
270/// * `mode` - Power spectrum mode: "psd", "magnitude", "angle", "phase" (default: "psd")
271///
272/// # Returns
273///
274/// * A tuple containing:
275/// - Frequencies vector (f)
276/// - Time vector (t)
277/// - Spectrogram result matrix (Sxx) where rows are frequencies and columns are time segments
278///
279/// # Errors
280///
281/// Returns an error if the computation fails or if parameters are invalid.
282///
283/// # Examples
284///
285/// ```
286/// use scirs2_fft::spectrogram;
287/// use scirs2_fft::window::Window;
288/// use std::f64::consts::PI;
289///
290/// // Generate a chirp signal
291/// let fs = 1000.0; // 1 kHz sampling rate
292/// let time = (0..1000).map(|i| i as f64 / fs).collect::<Vec<_>>();
293/// let chirp = time.iter().map(|&ti| (2.0 * PI * (10.0 + 50.0 * ti) * ti).sin()).collect::<Vec<_>>();
294///
295/// // Compute spectrogram
296/// let (f, t, sxx) = spectrogram(
297/// &chirp,
298/// Some(fs),
299/// Some(Window::Hann),
300/// Some(128),
301/// Some(64),
302/// None,
303/// Some(true),
304/// Some("density"),
305/// Some("psd"),
306/// ).unwrap();
307///
308/// // Check dimensions
309/// assert_eq!(f.len(), sxx.shape()[0]);
310/// assert_eq!(t.len(), sxx.shape()[1]);
311/// ```
312#[allow(clippy::too_many_arguments)]
313#[allow(dead_code)]
314pub fn spectrogram<T>(
315 x: &[T],
316 fs: Option<f64>,
317 window: Option<Window>,
318 nperseg: Option<usize>,
319 noverlap: Option<usize>,
320 nfft: Option<usize>,
321 detrend: Option<bool>,
322 scaling: Option<&str>,
323 mode: Option<&str>,
324) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<f64>)>
325where
326 T: NumCast + Copy + std::fmt::Debug,
327{
328 // Set default parameters
329 let fs = fs.unwrap_or(1.0);
330 let window = window.unwrap_or(Window::Hann);
331 let nperseg = nperseg.unwrap_or(256);
332
333 // Compute STFT
334 let (frequencies, times, stft_result) = stft(
335 x,
336 window.clone(),
337 nperseg,
338 noverlap,
339 nfft,
340 Some(fs),
341 detrend,
342 Some(true), // Always use onesided for real signals
343 None,
344 )?;
345
346 // Determine scaling factor
347 let win_vals = get_window(window, nperseg, true)?;
348 let win_sum_sq = win_vals.iter().map(|&x| x * x).sum::<f64>();
349
350 let scaling = scaling.unwrap_or("density");
351 let scale_factor = match scaling {
352 "density" => 1.0 / (fs * win_sum_sq),
353 "spectrum" => 1.0 / win_sum_sq,
354 _ => {
355 return Err(FFTError::ValueError(format!(
356 "Unknown scaling mode: {scaling}. Use 'density' or 'spectrum'."
357 )));
358 }
359 };
360
361 // Compute spectrogram based on the requested mode
362 let mode = mode.unwrap_or("psd");
363 let spectrogram_result = match mode {
364 "psd" => {
365 // Power spectral density
366 let mut psd = Array2::zeros(stft_result.dim());
367 for (i, row) in stft_result.axis_iter(Axis(0)).enumerate() {
368 for (j, &val) in row.iter().enumerate() {
369 psd[[i, j]] = val.norm_sqr() * scale_factor;
370 }
371 }
372 psd
373 }
374 "magnitude" => {
375 // Magnitude spectrum (linear scale)
376 let mut magnitude = Array2::zeros(stft_result.dim());
377 for (i, row) in stft_result.axis_iter(Axis(0)).enumerate() {
378 for (j, &val) in row.iter().enumerate() {
379 magnitude[[i, j]] = val.norm() * scale_factor.sqrt();
380 }
381 }
382 magnitude
383 }
384 "angle" | "phase" => {
385 // Phase spectrum in radians or degrees
386 let mut phase = Array2::zeros(stft_result.dim());
387 for (i, row) in stft_result.axis_iter(Axis(0)).enumerate() {
388 for (j, &val) in row.iter().enumerate() {
389 phase[[i, j]] = val.arg();
390 if mode == "angle" {
391 // Convert to degrees
392 phase[[i, j]] = phase[[i, j]] * 180.0 / PI;
393 }
394 }
395 }
396 phase
397 }
398 _ => {
399 return Err(FFTError::ValueError(format!(
400 "Unknown mode: {mode}. Use 'psd', 'magnitude', 'angle', or 'phase'."
401 )));
402 }
403 };
404
405 Ok((frequencies, times, spectrogram_result))
406}
407
408/// Compute a normalized spectrogram suitable for display as a heatmap.
409///
410/// This is a convenience function that computes a spectrogram and normalizes
411/// its values to a range suitable for visualization. It also applies a
412/// logarithmic scaling to better visualize the dynamic range of the signal.
413///
414/// # Arguments
415///
416/// * `x` - Input signal array
417/// * `fs` - Sampling frequency of the signal (default: 1.0)
418/// * `nperseg` - Length of each segment (default: 256)
419/// * `noverlap` - Number of points to overlap between segments (default: `nperseg // 2`)
420/// * `db_range` - Dynamic range in dB for normalization (default: 80.0)
421///
422/// # Returns
423///
424/// * A tuple containing:
425/// - Frequencies vector (f)
426/// - Time vector (t)
427/// - Normalized spectrogram result matrix (Sxx_norm) with values in [0, 1]
428///
429/// # Errors
430///
431/// Returns an error if the computation fails or if parameters are invalid.
432///
433/// # Examples
434///
435/// ```
436/// use scirs2_fft::spectrogram_normalized;
437/// use std::f64::consts::PI;
438///
439/// // Generate a chirp signal
440/// let fs = 1000.0; // 1 kHz sampling rate
441/// let time = (0..1000).map(|i| i as f64 / fs).collect::<Vec<_>>();
442/// let chirp = time.iter().map(|&ti| (2.0 * PI * (10.0 + 50.0 * ti) * ti).sin()).collect::<Vec<_>>();
443///
444/// // Compute normalized spectrogram
445/// let (f, t, sxx_norm) = spectrogram_normalized(
446/// &chirp,
447/// Some(fs),
448/// Some(128),
449/// Some(64),
450/// Some(80.0),
451/// ).unwrap();
452///
453/// // Values should be normalized to [0, 1]
454/// for row in sxx_norm.axis_iter(scirs2_core::ndarray::Axis(0)) {
455/// for &val in row {
456/// assert!((0.0..=1.0).contains(&val));
457/// }
458/// }
459/// ```
460#[allow(dead_code)]
461pub fn spectrogram_normalized<T>(
462 x: &[T],
463 fs: Option<f64>,
464 nperseg: Option<usize>,
465 noverlap: Option<usize>,
466 db_range: Option<f64>,
467) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<f64>)>
468where
469 T: NumCast + Copy + std::fmt::Debug,
470{
471 // Set default parameters
472 let fs = fs.unwrap_or(1.0);
473 let nperseg = nperseg.unwrap_or(256);
474 let db_range = db_range.unwrap_or(80.0);
475
476 // Compute spectrogram
477 let (frequencies, times, spectrogram_result) = spectrogram(
478 x,
479 Some(fs),
480 Some(Window::Hann),
481 Some(nperseg),
482 noverlap,
483 None,
484 Some(true),
485 Some("density"),
486 Some("psd"),
487 )?;
488
489 // Convert to dB scale with reference to maximum value
490 let max_val = spectrogram_result.iter().fold(f64::MIN, |a, &b| a.max(b));
491
492 if max_val <= 0.0 {
493 return Err(FFTError::ValueError(
494 "Spectrogram has no positive values".to_string(),
495 ));
496 }
497
498 // Convert to dB
499 let mut spec_db = Array2::zeros(spectrogram_result.dim());
500 for (i, row) in spectrogram_result.axis_iter(Axis(0)).enumerate() {
501 for (j, &val) in row.iter().enumerate() {
502 // Avoid taking log of zero
503 let val_db = if val > 0.0 {
504 10.0 * (val / max_val).log10()
505 } else {
506 -db_range
507 };
508 spec_db[[i, j]] = val_db;
509 }
510 }
511
512 // Normalize to [0, 1] _range
513 let mut spec_norm = Array2::zeros(spec_db.dim());
514 for (i, row) in spec_db.axis_iter(Axis(0)).enumerate() {
515 for (j, &val) in row.iter().enumerate() {
516 spec_norm[[i, j]] = (val + db_range).max(0.0).min(db_range) / db_range;
517 }
518 }
519
520 Ok((frequencies, times, spec_norm))
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample `n_samples` points of a unit-amplitude sine of frequency `freq`
    /// (Hz) taken at sampling rate `fs` (Hz).
    fn generate_sine_wave(freq: f64, fs: f64, n_samples: usize) -> Vec<f64> {
        let angular = 2.0 * PI * freq;
        (0..n_samples)
            .map(|n| (angular * (n as f64 / fs)).sin())
            .collect()
    }

    #[test]
    fn test_stft_dimensions() {
        const FS: f64 = 1000.0;
        const NPERSEG: usize = 256;
        const NOVERLAP: usize = 128;
        let signal = generate_sine_wave(100.0, FS, 1000);

        let (freqs, times, zxx) = stft(
            &signal,
            Window::Hann,
            NPERSEG,
            Some(NOVERLAP),
            None,
            Some(FS),
            Some(true),
            Some(true),
            None,
        )
        .expect("STFT computation should succeed for test data");

        // One-sided bin count and number of full segments that fit.
        let want_freqs = NPERSEG / 2 + 1;
        let want_frames = 1 + (signal.len() - NPERSEG) / (NPERSEG - NOVERLAP);

        assert_eq!(freqs.len(), want_freqs);
        assert_eq!(times.len(), want_frames);
        assert_eq!(zxx.shape(), &[want_freqs, want_frames]);
    }

    #[test]
    fn test_stft_frequency_content() {
        let fs = 1000.0;
        let target = 100.0;
        let signal = generate_sine_wave(target, fs, 1000);

        let (freqs, _times, zxx) = stft(
            &signal,
            Window::Hann,
            256,
            Some(128),
            None,
            Some(fs),
            Some(true),
            Some(true),
            None,
        )
        .expect("STFT computation should succeed for frequency test");

        // Bin whose centre frequency lies nearest to the tone.
        let (peak_bin, _) = freqs
            .iter()
            .map(|&f| (f - target).abs())
            .enumerate()
            .min_by(|(_, a), (_, b)| {
                a.partial_cmp(b)
                    .expect("Frequency comparison should succeed")
            })
            .expect("Should find minimum frequency difference");

        // Power spectrum of the middle frame.
        let mid_frame = zxx.shape()[1] / 2;
        let frame_powers: Vec<f64> = (0..zxx.shape()[0])
            .map(|k| zxx[[k, mid_frame]].norm_sqr())
            .collect();
        let peak_power = frame_powers[peak_bin];
        let mean_power = frame_powers.iter().sum::<f64>() / frame_powers.len() as f64;

        // The tone's bin must dominate the frame.
        assert!(peak_power > 5.0 * mean_power);
    }

    #[test]
    fn test_spectrogram() {
        // Linear chirp: instantaneous frequency grows with time.
        let fs = 1000.0;
        let chirp: Vec<f64> = (0..1000)
            .map(|n| n as f64 / fs)
            .map(|ti| (2.0 * PI * (10.0 + 50.0 * ti) * ti).sin())
            .collect();

        let (freqs, times, sxx) = spectrogram(
            &chirp,
            Some(fs),
            Some(Window::Hann),
            Some(128),
            Some(64),
            None,
            Some(true),
            Some("density"),
            Some("psd"),
        )
        .expect("Spectrogram computation should succeed for test data");

        assert!(!freqs.is_empty());
        assert!(!times.is_empty());
        assert_eq!(sxx.shape(), &[freqs.len(), times.len()]);

        // A PSD can never be negative.
        assert!(sxx.iter().all(|&p| p >= 0.0));
    }

    #[test]
    fn test_spectrogram_modes() {
        let fs = 1000.0;
        let signal = generate_sine_wave(100.0, fs, 1000);

        for mode in ["psd", "magnitude", "angle", "phase"] {
            let (freqs, times, sxx) = spectrogram(
                &signal,
                Some(fs),
                Some(Window::Hann),
                Some(128),
                Some(64),
                None,
                Some(true),
                Some("density"),
                Some(mode),
            )
            .expect("Spectrogram mode computation should succeed");

            assert!(!freqs.is_empty());
            assert!(!times.is_empty());
            assert_eq!(sxx.shape(), &[freqs.len(), times.len()]);

            // Angular outputs must stay inside their principal range.
            match mode {
                "phase" => assert!(sxx.iter().all(|v| (-PI..=PI).contains(v))),
                "angle" => assert!(sxx.iter().all(|v| (-180.0..=180.0).contains(v))),
                _ => {}
            }
        }
    }

    #[test]
    fn test_spectrogram_normalized() {
        let fs = 1000.0;
        let signal = generate_sine_wave(100.0, fs, 1000);

        let (freqs, times, sxx) =
            spectrogram_normalized(&signal, Some(fs), Some(128), Some(64), Some(80.0))
                .expect("Normalized spectrogram should succeed");

        assert!(!freqs.is_empty());
        assert!(!times.is_empty());
        assert_eq!(sxx.shape(), &[freqs.len(), times.len()]);

        // Normalised output must lie in the unit interval.
        assert!(sxx.iter().all(|v| (0.0..=1.0).contains(v)));
    }
}