1use std::f32::consts::PI;
38
39use crate::error::{FFTError, FFTResult};
40
/// One analysis frame produced by [`RingBufferStft`].
#[derive(Debug, Clone)]
pub struct StftFrame {
    /// Index (counted from the first pushed sample) of the sample at offset
    /// `window_size / 2` inside the analysed window, i.e. the window centre.
    pub sample_index: usize,
    /// One-sided magnitude spectrum `sqrt(re² + im²)` for bins
    /// `0..=window_size / 2` (length `window_size / 2 + 1`).
    pub magnitudes: Vec<f32>,
    /// One-sided phase spectrum `atan2(im, re)` in radians, same length as
    /// `magnitudes`.
    pub phases: Vec<f32>,
    /// Full complex spectrum of the windowed frame as `(re, im)` pairs,
    /// length `window_size`; the forward transform applies no scaling.
    pub spectrum: Vec<(f32, f32)>,
}
57
/// Window applied to each STFT frame.
///
/// Every window is generated in symmetric form: the cosine argument is
/// `x = 2π·i / (window_size − 1)`, so the first and last samples are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowFunction {
    /// All-ones window (no tapering).
    Rectangular,
    /// Hann window: `0.5 − 0.5·cos(x)`.
    Hann,
    /// Hamming window: `0.54 − 0.46·cos(x)`.
    Hamming,
    /// Blackman window: `0.42 − 0.5·cos(x) + 0.08·cos(2x)`.
    Blackman,
    /// Five-term flat-top window with coefficients `1, 1.93, 1.29, 0.388, 0.032`.
    /// The coefficients are not normalised, so the window peaks at roughly
    /// 4.64 rather than 1.
    FlatTop,
}
72
/// Configuration for [`RingBufferStft`].
#[derive(Debug, Clone)]
pub struct RingBufferStftConfig {
    /// Frame length in samples. Must be a power of two and at least 4
    /// (checked by [`RingBufferStft::new`]).
    pub window_size: usize,
    /// Number of new samples between consecutive frames. Must lie in
    /// `[1, window_size]`.
    pub hop_size: usize,
    /// Window applied to each frame before the FFT; overlap-add
    /// reconstruction multiplies by the same window.
    pub window_fn: WindowFunction,
    /// When `true`, an overlap-add buffer is allocated so that
    /// [`RingBufferStft::reconstruct`] can be used. When `false`,
    /// `reconstruct` returns an error.
    pub overlap_add: bool,
}
87
/// Streaming short-time Fourier transform backed by a fixed-size ring buffer.
///
/// Samples are pushed incrementally and frames are emitted as soon as enough
/// samples have arrived.
pub struct RingBufferStft {
    /// Configuration, validated in `new`.
    config: RingBufferStftConfig,
    /// Circular sample storage of length `window_size`.
    buffer: Vec<f32>,
    /// Slot the next sample is written to. Once the buffer has filled, this
    /// is also the slot holding the oldest sample.
    write_pos: usize,
    /// Samples received since the last emitted frame. Before the first
    /// frame it counts from the start of the stream, so it may exceed
    /// `hop_size`.
    samples_since_last_frame: usize,
    /// Total samples received, including any zero padding added by `flush`.
    total_samples: usize,
    /// Precomputed window coefficients (length `window_size`).
    window: Vec<f32>,
    /// Overlap-add accumulator (length `window_size`). It is `Some` only
    /// when `config.overlap_add` is set.
    overlap_add_buffer: Option<Vec<f32>>,
}
109
110impl RingBufferStft {
111 pub fn new(config: RingBufferStftConfig) -> FFTResult<Self> {
121 if !config.window_size.is_power_of_two() || config.window_size < 4 {
122 return Err(FFTError::ValueError(format!(
123 "ring_buffer_stft: window_size must be a power of two >= 4, got {}",
124 config.window_size
125 )));
126 }
127 if config.hop_size == 0 || config.hop_size > config.window_size {
128 return Err(FFTError::ValueError(format!(
129 "ring_buffer_stft: hop_size must be in [1, window_size], got {}",
130 config.hop_size
131 )));
132 }
133
134 let window = compute_window(config.window_fn, config.window_size);
135 let overlap_add_buffer = if config.overlap_add {
136 Some(vec![0.0_f32; config.window_size])
137 } else {
138 None
139 };
140
141 Ok(Self {
142 buffer: vec![0.0_f32; config.window_size],
143 write_pos: 0,
144 samples_since_last_frame: 0,
145 total_samples: 0,
146 window,
147 overlap_add_buffer,
148 config,
149 })
150 }
151
152 pub fn push(&mut self, samples: &[f32]) -> Vec<StftFrame> {
159 let mut frames = Vec::new();
160 for &s in samples {
161 self.buffer[self.write_pos] = s;
162 self.write_pos = (self.write_pos + 1) % self.config.window_size;
163 self.total_samples += 1;
164 self.samples_since_last_frame += 1;
165
166 if self.samples_since_last_frame >= self.config.hop_size
167 && self.total_samples >= self.config.window_size
168 {
169 if let Ok(frame) = self.emit_frame() {
170 frames.push(frame);
171 }
172 self.samples_since_last_frame = 0;
173 }
174 }
175 frames
176 }
177
178 pub fn flush(&mut self) -> Vec<StftFrame> {
181 if self.samples_since_last_frame == 0 {
182 return Vec::new();
183 }
184 let needed = self.config.hop_size - self.samples_since_last_frame;
186 let pad = vec![0.0_f32; needed];
187 self.push(&pad)
188 }
189
190 pub fn reconstruct(
206 &mut self,
207 frame: &StftFrame,
208 modified_spectrum: Option<&[(f32, f32)]>,
209 ) -> FFTResult<Vec<f32>> {
210 let ola_buf = self.overlap_add_buffer.as_mut().ok_or_else(|| {
211 FFTError::ValueError(
212 "ring_buffer_stft: reconstruct called but overlap_add is disabled".into(),
213 )
214 })?;
215 let n = self.config.window_size;
216
217 let spec: &[(f32, f32)] = match modified_spectrum {
218 Some(s) => {
219 if s.len() != n {
220 return Err(FFTError::ValueError(format!(
221 "ring_buffer_stft: modified_spectrum length {} != window_size {}",
222 s.len(),
223 n
224 )));
225 }
226 s
227 }
228 None => &frame.spectrum,
229 };
230
231 let mut data: Vec<(f32, f32)> = spec.to_vec();
233 fft_inplace_f32(&mut data, true);
234
235 let hop = self.config.hop_size;
237 for i in 0..n {
238 ola_buf[i] += data[i].0 * self.window[i];
239 }
240
241 let out: Vec<f32> = ola_buf[..hop].to_vec();
243 ola_buf.copy_within(hop..n, 0);
244 for v in &mut ola_buf[n - hop..n] {
245 *v = 0.0;
246 }
247 Ok(out)
248 }
249
250 pub fn latency(&self) -> usize {
254 self.config.window_size - self.config.hop_size
255 }
256
257 pub fn n_freq_bins(&self) -> usize {
259 self.config.window_size / 2 + 1
260 }
261
262 fn emit_frame(&self) -> FFTResult<StftFrame> {
267 let n = self.config.window_size;
268 let mut windowed = self.apply_window_from_ring();
270
271 fft_inplace_f32(&mut windowed, false);
273 let spectrum: Vec<(f32, f32)> = windowed;
274
275 let n_one_sided = n / 2 + 1;
276 let mut magnitudes = Vec::with_capacity(n_one_sided);
277 let mut phases = Vec::with_capacity(n_one_sided);
278 for &(re, im) in &spectrum[..n_one_sided] {
279 magnitudes.push((re * re + im * im).sqrt());
280 phases.push(im.atan2(re));
281 }
282
283 let center = self.total_samples.saturating_sub(n / 2);
285
286 Ok(StftFrame {
287 sample_index: center,
288 magnitudes,
289 phases,
290 spectrum,
291 })
292 }
293
294 fn apply_window_from_ring(&self) -> Vec<(f32, f32)> {
297 let n = self.config.window_size;
298 (0..n)
300 .map(|i| {
301 let idx = (self.write_pos + i) % n;
302 let v = self.buffer[idx] * self.window[i];
303 (v, 0.0_f32)
304 })
305 .collect()
306 }
307}
308
/// Rolling magnitude spectrogram built on top of [`RingBufferStft`].
///
/// Only the most recent frames are kept; the oldest frames are discarded first.
pub struct StreamingSpectrogram {
    /// Underlying streaming STFT.
    stft: RingBufferStft,
    /// Retained one-sided magnitude frames, oldest at the front.
    frames: std::collections::VecDeque<Vec<f32>>,
    /// Frame-history cap applied in `push`.
    max_frames: usize,
}
323
324impl StreamingSpectrogram {
325 pub fn new(stft_config: RingBufferStftConfig, max_frames: usize) -> FFTResult<Self> {
334 Ok(Self {
335 stft: RingBufferStft::new(stft_config)?,
336 frames: std::collections::VecDeque::new(),
337 max_frames,
338 })
339 }
340
341 pub fn push(&mut self, samples: &[f32]) {
343 let new_frames = self.stft.push(samples);
344 for frame in new_frames {
345 if self.frames.len() == self.max_frames {
346 self.frames.pop_front();
347 }
348 self.frames.push_back(frame.magnitudes);
349 }
350 }
351
352 pub fn get_spectrogram(&self) -> Vec<Vec<f32>> {
355 self.frames.iter().cloned().collect()
356 }
357
358 pub fn n_freq_bins(&self) -> usize {
360 self.stft.n_freq_bins()
361 }
362
363 pub fn n_frames(&self) -> usize {
365 self.frames.len()
366 }
367}
368
369fn compute_window(wf: WindowFunction, n: usize) -> Vec<f32> {
374 match wf {
375 WindowFunction::Rectangular => vec![1.0_f32; n],
376 WindowFunction::Hann => (0..n)
377 .map(|i| 0.5 - 0.5 * (2.0 * PI * i as f32 / (n - 1) as f32).cos())
378 .collect(),
379 WindowFunction::Hamming => (0..n)
380 .map(|i| 0.54 - 0.46 * (2.0 * PI * i as f32 / (n - 1) as f32).cos())
381 .collect(),
382 WindowFunction::Blackman => (0..n)
383 .map(|i| {
384 let x = 2.0 * PI * i as f32 / (n - 1) as f32;
385 0.42 - 0.5 * x.cos() + 0.08 * (2.0 * x).cos()
386 })
387 .collect(),
388 WindowFunction::FlatTop => (0..n)
389 .map(|i| {
390 let x = 2.0 * PI * i as f32 / (n - 1) as f32;
391 1.0 - 1.93_f32 * x.cos() + 1.29_f32 * (2.0 * x).cos() - 0.388_f32 * (3.0 * x).cos()
392 + 0.032_f32 * (4.0 * x).cos()
393 })
394 .collect(),
395 }
396}
397
/// In-place iterative radix-2 Cooley–Tukey FFT on `(re, im)` pairs.
///
/// The forward transform (`inverse == false`) uses the `e^{-2πi·jk/n}` kernel
/// and applies no scaling. The inverse uses `e^{+2πi·jk/n}` and scales by
/// `1/n`, so a forward/inverse round trip reproduces the input.
///
/// `data.len()` must be a power of two; lengths 0 and 1 are left unchanged.
/// The precondition is checked only by a debug assertion.
fn fft_inplace_f32(data: &mut [(f32, f32)], inverse: bool) {
    let n = data.len();
    if n <= 1 {
        return;
    }
    debug_assert!(
        n.is_power_of_two(),
        "fft_inplace_f32: length must be a power of two, got {n}"
    );

    // Reorder into bit-reversed index order so the butterflies can run in place.
    let log2_n = n.trailing_zeros() as usize;
    for i in 0..n {
        let j = bit_reverse(i, log2_n);
        if j > i {
            data.swap(i, j);
        }
    }

    let sign: f32 = if inverse { 1.0 } else { -1.0 };
    let mut half_size = 1_usize;
    while half_size < n {
        let full_size = half_size * 2;
        // The twiddle index is the outer loop, so each twiddle factor's
        // sin/cos is evaluated once per stage instead of once per butterfly.
        for j in 0..half_size {
            let angle = sign * PI * j as f32 / half_size as f32;
            let (sin_a, cos_a) = angle.sin_cos();
            for k in (j..n).step_by(full_size) {
                let (ur, ui) = data[k];
                let (vr, vi) = data[k + half_size];
                let (wr, wi) = (cos_a * vr - sin_a * vi, cos_a * vi + sin_a * vr);
                data[k] = (ur + wr, ui + wi);
                data[k + half_size] = (ur - wr, ui - wi);
            }
        }
        half_size = full_size;
    }

    if inverse {
        let inv_n = 1.0 / n as f32;
        for (re, im) in data.iter_mut() {
            *re *= inv_n;
            *im *= inv_n;
        }
    }
}

