1use num_traits::{Float, FromPrimitive};
25use rustfft::num_complex::Complex;
26use rustfft::{Fft, FftNum, FftPlanner};
27use std::collections::VecDeque;
28use std::fmt;
29use std::sync::Arc;
30
31mod utils;
32pub use utils::{apply_padding, deinterleave, deinterleave_into, interleave, interleave_into};
33
34pub mod prelude {
35 pub use crate::utils::{
36 apply_padding, deinterleave, deinterleave_into, interleave, interleave_into,
37 };
38 pub use crate::{
39 BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64,
40 MultiChannelStreamingIstft, MultiChannelStreamingIstftF32, MultiChannelStreamingIstftF64,
41 MultiChannelStreamingStft, MultiChannelStreamingStftF32, MultiChannelStreamingStftF64,
42 PadMode, ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame,
43 SpectrumFrameF32, SpectrumFrameF64, StftConfig, StftConfigBuilder, StftConfigBuilderF32,
44 StftConfigBuilderF64, StftConfigF32, StftConfigF64, StreamingIstft, StreamingIstftF32,
45 StreamingIstftF64, StreamingStft, StreamingStftF32, StreamingStftF64, WindowType,
46 };
47}
48
/// Strategy used by the inverse transforms to recombine overlapping frames.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReconstructionMode {
    /// Overlap-add: inverse-FFT frames are summed unchanged and every output sample is divided
    /// by the overlapped sum of the (analysis) window. Configs in this mode are validated with
    /// the COLA check.
    Ola,

    /// Weighted overlap-add: each inverse-FFT frame is multiplied by the window before summing,
    /// and every output sample is divided by the overlapped sum of the squared window. Configs
    /// in this mode are validated with the NOLA check.
    Wola,
}
57
/// Window shape applied to each frame.
///
/// All shapes use the symmetric definition (phase `2*pi*n / (N - 1)`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowType {
    /// `0.5 * (1 - cos(phase))`.
    Hann,
    /// `0.54 - 0.46 * cos(phase)`.
    Hamming,
    /// `0.42 - 0.5 * cos(phase) + 0.08 * cos(2 * phase)`.
    Blackman,
}
64
65#[derive(Debug, Clone)]
66pub enum ConfigError<T: Float + fmt::Debug> {
67 NolaViolation { min_energy: T, threshold: T },
68 ColaViolation { max_deviation: T, threshold: T },
69 InvalidHopSize,
70 InvalidFftSize,
71}
72
73impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 match self {
76 ConfigError::NolaViolation {
77 min_energy,
78 threshold,
79 } => {
80 write!(
81 f,
82 "NOLA condition violated: min_energy={} < threshold={}",
83 min_energy, threshold
84 )
85 }
86 ConfigError::ColaViolation {
87 max_deviation,
88 threshold,
89 } => {
90 write!(
91 f,
92 "COLA condition violated: max_deviation={} > threshold={}",
93 max_deviation, threshold
94 )
95 }
96 ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
97 ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
98 }
99 }
100}
101
102impl<T: Float + fmt::Display + fmt::Debug> std::error::Error for ConfigError<T> {}
103
/// Edge-padding strategy handed to `apply_padding` before batch analysis.
///
/// The exact sample values produced by each mode are defined by `apply_padding`.
/// `PartialEq`/`Eq` are derived so the mode can be compared like the other
/// configuration enums (`WindowType`, `ReconstructionMode`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadMode {
    /// Reflection padding (see `apply_padding` for the precise convention).
    Reflect,
    /// Zero padding.
    Zero,
    /// Edge-value padding (see `apply_padding`).
    Edge,
}
110
111#[derive(Clone)]
112pub struct StftConfig<T: Float> {
113 pub fft_size: usize,
114 pub hop_size: usize,
115 pub window: WindowType,
116 pub reconstruction_mode: ReconstructionMode,
117 _phantom: std::marker::PhantomData<T>,
118}
119
120impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
121 fn nola_threshold() -> T {
122 T::from(1e-8).unwrap()
123 }
124
125 fn cola_relative_tolerance() -> T {
126 T::from(1e-4).unwrap()
127 }
128
129 #[deprecated(
130 since = "0.4.0",
131 note = "Use `StftConfig::builder()` instead for a more flexible API"
132 )]
133 pub fn new(
134 fft_size: usize,
135 hop_size: usize,
136 window: WindowType,
137 reconstruction_mode: ReconstructionMode,
138 ) -> Result<Self, ConfigError<T>> {
139 if fft_size == 0 || !fft_size.is_power_of_two() {
140 return Err(ConfigError::InvalidFftSize);
141 }
142 if hop_size == 0 || hop_size > fft_size {
143 return Err(ConfigError::InvalidHopSize);
144 }
145
146 let config = Self {
147 fft_size,
148 hop_size,
149 window,
150 reconstruction_mode,
151 _phantom: std::marker::PhantomData,
152 };
153
154 match reconstruction_mode {
156 ReconstructionMode::Ola => config.validate_cola()?,
157 ReconstructionMode::Wola => config.validate_nola()?,
158 }
159
160 Ok(config)
161 }
162
163 pub fn builder() -> StftConfigBuilder<T> {
165 StftConfigBuilder::new()
166 }
167
168 #[allow(deprecated)]
170 pub fn default_4096() -> Self {
171 Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
172 .expect("Default config should always be valid")
173 }
174
175 pub fn freq_bins(&self) -> usize {
176 self.fft_size / 2 + 1
177 }
178
179 pub fn overlap_percent(&self) -> T {
180 let one = T::one();
181 let hundred = T::from(100.0).unwrap();
182 (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
183 }
184
185 fn generate_window(&self) -> Vec<T> {
186 generate_window(self.window, self.fft_size)
187 }
188
189 pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
191 let window = self.generate_window();
192 let num_overlaps = (self.fft_size + self.hop_size - 1) / self.hop_size;
193 let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
194 let mut energy = vec![T::zero(); test_len];
195
196 for i in 0..num_overlaps {
197 let offset = i * self.hop_size;
198 for j in 0..self.fft_size {
199 if offset + j < test_len {
200 energy[offset + j] = energy[offset + j] + window[j] * window[j];
201 }
202 }
203 }
204
205 let start = self.fft_size / 2;
207 let end = test_len - self.fft_size / 2;
208 let min_energy = energy[start..end]
209 .iter()
210 .copied()
211 .min_by(|a, b| a.partial_cmp(b).unwrap())
212 .unwrap_or_else(T::zero);
213
214 if min_energy < Self::nola_threshold() {
215 return Err(ConfigError::NolaViolation {
216 min_energy,
217 threshold: Self::nola_threshold(),
218 });
219 }
220
221 Ok(())
222 }
223
224 pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
226 let window = self.generate_window();
227 let window_len = window.len();
228
229 let mut cola_sum_period = vec![T::zero(); self.hop_size];
230 for i in 0..window_len {
231 let idx = i % self.hop_size;
232 cola_sum_period[idx] = cola_sum_period[idx] + window[i];
233 }
234
235 let zero = T::zero();
236 let min_sum = cola_sum_period
237 .iter()
238 .min_by(|a, b| a.partial_cmp(b).unwrap())
239 .unwrap_or(&zero);
240 let max_sum = cola_sum_period
241 .iter()
242 .max_by(|a, b| a.partial_cmp(b).unwrap())
243 .unwrap_or(&zero);
244
245 let epsilon = T::from(1e-9).unwrap();
246 if *max_sum < epsilon {
247 return Err(ConfigError::ColaViolation {
248 max_deviation: T::infinity(),
249 threshold: Self::cola_relative_tolerance(),
250 });
251 }
252
253 let ripple = (*max_sum - *min_sum) / *max_sum;
254
255 let is_compliant = ripple < Self::cola_relative_tolerance();
256
257 if !is_compliant {
258 return Err(ConfigError::ColaViolation {
259 max_deviation: ripple,
260 threshold: Self::cola_relative_tolerance(),
261 });
262 }
263 Ok(())
264 }
265}
266
267pub struct StftConfigBuilder<T: Float> {
269 fft_size: Option<usize>,
270 hop_size: Option<usize>,
271 window: WindowType,
272 reconstruction_mode: ReconstructionMode,
273 _phantom: std::marker::PhantomData<T>,
274}
275
276impl<T: Float + FromPrimitive + fmt::Debug> StftConfigBuilder<T> {
277 pub fn new() -> Self {
279 Self {
280 fft_size: None,
281 hop_size: None,
282 window: WindowType::Hann,
283 reconstruction_mode: ReconstructionMode::Ola,
284 _phantom: std::marker::PhantomData,
285 }
286 }
287
288 pub fn fft_size(mut self, fft_size: usize) -> Self {
290 self.fft_size = Some(fft_size);
291 self
292 }
293
294 pub fn hop_size(mut self, hop_size: usize) -> Self {
296 self.hop_size = Some(hop_size);
297 self
298 }
299
300 pub fn window(mut self, window: WindowType) -> Self {
302 self.window = window;
303 self
304 }
305
306 pub fn reconstruction_mode(mut self, mode: ReconstructionMode) -> Self {
308 self.reconstruction_mode = mode;
309 self
310 }
311
312 #[allow(deprecated)]
319 pub fn build(self) -> Result<StftConfig<T>, ConfigError<T>> {
320 let fft_size = self.fft_size.ok_or(ConfigError::InvalidFftSize)?;
321 let hop_size = self.hop_size.ok_or(ConfigError::InvalidHopSize)?;
322
323 StftConfig::new(fft_size, hop_size, self.window, self.reconstruction_mode)
324 }
325}
326
327impl<T: Float + FromPrimitive + fmt::Debug> Default for StftConfigBuilder<T> {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
333fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
334 let pi = T::from(std::f64::consts::PI).unwrap();
335 let two = T::from(2.0).unwrap();
336
337 match window_type {
338 WindowType::Hann => (0..size)
339 .map(|i| {
340 let half = T::from(0.5).unwrap();
341 let one = T::one();
342 let i_t = T::from(i).unwrap();
343 let size_m1 = T::from(size - 1).unwrap();
344 half * (one - (two * pi * i_t / size_m1).cos())
345 })
346 .collect(),
347 WindowType::Hamming => (0..size)
348 .map(|i| {
349 let i_t = T::from(i).unwrap();
350 let size_m1 = T::from(size - 1).unwrap();
351 T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_m1).cos()
352 })
353 .collect(),
354 WindowType::Blackman => (0..size)
355 .map(|i| {
356 let i_t = T::from(i).unwrap();
357 let size_m1 = T::from(size - 1).unwrap();
358 let angle = two * pi * i_t / size_m1;
359 T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
360 + T::from(0.08).unwrap() * (two * angle).cos()
361 })
362 .collect(),
363 }
364}
365
366#[derive(Clone)]
367pub struct SpectrumFrame<T: Float> {
368 pub freq_bins: usize,
369 pub data: Vec<Complex<T>>,
370}
371
372impl<T: Float> SpectrumFrame<T> {
373 pub fn new(freq_bins: usize) -> Self {
374 Self {
375 freq_bins,
376 data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
377 }
378 }
379
380 pub fn from_data(data: Vec<Complex<T>>) -> Self {
381 let freq_bins = data.len();
382 Self { freq_bins, data }
383 }
384
385 pub fn clear(&mut self) {
387 for val in &mut self.data {
388 *val = Complex::new(T::zero(), T::zero());
389 }
390 }
391
392 pub fn resize_if_needed(&mut self, freq_bins: usize) {
394 if self.freq_bins != freq_bins {
395 self.freq_bins = freq_bins;
396 self.data
397 .resize(freq_bins, Complex::new(T::zero(), T::zero()));
398 }
399 }
400
401 pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
403 self.resize_if_needed(data.len());
404 self.data.copy_from_slice(data);
405 }
406
407 #[inline]
409 pub fn magnitude(&self, bin: usize) -> T {
410 let c = &self.data[bin];
411 (c.re * c.re + c.im * c.im).sqrt()
412 }
413
414 #[inline]
416 pub fn phase(&self, bin: usize) -> T {
417 let c = &self.data[bin];
418 c.im.atan2(c.re)
419 }
420
421 pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
423 self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
424 }
425
426 pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
428 assert_eq!(
429 magnitudes.len(),
430 phases.len(),
431 "Magnitude and phase arrays must have same length"
432 );
433 let freq_bins = magnitudes.len();
434 let data: Vec<Complex<T>> = magnitudes
435 .iter()
436 .zip(phases.iter())
437 .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
438 .collect();
439 Self { freq_bins, data }
440 }
441
442 pub fn magnitudes(&self) -> Vec<T> {
444 self.data
445 .iter()
446 .map(|c| (c.re * c.re + c.im * c.im).sqrt())
447 .collect()
448 }
449
450 pub fn phases(&self) -> Vec<T> {
452 self.data.iter().map(|c| c.im.atan2(c.re)).collect()
453 }
454}
455
456#[derive(Clone)]
457pub struct Spectrum<T: Float> {
458 pub num_frames: usize,
459 pub freq_bins: usize,
460 pub data: Vec<T>,
461}
462
463impl<T: Float> Spectrum<T> {
464 pub fn new(num_frames: usize, freq_bins: usize) -> Self {
465 Self {
466 num_frames,
467 freq_bins,
468 data: vec![T::zero(); 2 * num_frames * freq_bins],
469 }
470 }
471
472 #[inline]
473 pub fn real(&self, frame: usize, bin: usize) -> T {
474 self.data[frame * self.freq_bins + bin]
475 }
476
477 #[inline]
478 pub fn imag(&self, frame: usize, bin: usize) -> T {
479 let offset = self.num_frames * self.freq_bins;
480 self.data[offset + frame * self.freq_bins + bin]
481 }
482
483 #[inline]
484 pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
485 Complex::new(self.real(frame, bin), self.imag(frame, bin))
486 }
487
488 pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
489 (0..self.num_frames).map(move |frame_idx| {
490 let data: Vec<Complex<T>> = (0..self.freq_bins)
491 .map(|bin| self.get_complex(frame_idx, bin))
492 .collect();
493 SpectrumFrame::from_data(data)
494 })
495 }
496
497 #[inline]
499 pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
500 self.data[frame * self.freq_bins + bin] = value;
501 }
502
503 #[inline]
505 pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
506 let offset = self.num_frames * self.freq_bins;
507 self.data[offset + frame * self.freq_bins + bin] = value;
508 }
509
510 #[inline]
512 pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
513 self.set_real(frame, bin, value.re);
514 self.set_imag(frame, bin, value.im);
515 }
516
517 #[inline]
519 pub fn magnitude(&self, frame: usize, bin: usize) -> T {
520 let re = self.real(frame, bin);
521 let im = self.imag(frame, bin);
522 (re * re + im * im).sqrt()
523 }
524
525 #[inline]
527 pub fn phase(&self, frame: usize, bin: usize) -> T {
528 let re = self.real(frame, bin);
529 let im = self.imag(frame, bin);
530 im.atan2(re)
531 }
532
533 pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
535 self.set_real(frame, bin, magnitude * phase.cos());
536 self.set_imag(frame, bin, magnitude * phase.sin());
537 }
538
539 pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
541 (0..self.freq_bins)
542 .map(|bin| self.magnitude(frame, bin))
543 .collect()
544 }
545
546 pub fn frame_phases(&self, frame: usize) -> Vec<T> {
548 (0..self.freq_bins)
549 .map(|bin| self.phase(frame, bin))
550 .collect()
551 }
552
553 pub fn apply<F>(&mut self, mut f: F)
555 where
556 F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
557 {
558 for frame in 0..self.num_frames {
559 for bin in 0..self.freq_bins {
560 let c = self.get_complex(frame, bin);
561 let new_c = f(frame, bin, c);
562 self.set_complex(frame, bin, new_c);
563 }
564 }
565 }
566
567 pub fn apply_gain(&mut self, bin_range: std::ops::Range<usize>, gain: T) {
569 for frame in 0..self.num_frames {
570 for bin in bin_range.clone() {
571 if bin < self.freq_bins {
572 let c = self.get_complex(frame, bin);
573 self.set_complex(frame, bin, c * gain);
574 }
575 }
576 }
577 }
578
579 pub fn zero_bins(&mut self, bin_range: std::ops::Range<usize>) {
581 for frame in 0..self.num_frames {
582 for bin in bin_range.clone() {
583 if bin < self.freq_bins {
584 self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
585 }
586 }
587 }
588 }
589}
590
591pub struct BatchStft<T: Float + FftNum> {
592 config: StftConfig<T>,
593 window: Vec<T>,
594 fft: Arc<dyn Fft<T>>,
595}
596
597impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
598 pub fn new(config: StftConfig<T>) -> Self {
599 let window = config.generate_window();
600 let mut planner = FftPlanner::new();
601 let fft = planner.plan_fft_forward(config.fft_size);
602
603 Self {
604 config,
605 window,
606 fft,
607 }
608 }
609
610 pub fn process(&self, signal: &[T]) -> Spectrum<T> {
611 self.process_padded(signal, PadMode::Reflect)
612 }
613
614 pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
615 let pad_amount = self.config.fft_size / 2;
616 let padded = utils::apply_padding(signal, pad_amount, pad_mode);
617
618 let num_frames = if padded.len() >= self.config.fft_size {
619 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
620 } else {
621 0
622 };
623
624 let freq_bins = self.config.freq_bins();
625 let mut result = Spectrum::new(num_frames, freq_bins);
626
627 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
628
629 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
630 .step_by(self.config.hop_size)
631 .enumerate()
632 {
633 for i in 0..self.config.fft_size {
635 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
636 }
637
638 self.fft.process(&mut fft_buffer);
640
641 for bin in 0..freq_bins {
643 let idx = frame_idx * freq_bins + bin;
644 result.data[idx] = fft_buffer[bin].re;
645 result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
646 }
647 }
648
649 result
650 }
651
652 pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
656 self.process_padded_into(signal, PadMode::Reflect, spectrum)
657 }
658
659 pub fn process_padded_into(
661 &self,
662 signal: &[T],
663 pad_mode: PadMode,
664 spectrum: &mut Spectrum<T>,
665 ) -> bool {
666 let pad_amount = self.config.fft_size / 2;
667 let padded = utils::apply_padding(signal, pad_amount, pad_mode);
668
669 let num_frames = if padded.len() >= self.config.fft_size {
670 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
671 } else {
672 0
673 };
674
675 let freq_bins = self.config.freq_bins();
676
677 if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
679 return false;
680 }
681
682 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
683
684 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
685 .step_by(self.config.hop_size)
686 .enumerate()
687 {
688 for i in 0..self.config.fft_size {
690 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
691 }
692
693 self.fft.process(&mut fft_buffer);
695
696 for bin in 0..freq_bins {
698 let idx = frame_idx * freq_bins + bin;
699 spectrum.data[idx] = fft_buffer[bin].re;
700 spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
701 }
702 }
703
704 true
705 }
706
707 pub fn process_multichannel(&self, channels: &[Vec<T>]) -> Vec<Spectrum<T>> {
734 assert!(!channels.is_empty(), "channels must not be empty");
735
736 let expected_len = channels[0].len();
738 for (i, channel) in channels.iter().enumerate() {
739 assert_eq!(
740 channel.len(),
741 expected_len,
742 "Channel {} has length {}, expected {}",
743 i,
744 channel.len(),
745 expected_len
746 );
747 }
748
749 #[cfg(feature = "rayon")]
751 {
752 use rayon::prelude::*;
753 channels
754 .par_iter()
755 .map(|channel| self.process(channel))
756 .collect()
757 }
758 #[cfg(not(feature = "rayon"))]
759 {
760 channels
761 .iter()
762 .map(|channel| self.process(channel))
763 .collect()
764 }
765 }
766
767 pub fn process_interleaved(&self, data: &[T], num_channels: usize) -> Vec<Spectrum<T>> {
795 let channels = utils::deinterleave(data, num_channels);
796 self.process_multichannel(&channels)
797 }
798}
799
800pub struct BatchIstft<T: Float + FftNum> {
801 config: StftConfig<T>,
802 window: Vec<T>,
803 ifft: Arc<dyn Fft<T>>,
804}
805
806impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
807 pub fn new(config: StftConfig<T>) -> Self {
808 let window = config.generate_window();
809 let mut planner = FftPlanner::new();
810 let ifft = planner.plan_fft_inverse(config.fft_size);
811
812 Self {
813 config,
814 window,
815 ifft,
816 }
817 }
818
819 pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
820 assert_eq!(
821 spectrum.freq_bins,
822 self.config.freq_bins(),
823 "Frequency bins mismatch"
824 );
825
826 let num_frames = spectrum.num_frames;
827 let original_time_len = (num_frames - 1) * self.config.hop_size;
828 let pad_amount = self.config.fft_size / 2;
829 let padded_len = original_time_len + 2 * pad_amount;
830
831 let mut overlap_buffer = vec![T::zero(); padded_len];
832 let mut window_energy = vec![T::zero(); padded_len];
833 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
834
835 for frame_idx in 0..num_frames {
837 let pos = frame_idx * self.config.hop_size;
838 for i in 0..self.config.fft_size {
839 match self.config.reconstruction_mode {
840 ReconstructionMode::Ola => {
841 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
842 }
843 ReconstructionMode::Wola => {
844 window_energy[pos + i] =
845 window_energy[pos + i] + self.window[i] * self.window[i];
846 }
847 }
848 }
849 }
850
851 for frame_idx in 0..num_frames {
853 for bin in 0..spectrum.freq_bins {
855 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
856 }
857
858 for bin in 1..(spectrum.freq_bins - 1) {
860 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
861 }
862
863 self.ifft.process(&mut ifft_buffer);
865
866 let pos = frame_idx * self.config.hop_size;
868 for i in 0..self.config.fft_size {
869 let fft_size_t = T::from(self.config.fft_size).unwrap();
870 let sample = ifft_buffer[i].re / fft_size_t;
871
872 match self.config.reconstruction_mode {
873 ReconstructionMode::Ola => {
874 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
876 }
877 ReconstructionMode::Wola => {
878 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
880 }
881 }
882 }
883 }
884
885 let threshold = T::from(1e-8).unwrap();
887 for i in 0..padded_len {
888 if window_energy[i] > threshold {
889 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
890 }
891 }
892
893 overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
895 }
896
897 pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
900 assert_eq!(
901 spectrum.freq_bins,
902 self.config.freq_bins(),
903 "Frequency bins mismatch"
904 );
905
906 let num_frames = spectrum.num_frames;
907 let original_time_len = (num_frames - 1) * self.config.hop_size;
908 let pad_amount = self.config.fft_size / 2;
909 let padded_len = original_time_len + 2 * pad_amount;
910
911 let mut overlap_buffer = vec![T::zero(); padded_len];
912 let mut window_energy = vec![T::zero(); padded_len];
913 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
914
915 for frame_idx in 0..num_frames {
917 let pos = frame_idx * self.config.hop_size;
918 for i in 0..self.config.fft_size {
919 match self.config.reconstruction_mode {
920 ReconstructionMode::Ola => {
921 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
922 }
923 ReconstructionMode::Wola => {
924 window_energy[pos + i] =
925 window_energy[pos + i] + self.window[i] * self.window[i];
926 }
927 }
928 }
929 }
930
931 for frame_idx in 0..num_frames {
933 for bin in 0..spectrum.freq_bins {
935 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
936 }
937
938 for bin in 1..(spectrum.freq_bins - 1) {
940 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
941 }
942
943 self.ifft.process(&mut ifft_buffer);
945
946 let pos = frame_idx * self.config.hop_size;
948 for i in 0..self.config.fft_size {
949 let fft_size_t = T::from(self.config.fft_size).unwrap();
950 let sample = ifft_buffer[i].re / fft_size_t;
951
952 match self.config.reconstruction_mode {
953 ReconstructionMode::Ola => {
954 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
955 }
956 ReconstructionMode::Wola => {
957 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
958 }
959 }
960 }
961 }
962
963 let threshold = T::from(1e-8).unwrap();
965 for i in 0..padded_len {
966 if window_energy[i] > threshold {
967 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
968 }
969 }
970
971 output.clear();
973 output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
974 }
975
976 pub fn process_multichannel(&self, spectra: &[Spectrum<T>]) -> Vec<Vec<T>> {
1006 assert!(!spectra.is_empty(), "spectra must not be empty");
1007
1008 #[cfg(feature = "rayon")]
1010 {
1011 use rayon::prelude::*;
1012 spectra
1013 .par_iter()
1014 .map(|spectrum| self.process(spectrum))
1015 .collect()
1016 }
1017 #[cfg(not(feature = "rayon"))]
1018 {
1019 spectra
1020 .iter()
1021 .map(|spectrum| self.process(spectrum))
1022 .collect()
1023 }
1024 }
1025
1026 pub fn process_multichannel_interleaved(&self, spectra: &[Spectrum<T>]) -> Vec<T> {
1056 let channels = self.process_multichannel(spectra);
1057 utils::interleave(&channels)
1058 }
1059}
1060
1061pub struct StreamingStft<T: Float + FftNum> {
1062 config: StftConfig<T>,
1063 window: Vec<T>,
1064 fft: Arc<dyn Fft<T>>,
1065 input_buffer: VecDeque<T>,
1066 frame_index: usize,
1067 fft_buffer: Vec<Complex<T>>,
1068}
1069
1070impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
1071 pub fn new(config: StftConfig<T>) -> Self {
1072 let window = config.generate_window();
1073 let mut planner = FftPlanner::new();
1074 let fft = planner.plan_fft_forward(config.fft_size);
1075 let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1076
1077 Self {
1078 config,
1079 window,
1080 fft,
1081 input_buffer: VecDeque::new(),
1082 frame_index: 0,
1083 fft_buffer,
1084 }
1085 }
1086
1087 pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
1088 self.input_buffer.extend(samples.iter().copied());
1089
1090 let mut frames = Vec::new();
1091
1092 while self.input_buffer.len() >= self.config.fft_size {
1093 for i in 0..self.config.fft_size {
1095 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1096 }
1097
1098 self.fft.process(&mut self.fft_buffer);
1099
1100 let freq_bins = self.config.freq_bins();
1101 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1102 frames.push(SpectrumFrame::from_data(data));
1103
1104 self.input_buffer.drain(..self.config.hop_size);
1106 self.frame_index += 1;
1107 }
1108
1109 frames
1110 }
1111
1112 pub fn push_samples_into(
1115 &mut self,
1116 samples: &[T],
1117 output: &mut Vec<SpectrumFrame<T>>,
1118 ) -> usize {
1119 self.input_buffer.extend(samples.iter().copied());
1120
1121 let initial_len = output.len();
1122
1123 while self.input_buffer.len() >= self.config.fft_size {
1124 for i in 0..self.config.fft_size {
1126 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1127 }
1128
1129 self.fft.process(&mut self.fft_buffer);
1130
1131 let freq_bins = self.config.freq_bins();
1132 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1133 output.push(SpectrumFrame::from_data(data));
1134
1135 self.input_buffer.drain(..self.config.hop_size);
1137 self.frame_index += 1;
1138 }
1139
1140 output.len() - initial_len
1141 }
1142
1143 pub fn push_samples_write(
1156 &mut self,
1157 samples: &[T],
1158 frame_pool: &mut [SpectrumFrame<T>],
1159 pool_index: &mut usize,
1160 ) -> usize {
1161 self.input_buffer.extend(samples.iter().copied());
1162
1163 let initial_index = *pool_index;
1164 let freq_bins = self.config.freq_bins();
1165
1166 while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
1167 for i in 0..self.config.fft_size {
1169 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1170 }
1171
1172 self.fft.process(&mut self.fft_buffer);
1173
1174 let frame = &mut frame_pool[*pool_index];
1176 debug_assert_eq!(
1177 frame.freq_bins, freq_bins,
1178 "Frame pool frames must match freq_bins"
1179 );
1180 frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
1181
1182 self.input_buffer.drain(..self.config.hop_size);
1184 self.frame_index += 1;
1185 *pool_index += 1;
1186 }
1187
1188 *pool_index - initial_index
1189 }
1190
1191 pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
1192 Vec::new()
1195 }
1196
1197 pub fn reset(&mut self) {
1198 self.input_buffer.clear();
1199 self.frame_index = 0;
1200 }
1201
1202 pub fn buffered_samples(&self) -> usize {
1203 self.input_buffer.len()
1204 }
1205}
1206
1207pub struct MultiChannelStreamingStft<T: Float + FftNum> {
1209 processors: Vec<StreamingStft<T>>,
1210}
1211
1212impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingStft<T> {
1213 pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1220 assert!(num_channels > 0, "num_channels must be > 0");
1221 let processors = (0..num_channels)
1222 .map(|_| StreamingStft::new(config.clone()))
1223 .collect();
1224 Self { processors }
1225 }
1226
1227 pub fn push_samples(&mut self, channels: &[&[T]]) -> Vec<Vec<SpectrumFrame<T>>> {
1238 assert_eq!(
1239 channels.len(),
1240 self.processors.len(),
1241 "Expected {} channels, got {}",
1242 self.processors.len(),
1243 channels.len()
1244 );
1245
1246 #[cfg(feature = "rayon")]
1247 {
1248 use rayon::prelude::*;
1249 self.processors
1250 .par_iter_mut()
1251 .zip(channels.par_iter())
1252 .map(|(stft, channel)| stft.push_samples(channel))
1253 .collect()
1254 }
1255 #[cfg(not(feature = "rayon"))]
1256 {
1257 self.processors
1258 .iter_mut()
1259 .zip(channels.iter())
1260 .map(|(stft, channel)| stft.push_samples(channel))
1261 .collect()
1262 }
1263 }
1264
1265 pub fn flush(&mut self) -> Vec<Vec<SpectrumFrame<T>>> {
1267 #[cfg(feature = "rayon")]
1268 {
1269 use rayon::prelude::*;
1270 self.processors
1271 .par_iter_mut()
1272 .map(|stft| stft.flush())
1273 .collect()
1274 }
1275 #[cfg(not(feature = "rayon"))]
1276 {
1277 self.processors
1278 .iter_mut()
1279 .map(|stft| stft.flush())
1280 .collect()
1281 }
1282 }
1283
1284 pub fn reset(&mut self) {
1286 #[cfg(feature = "rayon")]
1287 {
1288 use rayon::prelude::*;
1289 self.processors.par_iter_mut().for_each(|stft| stft.reset());
1290 }
1291 #[cfg(not(feature = "rayon"))]
1292 {
1293 self.processors.iter_mut().for_each(|stft| stft.reset());
1294 }
1295 }
1296
1297 pub fn num_channels(&self) -> usize {
1299 self.processors.len()
1300 }
1301}
1302
/// Inverse STFT that consumes frames incrementally and emits reconstructed samples.
pub struct StreamingIstft<T: Float + FftNum> {
    /// Parameters the processor was built with.
    config: StftConfig<T>,
    /// Precomputed window (`fft_size` samples).
    window: Vec<T>,
    /// Planned inverse FFT of length `fft_size`.
    ifft: Arc<dyn Fft<T>>,
    /// Overlap-add accumulator, indexed by absolute output sample position.
    overlap_buffer: Vec<T>,
    /// Accumulated window (OLA) or squared-window (WOLA) weight per sample position.
    window_energy: Vec<T>,
    /// Absolute position of the next sample to emit.
    output_position: usize,
    /// Number of frames pushed so far.
    frames_processed: usize,
    /// Reusable in-place inverse-FFT buffer (`fft_size` bins).
    ifft_buffer: Vec<Complex<T>>,
}
1313
1314impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
1315 pub fn new(config: StftConfig<T>) -> Self {
1316 let window = config.generate_window();
1317 let mut planner = FftPlanner::new();
1318 let ifft = planner.plan_fft_inverse(config.fft_size);
1319
1320 let buffer_size = config.fft_size * 2;
1323 let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1324
1325 Self {
1326 config,
1327 window,
1328 ifft,
1329 overlap_buffer: vec![T::zero(); buffer_size],
1330 window_energy: vec![T::zero(); buffer_size],
1331 output_position: 0,
1332 frames_processed: 0,
1333 ifft_buffer,
1334 }
1335 }
1336
1337 pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
1338 assert_eq!(
1339 frame.freq_bins,
1340 self.config.freq_bins(),
1341 "Frequency bins mismatch"
1342 );
1343
1344 for bin in 0..frame.freq_bins {
1346 self.ifft_buffer[bin] = frame.data[bin];
1347 }
1348
1349 for bin in 1..(frame.freq_bins - 1) {
1351 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1352 }
1353
1354 self.ifft.process(&mut self.ifft_buffer);
1356
1357 let write_pos = self.frames_processed * self.config.hop_size;
1359 for i in 0..self.config.fft_size {
1360 let fft_size_t = T::from(self.config.fft_size).unwrap();
1361 let sample = self.ifft_buffer[i].re / fft_size_t;
1362 let buf_idx = write_pos + i;
1363
1364 if buf_idx >= self.overlap_buffer.len() {
1366 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1367 self.window_energy.resize(buf_idx + 1, T::zero());
1368 }
1369
1370 match self.config.reconstruction_mode {
1371 ReconstructionMode::Ola => {
1372 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1373 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1374 }
1375 ReconstructionMode::Wola => {
1376 self.overlap_buffer[buf_idx] =
1377 self.overlap_buffer[buf_idx] + sample * self.window[i];
1378 self.window_energy[buf_idx] =
1379 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1380 }
1381 }
1382 }
1383
1384 self.frames_processed += 1;
1385
1386 let ready_until = if self.frames_processed == 1 {
1389 0 } else {
1391 (self.frames_processed - 1) * self.config.hop_size
1393 };
1394
1395 let output_start = self.output_position;
1397 let output_end = ready_until;
1398 let mut output = Vec::new();
1399
1400 let threshold = T::from(1e-8).unwrap();
1401 if output_end > output_start {
1402 for i in output_start..output_end {
1403 let normalized = if self.window_energy[i] > threshold {
1404 self.overlap_buffer[i] / self.window_energy[i]
1405 } else {
1406 T::zero()
1407 };
1408 output.push(normalized);
1409 }
1410 self.output_position = output_end;
1411 }
1412
1413 output
1414 }
1415
1416 pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1419 assert_eq!(
1420 frame.freq_bins,
1421 self.config.freq_bins(),
1422 "Frequency bins mismatch"
1423 );
1424
1425 for bin in 0..frame.freq_bins {
1427 self.ifft_buffer[bin] = frame.data[bin];
1428 }
1429
1430 for bin in 1..(frame.freq_bins - 1) {
1432 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1433 }
1434
1435 self.ifft.process(&mut self.ifft_buffer);
1437
1438 let write_pos = self.frames_processed * self.config.hop_size;
1440 for i in 0..self.config.fft_size {
1441 let fft_size_t = T::from(self.config.fft_size).unwrap();
1442 let sample = self.ifft_buffer[i].re / fft_size_t;
1443 let buf_idx = write_pos + i;
1444
1445 if buf_idx >= self.overlap_buffer.len() {
1447 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1448 self.window_energy.resize(buf_idx + 1, T::zero());
1449 }
1450
1451 match self.config.reconstruction_mode {
1452 ReconstructionMode::Ola => {
1453 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1454 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1455 }
1456 ReconstructionMode::Wola => {
1457 self.overlap_buffer[buf_idx] =
1458 self.overlap_buffer[buf_idx] + sample * self.window[i];
1459 self.window_energy[buf_idx] =
1460 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1461 }
1462 }
1463 }
1464
1465 self.frames_processed += 1;
1466
1467 let ready_until = if self.frames_processed == 1 {
1470 0 } else {
1472 (self.frames_processed - 1) * self.config.hop_size
1474 };
1475
1476 let output_start = self.output_position;
1478 let output_end = ready_until;
1479 let initial_len = output.len();
1480
1481 let threshold = T::from(1e-8).unwrap();
1482 if output_end > output_start {
1483 for i in output_start..output_end {
1484 let normalized = if self.window_energy[i] > threshold {
1485 self.overlap_buffer[i] / self.window_energy[i]
1486 } else {
1487 T::zero()
1488 };
1489 output.push(normalized);
1490 }
1491 self.output_position = output_end;
1492 }
1493
1494 output.len() - initial_len
1495 }
1496
1497 pub fn flush(&mut self) -> Vec<T> {
1498 let mut output = Vec::new();
1500 let threshold = T::from(1e-8).unwrap();
1501 for i in self.output_position..self.overlap_buffer.len() {
1502 if self.window_energy[i] > threshold {
1503 output.push(self.overlap_buffer[i] / self.window_energy[i]);
1504 } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1505 output.push(T::zero()); } else {
1507 break; }
1509 }
1510
1511 let valid_end =
1513 (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1514 if output.len() > valid_end - self.output_position {
1515 output.truncate(valid_end - self.output_position);
1516 }
1517
1518 self.reset();
1519 output
1520 }
1521
1522 pub fn reset(&mut self) {
1523 self.overlap_buffer.clear();
1524 self.overlap_buffer
1525 .resize(self.config.fft_size * 2, T::zero());
1526 self.window_energy.clear();
1527 self.window_energy
1528 .resize(self.config.fft_size * 2, T::zero());
1529 self.output_position = 0;
1530 self.frames_processed = 0;
1531 }
1532}
1533
1534pub struct MultiChannelStreamingIstft<T: Float + FftNum> {
1536 processors: Vec<StreamingIstft<T>>,
1537}
1538
1539impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingIstft<T> {
1540 pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1547 assert!(num_channels > 0, "num_channels must be > 0");
1548 let processors = (0..num_channels)
1549 .map(|_| StreamingIstft::new(config.clone()))
1550 .collect();
1551 Self { processors }
1552 }
1553
1554 pub fn push_frames(&mut self, frames: &[&SpectrumFrame<T>]) -> Vec<Vec<T>> {
1565 assert_eq!(
1566 frames.len(),
1567 self.processors.len(),
1568 "Expected {} channels, got {}",
1569 self.processors.len(),
1570 frames.len()
1571 );
1572
1573 #[cfg(feature = "rayon")]
1574 {
1575 use rayon::prelude::*;
1576 self.processors
1577 .par_iter_mut()
1578 .zip(frames.par_iter())
1579 .map(|(istft, frame)| istft.push_frame(frame))
1580 .collect()
1581 }
1582 #[cfg(not(feature = "rayon"))]
1583 {
1584 self.processors
1585 .iter_mut()
1586 .zip(frames.iter())
1587 .map(|(istft, frame)| istft.push_frame(frame))
1588 .collect()
1589 }
1590 }
1591
1592 pub fn flush(&mut self) -> Vec<Vec<T>> {
1594 #[cfg(feature = "rayon")]
1595 {
1596 use rayon::prelude::*;
1597 self.processors
1598 .par_iter_mut()
1599 .map(|istft| istft.flush())
1600 .collect()
1601 }
1602 #[cfg(not(feature = "rayon"))]
1603 {
1604 self.processors
1605 .iter_mut()
1606 .map(|istft| istft.flush())
1607 .collect()
1608 }
1609 }
1610
1611 pub fn reset(&mut self) {
1613 #[cfg(feature = "rayon")]
1614 {
1615 use rayon::prelude::*;
1616 self.processors
1617 .par_iter_mut()
1618 .for_each(|istft| istft.reset());
1619 }
1620 #[cfg(not(feature = "rayon"))]
1621 {
1622 self.processors.iter_mut().for_each(|istft| istft.reset());
1623 }
1624 }
1625
1626 pub fn num_channels(&self) -> usize {
1628 self.processors.len()
1629 }
1630}
1631
1632pub type StftConfigF32 = StftConfig<f32>;
1634pub type StftConfigF64 = StftConfig<f64>;
1635
1636pub type StftConfigBuilderF32 = StftConfigBuilder<f32>;
1637pub type StftConfigBuilderF64 = StftConfigBuilder<f64>;
1638
1639pub type BatchStftF32 = BatchStft<f32>;
1640pub type BatchStftF64 = BatchStft<f64>;
1641
1642pub type BatchIstftF32 = BatchIstft<f32>;
1643pub type BatchIstftF64 = BatchIstft<f64>;
1644
1645pub type StreamingStftF32 = StreamingStft<f32>;
1646pub type StreamingStftF64 = StreamingStft<f64>;
1647
1648pub type StreamingIstftF32 = StreamingIstft<f32>;
1649pub type StreamingIstftF64 = StreamingIstft<f64>;
1650
1651pub type SpectrumF32 = Spectrum<f32>;
1652pub type SpectrumF64 = Spectrum<f64>;
1653
1654pub type SpectrumFrameF32 = SpectrumFrame<f32>;
1655pub type SpectrumFrameF64 = SpectrumFrame<f64>;
1656
1657pub type MultiChannelStreamingStftF32 = MultiChannelStreamingStft<f32>;
1658pub type MultiChannelStreamingStftF64 = MultiChannelStreamingStft<f64>;
1659
1660pub type MultiChannelStreamingIstftF32 = MultiChannelStreamingIstft<f32>;
1661pub type MultiChannelStreamingIstftF64 = MultiChannelStreamingIstft<f64>;