1use num_traits::{Float, FromPrimitive};
25use rustfft::num_complex::Complex;
26use rustfft::{Fft, FftNum, FftPlanner};
27use std::collections::VecDeque;
28use std::fmt;
29use std::sync::Arc;
30
31mod utils;
32pub use utils::{apply_padding, deinterleave, deinterleave_into, interleave, interleave_into};
33
/// Convenience module that re-exports the crate's public surface.
///
/// It re-exports the padding and (de)interleaving helpers from `utils`
/// together with the configuration, spectrum, batch and streaming types
/// (including their `F32`/`F64`-suffixed names) from the crate root.
pub mod prelude {
    // Helper functions defined in the `utils` module.
    pub use crate::utils::{
        apply_padding, deinterleave, deinterleave_into, interleave, interleave_into,
    };
    // Types defined at the crate root.
    pub use crate::{
        BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64,
        MultiChannelStreamingIstft, MultiChannelStreamingIstftF32, MultiChannelStreamingIstftF64,
        MultiChannelStreamingStft, MultiChannelStreamingStftF32, MultiChannelStreamingStftF64,
        PadMode, ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame,
        SpectrumFrameF32, SpectrumFrameF64, StftConfig, StftConfigF32, StftConfigF64,
        StreamingIstft, StreamingIstftF32, StreamingIstftF64, StreamingStft, StreamingStftF32,
        StreamingStftF64, WindowType,
    };
}
48
/// Selects how the inverse transforms in this file combine overlapping frames.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReconstructionMode {
    /// Overlap-add: inverse-FFT samples are summed unmodified and the sum is
    /// divided by the accumulated window values. `StftConfig::new` runs
    /// `validate_cola` for this mode.
    Ola,

    /// Weighted overlap-add: inverse-FFT samples are multiplied by the window
    /// before summing and the sum is divided by the accumulated squared window
    /// values. `StftConfig::new` runs `validate_nola` for this mode.
    Wola,
}
57
/// Window shape produced by `generate_window`.
///
/// In the formulas below `a = 2 * pi * i / (size - 1)` for sample index `i`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowType {
    /// `0.5 * (1 - cos(a))`
    Hann,
    /// `0.54 - 0.46 * cos(a)`
    Hamming,
    /// `0.42 - 0.5 * cos(a) + 0.08 * cos(2 * a)`
    Blackman,
}
64
/// Reasons why `StftConfig::new` (or one of its validators) rejects a configuration.
#[derive(Debug, Clone)]
pub enum ConfigError<T: Float + fmt::Debug> {
    /// The smallest overlapped squared-window sum (`min_energy`) fell below `threshold`.
    NolaViolation { min_energy: T, threshold: T },
    /// The relative ripple of the overlapped window sum (`max_deviation`) was not below
    /// `threshold`; `max_deviation` is infinite when the window sum is effectively zero.
    ColaViolation { max_deviation: T, threshold: T },
    /// The hop size was zero or larger than the FFT size.
    InvalidHopSize,
    /// The FFT size was zero or not a power of two.
    InvalidFftSize,
}
72
73impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 match self {
76 ConfigError::NolaViolation {
77 min_energy,
78 threshold,
79 } => {
80 write!(
81 f,
82 "NOLA condition violated: min_energy={} < threshold={}",
83 min_energy, threshold
84 )
85 }
86 ConfigError::ColaViolation {
87 max_deviation,
88 threshold,
89 } => {
90 write!(
91 f,
92 "COLA condition violated: max_deviation={} > threshold={}",
93 max_deviation, threshold
94 )
95 }
96 ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
97 ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
98 }
99 }
100}
101
/// Lets `ConfigError` be used as a standard error value; all trait methods keep their defaults.
impl<T: Float + fmt::Display + fmt::Debug> std::error::Error for ConfigError<T> {}
103
/// Padding strategy passed to `utils::apply_padding` before a signal is framed.
///
/// `PartialEq`/`Eq` are derived so modes can be compared, matching the other
/// option enums in this file.
///
/// NOTE(review): the exact samples produced by each mode are defined in
/// `utils::apply_padding`, which is not visible here — the variant notes below
/// follow the variant names and should be confirmed against that function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadMode {
    /// Presumably mirrors the signal at each boundary.
    Reflect,
    /// Presumably pads with zeros.
    Zero,
    /// Presumably repeats the first/last sample.
    Edge,
}
110
/// Parameters shared by the forward and inverse transforms in this file.
///
/// `StftConfig::new` validates the values; the fields are public, so code that
/// mutates them afterwards bypasses that validation.
#[derive(Clone)]
pub struct StftConfig<T: Float> {
    /// Number of samples per analysis frame (and FFT length).
    pub fft_size: usize,
    /// Number of samples between the starts of consecutive frames.
    pub hop_size: usize,
    /// Window shape applied to each frame.
    pub window: WindowType,
    /// How the inverse transform combines overlapping frames.
    pub reconstruction_mode: ReconstructionMode,
    // Ties the configuration to a sample type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
119
120impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
121 fn nola_threshold() -> T {
122 T::from(1e-8).unwrap()
123 }
124
125 fn cola_relative_tolerance() -> T {
126 T::from(1e-4).unwrap()
127 }
128
129 pub fn new(
130 fft_size: usize,
131 hop_size: usize,
132 window: WindowType,
133 reconstruction_mode: ReconstructionMode,
134 ) -> Result<Self, ConfigError<T>> {
135 if fft_size == 0 || !fft_size.is_power_of_two() {
136 return Err(ConfigError::InvalidFftSize);
137 }
138 if hop_size == 0 || hop_size > fft_size {
139 return Err(ConfigError::InvalidHopSize);
140 }
141
142 let config = Self {
143 fft_size,
144 hop_size,
145 window,
146 reconstruction_mode,
147 _phantom: std::marker::PhantomData,
148 };
149
150 match reconstruction_mode {
152 ReconstructionMode::Ola => config.validate_cola()?,
153 ReconstructionMode::Wola => config.validate_nola()?,
154 }
155
156 Ok(config)
157 }
158
159 pub fn default_4096() -> Self {
161 Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
162 .expect("Default config should always be valid")
163 }
164
165 pub fn freq_bins(&self) -> usize {
166 self.fft_size / 2 + 1
167 }
168
169 pub fn overlap_percent(&self) -> T {
170 let one = T::one();
171 let hundred = T::from(100.0).unwrap();
172 (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
173 }
174
175 fn generate_window(&self) -> Vec<T> {
176 generate_window(self.window, self.fft_size)
177 }
178
179 pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
181 let window = self.generate_window();
182 let num_overlaps = (self.fft_size + self.hop_size - 1) / self.hop_size;
183 let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
184 let mut energy = vec![T::zero(); test_len];
185
186 for i in 0..num_overlaps {
187 let offset = i * self.hop_size;
188 for j in 0..self.fft_size {
189 if offset + j < test_len {
190 energy[offset + j] = energy[offset + j] + window[j] * window[j];
191 }
192 }
193 }
194
195 let start = self.fft_size / 2;
197 let end = test_len - self.fft_size / 2;
198 let min_energy = energy[start..end]
199 .iter()
200 .copied()
201 .min_by(|a, b| a.partial_cmp(b).unwrap())
202 .unwrap_or_else(T::zero);
203
204 if min_energy < Self::nola_threshold() {
205 return Err(ConfigError::NolaViolation {
206 min_energy,
207 threshold: Self::nola_threshold(),
208 });
209 }
210
211 Ok(())
212 }
213
214 pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
216 let window = self.generate_window();
217 let window_len = window.len();
218
219 let mut cola_sum_period = vec![T::zero(); self.hop_size];
220 for i in 0..window_len {
221 let idx = i % self.hop_size;
222 cola_sum_period[idx] = cola_sum_period[idx] + window[i];
223 }
224
225 let zero = T::zero();
226 let min_sum = cola_sum_period
227 .iter()
228 .min_by(|a, b| a.partial_cmp(b).unwrap())
229 .unwrap_or(&zero);
230 let max_sum = cola_sum_period
231 .iter()
232 .max_by(|a, b| a.partial_cmp(b).unwrap())
233 .unwrap_or(&zero);
234
235 let epsilon = T::from(1e-9).unwrap();
236 if *max_sum < epsilon {
237 return Err(ConfigError::ColaViolation {
238 max_deviation: T::infinity(),
239 threshold: Self::cola_relative_tolerance(),
240 });
241 }
242
243 let ripple = (*max_sum - *min_sum) / *max_sum;
244
245 let is_compliant = ripple < Self::cola_relative_tolerance();
246
247 if !is_compliant {
248 return Err(ConfigError::ColaViolation {
249 max_deviation: ripple,
250 threshold: Self::cola_relative_tolerance(),
251 });
252 }
253 Ok(())
254 }
255}
256
257fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
258 let pi = T::from(std::f64::consts::PI).unwrap();
259 let two = T::from(2.0).unwrap();
260
261 match window_type {
262 WindowType::Hann => (0..size)
263 .map(|i| {
264 let half = T::from(0.5).unwrap();
265 let one = T::one();
266 let i_t = T::from(i).unwrap();
267 let size_m1 = T::from(size - 1).unwrap();
268 half * (one - (two * pi * i_t / size_m1).cos())
269 })
270 .collect(),
271 WindowType::Hamming => (0..size)
272 .map(|i| {
273 let i_t = T::from(i).unwrap();
274 let size_m1 = T::from(size - 1).unwrap();
275 T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_m1).cos()
276 })
277 .collect(),
278 WindowType::Blackman => (0..size)
279 .map(|i| {
280 let i_t = T::from(i).unwrap();
281 let size_m1 = T::from(size - 1).unwrap();
282 let angle = two * pi * i_t / size_m1;
283 T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
284 + T::from(0.08).unwrap() * (two * angle).cos()
285 })
286 .collect(),
287 }
288}
289
/// One frame of complex frequency-domain data.
#[derive(Clone)]
pub struct SpectrumFrame<T: Float> {
    /// Number of frequency bins; the constructors set it to `data.len()`.
    pub freq_bins: usize,
    /// Complex value of each bin, indexed by bin number.
    pub data: Vec<Complex<T>>,
}
295
296impl<T: Float> SpectrumFrame<T> {
297 pub fn new(freq_bins: usize) -> Self {
298 Self {
299 freq_bins,
300 data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
301 }
302 }
303
304 pub fn from_data(data: Vec<Complex<T>>) -> Self {
305 let freq_bins = data.len();
306 Self { freq_bins, data }
307 }
308
309 pub fn clear(&mut self) {
311 for val in &mut self.data {
312 *val = Complex::new(T::zero(), T::zero());
313 }
314 }
315
316 pub fn resize_if_needed(&mut self, freq_bins: usize) {
318 if self.freq_bins != freq_bins {
319 self.freq_bins = freq_bins;
320 self.data
321 .resize(freq_bins, Complex::new(T::zero(), T::zero()));
322 }
323 }
324
325 pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
327 self.resize_if_needed(data.len());
328 self.data.copy_from_slice(data);
329 }
330
331 #[inline]
333 pub fn magnitude(&self, bin: usize) -> T {
334 let c = &self.data[bin];
335 (c.re * c.re + c.im * c.im).sqrt()
336 }
337
338 #[inline]
340 pub fn phase(&self, bin: usize) -> T {
341 let c = &self.data[bin];
342 c.im.atan2(c.re)
343 }
344
345 pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
347 self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
348 }
349
350 pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
352 assert_eq!(
353 magnitudes.len(),
354 phases.len(),
355 "Magnitude and phase arrays must have same length"
356 );
357 let freq_bins = magnitudes.len();
358 let data: Vec<Complex<T>> = magnitudes
359 .iter()
360 .zip(phases.iter())
361 .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
362 .collect();
363 Self { freq_bins, data }
364 }
365
366 pub fn magnitudes(&self) -> Vec<T> {
368 self.data
369 .iter()
370 .map(|c| (c.re * c.re + c.im * c.im).sqrt())
371 .collect()
372 }
373
374 pub fn phases(&self) -> Vec<T> {
376 self.data.iter().map(|c| c.im.atan2(c.re)).collect()
377 }
378}
379
/// Complex spectra for a sequence of frames, stored as two contiguous planes.
#[derive(Clone)]
pub struct Spectrum<T: Float> {
    /// Number of frames stored.
    pub num_frames: usize,
    /// Number of frequency bins per frame.
    pub freq_bins: usize,
    /// `2 * num_frames * freq_bins` values: the first half holds the real parts
    /// at index `frame * freq_bins + bin`, the second half holds the imaginary
    /// parts at the same index offset by `num_frames * freq_bins`.
    pub data: Vec<T>,
}
386
387impl<T: Float> Spectrum<T> {
388 pub fn new(num_frames: usize, freq_bins: usize) -> Self {
389 Self {
390 num_frames,
391 freq_bins,
392 data: vec![T::zero(); 2 * num_frames * freq_bins],
393 }
394 }
395
396 #[inline]
397 pub fn real(&self, frame: usize, bin: usize) -> T {
398 self.data[frame * self.freq_bins + bin]
399 }
400
401 #[inline]
402 pub fn imag(&self, frame: usize, bin: usize) -> T {
403 let offset = self.num_frames * self.freq_bins;
404 self.data[offset + frame * self.freq_bins + bin]
405 }
406
407 #[inline]
408 pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
409 Complex::new(self.real(frame, bin), self.imag(frame, bin))
410 }
411
412 pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
413 (0..self.num_frames).map(move |frame_idx| {
414 let data: Vec<Complex<T>> = (0..self.freq_bins)
415 .map(|bin| self.get_complex(frame_idx, bin))
416 .collect();
417 SpectrumFrame::from_data(data)
418 })
419 }
420
421 #[inline]
423 pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
424 self.data[frame * self.freq_bins + bin] = value;
425 }
426
427 #[inline]
429 pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
430 let offset = self.num_frames * self.freq_bins;
431 self.data[offset + frame * self.freq_bins + bin] = value;
432 }
433
434 #[inline]
436 pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
437 self.set_real(frame, bin, value.re);
438 self.set_imag(frame, bin, value.im);
439 }
440
441 #[inline]
443 pub fn magnitude(&self, frame: usize, bin: usize) -> T {
444 let re = self.real(frame, bin);
445 let im = self.imag(frame, bin);
446 (re * re + im * im).sqrt()
447 }
448
449 #[inline]
451 pub fn phase(&self, frame: usize, bin: usize) -> T {
452 let re = self.real(frame, bin);
453 let im = self.imag(frame, bin);
454 im.atan2(re)
455 }
456
457 pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
459 self.set_real(frame, bin, magnitude * phase.cos());
460 self.set_imag(frame, bin, magnitude * phase.sin());
461 }
462
463 pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
465 (0..self.freq_bins)
466 .map(|bin| self.magnitude(frame, bin))
467 .collect()
468 }
469
470 pub fn frame_phases(&self, frame: usize) -> Vec<T> {
472 (0..self.freq_bins)
473 .map(|bin| self.phase(frame, bin))
474 .collect()
475 }
476
477 pub fn apply<F>(&mut self, mut f: F)
479 where
480 F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
481 {
482 for frame in 0..self.num_frames {
483 for bin in 0..self.freq_bins {
484 let c = self.get_complex(frame, bin);
485 let new_c = f(frame, bin, c);
486 self.set_complex(frame, bin, new_c);
487 }
488 }
489 }
490
491 pub fn apply_gain(&mut self, bin_range: std::ops::Range<usize>, gain: T) {
493 for frame in 0..self.num_frames {
494 for bin in bin_range.clone() {
495 if bin < self.freq_bins {
496 let c = self.get_complex(frame, bin);
497 self.set_complex(frame, bin, c * gain);
498 }
499 }
500 }
501 }
502
503 pub fn zero_bins(&mut self, bin_range: std::ops::Range<usize>) {
505 for frame in 0..self.num_frames {
506 for bin in bin_range.clone() {
507 if bin < self.freq_bins {
508 self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
509 }
510 }
511 }
512 }
513}
514
/// Forward transform that processes a whole signal in one call.
pub struct BatchStft<T: Float + FftNum> {
    // Frame length, hop size and window selection.
    config: StftConfig<T>,
    // Window samples, generated once from `config`.
    window: Vec<T>,
    // Forward FFT plan of length `config.fft_size`.
    fft: Arc<dyn Fft<T>>,
}
520
521impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
522 pub fn new(config: StftConfig<T>) -> Self {
523 let window = config.generate_window();
524 let mut planner = FftPlanner::new();
525 let fft = planner.plan_fft_forward(config.fft_size);
526
527 Self {
528 config,
529 window,
530 fft,
531 }
532 }
533
534 pub fn process(&self, signal: &[T]) -> Spectrum<T> {
535 self.process_padded(signal, PadMode::Reflect)
536 }
537
538 pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
539 let pad_amount = self.config.fft_size / 2;
540 let padded = utils::apply_padding(signal, pad_amount, pad_mode);
541
542 let num_frames = if padded.len() >= self.config.fft_size {
543 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
544 } else {
545 0
546 };
547
548 let freq_bins = self.config.freq_bins();
549 let mut result = Spectrum::new(num_frames, freq_bins);
550
551 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
552
553 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
554 .step_by(self.config.hop_size)
555 .enumerate()
556 {
557 for i in 0..self.config.fft_size {
559 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
560 }
561
562 self.fft.process(&mut fft_buffer);
564
565 for bin in 0..freq_bins {
567 let idx = frame_idx * freq_bins + bin;
568 result.data[idx] = fft_buffer[bin].re;
569 result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
570 }
571 }
572
573 result
574 }
575
576 pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
580 self.process_padded_into(signal, PadMode::Reflect, spectrum)
581 }
582
583 pub fn process_padded_into(
585 &self,
586 signal: &[T],
587 pad_mode: PadMode,
588 spectrum: &mut Spectrum<T>,
589 ) -> bool {
590 let pad_amount = self.config.fft_size / 2;
591 let padded = utils::apply_padding(signal, pad_amount, pad_mode);
592
593 let num_frames = if padded.len() >= self.config.fft_size {
594 (padded.len() - self.config.fft_size) / self.config.hop_size + 1
595 } else {
596 0
597 };
598
599 let freq_bins = self.config.freq_bins();
600
601 if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
603 return false;
604 }
605
606 let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
607
608 for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
609 .step_by(self.config.hop_size)
610 .enumerate()
611 {
612 for i in 0..self.config.fft_size {
614 fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
615 }
616
617 self.fft.process(&mut fft_buffer);
619
620 for bin in 0..freq_bins {
622 let idx = frame_idx * freq_bins + bin;
623 spectrum.data[idx] = fft_buffer[bin].re;
624 spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
625 }
626 }
627
628 true
629 }
630
631 pub fn process_multichannel(&self, channels: &[Vec<T>]) -> Vec<Spectrum<T>> {
658 assert!(!channels.is_empty(), "channels must not be empty");
659
660 let expected_len = channels[0].len();
662 for (i, channel) in channels.iter().enumerate() {
663 assert_eq!(
664 channel.len(),
665 expected_len,
666 "Channel {} has length {}, expected {}",
667 i,
668 channel.len(),
669 expected_len
670 );
671 }
672
673 #[cfg(feature = "rayon")]
675 {
676 use rayon::prelude::*;
677 channels
678 .par_iter()
679 .map(|channel| self.process(channel))
680 .collect()
681 }
682 #[cfg(not(feature = "rayon"))]
683 {
684 channels
685 .iter()
686 .map(|channel| self.process(channel))
687 .collect()
688 }
689 }
690
691 pub fn process_interleaved(&self, data: &[T], num_channels: usize) -> Vec<Spectrum<T>> {
719 let channels = utils::deinterleave(data, num_channels);
720 self.process_multichannel(&channels)
721 }
722}
723
/// Inverse transform that reconstructs a whole signal from a `Spectrum` in one call.
pub struct BatchIstft<T: Float + FftNum> {
    // Frame length, hop size, window selection and reconstruction mode.
    config: StftConfig<T>,
    // Window samples, generated once from `config`.
    window: Vec<T>,
    // Inverse FFT plan of length `config.fft_size`.
    ifft: Arc<dyn Fft<T>>,
}
729
730impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
731 pub fn new(config: StftConfig<T>) -> Self {
732 let window = config.generate_window();
733 let mut planner = FftPlanner::new();
734 let ifft = planner.plan_fft_inverse(config.fft_size);
735
736 Self {
737 config,
738 window,
739 ifft,
740 }
741 }
742
743 pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
744 assert_eq!(
745 spectrum.freq_bins,
746 self.config.freq_bins(),
747 "Frequency bins mismatch"
748 );
749
750 let num_frames = spectrum.num_frames;
751 let original_time_len = (num_frames - 1) * self.config.hop_size;
752 let pad_amount = self.config.fft_size / 2;
753 let padded_len = original_time_len + 2 * pad_amount;
754
755 let mut overlap_buffer = vec![T::zero(); padded_len];
756 let mut window_energy = vec![T::zero(); padded_len];
757 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
758
759 for frame_idx in 0..num_frames {
761 let pos = frame_idx * self.config.hop_size;
762 for i in 0..self.config.fft_size {
763 match self.config.reconstruction_mode {
764 ReconstructionMode::Ola => {
765 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
766 }
767 ReconstructionMode::Wola => {
768 window_energy[pos + i] =
769 window_energy[pos + i] + self.window[i] * self.window[i];
770 }
771 }
772 }
773 }
774
775 for frame_idx in 0..num_frames {
777 for bin in 0..spectrum.freq_bins {
779 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
780 }
781
782 for bin in 1..(spectrum.freq_bins - 1) {
784 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
785 }
786
787 self.ifft.process(&mut ifft_buffer);
789
790 let pos = frame_idx * self.config.hop_size;
792 for i in 0..self.config.fft_size {
793 let fft_size_t = T::from(self.config.fft_size).unwrap();
794 let sample = ifft_buffer[i].re / fft_size_t;
795
796 match self.config.reconstruction_mode {
797 ReconstructionMode::Ola => {
798 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
800 }
801 ReconstructionMode::Wola => {
802 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
804 }
805 }
806 }
807 }
808
809 let threshold = T::from(1e-8).unwrap();
811 for i in 0..padded_len {
812 if window_energy[i] > threshold {
813 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
814 }
815 }
816
817 overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
819 }
820
821 pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
824 assert_eq!(
825 spectrum.freq_bins,
826 self.config.freq_bins(),
827 "Frequency bins mismatch"
828 );
829
830 let num_frames = spectrum.num_frames;
831 let original_time_len = (num_frames - 1) * self.config.hop_size;
832 let pad_amount = self.config.fft_size / 2;
833 let padded_len = original_time_len + 2 * pad_amount;
834
835 let mut overlap_buffer = vec![T::zero(); padded_len];
836 let mut window_energy = vec![T::zero(); padded_len];
837 let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
838
839 for frame_idx in 0..num_frames {
841 let pos = frame_idx * self.config.hop_size;
842 for i in 0..self.config.fft_size {
843 match self.config.reconstruction_mode {
844 ReconstructionMode::Ola => {
845 window_energy[pos + i] = window_energy[pos + i] + self.window[i];
846 }
847 ReconstructionMode::Wola => {
848 window_energy[pos + i] =
849 window_energy[pos + i] + self.window[i] * self.window[i];
850 }
851 }
852 }
853 }
854
855 for frame_idx in 0..num_frames {
857 for bin in 0..spectrum.freq_bins {
859 ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
860 }
861
862 for bin in 1..(spectrum.freq_bins - 1) {
864 ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
865 }
866
867 self.ifft.process(&mut ifft_buffer);
869
870 let pos = frame_idx * self.config.hop_size;
872 for i in 0..self.config.fft_size {
873 let fft_size_t = T::from(self.config.fft_size).unwrap();
874 let sample = ifft_buffer[i].re / fft_size_t;
875
876 match self.config.reconstruction_mode {
877 ReconstructionMode::Ola => {
878 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
879 }
880 ReconstructionMode::Wola => {
881 overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
882 }
883 }
884 }
885 }
886
887 let threshold = T::from(1e-8).unwrap();
889 for i in 0..padded_len {
890 if window_energy[i] > threshold {
891 overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
892 }
893 }
894
895 output.clear();
897 output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
898 }
899
900 pub fn process_multichannel(&self, spectra: &[Spectrum<T>]) -> Vec<Vec<T>> {
930 assert!(!spectra.is_empty(), "spectra must not be empty");
931
932 #[cfg(feature = "rayon")]
934 {
935 use rayon::prelude::*;
936 spectra
937 .par_iter()
938 .map(|spectrum| self.process(spectrum))
939 .collect()
940 }
941 #[cfg(not(feature = "rayon"))]
942 {
943 spectra
944 .iter()
945 .map(|spectrum| self.process(spectrum))
946 .collect()
947 }
948 }
949
950 pub fn process_multichannel_interleaved(&self, spectra: &[Spectrum<T>]) -> Vec<T> {
980 let channels = self.process_multichannel(spectra);
981 utils::interleave(&channels)
982 }
983}
984
/// Forward transform that accepts samples incrementally and emits frames as
/// soon as enough samples are buffered.
pub struct StreamingStft<T: Float + FftNum> {
    // Frame length, hop size and window selection.
    config: StftConfig<T>,
    // Window samples, generated once from `config`.
    window: Vec<T>,
    // Forward FFT plan of length `config.fft_size`.
    fft: Arc<dyn Fft<T>>,
    // Samples received but not yet consumed; `hop_size` samples are dropped per emitted frame.
    input_buffer: VecDeque<T>,
    // Number of frames emitted since construction or the last `reset`.
    frame_index: usize,
    // Scratch buffer reused for every FFT.
    fft_buffer: Vec<Complex<T>>,
}
993
994impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
995 pub fn new(config: StftConfig<T>) -> Self {
996 let window = config.generate_window();
997 let mut planner = FftPlanner::new();
998 let fft = planner.plan_fft_forward(config.fft_size);
999 let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1000
1001 Self {
1002 config,
1003 window,
1004 fft,
1005 input_buffer: VecDeque::new(),
1006 frame_index: 0,
1007 fft_buffer,
1008 }
1009 }
1010
1011 pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
1012 self.input_buffer.extend(samples.iter().copied());
1013
1014 let mut frames = Vec::new();
1015
1016 while self.input_buffer.len() >= self.config.fft_size {
1017 for i in 0..self.config.fft_size {
1019 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1020 }
1021
1022 self.fft.process(&mut self.fft_buffer);
1023
1024 let freq_bins = self.config.freq_bins();
1025 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1026 frames.push(SpectrumFrame::from_data(data));
1027
1028 self.input_buffer.drain(..self.config.hop_size);
1030 self.frame_index += 1;
1031 }
1032
1033 frames
1034 }
1035
1036 pub fn push_samples_into(
1039 &mut self,
1040 samples: &[T],
1041 output: &mut Vec<SpectrumFrame<T>>,
1042 ) -> usize {
1043 self.input_buffer.extend(samples.iter().copied());
1044
1045 let initial_len = output.len();
1046
1047 while self.input_buffer.len() >= self.config.fft_size {
1048 for i in 0..self.config.fft_size {
1050 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1051 }
1052
1053 self.fft.process(&mut self.fft_buffer);
1054
1055 let freq_bins = self.config.freq_bins();
1056 let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1057 output.push(SpectrumFrame::from_data(data));
1058
1059 self.input_buffer.drain(..self.config.hop_size);
1061 self.frame_index += 1;
1062 }
1063
1064 output.len() - initial_len
1065 }
1066
1067 pub fn push_samples_write(
1080 &mut self,
1081 samples: &[T],
1082 frame_pool: &mut [SpectrumFrame<T>],
1083 pool_index: &mut usize,
1084 ) -> usize {
1085 self.input_buffer.extend(samples.iter().copied());
1086
1087 let initial_index = *pool_index;
1088 let freq_bins = self.config.freq_bins();
1089
1090 while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
1091 for i in 0..self.config.fft_size {
1093 self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1094 }
1095
1096 self.fft.process(&mut self.fft_buffer);
1097
1098 let frame = &mut frame_pool[*pool_index];
1100 debug_assert_eq!(
1101 frame.freq_bins, freq_bins,
1102 "Frame pool frames must match freq_bins"
1103 );
1104 frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
1105
1106 self.input_buffer.drain(..self.config.hop_size);
1108 self.frame_index += 1;
1109 *pool_index += 1;
1110 }
1111
1112 *pool_index - initial_index
1113 }
1114
1115 pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
1116 Vec::new()
1119 }
1120
1121 pub fn reset(&mut self) {
1122 self.input_buffer.clear();
1123 self.frame_index = 0;
1124 }
1125
1126 pub fn buffered_samples(&self) -> usize {
1127 self.input_buffer.len()
1128 }
1129}
1130
/// Streaming forward transform for several channels, one `StreamingStft` per channel.
pub struct MultiChannelStreamingStft<T: Float + FftNum> {
    // One independent processor per channel, in channel order.
    processors: Vec<StreamingStft<T>>,
}
1135
1136impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingStft<T> {
1137 pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1144 assert!(num_channels > 0, "num_channels must be > 0");
1145 let processors = (0..num_channels)
1146 .map(|_| StreamingStft::new(config.clone()))
1147 .collect();
1148 Self { processors }
1149 }
1150
1151 pub fn push_samples(&mut self, channels: &[&[T]]) -> Vec<Vec<SpectrumFrame<T>>> {
1162 assert_eq!(
1163 channels.len(),
1164 self.processors.len(),
1165 "Expected {} channels, got {}",
1166 self.processors.len(),
1167 channels.len()
1168 );
1169
1170 #[cfg(feature = "rayon")]
1171 {
1172 use rayon::prelude::*;
1173 self.processors
1174 .par_iter_mut()
1175 .zip(channels.par_iter())
1176 .map(|(stft, channel)| stft.push_samples(channel))
1177 .collect()
1178 }
1179 #[cfg(not(feature = "rayon"))]
1180 {
1181 self.processors
1182 .iter_mut()
1183 .zip(channels.iter())
1184 .map(|(stft, channel)| stft.push_samples(channel))
1185 .collect()
1186 }
1187 }
1188
1189 pub fn flush(&mut self) -> Vec<Vec<SpectrumFrame<T>>> {
1191 #[cfg(feature = "rayon")]
1192 {
1193 use rayon::prelude::*;
1194 self.processors
1195 .par_iter_mut()
1196 .map(|stft| stft.flush())
1197 .collect()
1198 }
1199 #[cfg(not(feature = "rayon"))]
1200 {
1201 self.processors
1202 .iter_mut()
1203 .map(|stft| stft.flush())
1204 .collect()
1205 }
1206 }
1207
1208 pub fn reset(&mut self) {
1210 #[cfg(feature = "rayon")]
1211 {
1212 use rayon::prelude::*;
1213 self.processors.par_iter_mut().for_each(|stft| stft.reset());
1214 }
1215 #[cfg(not(feature = "rayon"))]
1216 {
1217 self.processors.iter_mut().for_each(|stft| stft.reset());
1218 }
1219 }
1220
1221 pub fn num_channels(&self) -> usize {
1223 self.processors.len()
1224 }
1225}
1226
/// Inverse transform that accepts frames incrementally and returns samples as
/// soon as later frames can no longer contribute to them.
pub struct StreamingIstft<T: Float + FftNum> {
    // Frame length, hop size, window selection and reconstruction mode.
    config: StftConfig<T>,
    // Window samples, generated once from `config`.
    window: Vec<T>,
    // Inverse FFT plan of length `config.fft_size`.
    ifft: Arc<dyn Fft<T>>,
    // Overlap-added inverse-FFT samples that have not been returned yet.
    overlap_buffer: Vec<T>,
    // Accumulated window weight for each entry of `overlap_buffer`.
    window_energy: Vec<T>,
    // Number of samples returned to the caller since construction or the last reset.
    output_position: usize,
    // Number of frames accepted since construction or the last reset.
    frames_processed: usize,
    // Scratch buffer reused for every inverse FFT.
    ifft_buffer: Vec<Complex<T>>,
}
1237
1238impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
1239 pub fn new(config: StftConfig<T>) -> Self {
1240 let window = config.generate_window();
1241 let mut planner = FftPlanner::new();
1242 let ifft = planner.plan_fft_inverse(config.fft_size);
1243
1244 let buffer_size = config.fft_size * 2;
1247 let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1248
1249 Self {
1250 config,
1251 window,
1252 ifft,
1253 overlap_buffer: vec![T::zero(); buffer_size],
1254 window_energy: vec![T::zero(); buffer_size],
1255 output_position: 0,
1256 frames_processed: 0,
1257 ifft_buffer,
1258 }
1259 }
1260
1261 pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
1262 assert_eq!(
1263 frame.freq_bins,
1264 self.config.freq_bins(),
1265 "Frequency bins mismatch"
1266 );
1267
1268 for bin in 0..frame.freq_bins {
1270 self.ifft_buffer[bin] = frame.data[bin];
1271 }
1272
1273 for bin in 1..(frame.freq_bins - 1) {
1275 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1276 }
1277
1278 self.ifft.process(&mut self.ifft_buffer);
1280
1281 let write_pos = self.frames_processed * self.config.hop_size;
1283 for i in 0..self.config.fft_size {
1284 let fft_size_t = T::from(self.config.fft_size).unwrap();
1285 let sample = self.ifft_buffer[i].re / fft_size_t;
1286 let buf_idx = write_pos + i;
1287
1288 if buf_idx >= self.overlap_buffer.len() {
1290 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1291 self.window_energy.resize(buf_idx + 1, T::zero());
1292 }
1293
1294 match self.config.reconstruction_mode {
1295 ReconstructionMode::Ola => {
1296 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1297 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1298 }
1299 ReconstructionMode::Wola => {
1300 self.overlap_buffer[buf_idx] =
1301 self.overlap_buffer[buf_idx] + sample * self.window[i];
1302 self.window_energy[buf_idx] =
1303 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1304 }
1305 }
1306 }
1307
1308 self.frames_processed += 1;
1309
1310 let ready_until = if self.frames_processed == 1 {
1313 0 } else {
1315 (self.frames_processed - 1) * self.config.hop_size
1317 };
1318
1319 let output_start = self.output_position;
1321 let output_end = ready_until;
1322 let mut output = Vec::new();
1323
1324 let threshold = T::from(1e-8).unwrap();
1325 if output_end > output_start {
1326 for i in output_start..output_end {
1327 let normalized = if self.window_energy[i] > threshold {
1328 self.overlap_buffer[i] / self.window_energy[i]
1329 } else {
1330 T::zero()
1331 };
1332 output.push(normalized);
1333 }
1334 self.output_position = output_end;
1335 }
1336
1337 output
1338 }
1339
1340 pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1343 assert_eq!(
1344 frame.freq_bins,
1345 self.config.freq_bins(),
1346 "Frequency bins mismatch"
1347 );
1348
1349 for bin in 0..frame.freq_bins {
1351 self.ifft_buffer[bin] = frame.data[bin];
1352 }
1353
1354 for bin in 1..(frame.freq_bins - 1) {
1356 self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1357 }
1358
1359 self.ifft.process(&mut self.ifft_buffer);
1361
1362 let write_pos = self.frames_processed * self.config.hop_size;
1364 for i in 0..self.config.fft_size {
1365 let fft_size_t = T::from(self.config.fft_size).unwrap();
1366 let sample = self.ifft_buffer[i].re / fft_size_t;
1367 let buf_idx = write_pos + i;
1368
1369 if buf_idx >= self.overlap_buffer.len() {
1371 self.overlap_buffer.resize(buf_idx + 1, T::zero());
1372 self.window_energy.resize(buf_idx + 1, T::zero());
1373 }
1374
1375 match self.config.reconstruction_mode {
1376 ReconstructionMode::Ola => {
1377 self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1378 self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1379 }
1380 ReconstructionMode::Wola => {
1381 self.overlap_buffer[buf_idx] =
1382 self.overlap_buffer[buf_idx] + sample * self.window[i];
1383 self.window_energy[buf_idx] =
1384 self.window_energy[buf_idx] + self.window[i] * self.window[i];
1385 }
1386 }
1387 }
1388
1389 self.frames_processed += 1;
1390
1391 let ready_until = if self.frames_processed == 1 {
1394 0 } else {
1396 (self.frames_processed - 1) * self.config.hop_size
1398 };
1399
1400 let output_start = self.output_position;
1402 let output_end = ready_until;
1403 let initial_len = output.len();
1404
1405 let threshold = T::from(1e-8).unwrap();
1406 if output_end > output_start {
1407 for i in output_start..output_end {
1408 let normalized = if self.window_energy[i] > threshold {
1409 self.overlap_buffer[i] / self.window_energy[i]
1410 } else {
1411 T::zero()
1412 };
1413 output.push(normalized);
1414 }
1415 self.output_position = output_end;
1416 }
1417
1418 output.len() - initial_len
1419 }
1420
1421 pub fn flush(&mut self) -> Vec<T> {
1422 let mut output = Vec::new();
1424 let threshold = T::from(1e-8).unwrap();
1425 for i in self.output_position..self.overlap_buffer.len() {
1426 if self.window_energy[i] > threshold {
1427 output.push(self.overlap_buffer[i] / self.window_energy[i]);
1428 } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1429 output.push(T::zero()); } else {
1431 break; }
1433 }
1434
1435 let valid_end =
1437 (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1438 if output.len() > valid_end - self.output_position {
1439 output.truncate(valid_end - self.output_position);
1440 }
1441
1442 self.reset();
1443 output
1444 }
1445
1446 pub fn reset(&mut self) {
1447 self.overlap_buffer.clear();
1448 self.overlap_buffer
1449 .resize(self.config.fft_size * 2, T::zero());
1450 self.window_energy.clear();
1451 self.window_energy
1452 .resize(self.config.fft_size * 2, T::zero());
1453 self.output_position = 0;
1454 self.frames_processed = 0;
1455 }
1456}
1457
1458pub struct MultiChannelStreamingIstft<T: Float + FftNum> {
1460 processors: Vec<StreamingIstft<T>>,
1461}
1462
1463impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingIstft<T> {
1464 pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1471 assert!(num_channels > 0, "num_channels must be > 0");
1472 let processors = (0..num_channels)
1473 .map(|_| StreamingIstft::new(config.clone()))
1474 .collect();
1475 Self { processors }
1476 }
1477
1478 pub fn push_frames(&mut self, frames: &[&SpectrumFrame<T>]) -> Vec<Vec<T>> {
1489 assert_eq!(
1490 frames.len(),
1491 self.processors.len(),
1492 "Expected {} channels, got {}",
1493 self.processors.len(),
1494 frames.len()
1495 );
1496
1497 #[cfg(feature = "rayon")]
1498 {
1499 use rayon::prelude::*;
1500 self.processors
1501 .par_iter_mut()
1502 .zip(frames.par_iter())
1503 .map(|(istft, frame)| istft.push_frame(frame))
1504 .collect()
1505 }
1506 #[cfg(not(feature = "rayon"))]
1507 {
1508 self.processors
1509 .iter_mut()
1510 .zip(frames.iter())
1511 .map(|(istft, frame)| istft.push_frame(frame))
1512 .collect()
1513 }
1514 }
1515
1516 pub fn flush(&mut self) -> Vec<Vec<T>> {
1518 #[cfg(feature = "rayon")]
1519 {
1520 use rayon::prelude::*;
1521 self.processors
1522 .par_iter_mut()
1523 .map(|istft| istft.flush())
1524 .collect()
1525 }
1526 #[cfg(not(feature = "rayon"))]
1527 {
1528 self.processors
1529 .iter_mut()
1530 .map(|istft| istft.flush())
1531 .collect()
1532 }
1533 }
1534
1535 pub fn reset(&mut self) {
1537 #[cfg(feature = "rayon")]
1538 {
1539 use rayon::prelude::*;
1540 self.processors
1541 .par_iter_mut()
1542 .for_each(|istft| istft.reset());
1543 }
1544 #[cfg(not(feature = "rayon"))]
1545 {
1546 self.processors.iter_mut().for_each(|istft| istft.reset());
1547 }
1548 }
1549
1550 pub fn num_channels(&self) -> usize {
1552 self.processors.len()
1553 }
1554}
1555
/// [`StftConfig`] with `f32` as the scalar type.
pub type StftConfigF32 = StftConfig<f32>;
/// [`StftConfig`] with `f64` as the scalar type.
pub type StftConfigF64 = StftConfig<f64>;