/// Reverses the lowest `bits` bits of `v`; higher bits of `v` are discarded.
#[inline]
fn bit_reverse(mut v: usize, bits: usize) -> usize {
    let mut r = 0_usize;
    for _ in 0..bits {
        r = (r << 1) | (v & 1);
        v >>= 1;
    }
    r
}
460
#[cfg(test)]
// Unit tests for the ring-buffer STFT, the streaming spectrogram, and the
// internal radix-2 FFT.
mod tests {
    use super::*;

    // Samples `n` points of a unit-amplitude sine at `freq` Hz for sample rate `fs`.
    fn make_sine(n: usize, freq: f32, fs: f32) -> Vec<f32> {
        (0..n)
            .map(|i| (2.0 * PI * freq * i as f32 / fs).sin())
            .collect()
    }

    #[test]
    // The first frame appears after `window_size` samples and then one per
    // `hop_size`, giving `(N - W) / H + 1` frames for N input samples.
    fn test_ring_buffer_stft_frame_count() {
        let window_size = 128;
        let hop_size = 64;
        let n_samples = 1024_usize;
        let config = RingBufferStftConfig {
            window_size,
            hop_size,
            window_fn: WindowFunction::Hann,
            overlap_add: false,
        };
        let mut proc = RingBufferStft::new(config).expect("valid config");
        let signal = make_sine(n_samples, 440.0, 8000.0);
        let frames = proc.push(&signal);
        let expected = (n_samples - window_size) / hop_size + 1;
        assert_eq!(
            frames.len(),
            expected,
            "expected {expected} frames, got {}",
            frames.len()
        );
    }

