// scirs2_transform/signal_transforms/stft.rs
use crate::error::{Result, TransformError};
6use rayon::prelude::*;
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
8use scirs2_core::numeric::Complex;
9use scirs2_fft::fft;
10use std::f64::consts::PI;
11
/// Window functions that can be applied to each STFT frame.
///
/// The `f64` payloads are the shape parameters handed to the matching
/// generator: `beta` for [`WindowType::Kaiser`] and `alpha` for
/// [`WindowType::Tukey`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum WindowType {
    /// Hann (raised-cosine) window.
    Hann,
    /// Hamming window (coefficients 0.54 / 0.46).
    Hamming,
    /// Blackman window (coefficients 0.42 / 0.5 / 0.08).
    Blackman,
    /// Bartlett (triangular) window.
    Bartlett,
    /// Rectangular window: every sample is `1.0`.
    Rectangular,
    /// Kaiser window with shape parameter `beta`.
    Kaiser(f64),
    /// Tukey (tapered-cosine) window with taper fraction `alpha`.
    Tukey(f64),
}
30
31impl WindowType {
32 pub fn generate(&self, n: usize) -> Array1<f64> {
34 match self {
35 WindowType::Hann => Self::hann(n),
36 WindowType::Hamming => Self::hamming(n),
37 WindowType::Blackman => Self::blackman(n),
38 WindowType::Bartlett => Self::bartlett(n),
39 WindowType::Rectangular => Array1::ones(n),
40 WindowType::Kaiser(beta) => Self::kaiser(n, *beta),
41 WindowType::Tukey(alpha) => Self::tukey(n, *alpha),
42 }
43 }
44
45 fn hann(n: usize) -> Array1<f64> {
46 Array1::from_vec(
47 (0..n)
48 .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f64 / (n - 1) as f64).cos()))
49 .collect(),
50 )
51 }
52
53 fn hamming(n: usize) -> Array1<f64> {
54 Array1::from_vec(
55 (0..n)
56 .map(|i| 0.54 - 0.46 * (2.0 * PI * i as f64 / (n - 1) as f64).cos())
57 .collect(),
58 )
59 }
60
61 fn blackman(n: usize) -> Array1<f64> {
62 Array1::from_vec(
63 (0..n)
64 .map(|i| {
65 let angle = 2.0 * PI * i as f64 / (n - 1) as f64;
66 0.42 - 0.5 * angle.cos() + 0.08 * (2.0 * angle).cos()
67 })
68 .collect(),
69 )
70 }
71
72 fn bartlett(n: usize) -> Array1<f64> {
73 Array1::from_vec(
74 (0..n)
75 .map(|i| 1.0 - (2.0 * (i as f64 - (n - 1) as f64 / 2.0).abs() / (n - 1) as f64))
76 .collect(),
77 )
78 }
79
80 fn kaiser(n: usize, beta: f64) -> Array1<f64> {
81 let i0_beta = Self::bessel_i0(beta);
82 Array1::from_vec(
83 (0..n)
84 .map(|i| {
85 let x = 2.0 * i as f64 / (n - 1) as f64 - 1.0;
86 let arg = beta * (1.0 - x * x).sqrt();
87 Self::bessel_i0(arg) / i0_beta
88 })
89 .collect(),
90 )
91 }
92
93 fn tukey(n: usize, alpha: f64) -> Array1<f64> {
94 let alpha = alpha.clamp(0.0, 1.0);
95 Array1::from_vec(
96 (0..n)
97 .map(|i| {
98 let x = i as f64 / (n - 1) as f64;
99 if x < alpha / 2.0 {
100 0.5 * (1.0 + (2.0 * PI * x / alpha - PI).cos())
101 } else if x > 1.0 - alpha / 2.0 {
102 0.5 * (1.0 + (2.0 * PI * (1.0 - x) / alpha - PI).cos())
103 } else {
104 1.0
105 }
106 })
107 .collect(),
108 )
109 }
110
111 fn bessel_i0(x: f64) -> f64 {
113 let mut sum = 1.0;
114 let mut term = 1.0;
115 let threshold = 1e-12;
116
117 for k in 1..50 {
118 term *= (x / 2.0) * (x / 2.0) / (k as f64 * k as f64);
119 sum += term;
120 if term < threshold {
121 break;
122 }
123 }
124
125 sum
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct STFTConfig {
132 pub window_size: usize,
134 pub hop_size: usize,
136 pub window_type: WindowType,
138 pub nfft: Option<usize>,
140 pub onesided: bool,
142 pub padding: PaddingMode,
144}
145
/// Strategy for filling frame samples that fall past the end of the signal.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PaddingMode {
    /// No padding: only the available samples are copied into the frame.
    None,
    /// Missing samples are zero.
    Zero,
    /// Missing samples repeat the last sample of the signal.
    Edge,
    /// Missing samples mirror the signal about its last sample.
    Reflect,
}
158
159impl Default for STFTConfig {
160 fn default() -> Self {
161 STFTConfig {
162 window_size: 256,
163 hop_size: 128,
164 window_type: WindowType::Hann,
165 nfft: None,
166 onesided: true,
167 padding: PaddingMode::Zero,
168 }
169 }
170}
171
172#[derive(Debug, Clone)]
174pub struct STFT {
175 config: STFTConfig,
176 window: Array1<f64>,
177}
178
179impl STFT {
180 pub fn new(config: STFTConfig) -> Self {
182 let window = config.window_type.generate(config.window_size);
183 STFT { config, window }
184 }
185
186 pub fn default() -> Self {
188 Self::new(STFTConfig::default())
189 }
190
191 pub fn with_params(window_size: usize, hop_size: usize) -> Self {
193 Self::new(STFTConfig {
194 window_size,
195 hop_size,
196 ..Default::default()
197 })
198 }
199
200 pub fn transform(&self, signal: &ArrayView1<f64>) -> Result<Array2<Complex<f64>>> {
202 let signal_len = signal.len();
203 if signal_len == 0 {
204 return Err(TransformError::InvalidInput("Empty signal".to_string()));
205 }
206
207 let nfft = self.config.nfft.unwrap_or(self.config.window_size);
208 if nfft < self.config.window_size {
209 return Err(TransformError::InvalidInput(
210 "FFT size must be >= window size".to_string(),
211 ));
212 }
213
214 let n_frames = self.calculate_n_frames(signal_len);
216 let n_freqs = if self.config.onesided {
217 nfft / 2 + 1
218 } else {
219 nfft
220 };
221
222 let mut stft = Array2::from_elem((n_freqs, n_frames), Complex::new(0.0, 0.0));
223
224 for (frame_idx, frame_start) in (0..signal_len)
226 .step_by(self.config.hop_size)
227 .take(n_frames)
228 .enumerate()
229 {
230 let frame = self.extract_frame(signal, frame_start)?;
231 let spectrum = self.compute_frame_spectrum(&frame, nfft)?;
232
233 for (freq_idx, &val) in spectrum.iter().enumerate() {
234 if freq_idx < n_freqs {
235 stft[[freq_idx, frame_idx]] = val;
236 }
237 }
238 }
239
240 Ok(stft)
241 }
242
243 pub fn inverse(&self, stft: &Array2<Complex<f64>>) -> Result<Array1<f64>> {
245 let (n_freqs, n_frames) = stft.dim();
246
247 if n_frames == 0 {
248 return Err(TransformError::InvalidInput(
249 "No frames in STFT".to_string(),
250 ));
251 }
252
253 let nfft = self.config.nfft.unwrap_or(self.config.window_size);
254
255 let output_len = (n_frames - 1) * self.config.hop_size + self.config.window_size;
257 let mut output = Array1::zeros(output_len);
258 let mut window_sum: Array1<f64> = Array1::zeros(output_len);
259
260 for frame_idx in 0..n_frames {
262 let mut spectrum = Vec::with_capacity(nfft);
264 for freq_idx in 0..n_freqs {
265 spectrum.push(stft[[freq_idx, frame_idx]]);
266 }
267
268 if self.config.onesided && nfft > 1 {
270 for freq_idx in (1..(nfft - n_freqs + 1)).rev() {
271 if freq_idx < n_freqs {
272 spectrum.push(spectrum[freq_idx].conj());
273 }
274 }
275 }
276
277 let time_frame = scirs2_fft::ifft(&spectrum, None)?;
279
280 let frame_start = frame_idx * self.config.hop_size;
282 for (i, &val) in time_frame.iter().take(self.config.window_size).enumerate() {
283 let idx = frame_start + i;
284 if idx < output_len {
285 output[idx] += val.re * self.window[i];
286 window_sum[idx] += self.window[i] * self.window[i];
287 }
288 }
289 }
290
291 for i in 0..output_len {
293 if window_sum[i] > 1e-10 {
294 output[i] /= window_sum[i];
295 }
296 }
297
298 Ok(output)
299 }
300
301 fn extract_frame(&self, signal: &ArrayView1<f64>, start: usize) -> Result<Array1<f64>> {
302 let signal_len = signal.len();
303 let mut frame = Array1::zeros(self.config.window_size);
304
305 match self.config.padding {
306 PaddingMode::None => {
307 let end = (start + self.config.window_size).min(signal_len);
308 for i in 0..(end - start) {
309 frame[i] = signal[start + i] * self.window[i];
310 }
311 }
312 PaddingMode::Zero => {
313 for i in 0..self.config.window_size {
314 let idx = start + i;
315 if idx < signal_len {
316 frame[i] = signal[idx] * self.window[i];
317 }
318 }
319 }
320 PaddingMode::Edge => {
321 for i in 0..self.config.window_size {
322 let idx = (start + i).min(signal_len - 1);
323 frame[i] = signal[idx] * self.window[i];
324 }
325 }
326 PaddingMode::Reflect => {
327 for i in 0..self.config.window_size {
328 let mut idx = start as i64 + i as i64;
329 if idx >= signal_len as i64 {
330 idx = 2 * signal_len as i64 - idx - 2;
331 }
332 if idx < 0 {
333 idx = -idx;
334 }
335 let idx = (idx as usize).min(signal_len - 1);
336 frame[i] = signal[idx] * self.window[i];
337 }
338 }
339 }
340
341 Ok(frame)
342 }
343
344 fn compute_frame_spectrum(
345 &self,
346 frame: &Array1<f64>,
347 nfft: usize,
348 ) -> Result<Vec<Complex<f64>>> {
349 let mut padded = vec![0.0; nfft];
351 for (i, &val) in frame.iter().enumerate() {
352 if i < nfft {
353 padded[i] = val;
354 }
355 }
356
357 Ok(fft(&padded, None)?)
358 }
359
360 fn calculate_n_frames(&self, signal_len: usize) -> usize {
361 if signal_len < self.config.window_size {
362 return 1;
363 }
364 ((signal_len - self.config.window_size) / self.config.hop_size) + 1
365 }
366
367 pub fn window(&self) -> &Array1<f64> {
369 &self.window
370 }
371
372 pub fn config(&self) -> &STFTConfig {
374 &self.config
375 }
376}
377
378#[derive(Debug, Clone)]
380pub struct Spectrogram {
381 stft: STFT,
382 scaling: SpectrogramScaling,
383}
384
/// Conversion applied to each STFT coefficient when building a spectrogram.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SpectrogramScaling {
    /// Squared magnitude.
    Power,
    /// Magnitude.
    Magnitude,
    /// `10 * log10(power)`, with `-100.0` used when the power is at most `1e-10`.
    Decibel,
}
395
396impl Spectrogram {
397 pub fn new(config: STFTConfig) -> Self {
399 Spectrogram {
400 stft: STFT::new(config),
401 scaling: SpectrogramScaling::Power,
402 }
403 }
404
405 pub fn with_scaling(mut self, scaling: SpectrogramScaling) -> Self {
407 self.scaling = scaling;
408 self
409 }
410
411 pub fn compute(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
413 let stft = self.stft.transform(signal)?;
414 let (n_freqs, n_frames) = stft.dim();
415
416 let mut spectrogram = Array2::zeros((n_freqs, n_frames));
417
418 for i in 0..n_freqs {
419 for j in 0..n_frames {
420 let mag = stft[[i, j]].norm();
421 spectrogram[[i, j]] = match self.scaling {
422 SpectrogramScaling::Power => mag * mag,
423 SpectrogramScaling::Magnitude => mag,
424 SpectrogramScaling::Decibel => {
425 let power = mag * mag;
426 if power > 1e-10 {
427 10.0 * power.log10()
428 } else {
429 -100.0 }
431 }
432 };
433 }
434 }
435
436 Ok(spectrogram)
437 }
438
439 pub fn frequency_bins(&self, sampling_rate: f64) -> Vec<f64> {
441 let nfft = self
442 .stft
443 .config
444 .nfft
445 .unwrap_or(self.stft.config.window_size);
446 let n_freqs = if self.stft.config.onesided {
447 nfft / 2 + 1
448 } else {
449 nfft
450 };
451
452 (0..n_freqs)
453 .map(|i| i as f64 * sampling_rate / nfft as f64)
454 .collect()
455 }
456
457 pub fn time_bins(&self, signal_len: usize, sampling_rate: f64) -> Vec<f64> {
459 let n_frames = self.stft.calculate_n_frames(signal_len);
460 (0..n_frames)
461 .map(|i| (i * self.stft.config.hop_size) as f64 / sampling_rate)
462 .collect()
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds `len` samples of `sin(i * step)`.
    fn sine_wave(len: usize, step: f64) -> Array1<f64> {
        Array1::from_vec((0..len).map(|i| (i as f64 * step).sin()).collect())
    }

    #[test]
    fn test_window_generation() {
        // Hann is zero at both ends and close to one in the middle.
        let hann = WindowType::Hann.generate(64);
        assert_eq!(hann.len(), 64);
        assert_abs_diff_eq!(hann[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(hann[63], 0.0, epsilon = 1e-10);
        assert!(hann[32] > 0.9);

        // Hamming does not reach zero at the ends.
        let hamming = WindowType::Hamming.generate(64);
        assert_eq!(hamming.len(), 64);
        assert!(hamming[0] > 0.0);
    }

    #[test]
    fn test_stft_simple() -> Result<()> {
        let samples = sine_wave(256, 0.1);
        let transform = STFT::with_params(64, 32);

        let (n_freqs, n_frames) = transform.transform(&samples.view())?.dim();

        assert!(n_freqs > 0);
        assert!(n_frames > 0);

        Ok(())
    }

    #[test]
    fn test_stft_inverse() -> Result<()> {
        let samples = sine_wave(256, 0.1);
        let transform = STFT::with_params(64, 32);

        let coefficients = transform.transform(&samples.view())?;
        let rebuilt = transform.inverse(&coefficients)?;

        assert!(!rebuilt.is_empty());

        Ok(())
    }

    #[test]
    fn test_spectrogram() -> Result<()> {
        let samples = sine_wave(512, 0.05);
        let calculator = Spectrogram::new(STFTConfig {
            window_size: 128,
            hop_size: 64,
            ..Default::default()
        });

        let spec = calculator.compute(&samples.view())?;
        let (n_freqs, n_frames) = spec.dim();

        assert!(n_freqs > 0);
        assert!(n_frames > 0);
        // Default scaling is power, which can never be negative.
        assert!(spec.iter().all(|&x| x >= 0.0));

        Ok(())
    }

    #[test]
    fn test_spectrogram_scaling() -> Result<()> {
        let samples = sine_wave(256, 0.1);
        let config = STFTConfig::default();

        let compute_with = |scaling: SpectrogramScaling| {
            Spectrogram::new(config.clone())
                .with_scaling(scaling)
                .compute(&samples.view())
        };

        let spec_power = compute_with(SpectrogramScaling::Power)?;
        let spec_mag = compute_with(SpectrogramScaling::Magnitude)?;
        let spec_db = compute_with(SpectrogramScaling::Decibel)?;

        assert_eq!(spec_power.dim(), spec_mag.dim());
        assert_eq!(spec_power.dim(), spec_db.dim());

        Ok(())
    }

    #[test]
    fn test_frequency_time_bins() {
        let calculator = Spectrogram::new(STFTConfig {
            window_size: 256,
            hop_size: 128,
            ..Default::default()
        });

        let freqs = calculator.frequency_bins(1000.0);
        let times = calculator.time_bins(1000, 1000.0);

        assert!(!freqs.is_empty());
        assert!(!times.is_empty());
        assert_abs_diff_eq!(freqs[0], 0.0, epsilon = 1e-10);
    }
}