/// [`BatchStft`] with `f32` as the scalar type.
pub type BatchStftF32 = BatchStft<f32>;
/// [`BatchStft`] with `f64` as the scalar type.
pub type BatchStftF64 = BatchStft<f64>;

/// [`BatchIstft`] with `f32` as the scalar type.
pub type BatchIstftF32 = BatchIstft<f32>;
/// [`BatchIstft`] with `f64` as the scalar type.
pub type BatchIstftF64 = BatchIstft<f64>;

/// [`StreamingStft`] with `f32` as the scalar type.
pub type StreamingStftF32 = StreamingStft<f32>;
/// [`StreamingStft`] with `f64` as the scalar type.
pub type StreamingStftF64 = StreamingStft<f64>;

/// [`StreamingIstft`] with `f32` as the scalar type.
pub type StreamingIstftF32 = StreamingIstft<f32>;
/// [`StreamingIstft`] with `f64` as the scalar type.
pub type StreamingIstftF64 = StreamingIstft<f64>;

/// [`Spectrum`] with `f32` as the scalar type.
pub type SpectrumF32 = Spectrum<f32>;
/// [`Spectrum`] with `f64` as the scalar type.
pub type SpectrumF64 = Spectrum<f64>;

/// [`SpectrumFrame`] with `f32` as the scalar type.
pub type SpectrumFrameF32 = SpectrumFrame<f32>;
/// [`SpectrumFrame`] with `f64` as the scalar type.
pub type SpectrumFrameF64 = SpectrumFrame<f64>;

/// [`MultiChannelStreamingStft`] with `f32` as the scalar type.
pub type MultiChannelStreamingStftF32 = MultiChannelStreamingStft<f32>;
/// [`MultiChannelStreamingStft`] with `f64` as the scalar type.
pub type MultiChannelStreamingStftF64 = MultiChannelStreamingStft<f64>;

/// [`MultiChannelStreamingIstft`] with `f32` as the scalar type.
pub type MultiChannelStreamingIstftF32 = MultiChannelStreamingIstft<f32>;
/// [`MultiChannelStreamingIstft`] with `f64` as the scalar type.
pub type MultiChannelStreamingIstftF64 = MultiChannelStreamingIstft<f64>;