1use scirs2_core::ndarray::{Array2, ArrayView1};
14use thiserror::Error;
15
/// Tapering window multiplied into every analysis frame before the transform.
///
/// Every shape is generated in its periodic form: for a frame of `N` samples the
/// coefficient at index `i` is evaluated at phase `φ = 2π·i / N`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowType {
    /// Hann window: `0.5 · (1 − cos φ)`.
    Hann,
    /// Hamming window: `0.54 − 0.46 · cos φ`.
    Hamming,
    /// Rectangular window: every coefficient is `1.0` (no tapering).
    Rectangular,
    /// Blackman window: `0.42 − 0.5 · cos φ + 0.08 · cos 2φ`.
    Blackman,
}
32
33#[derive(Debug, Clone)]
35pub struct GpuSpectrogramConfig {
36 pub fft_size: usize,
38 pub hop_size: usize,
40 pub window_type: WindowType,
42 pub batch_size: usize,
44 pub use_gpu: bool,
46}
47
48impl Default for GpuSpectrogramConfig {
49 fn default() -> Self {
50 Self {
51 fft_size: 512,
52 hop_size: 128,
53 window_type: WindowType::Hann,
54 batch_size: 64,
55 use_gpu: false,
56 }
57 }
58}
59
60#[derive(Debug, Error)]
62pub enum GpuSpectrogramError {
63 #[error("Invalid FFT size {0}: must be power of 2")]
65 InvalidFftSize(usize),
66
67 #[error("Signal too short: {0} samples, need at least {1}")]
69 SignalTooShort(usize, usize),
70
71 #[error("Computation error: {0}")]
73 ComputeError(String),
74}
75
76pub struct GpuSpectrogram {
100 config: GpuSpectrogramConfig,
101 window: Vec<f32>,
103}
104
105impl GpuSpectrogram {
106 pub fn new(config: GpuSpectrogramConfig) -> Result<Self, GpuSpectrogramError> {
113 let n = config.fft_size;
114 if n == 0 || !n.is_power_of_two() {
115 return Err(GpuSpectrogramError::InvalidFftSize(n));
116 }
117 let window = Self::compute_window(n, config.window_type);
118 Ok(Self { config, window })
119 }
120
121 pub fn compute(&self, signal: ArrayView1<f32>) -> Result<Array2<f32>, GpuSpectrogramError> {
135 let samples = signal.as_slice().ok_or_else(|| {
136 GpuSpectrogramError::ComputeError("signal must be contiguous".to_string())
137 })?;
138 let frames = self.extract_frames(samples)?;
139 let n_frames = frames.len();
140 let n_bins = self.config.fft_size / 2 + 1;
141
142 let mut output = Array2::<f32>::zeros((n_frames, n_bins));
143 for (i, frame) in frames.iter().enumerate() {
144 let mag = Self::fft_magnitude(frame);
145 for (j, &v) in mag.iter().enumerate() {
146 output[[i, j]] = v;
147 }
148 }
149 Ok(output)
150 }
151
152 pub fn compute_power(
156 &self,
157 signal: ArrayView1<f32>,
158 ) -> Result<Array2<f32>, GpuSpectrogramError> {
159 let mag = self.compute(signal)?;
160 Ok(mag.mapv(|v| v * v))
161 }
162
163 pub fn compute_batch(
173 &self,
174 signals: &[Vec<f32>],
175 ) -> Result<Vec<Array2<f32>>, GpuSpectrogramError> {
176 signals
177 .iter()
178 .map(|s| self.compute(ArrayView1::from(s.as_slice())))
179 .collect()
180 }
181
182 fn extract_frames(&self, samples: &[f32]) -> Result<Vec<Vec<f32>>, GpuSpectrogramError> {
190 let fft_size = self.config.fft_size;
191 let hop = self.config.hop_size;
192
193 if samples.len() < fft_size {
194 return Err(GpuSpectrogramError::SignalTooShort(samples.len(), fft_size));
195 }
196
197 let n_frames = 1 + (samples.len() - fft_size) / hop;
198 let mut frames = Vec::with_capacity(n_frames);
199
200 for k in 0..n_frames {
201 let start = k * hop;
202 let mut frame: Vec<f32> = samples[start..start + fft_size].to_vec();
203 self.apply_window(&mut frame);
204 frames.push(frame);
205 }
206
207 Ok(frames)
208 }
209
210 fn apply_window(&self, frame: &mut Vec<f32>) {
212 for (sample, &w) in frame.iter_mut().zip(self.window.iter()) {
213 *sample *= w;
214 }
215 }
216
217 fn compute_window(fft_size: usize, window_type: WindowType) -> Vec<f32> {
219 let n = fft_size as f32;
220 (0..fft_size)
221 .map(|i| {
222 let phase = std::f32::consts::PI * 2.0 * i as f32 / n;
223 match window_type {
224 WindowType::Hann => 0.5 * (1.0 - phase.cos()),
225 WindowType::Hamming => 0.54 - 0.46 * phase.cos(),
226 WindowType::Rectangular => 1.0,
227 WindowType::Blackman => 0.42 - 0.5 * phase.cos() + 0.08 * (2.0 * phase).cos(),
228 }
229 })
230 .collect()
231 }
232
233 fn fft_magnitude(frame: &[f32]) -> Vec<f32> {
239 let n = frame.len();
240 let n_bins = n / 2 + 1;
241 let mut magnitudes = Vec::with_capacity(n_bins);
242
243 for k in 0..n_bins {
244 let mut re = 0.0_f32;
245 let mut im = 0.0_f32;
246 for (j, &sample) in frame.iter().enumerate() {
247 let angle = -2.0 * std::f32::consts::PI * k as f32 * j as f32 / n as f32;
248 re += sample * angle.cos();
249 im += sample * angle.sin();
250 }
251 magnitudes.push((re * re + im * im).sqrt());
252 }
253
254 magnitudes
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayView1;
    use std::f32::consts::PI;

    /// Unit-amplitude sine whose frequency is given in cycles per sample.
    fn sine_wave(freq_normalised: f32, n_samples: usize) -> Vec<f32> {
        let angular_step = 2.0 * PI * freq_normalised;
        let mut out = Vec::with_capacity(n_samples);
        for i in 0..n_samples {
            out.push((angular_step * i as f32).sin());
        }
        out
    }

    /// Builds a CPU engine from explicit parameters, panicking if the config is rejected.
    fn make_engine(
        fft_size: usize,
        hop_size: usize,
        window_type: WindowType,
        batch_size: usize,
    ) -> GpuSpectrogram {
        let config = GpuSpectrogramConfig {
            fft_size,
            hop_size,
            window_type,
            batch_size,
            use_gpu: false,
        };
        GpuSpectrogram::new(config).expect("valid config")
    }

    /// Position of the largest value.
    ///
    /// Ties go to the last occurrence, as with `Iterator::max_by`, and incomparable
    /// (NaN) pairs are treated as equal.
    fn peak_index<'a>(values: impl Iterator<Item = &'a f32>) -> usize {
        values
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .expect("row is non-empty")
    }

    #[test]
    fn test_gpu_spectrogram_basic() {
        let fft_size = 256_usize;
        let sg = make_engine(fft_size, 128, WindowType::Hann, 16);

        // 0.125 cycles/sample should peak near bin 0.125 * 256 = 32.
        let freq_norm = 0.125_f32;
        let expected_bin = (freq_norm * fft_size as f32).round() as usize;
        let signal = sine_wave(freq_norm, 4 * fft_size);

        let mag = sg
            .compute(ArrayView1::from(&signal))
            .expect("compute should succeed");

        for (row, spectrum) in mag.outer_iter().enumerate() {
            let peak_bin = peak_index(spectrum.iter());
            // A couple of bins of slack absorbs window leakage.
            assert!(
                peak_bin.abs_diff(expected_bin) <= 2,
                "frame {}: peak bin {} too far from expected {}",
                row,
                peak_bin,
                expected_bin
            );
        }
    }

    #[test]
    fn test_gpu_spectrogram_shape() {
        let (fft_size, hop_size, n_samples) = (128_usize, 64_usize, 1024_usize);
        let sg = make_engine(fft_size, hop_size, WindowType::Rectangular, 8);

        let silence = vec![0.0_f32; n_samples];
        let mag = sg
            .compute(ArrayView1::from(&silence))
            .expect("compute should succeed");

        let expected = (1 + (n_samples - fft_size) / hop_size, fft_size / 2 + 1);
        assert_eq!(mag.dim(), expected, "unexpected output shape");
    }

    #[test]
    fn test_gpu_spectrogram_batch() {
        let sg = make_engine(64, 32, WindowType::Hann, 4);
        let signals: Vec<Vec<f32>> = [0.1, 0.2, 0.3]
            .iter()
            .map(|&f| sine_wave(f, 512))
            .collect();

        let batch_results = sg.compute_batch(&signals).expect("batch compute ok");

        for (idx, signal) in signals.iter().enumerate() {
            let single = sg
                .compute(ArrayView1::from(signal.as_slice()))
                .expect("single compute ok");
            let batched = &batch_results[idx];
            assert_eq!(
                batched.dim(),
                single.dim(),
                "signal {}: shape mismatch between batch and single",
                idx
            );
            for (b, s) in batched.iter().zip(single.iter()) {
                assert!(
                    (b - s).abs() < 1e-5,
                    "signal {}: value mismatch batch={} single={}",
                    idx,
                    b,
                    s
                );
            }
        }
    }

    #[test]
    fn test_gpu_spectrogram_power() {
        let sg = make_engine(64, 32, WindowType::Hann, 4);
        let signal = sine_wave(0.1, 512);

        let mag = sg
            .compute(ArrayView1::from(&signal))
            .expect("magnitude compute ok");
        let power = sg
            .compute_power(ArrayView1::from(&signal))
            .expect("power compute ok");

        assert_eq!(mag.dim(), power.dim(), "shape mismatch");
        for (&m, &p) in mag.iter().zip(power.iter()) {
            let expected = m * m;
            assert!(
                (p - expected).abs() < 1e-4,
                "power mismatch: {} vs {} (mag={})",
                p,
                expected,
                m
            );
        }
    }

    #[test]
    fn test_gpu_spectrogram_invalid_fft_size() {
        // 300 is not a power of two.
        let config = GpuSpectrogramConfig {
            fft_size: 300,
            ..Default::default()
        };
        let result = GpuSpectrogram::new(config);
        assert!(matches!(
            result,
            Err(GpuSpectrogramError::InvalidFftSize(300))
        ));
    }

    #[test]
    fn test_gpu_spectrogram_signal_too_short() {
        let config = GpuSpectrogramConfig {
            fft_size: 256,
            hop_size: 128,
            ..Default::default()
        };
        let sg = GpuSpectrogram::new(config).expect("valid config");

        let short_signal = vec![0.0_f32; 100];
        let result = sg.compute(ArrayView1::from(&short_signal));
        assert!(matches!(
            result,
            Err(GpuSpectrogramError::SignalTooShort(100, 256))
        ));
    }

    #[test]
    fn test_gpu_spectrogram_all_windows() {
        let signal = sine_wave(0.25, 512);
        let window_types = [
            WindowType::Hann,
            WindowType::Hamming,
            WindowType::Rectangular,
            WindowType::Blackman,
        ];

        for wt in window_types {
            let sg = make_engine(64, 32, wt, 4);
            let mag = sg
                .compute(ArrayView1::from(&signal))
                .expect("compute with window type should succeed");

            for &v in mag.iter() {
                assert!(
                    v.is_finite() && v >= 0.0,
                    "unexpected value {} for {:?}",
                    v,
                    wt
                );
            }
        }
    }
}