1use num_traits::{Float, FromPrimitive};
25use rustfft::num_complex::Complex;
26use rustfft::{Fft, FftNum, FftPlanner};
27use std::collections::VecDeque;
28use std::fmt;
29use std::sync::Arc;
30
31mod utils;
32pub use utils::{apply_padding, deinterleave, deinterleave_into, interleave, interleave_into};
33
34pub mod mel;
35
36pub mod prelude {
37 pub use crate::mel::{
38 BatchMelSpectrogram, BatchMelSpectrogramF32, BatchMelSpectrogramF64, MelConfig,
39 MelConfigF32, MelConfigF64, MelFilterbank, MelFilterbankF32, MelFilterbankF64, MelNorm,
40 MelScale, MelSpectrum, MelSpectrumF32, MelSpectrumF64, StreamingMelSpectrogram,
41 StreamingMelSpectrogramF32, StreamingMelSpectrogramF64,
42 };
43 pub use crate::utils::{
44 apply_padding, deinterleave, deinterleave_into, interleave, interleave_into,
45 };
46 pub use crate::{
47 BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64,
48 MultiChannelStreamingIstft, MultiChannelStreamingIstftF32, MultiChannelStreamingIstftF64,
49 MultiChannelStreamingStft, MultiChannelStreamingStftF32, MultiChannelStreamingStftF64,
50 PadMode, ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame,
51 SpectrumFrameF32, SpectrumFrameF64, StftConfig, StftConfigBuilder, StftConfigBuilderF32,
52 StftConfigBuilderF64, StftConfigF32, StftConfigF64, StreamingIstft, StreamingIstftF32,
53 StreamingIstftF64, StreamingStft, StreamingStftF32, StreamingStftF64, WindowType,
54 };
55}
56
/// How inverse transforms combine overlapping frames.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReconstructionMode {
    /// Overlap-add: inverse-FFT frames are summed as-is and each output sample is
    /// divided by the accumulated window sum. Configurations are checked for COLA.
    Ola,

    /// Weighted overlap-add: inverse-FFT frames are multiplied by the window again
    /// before summing, and each output sample is divided by the accumulated sum of
    /// squared window values. Configurations are checked for NOLA.
    Wola,
}
65
/// Symmetric window shapes used for analysis and synthesis.
///
/// With `N` the window length and `φ = 2πn / (N - 1)`:
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowType {
    /// `0.5 · (1 - cos φ)`
    Hann,
    /// `0.54 - 0.46 · cos φ`
    Hamming,
    /// `0.42 - 0.5 · cos φ + 0.08 · cos 2φ`
    Blackman,
}
72
73#[derive(Debug, Clone)]
74pub enum ConfigError<T: Float + fmt::Debug> {
75 NolaViolation { min_energy: T, threshold: T },
76 ColaViolation { max_deviation: T, threshold: T },
77 InvalidHopSize,
78 InvalidFftSize,
79}
80
81impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 ConfigError::NolaViolation {
85 min_energy,
86 threshold,
87 } => {
88 write!(
89 f,
90 "NOLA condition violated: min_energy={} < threshold={}",
91 min_energy, threshold
92 )
93 }
94 ConfigError::ColaViolation {
95 max_deviation,
96 threshold,
97 } => {
98 write!(
99 f,
100 "COLA condition violated: max_deviation={} > threshold={}",
101 max_deviation, threshold
102 )
103 }
104 ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
105 ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
106 }
107 }
108}
109
110impl<T: Float + fmt::Display + fmt::Debug> std::error::Error for ConfigError<T> {}
111
/// How [`BatchStft`] extends the signal by `fft_size / 2` samples on each side
/// before framing.
///
/// NOTE(review): the padding itself is implemented by `utils::apply_padding`. The
/// variant names suggest mirror, zero-fill and edge-replicate padding; confirm
/// the exact semantics there.
#[derive(Clone, Copy, Debug)]
pub enum PadMode {
    /// Mirror the signal at its boundaries.
    Reflect,
    /// Pad with zeros.
    Zero,
    /// Repeat the boundary samples.
    Edge,
}
118
119#[derive(Clone)]
120pub struct StftConfig<T: Float> {
121 pub fft_size: usize,
122 pub hop_size: usize,
123 pub window: WindowType,
124 pub reconstruction_mode: ReconstructionMode,
125 _phantom: std::marker::PhantomData<T>,
126}
127
128impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
129 fn nola_threshold() -> T {
130 T::from(1e-8).unwrap()
131 }
132
133 fn cola_relative_tolerance() -> T {
134 T::from(1e-4).unwrap()
135 }
136
137 #[deprecated(
138 since = "0.4.0",
139 note = "Use `StftConfig::builder()` instead for a more flexible API"
140 )]
141 pub fn new(
142 fft_size: usize,
143 hop_size: usize,
144 window: WindowType,
145 reconstruction_mode: ReconstructionMode,
146 ) -> Result<Self, ConfigError<T>> {
147 if fft_size == 0 || !fft_size.is_power_of_two() {
148 return Err(ConfigError::InvalidFftSize);
149 }
150 if hop_size == 0 || hop_size > fft_size {
151 return Err(ConfigError::InvalidHopSize);
152 }
153
154 let config = Self {
155 fft_size,
156 hop_size,
157 window,
158 reconstruction_mode,
159 _phantom: std::marker::PhantomData,
160 };
161
162 match reconstruction_mode {
164 ReconstructionMode::Ola => config.validate_cola()?,
165 ReconstructionMode::Wola => config.validate_nola()?,
166 }
167
168 Ok(config)
169 }
170
171 pub fn builder() -> StftConfigBuilder<T> {
173 StftConfigBuilder::new()
174 }
175
176 #[allow(deprecated)]
178 pub fn default_4096() -> Self {
179 Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
180 .expect("Default config should always be valid")
181 }
182
183 pub fn freq_bins(&self) -> usize {
184 self.fft_size / 2 + 1
185 }
186
187 pub fn overlap_percent(&self) -> T {
188 let one = T::one();
189 let hundred = T::from(100.0).unwrap();
190 (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
191 }
192
193 fn generate_window(&self) -> Vec<T> {
194 generate_window(self.window, self.fft_size)
195 }
196
197 pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
199 let window = self.generate_window();
200 let num_overlaps = self.fft_size.div_ceil(self.hop_size);
201 let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
202 let mut energy = vec![T::zero(); test_len];
203
204 for i in 0..num_overlaps {
205 let offset = i * self.hop_size;
206 for j in 0..self.fft_size {
207 if offset + j < test_len {
208 energy[offset + j] = energy[offset + j] + window[j] * window[j];
209 }
210 }
211 }
212
213 let start = self.fft_size / 2;
215 let end = test_len - self.fft_size / 2;
216 let min_energy = energy[start..end]
217 .iter()
218 .copied()
219 .min_by(|a, b| a.partial_cmp(b).unwrap())
220 .unwrap_or_else(T::zero);
221
222 if min_energy < Self::nola_threshold() {
223 return Err(ConfigError::NolaViolation {
224 min_energy,
225 threshold: Self::nola_threshold(),
226 });
227 }
228
229 Ok(())
230 }
231
232 pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
234 let window = self.generate_window();
235 let window_len = window.len();
236
237 let mut cola_sum_period = vec![T::zero(); self.hop_size];
238 (0..window_len).for_each(|i| {
239 let idx = i % self.hop_size;
240 cola_sum_period[idx] = cola_sum_period[idx] + window[i];
241 });
242
243 let zero = T::zero();
244 let min_sum = cola_sum_period
245 .iter()
246 .min_by(|a, b| a.partial_cmp(b).unwrap())
247 .unwrap_or(&zero);
248 let max_sum = cola_sum_period
249 .iter()
250 .max_by(|a, b| a.partial_cmp(b).unwrap())
251 .unwrap_or(&zero);
252
253 let epsilon = T::from(1e-9).unwrap();
254 if *max_sum < epsilon {
255 return Err(ConfigError::ColaViolation {
256 max_deviation: T::infinity(),
257 threshold: Self::cola_relative_tolerance(),
258 });
259 }
260
261 let ripple = (*max_sum - *min_sum) / *max_sum;
262
263 let is_compliant = ripple < Self::cola_relative_tolerance();
264
265 if !is_compliant {
266 return Err(ConfigError::ColaViolation {
267 max_deviation: ripple,
268 threshold: Self::cola_relative_tolerance(),
269 });
270 }
271 Ok(())
272 }
273}
274
275pub struct StftConfigBuilder<T: Float> {
277 fft_size: Option<usize>,
278 hop_size: Option<usize>,
279 window: WindowType,
280 reconstruction_mode: ReconstructionMode,
281 _phantom: std::marker::PhantomData<T>,
282}
283
284impl<T: Float + FromPrimitive + fmt::Debug> StftConfigBuilder<T> {
285 pub fn new() -> Self {
287 Self {
288 fft_size: None,
289 hop_size: None,
290 window: WindowType::Hann,
291 reconstruction_mode: ReconstructionMode::Ola,
292 _phantom: std::marker::PhantomData,
293 }
294 }
295
296 pub fn fft_size(mut self, fft_size: usize) -> Self {
298 self.fft_size = Some(fft_size);
299 self
300 }
301
302 pub fn hop_size(mut self, hop_size: usize) -> Self {
304 self.hop_size = Some(hop_size);
305 self
306 }
307
308 pub fn window(mut self, window: WindowType) -> Self {
310 self.window = window;
311 self
312 }
313
314 pub fn reconstruction_mode(mut self, mode: ReconstructionMode) -> Self {
316 self.reconstruction_mode = mode;
317 self
318 }
319
320 #[allow(deprecated)]
327 pub fn build(self) -> Result<StftConfig<T>, ConfigError<T>> {
328 let fft_size = self.fft_size.ok_or(ConfigError::InvalidFftSize)?;
329 let hop_size = self.hop_size.ok_or(ConfigError::InvalidHopSize)?;
330
331 StftConfig::new(fft_size, hop_size, self.window, self.reconstruction_mode)
332 }
333}
334
335impl<T: Float + FromPrimitive + fmt::Debug> Default for StftConfigBuilder<T> {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
341fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
342 let pi = T::from(std::f64::consts::PI).unwrap();
343 let two = T::from(2.0).unwrap();
344
345 match window_type {
346 WindowType::Hann => (0..size)
347 .map(|i| {
348 let half = T::from(0.5).unwrap();
349 let one = T::one();
350 let i_t = T::from(i).unwrap();
351 let size_m1 = T::from(size - 1).unwrap();
352 half * (one - (two * pi * i_t / size_m1).cos())
353 })
354 .collect(),
355 WindowType::Hamming => (0..size)
356 .map(|i| {
357 let i_t = T::from(i).unwrap();
358 let size_m1 = T::from(size - 1).unwrap();
359 T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_m1).cos()
360 })
361 .collect(),
362 WindowType::Blackman => (0..size)
363 .map(|i| {
364 let i_t = T::from(i).unwrap();
365 let size_m1 = T::from(size - 1).unwrap();
366 let angle = two * pi * i_t / size_m1;
367 T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
368 + T::from(0.08).unwrap() * (two * angle).cos()
369 })
370 .collect(),
371 }
372}
373
374#[derive(Clone)]
375pub struct SpectrumFrame<T: Float> {
376 pub freq_bins: usize,
377 pub data: Vec<Complex<T>>,
378}
379
380impl<T: Float> SpectrumFrame<T> {
381 pub fn new(freq_bins: usize) -> Self {
382 Self {
383 freq_bins,
384 data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
385 }
386 }
387
388 pub fn from_data(data: Vec<Complex<T>>) -> Self {
389 let freq_bins = data.len();
390 Self { freq_bins, data }
391 }
392
393 pub fn clear(&mut self) {
395 for val in &mut self.data {
396 *val = Complex::new(T::zero(), T::zero());
397 }
398 }
399
400 pub fn resize_if_needed(&mut self, freq_bins: usize) {
402 if self.freq_bins != freq_bins {
403 self.freq_bins = freq_bins;
404 self.data
405 .resize(freq_bins, Complex::new(T::zero(), T::zero()));
406 }
407 }
408
409 pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
411 self.resize_if_needed(data.len());
412 self.data.copy_from_slice(data);
413 }
414
415 #[inline]
417 pub fn magnitude(&self, bin: usize) -> T {
418 let c = &self.data[bin];
419 (c.re * c.re + c.im * c.im).sqrt()
420 }
421
422 #[inline]
424 pub fn phase(&self, bin: usize) -> T {
425 let c = &self.data[bin];
426 c.im.atan2(c.re)
427 }
428
429 pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
431 self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
432 }
433
434 pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
436 assert_eq!(
437 magnitudes.len(),
438 phases.len(),
439 "Magnitude and phase arrays must have same length"
440 );
441 let freq_bins = magnitudes.len();
442 let data: Vec<Complex<T>> = magnitudes
443 .iter()
444 .zip(phases.iter())
445 .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
446 .collect();
447 Self { freq_bins, data }
448 }
449
450 pub fn magnitudes(&self) -> Vec<T> {
452 self.data
453 .iter()
454 .map(|c| (c.re * c.re + c.im * c.im).sqrt())
455 .collect()
456 }
457
458 pub fn phases(&self) -> Vec<T> {
460 self.data.iter().map(|c| c.im.atan2(c.re)).collect()
461 }
462}
463
464#[derive(Clone)]
465pub struct Spectrum<T: Float> {
466 pub num_frames: usize,
467 pub freq_bins: usize,
468 pub data: Vec<T>,
469}
470
471impl<T: Float> Spectrum<T> {
472 pub fn new(num_frames: usize, freq_bins: usize) -> Self {
473 Self {
474 num_frames,
475 freq_bins,
476 data: vec![T::zero(); 2 * num_frames * freq_bins],
477 }
478 }
479
480 #[inline]
481 pub fn real(&self, frame: usize, bin: usize) -> T {
482 self.data[frame * self.freq_bins + bin]
483 }
484
485 #[inline]
486 pub fn imag(&self, frame: usize, bin: usize) -> T {
487 let offset = self.num_frames * self.freq_bins;
488 self.data[offset + frame * self.freq_bins + bin]
489 }
490
491 #[inline]
492 pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
493 Complex::new(self.real(frame, bin), self.imag(frame, bin))
494 }
495
496 pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
497 (0..self.num_frames).map(move |frame_idx| {
498 let data: Vec<Complex<T>> = (0..self.freq_bins)
499 .map(|bin| self.get_complex(frame_idx, bin))
500 .collect();
501 SpectrumFrame::from_data(data)
502 })
503 }
504
505 #[inline]
507 pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
508 self.data[frame * self.freq_bins + bin] = value;
509 }
510
511 #[inline]
513 pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
514 let offset = self.num_frames * self.freq_bins;
515 self.data[offset + frame * self.freq_bins + bin] = value;
516 }
517
518 #[inline]
520 pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
521 self.set_real(frame, bin, value.re);
522 self.set_imag(frame, bin, value.im);
523 }
524
525 #[inline]
527 pub fn magnitude(&self, frame: usize, bin: usize) -> T {
528 let re = self.real(frame, bin);
529 let im = self.imag(frame, bin);
530 (re * re + im * im).sqrt()
531 }
532
533 #[inline]
535 pub fn phase(&self, frame: usize, bin: usize) -> T {
536 let re = self.real(frame, bin);
537 let im = self.imag(frame, bin);
538 im.atan2(re)
539 }
540
541 pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
543 self.set_real(frame, bin, magnitude * phase.cos());
544 self.set_imag(frame, bin, magnitude * phase.sin());
545 }
546
547 pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
549 (0..self.freq_bins)
550 .map(|bin| self.magnitude(frame, bin))
551 .collect()
552 }
553
554 pub fn frame_phases(&self, frame: usize) -> Vec<T> {
556 (0..self.freq_bins)
557 .map(|bin| self.phase(frame, bin))
558 .collect()
559 }
560
561 pub fn apply<F>(&mut self, mut f: F)
563 where
564 F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
565 {
566 for frame in 0..self.num_frames {
567 for bin in 0..self.freq_bins {
568 let c = self.get_complex(frame, bin);
569 let new_c = f(frame, bin, c);
570 self.set_complex(frame, bin, new_c);
571 }
572 }
573 }
574
575 pub fn apply_gain(&mut self, bin_range: std::ops::Range<usize>, gain: T) {
577 for frame in 0..self.num_frames {
578 for bin in bin_range.clone() {
579 if bin < self.freq_bins {
580 let c = self.get_complex(frame, bin);
581 self.set_complex(frame, bin, c * gain);
582 }
583 }
584 }
585 }
586
587 pub fn zero_bins(&mut self, bin_range: std::ops::Range<usize>) {
589 for frame in 0..self.num_frames {
590 for bin in bin_range.clone() {
591 if bin < self.freq_bins {
592 self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
593 }
594 }
595 }
596 }
597}
598
599pub struct BatchStft<T: Float + FftNum> {
600 config: StftConfig<T>,
601 window: Vec<T>,
602 fft: Arc<dyn Fft<T>>,
603}
604
605impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
606 pub fn new(config: StftConfig<T>) -> Self {
607 let window = config.generate_window();
608 let mut planner = FftPlanner::new();
609 let fft = planner.plan_fft_forward(config.fft_size);
610
611 Self {
612 config,
613 window,
614 fft,
615 }
616 }
617
618 pub fn process(&self, signal: &[T]) -> Spectrum<T> {
619 self.process_padded(signal, PadMode::Reflect)
620 }
621
622 pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
623 let pad_amount = self.config.fft_size / 2;
624 let padded = utils::apply_padding(signal, pad_amount, pad_mode);
625
626 let num_frames = if padded.len() >= self.config.fft_size {
627 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
628 } else {
629 0
630 };
631
632 let freq_bins = self.config.freq_bins();
633 let mut result = Spectrum::new(num_frames, freq_bins);
634
635 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
636
637 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
638 .step_by(self.config.hop_size)
639 .enumerate()
640 {
641 for i in 0..self.config.fft_size {
643 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
644 }
645
646 self.fft.process(&mut fft_buffer);
648
649 (0..freq_bins).for_each(|bin| {
651 let idx = frame_idx * freq_bins + bin;
652 result.data[idx] = fft_buffer[bin].re;
653 result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
654 });
655 }
656
657 result
658 }
659
660 pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
664 self.process_padded_into(signal, PadMode::Reflect, spectrum)
665 }
666
667 pub fn process_padded_into(
669 &self,
670 signal: &[T],
671 pad_mode: PadMode,
672 spectrum: &mut Spectrum<T>,
673 ) -> bool {
674 let pad_amount = self.config.fft_size / 2;
675 let padded = utils::apply_padding(signal, pad_amount, pad_mode);
676
677 let num_frames = if padded.len() >= self.config.fft_size {
678 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
679 } else {
680 0
681 };
682
683 let freq_bins = self.config.freq_bins();
684
685 if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
687 return false;
688 }
689
690 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
691
692 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
693 .step_by(self.config.hop_size)
694 .enumerate()
695 {
696 for i in 0..self.config.fft_size {
698 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
699 }
700
701 self.fft.process(&mut fft_buffer);
703
704 (0..freq_bins).for_each(|bin| {
706 let idx = frame_idx * freq_bins + bin;
707 spectrum.data[idx] = fft_buffer[bin].re;
708 spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
709 });
710 }
711
712 true
713 }
714
715 pub fn process_multichannel(&self, channels: &[Vec<T>]) -> Vec<Spectrum<T>> {
742 assert!(!channels.is_empty(), "channels must not be empty");
743
744 let expected_len = channels[0].len();
746 for (i, channel) in channels.iter().enumerate() {
747 assert_eq!(
748 channel.len(),
749 expected_len,
750 "Channel {} has length {}, expected {}",
751 i,
752 channel.len(),
753 expected_len
754 );
755 }
756
757 #[cfg(feature = "rayon")]
759 {
760 use rayon::prelude::*;
761 channels
762 .par_iter()
763 .map(|channel| self.process(channel))
764 .collect()
765 }
766 #[cfg(not(feature = "rayon"))]
767 {
768 channels
769 .iter()
770 .map(|channel| self.process(channel))
771 .collect()
772 }
773 }
774
775 pub fn process_interleaved(&self, data: &[T], num_channels: usize) -> Vec<Spectrum<T>> {
803 let channels = utils::deinterleave(data, num_channels);
804 self.process_multichannel(&channels)
805 }
806}
807
808pub struct BatchIstft<T: Float + FftNum> {
809 config: StftConfig<T>,
810 window: Vec<T>,
811 ifft: Arc<dyn Fft<T>>,
812}
813
814impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
815 pub fn new(config: StftConfig<T>) -> Self {
816 let window = config.generate_window();
817 let mut planner = FftPlanner::new();
818 let ifft = planner.plan_fft_inverse(config.fft_size);
819
820 Self {
821 config,
822 window,
823 ifft,
824 }
825 }
826
827 pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
828 assert_eq!(
829 spectrum.freq_bins,
830 self.config.freq_bins(),
831 "Frequency bins mismatch"
832 );
833
834 let num_frames = spectrum.num_frames;
835 let original_time_len = (num_frames - 1) * self.config.hop_size;
836 let pad_amount = self.config.fft_size / 2;
837 let padded_len = original_time_len + 2 * pad_amount;
838
839 let mut overlap_buffer = vec![T::zero(); padded_len];
840 let mut window_energy = vec![T::zero(); padded_len];
841 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
842
843 for frame_idx in 0..num_frames {
845 let pos = frame_idx * self.config.hop_size;
846 for i in 0..self.config.fft_size {
847 match self.config.reconstruction_mode {
848 ReconstructionMode::Ola => {
849 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
850 }
851 ReconstructionMode::Wola => {
852 window_energy[pos + i] =
853 window_energy[pos + i] + self.window[i] * self.window[i];
854 }
855 }
856 }
857 }
858
859 for frame_idx in 0..num_frames {
861 (0..spectrum.freq_bins).for_each(|bin| {
863 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
864 });
865
866 for bin in 1..(spectrum.freq_bins - 1) {
868 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
869 }
870
871 self.ifft.process(&mut ifft_buffer);
873
874 let pos = frame_idx * self.config.hop_size;
876 for i in 0..self.config.fft_size {
877 let fft_size_t = T::from(self.config.fft_size).unwrap();
878 let sample = ifft_buffer[i].re / fft_size_t;
879
880 match self.config.reconstruction_mode {
881 ReconstructionMode::Ola => {
882 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
884 }
885 ReconstructionMode::Wola => {
886 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
888 }
889 }
890 }
891 }
892
893 let threshold = T::from(1e-8).unwrap();
895 for i in 0..padded_len {
896 if window_energy[i] > threshold {
897 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
898 }
899 }
900
901 overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
903 }
904
905 pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
908 assert_eq!(
909 spectrum.freq_bins,
910 self.config.freq_bins(),
911 "Frequency bins mismatch"
912 );
913
914 let num_frames = spectrum.num_frames;
915 let original_time_len = (num_frames - 1) * self.config.hop_size;
916 let pad_amount = self.config.fft_size / 2;
917 let padded_len = original_time_len + 2 * pad_amount;
918
919 let mut overlap_buffer = vec![T::zero(); padded_len];
920 let mut window_energy = vec![T::zero(); padded_len];
921 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
922
923 for frame_idx in 0..num_frames {
925 let pos = frame_idx * self.config.hop_size;
926 for i in 0..self.config.fft_size {
927 match self.config.reconstruction_mode {
928 ReconstructionMode::Ola => {
929 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
930 }
931 ReconstructionMode::Wola => {
932 window_energy[pos + i] =
933 window_energy[pos + i] + self.window[i] * self.window[i];
934 }
935 }
936 }
937 }
938
939 for frame_idx in 0..num_frames {
941 (0..spectrum.freq_bins).for_each(|bin| {
943 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
944 });
945
946 for bin in 1..(spectrum.freq_bins - 1) {
948 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
949 }
950
951 self.ifft.process(&mut ifft_buffer);
953
954 let pos = frame_idx * self.config.hop_size;
956 for i in 0..self.config.fft_size {
957 let fft_size_t = T::from(self.config.fft_size).unwrap();
958 let sample = ifft_buffer[i].re / fft_size_t;
959
960 match self.config.reconstruction_mode {
961 ReconstructionMode::Ola => {
962 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
963 }
964 ReconstructionMode::Wola => {
965 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
966 }
967 }
968 }
969 }
970
971 let threshold = T::from(1e-8).unwrap();
973 for i in 0..padded_len {
974 if window_energy[i] > threshold {
975 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
976 }
977 }
978
979 output.clear();
981 output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
982 }
983
984 pub fn process_multichannel(&self, spectra: &[Spectrum<T>]) -> Vec<Vec<T>> {
1014 assert!(!spectra.is_empty(), "spectra must not be empty");
1015
1016 #[cfg(feature = "rayon")]
1018 {
1019 use rayon::prelude::*;
1020 spectra
1021 .par_iter()
1022 .map(|spectrum| self.process(spectrum))
1023 .collect()
1024 }
1025 #[cfg(not(feature = "rayon"))]
1026 {
1027 spectra
1028 .iter()
1029 .map(|spectrum| self.process(spectrum))
1030 .collect()
1031 }
1032 }
1033
1034 pub fn process_multichannel_interleaved(&self, spectra: &[Spectrum<T>]) -> Vec<T> {
1064 let channels = self.process_multichannel(spectra);
1065 utils::interleave(&channels)
1066 }
1067}
1068
1069pub struct StreamingStft<T: Float + FftNum> {
1070 config: StftConfig<T>,
1071 window: Vec<T>,
1072 fft: Arc<dyn Fft<T>>,
1073 input_buffer: VecDeque<T>,
1074 frame_index: usize,
1075 fft_buffer: Vec<Complex<T>>,
1076}
1077
1078impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
1079 pub fn new(config: StftConfig<T>) -> Self {
1080 let window = config.generate_window();
1081 let mut planner = FftPlanner::new();
1082 let fft = planner.plan_fft_forward(config.fft_size);
1083 let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1084
1085 Self {
1086 config,
1087 window,
1088 fft,
1089 input_buffer: VecDeque::new(),
1090 frame_index: 0,
1091 fft_buffer,
1092 }
1093 }
1094
1095 pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
1096 self.input_buffer.extend(samples.iter().copied());
1097
1098 let mut frames = Vec::new();
1099
1100 while self.input_buffer.len() >= self.config.fft_size {
1101 for i in 0..self.config.fft_size {
1103 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1104 }
1105
1106 self.fft.process(&mut self.fft_buffer);
1107
1108 let freq_bins = self.config.freq_bins();
1109 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1110 frames.push(SpectrumFrame::from_data(data));
1111
1112 self.input_buffer.drain(..self.config.hop_size);
1114 self.frame_index += 1;
1115 }
1116
1117 frames
1118 }
1119
1120 pub fn push_samples_into(
1123 &mut self,
1124 samples: &[T],
1125 output: &mut Vec<SpectrumFrame<T>>,
1126 ) -> usize {
1127 self.input_buffer.extend(samples.iter().copied());
1128
1129 let initial_len = output.len();
1130
1131 while self.input_buffer.len() >= self.config.fft_size {
1132 for i in 0..self.config.fft_size {
1134 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1135 }
1136
1137 self.fft.process(&mut self.fft_buffer);
1138
1139 let freq_bins = self.config.freq_bins();
1140 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1141 output.push(SpectrumFrame::from_data(data));
1142
1143 self.input_buffer.drain(..self.config.hop_size);
1145 self.frame_index += 1;
1146 }
1147
1148 output.len() - initial_len
1149 }
1150
1151 pub fn push_samples_write(
1164 &mut self,
1165 samples: &[T],
1166 frame_pool: &mut [SpectrumFrame<T>],
1167 pool_index: &mut usize,
1168 ) -> usize {
1169 self.input_buffer.extend(samples.iter().copied());
1170
1171 let initial_index = *pool_index;
1172 let freq_bins = self.config.freq_bins();
1173
1174 while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
1175 for i in 0..self.config.fft_size {
1177 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1178 }
1179
1180 self.fft.process(&mut self.fft_buffer);
1181
1182 let frame = &mut frame_pool[*pool_index];
1184 debug_assert_eq!(
1185 frame.freq_bins, freq_bins,
1186 "Frame pool frames must match freq_bins"
1187 );
1188 frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
1189
1190 self.input_buffer.drain(..self.config.hop_size);
1192 self.frame_index += 1;
1193 *pool_index += 1;
1194 }
1195
1196 *pool_index - initial_index
1197 }
1198
1199 pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
1200 Vec::new()
1203 }
1204
1205 pub fn reset(&mut self) {
1206 self.input_buffer.clear();
1207 self.frame_index = 0;
1208 }
1209
1210 pub fn buffered_samples(&self) -> usize {
1211 self.input_buffer.len()
1212 }
1213}
1214
1215pub struct MultiChannelStreamingStft<T: Float + FftNum> {
1217 processors: Vec<StreamingStft<T>>,
1218}
1219
1220impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingStft<T> {
1221 pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1228 assert!(num_channels > 0, "num_channels must be > 0");
1229 let processors = (0..num_channels)
1230 .map(|_| StreamingStft::new(config.clone()))
1231 .collect();
1232 Self { processors }
1233 }
1234
1235 pub fn push_samples(&mut self, channels: &[&[T]]) -> Vec<Vec<SpectrumFrame<T>>> {
1246 assert_eq!(
1247 channels.len(),
1248 self.processors.len(),
1249 "Expected {} channels, got {}",
1250 self.processors.len(),
1251 channels.len()
1252 );
1253
1254 #[cfg(feature = "rayon")]
1255 {
1256 use rayon::prelude::*;
1257 self.processors
1258 .par_iter_mut()
1259 .zip(channels.par_iter())
1260 .map(|(stft, channel)| stft.push_samples(channel))
1261 .collect()
1262 }
1263 #[cfg(not(feature = "rayon"))]
1264 {
1265 self.processors
1266 .iter_mut()
1267 .zip(channels.iter())
1268 .map(|(stft, channel)| stft.push_samples(channel))
1269 .collect()
1270 }
1271 }
1272
1273 pub fn flush(&mut self) -> Vec<Vec<SpectrumFrame<T>>> {
1275 #[cfg(feature = "rayon")]
1276 {
1277 use rayon::prelude::*;
1278 self.processors
1279 .par_iter_mut()
1280 .map(|stft| stft.flush())
1281 .collect()
1282 }
1283 #[cfg(not(feature = "rayon"))]
1284 {
1285 self.processors
1286 .iter_mut()
1287 .map(|stft| stft.flush())
1288 .collect()
1289 }
1290 }
1291
1292 pub fn reset(&mut self) {
1294 #[cfg(feature = "rayon")]
1295 {
1296 use rayon::prelude::*;
1297 self.processors.par_iter_mut().for_each(|stft| stft.reset());
1298 }
1299 #[cfg(not(feature = "rayon"))]
1300 {
1301 self.processors.iter_mut().for_each(|stft| stft.reset());
1302 }
1303 }
1304
1305 pub fn num_channels(&self) -> usize {
1307 self.processors.len()
1308 }
1309}
1310
1311pub struct StreamingIstft<T: Float + FftNum> {
1312 config: StftConfig<T>,
1313 window: Vec<T>,
1314 ifft: Arc<dyn Fft<T>>,
1315 overlap_buffer: Vec<T>,
1316 window_energy: Vec<T>,
1317 output_position: usize,
1318 frames_processed: usize,
1319 ifft_buffer: Vec<Complex<T>>,
1320}
1321
1322impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
1323 pub fn new(config: StftConfig<T>) -> Self {
1324 let window = config.generate_window();
1325 let mut planner = FftPlanner::new();
1326 let ifft = planner.plan_fft_inverse(config.fft_size);
1327
1328 let buffer_size = config.fft_size * 2;
1331 let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1332
1333 Self {
1334 config,
1335 window,
1336 ifft,
1337 overlap_buffer: vec![T::zero(); buffer_size],
1338 window_energy: vec![T::zero(); buffer_size],
1339 output_position: 0,
1340 frames_processed: 0,
1341 ifft_buffer,
1342 }
1343 }
1344
1345 pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
1346 assert_eq!(
1347 frame.freq_bins,
1348 self.config.freq_bins(),
1349 "Frequency bins mismatch"
1350 );
1351
1352 for bin in 0..frame.freq_bins {
1354 self.ifft_buffer[bin] = frame.data[bin];
1355 }
1356
1357 for bin in 1..(frame.freq_bins - 1) {
1359 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1360 }
1361
1362 self.ifft.process(&mut self.ifft_buffer);
1364
1365 let write_pos = self.frames_processed * self.config.hop_size;
1367 for i in 0..self.config.fft_size {
1368 let fft_size_t = T::from(self.config.fft_size).unwrap();
1369 let sample = self.ifft_buffer[i].re / fft_size_t;
1370 let buf_idx = write_pos + i;
1371
1372 if buf_idx >= self.overlap_buffer.len() {
1374 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1375 self.window_energy.resize(buf_idx + 1, T::zero());
1376 }
1377
1378 match self.config.reconstruction_mode {
1379 ReconstructionMode::Ola => {
1380 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1381 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1382 }
1383 ReconstructionMode::Wola => {
1384 self.overlap_buffer[buf_idx] =
1385 self.overlap_buffer[buf_idx] + sample * self.window[i];
1386 self.window_energy[buf_idx] =
1387 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1388 }
1389 }
1390 }
1391
1392 self.frames_processed += 1;
1393
1394 let ready_until = if self.frames_processed == 1 {
1397 0 } else {
1399 (self.frames_processed - 1) * self.config.hop_size
1401 };
1402
1403 let output_start = self.output_position;
1405 let output_end = ready_until;
1406 let mut output = Vec::new();
1407
1408 let threshold = T::from(1e-8).unwrap();
1409 if output_end > output_start {
1410 for i in output_start..output_end {
1411 let normalized = if self.window_energy[i] > threshold {
1412 self.overlap_buffer[i] / self.window_energy[i]
1413 } else {
1414 T::zero()
1415 };
1416 output.push(normalized);
1417 }
1418 self.output_position = output_end;
1419 }
1420
1421 output
1422 }
1423
1424 pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1427 assert_eq!(
1428 frame.freq_bins,
1429 self.config.freq_bins(),
1430 "Frequency bins mismatch"
1431 );
1432
1433 for bin in 0..frame.freq_bins {
1435 self.ifft_buffer[bin] = frame.data[bin];
1436 }
1437
1438 for bin in 1..(frame.freq_bins - 1) {
1440 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1441 }
1442
1443 self.ifft.process(&mut self.ifft_buffer);
1445
1446 let write_pos = self.frames_processed * self.config.hop_size;
1448 for i in 0..self.config.fft_size {
1449 let fft_size_t = T::from(self.config.fft_size).unwrap();
1450 let sample = self.ifft_buffer[i].re / fft_size_t;
1451 let buf_idx = write_pos + i;
1452
1453 if buf_idx >= self.overlap_buffer.len() {
1455 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1456 self.window_energy.resize(buf_idx + 1, T::zero());
1457 }
1458
1459 match self.config.reconstruction_mode {
1460 ReconstructionMode::Ola => {
1461 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1462 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1463 }
1464 ReconstructionMode::Wola => {
1465 self.overlap_buffer[buf_idx] =
1466 self.overlap_buffer[buf_idx] + sample * self.window[i];
1467 self.window_energy[buf_idx] =
1468 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1469 }
1470 }
1471 }
1472
1473 self.frames_processed += 1;
1474
1475 let ready_until = if self.frames_processed == 1 {
1478 0 } else {
1480 (self.frames_processed - 1) * self.config.hop_size
1482 };
1483
1484 let output_start = self.output_position;
1486 let output_end = ready_until;
1487 let initial_len = output.len();
1488
1489 let threshold = T::from(1e-8).unwrap();
1490 if output_end > output_start {
1491 for i in output_start..output_end {
1492 let normalized = if self.window_energy[i] > threshold {
1493 self.overlap_buffer[i] / self.window_energy[i]
1494 } else {
1495 T::zero()
1496 };
1497 output.push(normalized);
1498 }
1499 self.output_position = output_end;
1500 }
1501
1502 output.len() - initial_len
1503 }
1504
1505 pub fn flush(&mut self) -> Vec<T> {
1506 let mut output = Vec::new();
1508 let threshold = T::from(1e-8).unwrap();
1509 for i in self.output_position..self.overlap_buffer.len() {
1510 if self.window_energy[i] > threshold {
1511 output.push(self.overlap_buffer[i] / self.window_energy[i]);
1512 } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1513 output.push(T::zero()); } else {
1515 break; }
1517 }
1518
1519 let valid_end =
1521 (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1522 if output.len() > valid_end - self.output_position {
1523 output.truncate(valid_end - self.output_position);
1524 }
1525
1526 self.reset();
1527 output
1528 }
1529
1530 pub fn reset(&mut self) {
1531 self.overlap_buffer.clear();
1532 self.overlap_buffer
1533 .resize(self.config.fft_size * 2, T::zero());
1534 self.window_energy.clear();
1535 self.window_energy
1536 .resize(self.config.fft_size * 2, T::zero());
1537 self.output_position = 0;
1538 self.frames_processed = 0;
1539 }
1540}
1541
1542pub struct MultiChannelStreamingIstft<T: Float + FftNum> {
1544 processors: Vec<StreamingIstft<T>>,
1545}
1546
1547impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingIstft<T> {
1548 pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1555 assert!(num_channels > 0, "num_channels must be > 0");
1556 let processors = (0..num_channels)
1557 .map(|_| StreamingIstft::new(config.clone()))
1558 .collect();
1559 Self { processors }
1560 }
1561
1562 pub fn push_frames(&mut self, frames: &[&SpectrumFrame<T>]) -> Vec<Vec<T>> {
1573 assert_eq!(
1574 frames.len(),
1575 self.processors.len(),
1576 "Expected {} channels, got {}",
1577 self.processors.len(),
1578 frames.len()
1579 );
1580
1581 #[cfg(feature = "rayon")]
1582 {
1583 use rayon::prelude::*;
1584 self.processors
1585 .par_iter_mut()
1586 .zip(frames.par_iter())
1587 .map(|(istft, frame)| istft.push_frame(frame))
1588 .collect()
1589 }
1590 #[cfg(not(feature = "rayon"))]
1591 {
1592 self.processors
1593 .iter_mut()
1594 .zip(frames.iter())
1595 .map(|(istft, frame)| istft.push_frame(frame))
1596 .collect()
1597 }
1598 }
1599
1600 pub fn flush(&mut self) -> Vec<Vec<T>> {
1602 #[cfg(feature = "rayon")]
1603 {
1604 use rayon::prelude::*;
1605 self.processors
1606 .par_iter_mut()
1607 .map(|istft| istft.flush())
1608 .collect()
1609 }
1610 #[cfg(not(feature = "rayon"))]
1611 {
1612 self.processors
1613 .iter_mut()
1614 .map(|istft| istft.flush())
1615 .collect()
1616 }
1617 }
1618
1619 pub fn reset(&mut self) {
1621 #[cfg(feature = "rayon")]
1622 {
1623 use rayon::prelude::*;
1624 self.processors
1625 .par_iter_mut()
1626 .for_each(|istft| istft.reset());
1627 }
1628 #[cfg(not(feature = "rayon"))]
1629 {
1630 self.processors.iter_mut().for_each(|istft| istft.reset());
1631 }
1632 }
1633
1634 pub fn num_channels(&self) -> usize {
1636 self.processors.len()
1637 }
1638}
1639
1640pub type StftConfigF32 = StftConfig<f32>;
1642pub type StftConfigF64 = StftConfig<f64>;
1643
1644pub type StftConfigBuilderF32 = StftConfigBuilder<f32>;
1645pub type StftConfigBuilderF64 = StftConfigBuilder<f64>;
1646
1647pub type BatchStftF32 = BatchStft<f32>;
1648pub type BatchStftF64 = BatchStft<f64>;
1649
1650pub type BatchIstftF32 = BatchIstft<f32>;
1651pub type BatchIstftF64 = BatchIstft<f64>;
1652
1653pub type StreamingStftF32 = StreamingStft<f32>;
1654pub type StreamingStftF64 = StreamingStft<f64>;
1655
1656pub type StreamingIstftF32 = StreamingIstft<f32>;
1657pub type StreamingIstftF64 = StreamingIstft<f64>;
1658
1659pub type SpectrumF32 = Spectrum<f32>;
1660pub type SpectrumF64 = Spectrum<f64>;
1661
1662pub type SpectrumFrameF32 = SpectrumFrame<f32>;
1663pub type SpectrumFrameF64 = SpectrumFrame<f64>;
1664
1665pub type MultiChannelStreamingStftF32 = MultiChannelStreamingStft<f32>;
1666pub type MultiChannelStreamingStftF64 = MultiChannelStreamingStft<f64>;
1667
1668pub type MultiChannelStreamingIstftF32 = MultiChannelStreamingIstft<f32>;
1669pub type MultiChannelStreamingIstftF64 = MultiChannelStreamingIstft<f64>;