1#![cfg_attr(not(feature = "std"), no_std)]
25
26#[cfg(not(feature = "std"))]
27extern crate alloc;
28
29#[cfg(not(feature = "std"))]
30use alloc::{collections::VecDeque, sync::Arc, vec, vec::Vec};
31
32#[cfg(feature = "std")]
33use std::{collections::VecDeque, sync::Arc, vec};
34
35use core::fmt;
36use core::marker::PhantomData;
37use num_traits::{Float, FromPrimitive};
38
39pub mod fft_backend;
40use fft_backend::{Complex, FftBackend, FftNum, FftPlanner, FftPlannerTrait};
41
42mod utils;
43pub use utils::{apply_padding, deinterleave, deinterleave_into, interleave, interleave_into};
44
45pub mod mel;
46
47pub mod prelude {
48 pub use crate::fft_backend::Complex;
49 pub use crate::mel::{
50 BatchMelSpectrogram, BatchMelSpectrogramF32, BatchMelSpectrogramF64, MelConfig,
51 MelConfigF32, MelConfigF64, MelFilterbank, MelFilterbankF32, MelFilterbankF64, MelNorm,
52 MelScale, MelSpectrum, MelSpectrumF32, MelSpectrumF64, StreamingMelSpectrogram,
53 StreamingMelSpectrogramF32, StreamingMelSpectrogramF64,
54 };
55 pub use crate::utils::{
56 apply_padding, deinterleave, deinterleave_into, interleave, interleave_into,
57 };
58 pub use crate::{
59 BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64,
60 MultiChannelStreamingIstft, MultiChannelStreamingIstftF32, MultiChannelStreamingIstftF64,
61 MultiChannelStreamingStft, MultiChannelStreamingStftF32, MultiChannelStreamingStftF64,
62 PadMode, ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame,
63 SpectrumFrameF32, SpectrumFrameF64, StftConfig, StftConfigBuilder, StftConfigBuilderF32,
64 StftConfigBuilderF64, StftConfigF32, StftConfigF64, StreamingIstft, StreamingIstftF32,
65 StreamingIstftF64, StreamingStft, StreamingStftF32, StreamingStftF64, WindowType,
66 };
67}
68
/// Overlap-add strategy used when resynthesising a signal from its spectrum.
///
/// The mode also selects which condition the configuration validates:
/// COLA for [`ReconstructionMode::Ola`], NOLA for [`ReconstructionMode::Wola`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReconstructionMode {
    /// Plain overlap-add: inverse frames are summed unweighted and normalised
    /// by the accumulated window.
    Ola,

    /// Weighted overlap-add: inverse frames are multiplied by the window
    /// before summing and normalised by the accumulated squared window.
    Wola,
}
77
/// Window shape applied to every STFT frame.
///
/// All windows are generated in periodic form: sample `n` of an `N`-point
/// window uses the phase `2πn / N` (not `2πn / (N - 1)`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WindowType {
    /// `0.5 · (1 − cos(2πn/N))`
    Hann,
    /// `0.54 − 0.46 · cos(2πn/N)`
    Hamming,
    /// `0.42 − 0.5 · cos(2πn/N) + 0.08 · cos(4πn/N)`
    Blackman,
}
84
85#[derive(Debug, Clone)]
86pub enum ConfigError<T: Float + fmt::Debug> {
87 NolaViolation { min_energy: T, threshold: T },
88 ColaViolation { max_deviation: T, threshold: T },
89 InvalidHopSize,
90 InvalidFftSize,
91}
92
93impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95 match self {
96 ConfigError::NolaViolation {
97 min_energy,
98 threshold,
99 } => {
100 write!(
101 f,
102 "NOLA condition violated: min_energy={} < threshold={}",
103 min_energy, threshold
104 )
105 }
106 ConfigError::ColaViolation {
107 max_deviation,
108 threshold,
109 } => {
110 write!(
111 f,
112 "COLA condition violated: max_deviation={} > threshold={}",
113 max_deviation, threshold
114 )
115 }
116 ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
117 ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
118 }
119 }
120}
121
/// With `std` available, `ConfigError` is a standard error type and can be
/// boxed as `Box<dyn std::error::Error>`. It has no underlying source.
#[cfg(feature = "std")]
impl<T> std::error::Error for ConfigError<T> where T: Float + fmt::Display + fmt::Debug {}
124
/// How a signal is extended past its ends before batch analysis.
///
/// Batch analysis pads `fft_size / 2` samples on each side, so the first
/// frame is centred on the first input sample. The padding itself is
/// produced by `apply_padding`; the per-variant notes below describe the
/// names' conventional meaning (NOTE: confirm against `apply_padding`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PadMode {
    /// Presumably mirrors the signal about its end samples.
    Reflect,
    /// Presumably pads with zeros.
    Zero,
    /// Presumably repeats the end samples.
    Edge,
}
131
132#[derive(Debug, Clone, PartialEq)]
133pub struct StftConfig<T: Float> {
134 pub fft_size: usize,
135 pub hop_size: usize,
136 pub window: WindowType,
137 pub reconstruction_mode: ReconstructionMode,
138 _phantom: PhantomData<T>,
139}
140
141impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
142 fn nola_threshold() -> T {
143 T::from(1e-8).unwrap()
144 }
145
146 fn cola_relative_tolerance() -> T {
147 T::from(1e-4).unwrap()
148 }
149
150 #[deprecated(
151 since = "0.4.0",
152 note = "Use `StftConfig::builder()` instead for a more flexible API"
153 )]
154 pub fn new(
155 fft_size: usize,
156 hop_size: usize,
157 window: WindowType,
158 reconstruction_mode: ReconstructionMode,
159 ) -> Result<Self, ConfigError<T>> {
160 if fft_size == 0 || !(cfg!(feature = "rustfft-backend") || fft_size.is_power_of_two()) {
161 return Err(ConfigError::InvalidFftSize);
162 }
163 if hop_size == 0 || hop_size > fft_size {
164 return Err(ConfigError::InvalidHopSize);
165 }
166
167 let config = Self {
168 fft_size,
169 hop_size,
170 window,
171 reconstruction_mode,
172 _phantom: PhantomData,
173 };
174
175 match reconstruction_mode {
177 ReconstructionMode::Ola => config.validate_cola()?,
178 ReconstructionMode::Wola => config.validate_nola()?,
179 }
180
181 Ok(config)
182 }
183
184 pub fn builder() -> StftConfigBuilder<T> {
186 StftConfigBuilder::new()
187 }
188
189 #[allow(deprecated)]
191 pub fn default_4096() -> Self {
192 Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
193 .expect("Default config should always be valid")
194 }
195
196 pub fn freq_bins(&self) -> usize {
197 self.fft_size / 2 + 1
198 }
199
200 pub fn overlap_percent(&self) -> T {
201 let one = T::one();
202 let hundred = T::from(100.0).unwrap();
203 (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
204 }
205
206 fn generate_window(&self) -> Vec<T> {
207 generate_window(self.window, self.fft_size)
208 }
209
210 pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
212 let window = self.generate_window();
213 let num_overlaps = self.fft_size.div_ceil(self.hop_size);
214 let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
215 let mut energy = vec![T::zero(); test_len];
216
217 for i in 0..num_overlaps {
218 let offset = i * self.hop_size;
219 for j in 0..self.fft_size {
220 if offset + j < test_len {
221 energy[offset + j] = energy[offset + j] + window[j] * window[j];
222 }
223 }
224 }
225
226 let start = self.fft_size / 2;
228 let end = test_len - self.fft_size / 2;
229 let min_energy = energy[start..end]
230 .iter()
231 .copied()
232 .min_by(|a, b| a.partial_cmp(b).unwrap())
233 .unwrap_or_else(T::zero);
234
235 if min_energy < Self::nola_threshold() {
236 return Err(ConfigError::NolaViolation {
237 min_energy,
238 threshold: Self::nola_threshold(),
239 });
240 }
241
242 Ok(())
243 }
244
245 pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
247 let window = self.generate_window();
248 let window_len = window.len();
249
250 let mut cola_sum_period = vec![T::zero(); self.hop_size];
251 (0..window_len).for_each(|i| {
252 let idx = i % self.hop_size;
253 cola_sum_period[idx] = cola_sum_period[idx] + window[i];
254 });
255
256 let zero = T::zero();
257 let min_sum = cola_sum_period
258 .iter()
259 .min_by(|a, b| a.partial_cmp(b).unwrap())
260 .unwrap_or(&zero);
261 let max_sum = cola_sum_period
262 .iter()
263 .max_by(|a, b| a.partial_cmp(b).unwrap())
264 .unwrap_or(&zero);
265
266 let epsilon = T::from(1e-9).unwrap();
267 if *max_sum < epsilon {
268 return Err(ConfigError::ColaViolation {
269 max_deviation: T::infinity(),
270 threshold: Self::cola_relative_tolerance(),
271 });
272 }
273
274 let ripple = (*max_sum - *min_sum) / *max_sum;
275
276 let is_compliant = ripple < Self::cola_relative_tolerance();
277
278 if !is_compliant {
279 return Err(ConfigError::ColaViolation {
280 max_deviation: ripple,
281 threshold: Self::cola_relative_tolerance(),
282 });
283 }
284 Ok(())
285 }
286}
287
288#[derive(Debug, Clone, PartialEq)]
290pub struct StftConfigBuilder<T: Float> {
291 fft_size: Option<usize>,
292 hop_size: Option<usize>,
293 window: WindowType,
294 reconstruction_mode: ReconstructionMode,
295 _phantom: PhantomData<T>,
296}
297
298impl<T: Float + FromPrimitive + fmt::Debug> StftConfigBuilder<T> {
299 pub fn new() -> Self {
301 Self {
302 fft_size: None,
303 hop_size: None,
304 window: WindowType::Hann,
305 reconstruction_mode: ReconstructionMode::Ola,
306 _phantom: PhantomData,
307 }
308 }
309
310 pub fn fft_size(mut self, fft_size: usize) -> Self {
312 self.fft_size = Some(fft_size);
313 self
314 }
315
316 pub fn hop_size(mut self, hop_size: usize) -> Self {
318 self.hop_size = Some(hop_size);
319 self
320 }
321
322 pub fn window(mut self, window: WindowType) -> Self {
324 self.window = window;
325 self
326 }
327
328 pub fn reconstruction_mode(mut self, mode: ReconstructionMode) -> Self {
330 self.reconstruction_mode = mode;
331 self
332 }
333
334 #[allow(deprecated)]
341 pub fn build(self) -> Result<StftConfig<T>, ConfigError<T>> {
342 let fft_size = self.fft_size.ok_or(ConfigError::InvalidFftSize)?;
343 let hop_size = self.hop_size.ok_or(ConfigError::InvalidHopSize)?;
344
345 StftConfig::new(fft_size, hop_size, self.window, self.reconstruction_mode)
346 }
347}
348
349impl<T: Float + FromPrimitive + fmt::Debug> Default for StftConfigBuilder<T> {
350 fn default() -> Self {
351 Self::new()
352 }
353}
354
355fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
356 let pi = T::from(core::f64::consts::PI).unwrap();
357 let two = T::from(2.0).unwrap();
358
359 match window_type {
360 WindowType::Hann => (0..size)
361 .map(|i| {
362 let half = T::from(0.5).unwrap();
363 let one = T::one();
364 let i_t = T::from(i).unwrap();
365 let size_t = T::from(size).unwrap(); half * (one - (two * pi * i_t / size_t).cos())
367 })
368 .collect(),
369 WindowType::Hamming => (0..size)
370 .map(|i| {
371 let i_t = T::from(i).unwrap();
372 let size_t = T::from(size).unwrap(); T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_t).cos()
374 })
375 .collect(),
376 WindowType::Blackman => (0..size)
377 .map(|i| {
378 let i_t = T::from(i).unwrap();
379 let size_t = T::from(size).unwrap(); let angle = two * pi * i_t / size_t;
381 T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
382 + T::from(0.08).unwrap() * (two * angle).cos()
383 })
384 .collect(),
385 }
386}
387
388#[derive(Debug, Clone, PartialEq)]
389pub struct SpectrumFrame<T: Float> {
390 pub freq_bins: usize,
391 pub data: Vec<Complex<T>>,
392}
393
394impl<T: Float> SpectrumFrame<T> {
395 pub fn new(freq_bins: usize) -> Self {
396 Self {
397 freq_bins,
398 data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
399 }
400 }
401
402 pub fn from_data(data: Vec<Complex<T>>) -> Self {
403 let freq_bins = data.len();
404 Self { freq_bins, data }
405 }
406
407 pub fn clear(&mut self) {
409 for val in &mut self.data {
410 *val = Complex::new(T::zero(), T::zero());
411 }
412 }
413
414 pub fn resize_if_needed(&mut self, freq_bins: usize) {
416 if self.freq_bins != freq_bins {
417 self.freq_bins = freq_bins;
418 self.data
419 .resize(freq_bins, Complex::new(T::zero(), T::zero()));
420 }
421 }
422
423 pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
425 self.resize_if_needed(data.len());
426 self.data.copy_from_slice(data);
427 }
428
429 #[inline]
431 pub fn magnitude(&self, bin: usize) -> T {
432 let c = &self.data[bin];
433 (c.re * c.re + c.im * c.im).sqrt()
434 }
435
436 #[inline]
438 pub fn phase(&self, bin: usize) -> T {
439 let c = &self.data[bin];
440 c.im.atan2(c.re)
441 }
442
443 pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
445 self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
446 }
447
448 pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
450 assert_eq!(
451 magnitudes.len(),
452 phases.len(),
453 "Magnitude and phase arrays must have same length"
454 );
455 let freq_bins = magnitudes.len();
456 let data: Vec<Complex<T>> = magnitudes
457 .iter()
458 .zip(phases.iter())
459 .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
460 .collect();
461 Self { freq_bins, data }
462 }
463
464 pub fn magnitudes(&self) -> Vec<T> {
466 self.data
467 .iter()
468 .map(|c| (c.re * c.re + c.im * c.im).sqrt())
469 .collect()
470 }
471
472 pub fn phases(&self) -> Vec<T> {
474 self.data.iter().map(|c| c.im.atan2(c.re)).collect()
475 }
476}
477
478#[derive(Debug, Clone, PartialEq)]
479pub struct Spectrum<T: Float> {
480 pub num_frames: usize,
481 pub freq_bins: usize,
482 pub data: Vec<T>,
483}
484
485impl<T: Float> Spectrum<T> {
486 pub fn new(num_frames: usize, freq_bins: usize) -> Self {
487 Self {
488 num_frames,
489 freq_bins,
490 data: vec![T::zero(); 2 * num_frames * freq_bins],
491 }
492 }
493
494 #[inline]
495 pub fn real(&self, frame: usize, bin: usize) -> T {
496 self.data[frame * self.freq_bins + bin]
497 }
498
499 #[inline]
500 pub fn imag(&self, frame: usize, bin: usize) -> T {
501 let offset = self.num_frames * self.freq_bins;
502 self.data[offset + frame * self.freq_bins + bin]
503 }
504
505 #[inline]
506 pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
507 Complex::new(self.real(frame, bin), self.imag(frame, bin))
508 }
509
510 pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
511 (0..self.num_frames).map(move |frame_idx| {
512 let data: Vec<Complex<T>> = (0..self.freq_bins)
513 .map(|bin| self.get_complex(frame_idx, bin))
514 .collect();
515 SpectrumFrame::from_data(data)
516 })
517 }
518
519 #[inline]
521 pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
522 self.data[frame * self.freq_bins + bin] = value;
523 }
524
525 #[inline]
527 pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
528 let offset = self.num_frames * self.freq_bins;
529 self.data[offset + frame * self.freq_bins + bin] = value;
530 }
531
532 #[inline]
534 pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
535 self.set_real(frame, bin, value.re);
536 self.set_imag(frame, bin, value.im);
537 }
538
539 #[inline]
541 pub fn magnitude(&self, frame: usize, bin: usize) -> T {
542 let re = self.real(frame, bin);
543 let im = self.imag(frame, bin);
544 (re * re + im * im).sqrt()
545 }
546
547 #[inline]
549 pub fn phase(&self, frame: usize, bin: usize) -> T {
550 let re = self.real(frame, bin);
551 let im = self.imag(frame, bin);
552 im.atan2(re)
553 }
554
555 pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
557 self.set_real(frame, bin, magnitude * phase.cos());
558 self.set_imag(frame, bin, magnitude * phase.sin());
559 }
560
561 pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
563 (0..self.freq_bins)
564 .map(|bin| self.magnitude(frame, bin))
565 .collect()
566 }
567
568 pub fn frame_phases(&self, frame: usize) -> Vec<T> {
570 (0..self.freq_bins)
571 .map(|bin| self.phase(frame, bin))
572 .collect()
573 }
574
575 pub fn apply<F>(&mut self, mut f: F)
577 where
578 F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
579 {
580 for frame in 0..self.num_frames {
581 for bin in 0..self.freq_bins {
582 let c = self.get_complex(frame, bin);
583 let new_c = f(frame, bin, c);
584 self.set_complex(frame, bin, new_c);
585 }
586 }
587 }
588
589 pub fn apply_gain(&mut self, bin_range: core::ops::Range<usize>, gain: T) {
591 for frame in 0..self.num_frames {
592 for bin in bin_range.clone() {
593 if bin < self.freq_bins {
594 let c = self.get_complex(frame, bin);
595 self.set_complex(frame, bin, c * gain);
596 }
597 }
598 }
599 }
600
601 pub fn zero_bins(&mut self, bin_range: core::ops::Range<usize>) {
603 for frame in 0..self.num_frames {
604 for bin in bin_range.clone() {
605 if bin < self.freq_bins {
606 self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
607 }
608 }
609 }
610 }
611}
612
613#[derive(Debug, Clone)]
614pub struct BatchStft<T: Float + FftNum> {
615 config: StftConfig<T>,
616 window: Vec<T>,
617 fft: Arc<dyn FftBackend<T>>,
618}
619
620impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
621 pub fn new(config: StftConfig<T>) -> Self
622 where
623 FftPlanner<T>: FftPlannerTrait<T>,
624 {
625 let window = config.generate_window();
626 let mut planner = <FftPlanner<T> as FftPlannerTrait<T>>::new();
627 let fft = planner.plan_fft_forward(config.fft_size);
628
629 Self {
630 config,
631 window,
632 fft,
633 }
634 }
635
636 pub fn process(&self, signal: &[T]) -> Spectrum<T> {
637 self.process_padded(signal, PadMode::Reflect)
638 }
639
640 pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
641 let pad_amount = self.config.fft_size / 2;
642 let padded = utils::apply_padding(signal, pad_amount, pad_mode);
643
644 let num_frames = if padded.len() >= self.config.fft_size {
645 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
646 } else {
647 0
648 };
649
650 let freq_bins = self.config.freq_bins();
651 let mut result = Spectrum::new(num_frames, freq_bins);
652
653 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
654
655 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
656 .step_by(self.config.hop_size)
657 .enumerate()
658 {
659 for i in 0..self.config.fft_size {
661 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
662 }
663
664 self.fft.process(&mut fft_buffer);
666
667 (0..freq_bins).for_each(|bin| {
669 let idx = frame_idx * freq_bins + bin;
670 result.data[idx] = fft_buffer[bin].re;
671 result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
672 });
673 }
674
675 result
676 }
677
678 pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
682 self.process_padded_into(signal, PadMode::Reflect, spectrum)
683 }
684
685 pub fn process_padded_into(
687 &self,
688 signal: &[T],
689 pad_mode: PadMode,
690 spectrum: &mut Spectrum<T>,
691 ) -> bool {
692 let pad_amount = self.config.fft_size / 2;
693 let padded = utils::apply_padding(signal, pad_amount, pad_mode);
694
695 let num_frames = if padded.len() >= self.config.fft_size {
696 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
697 } else {
698 0
699 };
700
701 let freq_bins = self.config.freq_bins();
702
703 if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
705 return false;
706 }
707
708 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
709
710 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
711 .step_by(self.config.hop_size)
712 .enumerate()
713 {
714 for i in 0..self.config.fft_size {
716 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
717 }
718
719 self.fft.process(&mut fft_buffer);
721
722 (0..freq_bins).for_each(|bin| {
724 let idx = frame_idx * freq_bins + bin;
725 spectrum.data[idx] = fft_buffer[bin].re;
726 spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
727 });
728 }
729
730 true
731 }
732
733 pub fn process_multichannel(&self, channels: &[Vec<T>]) -> Vec<Spectrum<T>> {
760 assert!(!channels.is_empty(), "channels must not be empty");
761
762 let expected_len = channels[0].len();
764 for (i, channel) in channels.iter().enumerate() {
765 assert_eq!(
766 channel.len(),
767 expected_len,
768 "Channel {} has length {}, expected {}",
769 i,
770 channel.len(),
771 expected_len
772 );
773 }
774
775 #[cfg(feature = "rayon")]
777 {
778 use rayon::prelude::*;
779 channels
780 .par_iter()
781 .map(|channel| self.process(channel))
782 .collect()
783 }
784 #[cfg(not(feature = "rayon"))]
785 {
786 channels
787 .iter()
788 .map(|channel| self.process(channel))
789 .collect()
790 }
791 }
792
793 pub fn process_interleaved(&self, data: &[T], num_channels: usize) -> Vec<Spectrum<T>> {
821 let channels = utils::deinterleave(data, num_channels);
822 self.process_multichannel(&channels)
823 }
824}
825
826#[derive(Debug, Clone)]
827pub struct BatchIstft<T: Float + FftNum> {
828 config: StftConfig<T>,
829 window: Vec<T>,
830 ifft: Arc<dyn FftBackend<T>>,
831}
832
833impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
834 pub fn new(config: StftConfig<T>) -> Self
835 where
836 FftPlanner<T>: FftPlannerTrait<T>,
837 {
838 let window = config.generate_window();
839 let mut planner = <FftPlanner<T> as FftPlannerTrait<T>>::new();
840 let ifft = planner.plan_fft_inverse(config.fft_size);
841
842 Self {
843 config,
844 window,
845 ifft,
846 }
847 }
848
849 pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
850 assert_eq!(
851 spectrum.freq_bins,
852 self.config.freq_bins(),
853 "Frequency bins mismatch"
854 );
855
856 let num_frames = spectrum.num_frames;
857 let original_time_len = (num_frames - 1) * self.config.hop_size;
858 let pad_amount = self.config.fft_size / 2;
859 let padded_len = original_time_len + 2 * pad_amount;
860
861 let mut overlap_buffer = vec![T::zero(); padded_len];
862 let mut window_energy = vec![T::zero(); padded_len];
863 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
864
865 for frame_idx in 0..num_frames {
867 let pos = frame_idx * self.config.hop_size;
868 for i in 0..self.config.fft_size {
869 match self.config.reconstruction_mode {
870 ReconstructionMode::Ola => {
871 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
872 }
873 ReconstructionMode::Wola => {
874 window_energy[pos + i] =
875 window_energy[pos + i] + self.window[i] * self.window[i];
876 }
877 }
878 }
879 }
880
881 for frame_idx in 0..num_frames {
883 (0..spectrum.freq_bins).for_each(|bin| {
885 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
886 });
887
888 for bin in 1..(spectrum.freq_bins - 1) {
890 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
891 }
892
893 self.ifft.process(&mut ifft_buffer);
895
896 let pos = frame_idx * self.config.hop_size;
898 for i in 0..self.config.fft_size {
899 let fft_size_t = T::from(self.config.fft_size).unwrap();
900 let sample = ifft_buffer[i].re / fft_size_t;
901
902 match self.config.reconstruction_mode {
903 ReconstructionMode::Ola => {
904 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
906 }
907 ReconstructionMode::Wola => {
908 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
910 }
911 }
912 }
913 }
914
915 let threshold = T::from(1e-8).unwrap();
917 for i in 0..padded_len {
918 if window_energy[i] > threshold {
919 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
920 }
921 }
922
923 overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
925 }
926
927 pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
930 assert_eq!(
931 spectrum.freq_bins,
932 self.config.freq_bins(),
933 "Frequency bins mismatch"
934 );
935
936 let num_frames = spectrum.num_frames;
937 let original_time_len = (num_frames - 1) * self.config.hop_size;
938 let pad_amount = self.config.fft_size / 2;
939 let padded_len = original_time_len + 2 * pad_amount;
940
941 let mut overlap_buffer = vec![T::zero(); padded_len];
942 let mut window_energy = vec![T::zero(); padded_len];
943 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
944
945 for frame_idx in 0..num_frames {
947 let pos = frame_idx * self.config.hop_size;
948 for i in 0..self.config.fft_size {
949 match self.config.reconstruction_mode {
950 ReconstructionMode::Ola => {
951 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
952 }
953 ReconstructionMode::Wola => {
954 window_energy[pos + i] =
955 window_energy[pos + i] + self.window[i] * self.window[i];
956 }
957 }
958 }
959 }
960
961 for frame_idx in 0..num_frames {
963 (0..spectrum.freq_bins).for_each(|bin| {
965 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
966 });
967
968 for bin in 1..(spectrum.freq_bins - 1) {
970 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
971 }
972
973 self.ifft.process(&mut ifft_buffer);
975
976 let pos = frame_idx * self.config.hop_size;
978 for i in 0..self.config.fft_size {
979 let fft_size_t = T::from(self.config.fft_size).unwrap();
980 let sample = ifft_buffer[i].re / fft_size_t;
981
982 match self.config.reconstruction_mode {
983 ReconstructionMode::Ola => {
984 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
985 }
986 ReconstructionMode::Wola => {
987 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
988 }
989 }
990 }
991 }
992
993 let threshold = T::from(1e-8).unwrap();
995 for i in 0..padded_len {
996 if window_energy[i] > threshold {
997 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
998 }
999 }
1000
1001 output.clear();
1003 output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
1004 }
1005
1006 pub fn process_multichannel(&self, spectra: &[Spectrum<T>]) -> Vec<Vec<T>> {
1036 assert!(!spectra.is_empty(), "spectra must not be empty");
1037
1038 #[cfg(feature = "rayon")]
1040 {
1041 use rayon::prelude::*;
1042 spectra
1043 .par_iter()
1044 .map(|spectrum| self.process(spectrum))
1045 .collect()
1046 }
1047 #[cfg(not(feature = "rayon"))]
1048 {
1049 spectra
1050 .iter()
1051 .map(|spectrum| self.process(spectrum))
1052 .collect()
1053 }
1054 }
1055
1056 pub fn process_multichannel_interleaved(&self, spectra: &[Spectrum<T>]) -> Vec<T> {
1086 let channels = self.process_multichannel(spectra);
1087 utils::interleave(&channels)
1088 }
1089}
1090
1091#[derive(Debug, Clone)]
1092pub struct StreamingStft<T: Float + FftNum> {
1093 config: StftConfig<T>,
1094 window: Vec<T>,
1095 fft: Arc<dyn FftBackend<T>>,
1096 input_buffer: VecDeque<T>,
1097 frame_index: usize,
1098 fft_buffer: Vec<Complex<T>>,
1099}
1100
1101impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
1102 pub fn new(config: StftConfig<T>) -> Self
1103 where
1104 FftPlanner<T>: FftPlannerTrait<T>,
1105 {
1106 let window = config.generate_window();
1107 let mut planner = <FftPlanner<T> as FftPlannerTrait<T>>::new();
1108 let fft = planner.plan_fft_forward(config.fft_size);
1109 let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1110
1111 Self {
1112 config,
1113 window,
1114 fft,
1115 input_buffer: VecDeque::new(),
1116 frame_index: 0,
1117 fft_buffer,
1118 }
1119 }
1120
1121 pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
1122 self.input_buffer.extend(samples.iter().copied());
1123
1124 let mut frames = Vec::new();
1125
1126 while self.input_buffer.len() >= self.config.fft_size {
1127 for i in 0..self.config.fft_size {
1129 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1130 }
1131
1132 self.fft.process(&mut self.fft_buffer);
1133
1134 let freq_bins = self.config.freq_bins();
1135 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1136 frames.push(SpectrumFrame::from_data(data));
1137
1138 self.input_buffer.drain(..self.config.hop_size);
1140 self.frame_index += 1;
1141 }
1142
1143 frames
1144 }
1145
1146 pub fn push_samples_into(
1149 &mut self,
1150 samples: &[T],
1151 output: &mut Vec<SpectrumFrame<T>>,
1152 ) -> usize {
1153 self.input_buffer.extend(samples.iter().copied());
1154
1155 let initial_len = output.len();
1156
1157 while self.input_buffer.len() >= self.config.fft_size {
1158 for i in 0..self.config.fft_size {
1160 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1161 }
1162
1163 self.fft.process(&mut self.fft_buffer);
1164
1165 let freq_bins = self.config.freq_bins();
1166 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1167 output.push(SpectrumFrame::from_data(data));
1168
1169 self.input_buffer.drain(..self.config.hop_size);
1171 self.frame_index += 1;
1172 }
1173
1174 output.len() - initial_len
1175 }
1176
1177 pub fn push_samples_write(
1190 &mut self,
1191 samples: &[T],
1192 frame_pool: &mut [SpectrumFrame<T>],
1193 pool_index: &mut usize,
1194 ) -> usize {
1195 self.input_buffer.extend(samples.iter().copied());
1196
1197 let initial_index = *pool_index;
1198 let freq_bins = self.config.freq_bins();
1199
1200 while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
1201 for i in 0..self.config.fft_size {
1203 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1204 }
1205
1206 self.fft.process(&mut self.fft_buffer);
1207
1208 let frame = &mut frame_pool[*pool_index];
1210 debug_assert_eq!(
1211 frame.freq_bins, freq_bins,
1212 "Frame pool frames must match freq_bins"
1213 );
1214 frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
1215
1216 self.input_buffer.drain(..self.config.hop_size);
1218 self.frame_index += 1;
1219 *pool_index += 1;
1220 }
1221
1222 *pool_index - initial_index
1223 }
1224
1225 pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
1226 Vec::new()
1229 }
1230
1231 pub fn reset(&mut self) {
1232 self.input_buffer.clear();
1233 self.frame_index = 0;
1234 }
1235
1236 pub fn buffered_samples(&self) -> usize {
1237 self.input_buffer.len()
1238 }
1239}
1240
1241#[derive(Debug, Clone)]
1243pub struct MultiChannelStreamingStft<T: Float + FftNum> {
1244 processors: Vec<StreamingStft<T>>,
1245}
1246
1247impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingStft<T>
1248where
1249 FftPlanner<T>: FftPlannerTrait<T>,
1250{
1251 pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1258 assert!(num_channels > 0, "num_channels must be > 0");
1259 let processors = (0..num_channels)
1260 .map(|_| StreamingStft::new(config.clone()))
1261 .collect();
1262 Self { processors }
1263 }
1264
1265 pub fn push_samples(&mut self, channels: &[&[T]]) -> Vec<Vec<SpectrumFrame<T>>> {
1276 assert_eq!(
1277 channels.len(),
1278 self.processors.len(),
1279 "Expected {} channels, got {}",
1280 self.processors.len(),
1281 channels.len()
1282 );
1283
1284 #[cfg(feature = "rayon")]
1285 {
1286 use rayon::prelude::*;
1287 self.processors
1288 .par_iter_mut()
1289 .zip(channels.par_iter())
1290 .map(|(stft, channel)| stft.push_samples(channel))
1291 .collect()
1292 }
1293 #[cfg(not(feature = "rayon"))]
1294 {
1295 self.processors
1296 .iter_mut()
1297 .zip(channels.iter())
1298 .map(|(stft, channel)| stft.push_samples(channel))
1299 .collect()
1300 }
1301 }
1302
1303 pub fn flush(&mut self) -> Vec<Vec<SpectrumFrame<T>>> {
1305 #[cfg(feature = "rayon")]
1306 {
1307 use rayon::prelude::*;
1308 self.processors
1309 .par_iter_mut()
1310 .map(|stft| stft.flush())
1311 .collect()
1312 }
1313 #[cfg(not(feature = "rayon"))]
1314 {
1315 self.processors
1316 .iter_mut()
1317 .map(|stft| stft.flush())
1318 .collect()
1319 }
1320 }
1321
1322 pub fn reset(&mut self) {
1324 #[cfg(feature = "rayon")]
1325 {
1326 use rayon::prelude::*;
1327 self.processors.par_iter_mut().for_each(|stft| stft.reset());
1328 }
1329 #[cfg(not(feature = "rayon"))]
1330 {
1331 self.processors.iter_mut().for_each(|stft| stft.reset());
1332 }
1333 }
1334
1335 pub fn num_channels(&self) -> usize {
1337 self.processors.len()
1338 }
1339}
1340
1341#[derive(Debug, Clone)]
1342pub struct StreamingIstft<T: Float + FftNum> {
1343 config: StftConfig<T>,
1344 window: Vec<T>,
1345 ifft: Arc<dyn FftBackend<T>>,
1346 overlap_buffer: Vec<T>,
1347 window_energy: Vec<T>,
1348 output_position: usize,
1349 frames_processed: usize,
1350 ifft_buffer: Vec<Complex<T>>,
1351}
1352
1353impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
1354 pub fn new(config: StftConfig<T>) -> Self
1355 where
1356 FftPlanner<T>: FftPlannerTrait<T>,
1357 {
1358 let window = config.generate_window();
1359 let mut planner = <FftPlanner<T> as FftPlannerTrait<T>>::new();
1360 let ifft = planner.plan_fft_inverse(config.fft_size);
1361
1362 let buffer_size = config.fft_size * 2;
1365 let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1366
1367 Self {
1368 config,
1369 window,
1370 ifft,
1371 overlap_buffer: vec![T::zero(); buffer_size],
1372 window_energy: vec![T::zero(); buffer_size],
1373 output_position: 0,
1374 frames_processed: 0,
1375 ifft_buffer,
1376 }
1377 }
1378
1379 pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
1380 assert_eq!(
1381 frame.freq_bins,
1382 self.config.freq_bins(),
1383 "Frequency bins mismatch"
1384 );
1385
1386 for bin in 0..frame.freq_bins {
1388 self.ifft_buffer[bin] = frame.data[bin];
1389 }
1390
1391 for bin in 1..(frame.freq_bins - 1) {
1393 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1394 }
1395
1396 self.ifft.process(&mut self.ifft_buffer);
1398
1399 let write_pos = self.frames_processed * self.config.hop_size;
1401 for i in 0..self.config.fft_size {
1402 let fft_size_t = T::from(self.config.fft_size).unwrap();
1403 let sample = self.ifft_buffer[i].re / fft_size_t;
1404 let buf_idx = write_pos + i;
1405
1406 if buf_idx >= self.overlap_buffer.len() {
1408 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1409 self.window_energy.resize(buf_idx + 1, T::zero());
1410 }
1411
1412 match self.config.reconstruction_mode {
1413 ReconstructionMode::Ola => {
1414 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1415 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1416 }
1417 ReconstructionMode::Wola => {
1418 self.overlap_buffer[buf_idx] =
1419 self.overlap_buffer[buf_idx] + sample * self.window[i];
1420 self.window_energy[buf_idx] =
1421 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1422 }
1423 }
1424 }
1425
1426 self.frames_processed += 1;
1427
1428 let ready_until = if self.frames_processed == 1 {
1431 0 } else {
1433 (self.frames_processed - 1) * self.config.hop_size
1435 };
1436
1437 let output_start = self.output_position;
1439 let output_end = ready_until;
1440 let mut output = Vec::new();
1441
1442 let threshold = T::from(1e-8).unwrap();
1443 if output_end > output_start {
1444 for i in output_start..output_end {
1445 let normalized = if self.window_energy[i] > threshold {
1446 self.overlap_buffer[i] / self.window_energy[i]
1447 } else {
1448 T::zero()
1449 };
1450 output.push(normalized);
1451 }
1452 self.output_position = output_end;
1453 }
1454
1455 output
1456 }
1457
1458 pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1461 assert_eq!(
1462 frame.freq_bins,
1463 self.config.freq_bins(),
1464 "Frequency bins mismatch"
1465 );
1466
1467 for bin in 0..frame.freq_bins {
1469 self.ifft_buffer[bin] = frame.data[bin];
1470 }
1471
1472 for bin in 1..(frame.freq_bins - 1) {
1474 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1475 }
1476
1477 self.ifft.process(&mut self.ifft_buffer);
1479
1480 let write_pos = self.frames_processed * self.config.hop_size;
1482 for i in 0..self.config.fft_size {
1483 let fft_size_t = T::from(self.config.fft_size).unwrap();
1484 let sample = self.ifft_buffer[i].re / fft_size_t;
1485 let buf_idx = write_pos + i;
1486
1487 if buf_idx >= self.overlap_buffer.len() {
1489 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1490 self.window_energy.resize(buf_idx + 1, T::zero());
1491 }
1492
1493 match self.config.reconstruction_mode {
1494 ReconstructionMode::Ola => {
1495 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1496 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1497 }
1498 ReconstructionMode::Wola => {
1499 self.overlap_buffer[buf_idx] =
1500 self.overlap_buffer[buf_idx] + sample * self.window[i];
1501 self.window_energy[buf_idx] =
1502 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1503 }
1504 }
1505 }
1506
1507 self.frames_processed += 1;
1508
1509 let ready_until = if self.frames_processed == 1 {
1512 0 } else {
1514 (self.frames_processed - 1) * self.config.hop_size
1516 };
1517
1518 let output_start = self.output_position;
1520 let output_end = ready_until;
1521 let initial_len = output.len();
1522
1523 let threshold = T::from(1e-8).unwrap();
1524 if output_end > output_start {
1525 for i in output_start..output_end {
1526 let normalized = if self.window_energy[i] > threshold {
1527 self.overlap_buffer[i] / self.window_energy[i]
1528 } else {
1529 T::zero()
1530 };
1531 output.push(normalized);
1532 }
1533 self.output_position = output_end;
1534 }
1535
1536 output.len() - initial_len
1537 }
1538
1539 pub fn flush(&mut self) -> Vec<T> {
1540 let mut output = Vec::new();
1542 let threshold = T::from(1e-8).unwrap();
1543 for i in self.output_position..self.overlap_buffer.len() {
1544 if self.window_energy[i] > threshold {
1545 output.push(self.overlap_buffer[i] / self.window_energy[i]);
1546 } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1547 output.push(T::zero()); } else {
1549 break; }
1551 }
1552
1553 let valid_end =
1555 (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1556 if output.len() > valid_end - self.output_position {
1557 output.truncate(valid_end - self.output_position);
1558 }
1559
1560 self.reset();
1561 output
1562 }
1563
1564 pub fn reset(&mut self) {
1565 self.overlap_buffer.clear();
1566 self.overlap_buffer
1567 .resize(self.config.fft_size * 2, T::zero());
1568 self.window_energy.clear();
1569 self.window_energy
1570 .resize(self.config.fft_size * 2, T::zero());
1571 self.output_position = 0;
1572 self.frames_processed = 0;
1573 }
1574}
1575
/// Multi-channel streaming inverse STFT: one independent [`StreamingIstft`] per channel.
#[derive(Debug, Clone)]
pub struct MultiChannelStreamingIstft<T: Float + FftNum> {
    /// One inverse-STFT processor per channel; index `i` handles the `i`-th frame passed
    /// to `push_frames`.
    processors: Vec<StreamingIstft<T>>,
}
1581
1582impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingIstft<T>
1583where
1584 FftPlanner<T>: FftPlannerTrait<T>,
1585{
1586 pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1593 assert!(num_channels > 0, "num_channels must be > 0");
1594 let processors = (0..num_channels)
1595 .map(|_| StreamingIstft::new(config.clone()))
1596 .collect();
1597 Self { processors }
1598 }
1599
1600 pub fn push_frames(&mut self, frames: &[&SpectrumFrame<T>]) -> Vec<Vec<T>> {
1611 assert_eq!(
1612 frames.len(),
1613 self.processors.len(),
1614 "Expected {} channels, got {}",
1615 self.processors.len(),
1616 frames.len()
1617 );
1618
1619 #[cfg(feature = "rayon")]
1620 {
1621 use rayon::prelude::*;
1622 self.processors
1623 .par_iter_mut()
1624 .zip(frames.par_iter())
1625 .map(|(istft, frame)| istft.push_frame(frame))
1626 .collect()
1627 }
1628 #[cfg(not(feature = "rayon"))]
1629 {
1630 self.processors
1631 .iter_mut()
1632 .zip(frames.iter())
1633 .map(|(istft, frame)| istft.push_frame(frame))
1634 .collect()
1635 }
1636 }
1637
1638 pub fn flush(&mut self) -> Vec<Vec<T>> {
1640 #[cfg(feature = "rayon")]
1641 {
1642 use rayon::prelude::*;
1643 self.processors
1644 .par_iter_mut()
1645 .map(|istft| istft.flush())
1646 .collect()
1647 }
1648 #[cfg(not(feature = "rayon"))]
1649 {
1650 self.processors
1651 .iter_mut()
1652 .map(|istft| istft.flush())
1653 .collect()
1654 }
1655 }
1656
1657 pub fn reset(&mut self) {
1659 #[cfg(feature = "rayon")]
1660 {
1661 use rayon::prelude::*;
1662 self.processors
1663 .par_iter_mut()
1664 .for_each(|istft| istft.reset());
1665 }
1666 #[cfg(not(feature = "rayon"))]
1667 {
1668 self.processors.iter_mut().for_each(|istft| istft.reset());
1669 }
1670 }
1671
1672 pub fn num_channels(&self) -> usize {
1674 self.processors.len()
1675 }
1676}
1677
1678pub type StftConfigF32 = StftConfig<f32>;
1680pub type StftConfigF64 = StftConfig<f64>;
1681
1682pub type StftConfigBuilderF32 = StftConfigBuilder<f32>;
1683pub type StftConfigBuilderF64 = StftConfigBuilder<f64>;
1684
1685pub type BatchStftF32 = BatchStft<f32>;
1686pub type BatchStftF64 = BatchStft<f64>;
1687
1688pub type BatchIstftF32 = BatchIstft<f32>;
1689pub type BatchIstftF64 = BatchIstft<f64>;
1690
1691pub type StreamingStftF32 = StreamingStft<f32>;
1692pub type StreamingStftF64 = StreamingStft<f64>;
1693
1694pub type StreamingIstftF32 = StreamingIstft<f32>;
1695pub type StreamingIstftF64 = StreamingIstft<f64>;
1696
1697pub type SpectrumF32 = Spectrum<f32>;
1698pub type SpectrumF64 = Spectrum<f64>;
1699
1700pub type SpectrumFrameF32 = SpectrumFrame<f32>;
1701pub type SpectrumFrameF64 = SpectrumFrame<f64>;
1702
1703pub type MultiChannelStreamingStftF32 = MultiChannelStreamingStft<f32>;
1704pub type MultiChannelStreamingStftF64 = MultiChannelStreamingStft<f64>;
1705
1706pub type MultiChannelStreamingIstftF32 = MultiChannelStreamingIstft<f32>;
1707pub type MultiChannelStreamingIstftF64 = MultiChannelStreamingIstft<f64>;