Skip to main content

scirs2_signal/
gpu_spectrograms.rs

//! GPU-accelerated spectrogram computation.
//!
//! Provides a [`GpuSpectrogram`] struct that computes magnitude and power
//! spectrograms from real-valued 1-D signals.  The configuration exposes a
//! `use_gpu` flag (and a `batch_size` intended for GPU dispatch), but no GPU
//! backend is wired up yet: every computation currently runs on the CPU and
//! neither field is read.
//!
//! The CPU path uses a direct DFT that intentionally favours correctness over
//! raw performance (O(N²) per frame).  It is suitable for testing and short
//! signals; heavy workloads need an FFT-based implementation.
12
13use scirs2_core::ndarray::{Array2, ArrayView1};
14use thiserror::Error;
15
16// ---------------------------------------------------------------------------
17// Public types
18// ---------------------------------------------------------------------------
19
/// Selects the analysis window function applied to each STFT frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowType {
    /// Hann window — good general-purpose choice (first sidelobe ≈ −31 dB,
    /// with sidelobes that fall off quickly).
    Hann,
    /// Hamming window — lower nearest sidelobe than Hann (≈ −43 dB) but a
    /// slower sidelobe roll-off; main-lobe width is essentially the same.
    Hamming,
    /// Rectangular (boxcar) window — no windowing; narrowest main lobe but
    /// the highest sidelobes (≈ −13 dB).
    Rectangular,
    /// Blackman window — very low sidelobes (≈ −58 dB) at the cost of a
    /// wider main lobe.
    Blackman,
}
32
33/// Configuration passed to [`GpuSpectrogram::new`].
34#[derive(Debug, Clone)]
35pub struct GpuSpectrogramConfig {
36    /// FFT length in samples.  Must be a power of two.
37    pub fft_size: usize,
38    /// Number of samples between consecutive frames (hop / stride).
39    pub hop_size: usize,
40    /// Window function applied to each frame.
41    pub window_type: WindowType,
42    /// Number of frames to process in a single GPU dispatch batch.
43    pub batch_size: usize,
44    /// When `true`, prefer the GPU path; fall back to CPU when unavailable.
45    pub use_gpu: bool,
46}
47
48impl Default for GpuSpectrogramConfig {
49    fn default() -> Self {
50        Self {
51            fft_size: 512,
52            hop_size: 128,
53            window_type: WindowType::Hann,
54            batch_size: 64,
55            use_gpu: false,
56        }
57    }
58}
59
60/// Error type for GPU spectrogram operations.
61#[derive(Debug, Error)]
62pub enum GpuSpectrogramError {
63    /// The requested FFT size is not a power of two.
64    #[error("Invalid FFT size {0}: must be power of 2")]
65    InvalidFftSize(usize),
66
67    /// The input signal does not contain at least one full frame.
68    #[error("Signal too short: {0} samples, need at least {1}")]
69    SignalTooShort(usize, usize),
70
71    /// A numerical computation failed.
72    #[error("Computation error: {0}")]
73    ComputeError(String),
74}
75
76// ---------------------------------------------------------------------------
77// GpuSpectrogram implementation
78// ---------------------------------------------------------------------------
79
80/// GPU-accelerated (or CPU-fallback) spectrogram computer.
81///
82/// # Example
83///
84/// ```rust
85/// use scirs2_signal::gpu_spectrograms::{GpuSpectrogram, GpuSpectrogramConfig};
86/// use scirs2_core::ndarray::ArrayView1;
87///
88/// let config = GpuSpectrogramConfig::default();
89/// let sg = GpuSpectrogram::new(config).expect("config is valid");
90///
91/// // 4096-sample sine wave at normalised frequency 0.1
92/// let signal: Vec<f32> = (0..4096)
93///     .map(|i| (2.0 * std::f32::consts::PI * 0.1 * i as f32).sin())
94///     .collect();
95///
96/// let mag = sg.compute(ArrayView1::from(&signal)).expect("compute ok");
97/// println!("spectrogram shape: {:?}", mag.dim());
98/// ```
99pub struct GpuSpectrogram {
100    config: GpuSpectrogramConfig,
101    /// Pre-computed window coefficients of length `config.fft_size`.
102    window: Vec<f32>,
103}
104
105impl GpuSpectrogram {
106    /// Construct a new spectrogram computer from the given configuration.
107    ///
108    /// # Errors
109    ///
110    /// Returns [`GpuSpectrogramError::InvalidFftSize`] when `fft_size` is not
111    /// a power of two or is zero.
112    pub fn new(config: GpuSpectrogramConfig) -> Result<Self, GpuSpectrogramError> {
113        let n = config.fft_size;
114        if n == 0 || !n.is_power_of_two() {
115            return Err(GpuSpectrogramError::InvalidFftSize(n));
116        }
117        let window = Self::compute_window(n, config.window_type);
118        Ok(Self { config, window })
119    }
120
121    // ------------------------------------------------------------------
122    // Public compute API
123    // ------------------------------------------------------------------
124
125    /// Compute the magnitude spectrogram of `signal`.
126    ///
127    /// Returns an `Array2<f32>` of shape `[n_frames, fft_size / 2 + 1]`.
128    /// Each row is the single-sided magnitude spectrum of one analysis frame.
129    ///
130    /// # Errors
131    ///
132    /// Returns [`GpuSpectrogramError::SignalTooShort`] when `signal` is
133    /// shorter than one FFT frame.
134    pub fn compute(&self, signal: ArrayView1<f32>) -> Result<Array2<f32>, GpuSpectrogramError> {
135        let samples = signal.as_slice().ok_or_else(|| {
136            GpuSpectrogramError::ComputeError("signal must be contiguous".to_string())
137        })?;
138        let frames = self.extract_frames(samples)?;
139        let n_frames = frames.len();
140        let n_bins = self.config.fft_size / 2 + 1;
141
142        let mut output = Array2::<f32>::zeros((n_frames, n_bins));
143        for (i, frame) in frames.iter().enumerate() {
144            let mag = Self::fft_magnitude(frame);
145            for (j, &v) in mag.iter().enumerate() {
146                output[[i, j]] = v;
147            }
148        }
149        Ok(output)
150    }
151
152    /// Compute the power spectrogram (magnitude squared) of `signal`.
153    ///
154    /// Returns an `Array2<f32>` of shape `[n_frames, fft_size / 2 + 1]`.
155    pub fn compute_power(
156        &self,
157        signal: ArrayView1<f32>,
158    ) -> Result<Array2<f32>, GpuSpectrogramError> {
159        let mag = self.compute(signal)?;
160        Ok(mag.mapv(|v| v * v))
161    }
162
163    /// Compute spectrograms for a batch of signals.
164    ///
165    /// Each element of `signals` is processed independently.  Returns a
166    /// `Vec` of `Array2` in the same order as the input slice.
167    ///
168    /// # Errors
169    ///
170    /// Propagates errors from the single-signal [`GpuSpectrogram::compute`]
171    /// call for each element.
172    pub fn compute_batch(
173        &self,
174        signals: &[Vec<f32>],
175    ) -> Result<Vec<Array2<f32>>, GpuSpectrogramError> {
176        signals
177            .iter()
178            .map(|s| self.compute(ArrayView1::from(s.as_slice())))
179            .collect()
180    }
181
182    // ------------------------------------------------------------------
183    // Private helpers
184    // ------------------------------------------------------------------
185
186    /// Slice `samples` into overlapping frames of length `fft_size`, advanced
187    /// by `hop_size` between frames.  Each returned frame has the analysis
188    /// window applied in-place.
189    fn extract_frames(&self, samples: &[f32]) -> Result<Vec<Vec<f32>>, GpuSpectrogramError> {
190        let fft_size = self.config.fft_size;
191        let hop = self.config.hop_size;
192
193        if samples.len() < fft_size {
194            return Err(GpuSpectrogramError::SignalTooShort(samples.len(), fft_size));
195        }
196
197        let n_frames = 1 + (samples.len() - fft_size) / hop;
198        let mut frames = Vec::with_capacity(n_frames);
199
200        for k in 0..n_frames {
201            let start = k * hop;
202            let mut frame: Vec<f32> = samples[start..start + fft_size].to_vec();
203            self.apply_window(&mut frame);
204            frames.push(frame);
205        }
206
207        Ok(frames)
208    }
209
210    /// Multiply each sample in `frame` by the corresponding window coefficient.
211    fn apply_window(&self, frame: &mut Vec<f32>) {
212        for (sample, &w) in frame.iter_mut().zip(self.window.iter()) {
213            *sample *= w;
214        }
215    }
216
217    /// Compute window coefficients for a given size and window type.
218    fn compute_window(fft_size: usize, window_type: WindowType) -> Vec<f32> {
219        let n = fft_size as f32;
220        (0..fft_size)
221            .map(|i| {
222                let phase = std::f32::consts::PI * 2.0 * i as f32 / n;
223                match window_type {
224                    WindowType::Hann => 0.5 * (1.0 - phase.cos()),
225                    WindowType::Hamming => 0.54 - 0.46 * phase.cos(),
226                    WindowType::Rectangular => 1.0,
227                    WindowType::Blackman => 0.42 - 0.5 * phase.cos() + 0.08 * (2.0 * phase).cos(),
228                }
229            })
230            .collect()
231    }
232
233    /// Compute the single-sided magnitude spectrum of `frame` using a direct
234    /// DFT.  Returns `fft_size / 2 + 1` non-negative magnitude values.
235    ///
236    /// Time complexity is O(N²) — sufficient for correctness testing; the GPU
237    /// path would replace this with an O(N log N) kernel.
238    fn fft_magnitude(frame: &[f32]) -> Vec<f32> {
239        let n = frame.len();
240        let n_bins = n / 2 + 1;
241        let mut magnitudes = Vec::with_capacity(n_bins);
242
243        for k in 0..n_bins {
244            let mut re = 0.0_f32;
245            let mut im = 0.0_f32;
246            for (j, &sample) in frame.iter().enumerate() {
247                let angle = -2.0 * std::f32::consts::PI * k as f32 * j as f32 / n as f32;
248                re += sample * angle.cos();
249                im += sample * angle.sin();
250            }
251            magnitudes.push((re * re + im * im).sqrt());
252        }
253
254        magnitudes
255    }
256}
257
258// ---------------------------------------------------------------------------
259// Tests
260// ---------------------------------------------------------------------------
261
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayView1;
    use std::f32::consts::PI;

    /// Unit-amplitude sine at `freq_normalised` cycles per sample.
    fn sine_wave(freq_normalised: f32, n_samples: usize) -> Vec<f32> {
        (0..n_samples)
            .map(|i| (2.0 * PI * freq_normalised * i as f32).sin())
            .collect()
    }

    /// The dominant frequency bin of a spectrogram should correspond to the
    /// input sine frequency.
    #[test]
    fn test_gpu_spectrogram_basic() {
        let fft_size = 256_usize;
        let config = GpuSpectrogramConfig {
            fft_size,
            hop_size: 128,
            window_type: WindowType::Hann,
            batch_size: 16,
            use_gpu: false,
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");

        // Normalised frequency 0.125 → bin index = 0.125 * fft_size = 32
        let freq_norm = 0.125_f32;
        let expected_bin = (freq_norm * fft_size as f32).round() as usize;
        let signal = sine_wave(freq_norm, 4 * fft_size);

        let mag = sg
            .compute(ArrayView1::from(&signal))
            .expect("compute should succeed");

        // Check every row (frame) for the dominant bin.
        for row in 0..mag.nrows() {
            let frame_row = mag.row(row);
            let peak_bin = frame_row
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .expect("row is non-empty");

            // Allow ±2 bin tolerance due to windowing spectral leakage.
            assert!(
                peak_bin.abs_diff(expected_bin) <= 2,
                "frame {}: peak bin {} too far from expected {}",
                row,
                peak_bin,
                expected_bin
            );
        }
    }

    /// With a rectangular window a constant signal puts all energy into the
    /// DC bin: |X[0]| = N and every other bin is (numerically) zero.
    #[test]
    fn test_gpu_spectrogram_dc_rectangular() {
        let fft_size = 8_usize;
        let config = GpuSpectrogramConfig {
            fft_size,
            hop_size: fft_size,
            window_type: WindowType::Rectangular,
            batch_size: 1,
            use_gpu: false,
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");
        let signal = vec![1.0_f32; 2 * fft_size];
        let mag = sg
            .compute(ArrayView1::from(&signal))
            .expect("compute should succeed");

        assert_eq!(mag.dim(), (2, fft_size / 2 + 1), "unexpected output shape");
        for r in 0..mag.nrows() {
            let row = mag.row(r);
            assert!(
                (row[0] - fft_size as f32).abs() < 1e-4,
                "frame {}: DC magnitude {} != {}",
                r,
                row[0],
                fft_size
            );
            for (k, &v) in row.iter().enumerate().skip(1) {
                assert!(v.abs() < 1e-4, "frame {}: bin {} should be ~0, got {}", r, k, v);
            }
        }
    }

    /// A rectangular-windowed sine with an integer number of cycles per frame
    /// has magnitude N/2 in its bin and no leakage elsewhere.
    #[test]
    fn test_gpu_spectrogram_exact_bin_amplitude() {
        let fft_size = 32_usize;
        let bin = 4_usize;
        let config = GpuSpectrogramConfig {
            fft_size,
            hop_size: fft_size,
            window_type: WindowType::Rectangular,
            batch_size: 1,
            use_gpu: false,
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");
        let signal = sine_wave(bin as f32 / fft_size as f32, fft_size);
        let mag = sg
            .compute(ArrayView1::from(&signal))
            .expect("compute should succeed");

        assert_eq!(mag.dim(), (1, fft_size / 2 + 1), "unexpected output shape");
        for (k, &v) in mag.row(0).iter().enumerate() {
            let expected = if k == bin { fft_size as f32 / 2.0 } else { 0.0 };
            assert!(
                (v - expected).abs() < 1e-3,
                "bin {}: got {}, expected {}",
                k,
                v,
                expected
            );
        }
    }

    /// Output shape must be [n_frames, fft_size / 2 + 1].
    #[test]
    fn test_gpu_spectrogram_shape() {
        let fft_size = 128_usize;
        let hop_size = 64_usize;
        let n_samples = 1024_usize;

        let config = GpuSpectrogramConfig {
            fft_size,
            hop_size,
            window_type: WindowType::Rectangular,
            batch_size: 8,
            use_gpu: false,
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");
        let signal = vec![0.0_f32; n_samples];
        let mag = sg
            .compute(ArrayView1::from(&signal))
            .expect("compute should succeed");

        let expected_frames = 1 + (n_samples - fft_size) / hop_size;
        let expected_bins = fft_size / 2 + 1;

        assert_eq!(
            mag.dim(),
            (expected_frames, expected_bins),
            "unexpected output shape"
        );
    }

    /// A signal exactly one frame long yields exactly one row.
    #[test]
    fn test_gpu_spectrogram_single_frame() {
        let config = GpuSpectrogramConfig {
            fft_size: 64,
            hop_size: 32,
            window_type: WindowType::Hann,
            batch_size: 1,
            use_gpu: false,
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");
        let signal = sine_wave(0.1, 64);
        let mag = sg
            .compute(ArrayView1::from(&signal))
            .expect("compute should succeed");
        assert_eq!(mag.dim(), (1, 33), "unexpected output shape");
    }

    /// Batch results should be identical to computing each signal individually.
    #[test]
    fn test_gpu_spectrogram_batch() {
        let config = GpuSpectrogramConfig {
            fft_size: 64,
            hop_size: 32,
            window_type: WindowType::Hann,
            batch_size: 4,
            use_gpu: false,
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");

        let signals: Vec<Vec<f32>> = vec![
            sine_wave(0.1, 512),
            sine_wave(0.2, 512),
            sine_wave(0.3, 512),
        ];

        let batch_results = sg.compute_batch(&signals).expect("batch compute ok");

        for (idx, signal) in signals.iter().enumerate() {
            let single = sg
                .compute(ArrayView1::from(signal.as_slice()))
                .expect("single compute ok");
            assert_eq!(
                batch_results[idx].dim(),
                single.dim(),
                "signal {}: shape mismatch between batch and single",
                idx
            );
            for (b, s) in batch_results[idx].iter().zip(single.iter()) {
                assert!(
                    (b - s).abs() < 1e-5,
                    "signal {}: value mismatch batch={} single={}",
                    idx,
                    b,
                    s
                );
            }
        }
    }

    /// An empty batch produces an empty result rather than an error.
    #[test]
    fn test_gpu_spectrogram_empty_batch() {
        let sg = GpuSpectrogram::new(GpuSpectrogramConfig::default()).expect("valid config");
        let results = sg.compute_batch(&[]).expect("empty batch is ok");
        assert!(results.is_empty());
    }

    /// A failing element in a batch surfaces its error to the caller.
    #[test]
    fn test_gpu_spectrogram_batch_propagates_error() {
        let config = GpuSpectrogramConfig {
            fft_size: 64,
            hop_size: 32,
            window_type: WindowType::Hann,
            batch_size: 2,
            use_gpu: false,
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");
        let signals: Vec<Vec<f32>> = vec![sine_wave(0.1, 128), vec![0.0_f32; 10]];
        assert!(matches!(
            sg.compute_batch(&signals),
            Err(GpuSpectrogramError::SignalTooShort(10, 64))
        ));
    }

    /// Power spectrogram should equal magnitude spectrogram element-wise
    /// squared.
    #[test]
    fn test_gpu_spectrogram_power() {
        let config = GpuSpectrogramConfig {
            fft_size: 64,
            hop_size: 32,
            window_type: WindowType::Hann,
            batch_size: 4,
            use_gpu: false,
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");
        let signal = sine_wave(0.1, 512);
        let view = ArrayView1::from(&signal);

        let mag = sg.compute(view).expect("magnitude compute ok");
        let power = sg
            .compute_power(ArrayView1::from(&signal))
            .expect("power compute ok");

        assert_eq!(mag.dim(), power.dim(), "shape mismatch");
        for (m, p) in mag.iter().zip(power.iter()) {
            let expected = m * m;
            assert!(
                (p - expected).abs() < 1e-4,
                "power mismatch: {} vs {} (mag={})",
                p,
                expected,
                m
            );
        }
    }

    /// Requesting an FFT size that is not a power of two must return an error.
    #[test]
    fn test_gpu_spectrogram_invalid_fft_size() {
        let config = GpuSpectrogramConfig {
            fft_size: 300, // not a power of two
            ..Default::default()
        };
        assert!(matches!(
            GpuSpectrogram::new(config),
            Err(GpuSpectrogramError::InvalidFftSize(300))
        ));
    }

    /// An FFT size of zero is rejected just like any other non-power-of-two.
    #[test]
    fn test_gpu_spectrogram_zero_fft_size() {
        let config = GpuSpectrogramConfig {
            fft_size: 0,
            ..Default::default()
        };
        assert!(matches!(
            GpuSpectrogram::new(config),
            Err(GpuSpectrogramError::InvalidFftSize(0))
        ));
    }

    /// A signal shorter than one FFT frame should produce a `SignalTooShort`
    /// error.
    #[test]
    fn test_gpu_spectrogram_signal_too_short() {
        let config = GpuSpectrogramConfig {
            fft_size: 256,
            hop_size: 128,
            ..Default::default()
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");
        let short_signal = vec![0.0_f32; 100]; // < 256

        assert!(matches!(
            sg.compute(ArrayView1::from(&short_signal)),
            Err(GpuSpectrogramError::SignalTooShort(100, 256))
        ));
    }

    /// Every window type should yield finite, non-negative magnitudes.
    #[test]
    fn test_gpu_spectrogram_all_windows() {
        let window_types = [
            WindowType::Hann,
            WindowType::Hamming,
            WindowType::Rectangular,
            WindowType::Blackman,
        ];

        for wt in window_types {
            let config = GpuSpectrogramConfig {
                fft_size: 64,
                hop_size: 32,
                window_type: wt,
                batch_size: 4,
                use_gpu: false,
            };
            let sg = GpuSpectrogram::new(config).expect("valid config");
            let signal = sine_wave(0.25, 512);
            let mag = sg
                .compute(ArrayView1::from(&signal))
                .expect("compute with window type should succeed");

            for &v in mag.iter() {
                assert!(
                    v.is_finite() && v >= 0.0,
                    "unexpected value {} for {:?}",
                    v,
                    wt
                );
            }
        }
    }
}