    #[test]
    // Magnitudes and phases are one-sided (W/2 + 1 bins); the raw spectrum
    // keeps all W bins.
    fn test_ring_buffer_stft_freq_bins() {
        let window_size = 64;
        let config = RingBufferStftConfig {
            window_size,
            hop_size: 32,
            window_fn: WindowFunction::Hamming,
            overlap_add: false,
        };
        let mut proc = RingBufferStft::new(config).expect("valid config");
        let signal = make_sine(512, 220.0, 8000.0);
        let frames = proc.push(&signal);
        assert!(!frames.is_empty(), "no frames generated");
        for frame in &frames {
            assert_eq!(
                frame.magnitudes.len(),
                window_size / 2 + 1,
                "wrong number of freq bins"
            );
            assert_eq!(frame.phases.len(), window_size / 2 + 1);
            assert_eq!(frame.spectrum.len(), window_size);
        }
    }

    #[test]
    // After one full window (one frame emitted) plus half a hop of extra
    // samples, flush must zero-pad and emit the pending frame.
    fn test_ring_buffer_stft_flush_emits_remaining() {
        let window_size = 64;
        let hop_size = 32;
        let config = RingBufferStftConfig {
            window_size,
            hop_size,
            window_fn: WindowFunction::Hann,
            overlap_add: false,
        };
        let mut proc = RingBufferStft::new(config).expect("valid config");

        let init = make_sine(window_size, 100.0, 8000.0);
        proc.push(&init);

        let partial = vec![0.0_f32; hop_size / 2];
        proc.push(&partial);

        let flushed = proc.flush();
        assert!(
            !flushed.is_empty(),
            "flush should emit at least one frame for partial hop"
        );
    }

    #[test]
    // A non-power-of-two window size and a zero hop size are both rejected.
    fn test_ring_buffer_stft_invalid_config() {
        let cfg = RingBufferStftConfig {
            window_size: 100,
            hop_size: 50,
            window_fn: WindowFunction::Rectangular,
            overlap_add: false,
        };
        assert!(RingBufferStft::new(cfg).is_err());

        let cfg2 = RingBufferStftConfig {
            window_size: 128,
            hop_size: 0,
            window_fn: WindowFunction::Hann,
            overlap_add: false,
        };
        assert!(RingBufferStft::new(cfg2).is_err());
    }

    #[test]
    // Pushing enough samples yields stored rows of one-sided magnitudes.
    fn test_streaming_spectrogram_push_update() {
        let config = RingBufferStftConfig {
            window_size: 64,
            hop_size: 32,
            window_fn: WindowFunction::Hann,
            overlap_add: false,
        };
        let mut spec = StreamingSpectrogram::new(config, 10).expect("valid config");
        let signal = make_sine(512, 440.0, 8000.0);
        spec.push(&signal);
        let sg = spec.get_spectrogram();
        assert!(!sg.is_empty(), "spectrogram should have frames after push");
        assert_eq!(sg[0].len(), 64 / 2 + 1, "wrong freq bin count");
    }

    #[test]
    // 4096 samples produce far more than `max_frames` frames; the history
    // must stay capped.
    fn test_streaming_spectrogram_max_frames() {
        let config = RingBufferStftConfig {
            window_size: 64,
            hop_size: 32,
            window_fn: WindowFunction::Hann,
            overlap_add: false,
        };
        let max_frames = 5;
        let mut spec = StreamingSpectrogram::new(config, max_frames).expect("valid config");
        let signal = make_sine(4096, 440.0, 8000.0);
        spec.push(&signal);
        assert!(
            spec.n_frames() <= max_frames,
            "should not exceed max_frames"
        );
    }

    #[test]
    // Loose energy check on overlap-add output. `reconstruct` applies no
    // window-sum normalisation, hence the 50% tolerance.
    // NOTE(review): compares signal[hop_size..] with reconstructed[..]; confirm
    // this offset matches the intended alignment.
    fn test_reconstruction_roundtrip() {
        let window_size = 64_usize;
        let hop_size = 32_usize;
        let signal = make_sine(512, 440.0, 8000.0);

        let config = RingBufferStftConfig {
            window_size,
            hop_size,
            window_fn: WindowFunction::Hann,
            overlap_add: true,
        };
        let mut proc = RingBufferStft::new(config).expect("valid config");
        let frames = proc.push(&signal);

        let mut reconstructed: Vec<f32> = Vec::new();
        for frame in &frames {
            let chunk = proc
                .reconstruct(frame, None)
                .expect("reconstruction should work");
            reconstructed.extend(chunk);
        }

        let common_len = reconstructed.len().min(signal.len());
        if common_len > hop_size {
            let sig_energy: f32 = signal[hop_size..common_len]
                .iter()
                .map(|&x| x * x)
                .sum::<f32>();
            let rec_energy: f32 = reconstructed[..common_len - hop_size]
                .iter()
                .map(|&x| x * x)
                .sum::<f32>();
            if sig_energy > 1e-6 {
                let ratio = (rec_energy - sig_energy).abs() / sig_energy;
                assert!(ratio < 0.5, "energy ratio error too large: {ratio}");
            }
        }
    }

    #[test]
    // A forward FFT followed by the inverse FFT reproduces the input within 1e-4.
    fn test_fft_inplace_identity() {
        let n = 16_usize;
        let original: Vec<(f32, f32)> = (0..n).map(|i| (i as f32, 0.0_f32)).collect();
        let mut data = original.clone();
        fft_inplace_f32(&mut data, false);
        fft_inplace_f32(&mut data, true);
        for (i, (&(re_orig, im_orig), &(re_rec, im_rec))) in
            original.iter().zip(data.iter()).enumerate()
        {
            assert!(
                (re_orig - re_rec).abs() < 1e-4,
                "re mismatch at {i}: {re_orig} vs {re_rec}"
            );
            assert!(
                (im_orig - im_rec).abs() < 1e-4,
                "im mismatch at {i}: {im_orig} vs {im_rec}"
            );
        }
    }
}