stft_rs/
lib.rs

1/*MIT License
2
3Copyright (c) 2025 David Maseda Neira
4
5Permission is hereby granted, free of charge, to any person obtaining a copy
6of this software and associated documentation files (the "Software"), to deal
7in the Software without restriction, including without limitation the rights
8to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9copies of the Software, and to permit persons to whom the Software is
10furnished to do so, subject to the following conditions:
11
12The above copyright notice and this permission notice shall be included in all
13copies or substantial portions of the Software.
14
15THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21SOFTWARE.
22*/
23
24use num_traits::{Float, FromPrimitive};
25use rustfft::num_complex::Complex;
26use rustfft::{Fft, FftNum, FftPlanner};
27use std::collections::VecDeque;
28use std::fmt;
29use std::sync::Arc;
30
31pub mod prelude {
32    pub use crate::{
33        BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64, PadMode,
34        ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame, SpectrumFrameF32,
35        SpectrumFrameF64, StftConfig, StftConfigF32, StftConfigF64, StreamingIstft,
36        StreamingIstftF32, StreamingIstftF64, StreamingStft, StreamingStftF32, StreamingStftF64,
37        WindowType, apply_padding,
38    };
39}
40
/// How the inverse STFT normalizes overlap-added frames.
///
/// The mode is also what `StftConfig::new` validates: COLA for [`ReconstructionMode::Ola`],
/// NOLA for [`ReconstructionMode::Wola`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReconstructionMode {
    /// Overlap-Add: no synthesis window, normalize by sum(w); requires the COLA condition.
    Ola,

    /// Weighted Overlap-Add: window applied again on synthesis, normalize by sum(w^2);
    /// requires the NOLA condition.
    Wola,
}
49
/// Analysis/synthesis window shape.
///
/// All windows are generated in symmetric form, with `θ = 2πn / (N - 1)` for `n` in `0..N`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowType {
    /// `0.5 * (1 - cos θ)`
    Hann,
    /// `0.54 - 0.46 * cos θ`
    Hamming,
    /// `0.42 - 0.5 * cos θ + 0.08 * cos 2θ`
    Blackman,
}
56
/// Reasons `StftConfig::new` (or its `validate_*` methods) can reject a configuration.
///
/// The type itself places no bounds on `T`; the `Display`/`Error` impls require a float type.
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigError<T> {
    /// WOLA mode: the overlapped sum of squared window values fell below `threshold`,
    /// so normalizing by it on resynthesis would be unstable.
    NolaViolation { min_energy: T, threshold: T },
    /// OLA mode: the relative ripple of the overlapped window sum is not below `threshold`.
    /// `max_deviation` is infinite when the window sum is numerically zero.
    ColaViolation { max_deviation: T, threshold: T },
    /// `hop_size` was zero or larger than `fft_size`.
    InvalidHopSize,
    /// `fft_size` was zero or not a power of two.
    InvalidFftSize,
}
64
65impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        match self {
68            ConfigError::NolaViolation {
69                min_energy,
70                threshold,
71            } => {
72                write!(
73                    f,
74                    "NOLA condition violated: min_energy={} < threshold={}",
75                    min_energy, threshold
76                )
77            }
78            ConfigError::ColaViolation {
79                max_deviation,
80                threshold,
81            } => {
82                write!(
83                    f,
84                    "COLA condition violated: max_deviation={} > threshold={}",
85                    max_deviation, threshold
86                )
87            }
88            ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
89            ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
90        }
91    }
92}
93
94impl<T: Float + fmt::Display + fmt::Debug> std::error::Error for ConfigError<T> {}
95
/// Padding strategy handed to `apply_padding`.
///
/// `BatchStft::process_padded` pads by `fft_size / 2` with this mode before framing, and
/// `BatchStft::process` uses [`PadMode::Reflect`]. The exact semantics of each variant
/// are defined by `apply_padding`; the descriptions below follow the variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PadMode {
    /// Mirror the signal at its edges (presumably excluding the edge sample; see `apply_padding`).
    Reflect,
    /// Pad with zeros.
    Zero,
    /// Repeat the edge samples.
    Edge,
}
102
103#[derive(Clone)]
104pub struct StftConfig<T: Float> {
105    pub fft_size: usize,
106    pub hop_size: usize,
107    pub window: WindowType,
108    pub reconstruction_mode: ReconstructionMode,
109    _phantom: std::marker::PhantomData<T>,
110}
111
112impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
113    fn nola_threshold() -> T {
114        T::from(1e-8).unwrap()
115    }
116
117    fn cola_relative_tolerance() -> T {
118        T::from(1e-4).unwrap()
119    }
120
121    pub fn new(
122        fft_size: usize,
123        hop_size: usize,
124        window: WindowType,
125        reconstruction_mode: ReconstructionMode,
126    ) -> Result<Self, ConfigError<T>> {
127        if fft_size == 0 || !fft_size.is_power_of_two() {
128            return Err(ConfigError::InvalidFftSize);
129        }
130        if hop_size == 0 || hop_size > fft_size {
131            return Err(ConfigError::InvalidHopSize);
132        }
133
134        let config = Self {
135            fft_size,
136            hop_size,
137            window,
138            reconstruction_mode,
139            _phantom: std::marker::PhantomData,
140        };
141
142        // Validate appropriate condition based on reconstruction mode
143        match reconstruction_mode {
144            ReconstructionMode::Ola => config.validate_cola()?,
145            ReconstructionMode::Wola => config.validate_nola()?,
146        }
147
148        Ok(config)
149    }
150
151    /// Default: 4096 FFT, 1024 hop, Hann window, OLA mode
152    pub fn default_4096() -> Self {
153        Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
154            .expect("Default config should always be valid")
155    }
156
157    pub fn freq_bins(&self) -> usize {
158        self.fft_size / 2 + 1
159    }
160
161    pub fn overlap_percent(&self) -> T {
162        let one = T::one();
163        let hundred = T::from(100.0).unwrap();
164        (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
165    }
166
167    fn generate_window(&self) -> Vec<T> {
168        generate_window(self.window, self.fft_size)
169    }
170
171    /// Validate NOLA condition: sum(w^2) > threshold everywhere
172    pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
173        let window = self.generate_window();
174        let num_overlaps = (self.fft_size + self.hop_size - 1) / self.hop_size;
175        let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
176        let mut energy = vec![T::zero(); test_len];
177
178        for i in 0..num_overlaps {
179            let offset = i * self.hop_size;
180            for j in 0..self.fft_size {
181                if offset + j < test_len {
182                    energy[offset + j] = energy[offset + j] + window[j] * window[j];
183                }
184            }
185        }
186
187        // Check the steady-state region (skip edges)
188        let start = self.fft_size / 2;
189        let end = test_len - self.fft_size / 2;
190        let min_energy = energy[start..end]
191            .iter()
192            .copied()
193            .min_by(|a, b| a.partial_cmp(b).unwrap())
194            .unwrap_or_else(T::zero);
195
196        if min_energy < Self::nola_threshold() {
197            return Err(ConfigError::NolaViolation {
198                min_energy,
199                threshold: Self::nola_threshold(),
200            });
201        }
202
203        Ok(())
204    }
205
206    /// Validate weak COLA condition: sum(w) is constant (within relative tolerance)
207    pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
208        let window = self.generate_window();
209        let window_len = window.len();
210
211        let mut cola_sum_period = vec![T::zero(); self.hop_size];
212        for i in 0..window_len {
213            let idx = i % self.hop_size;
214            cola_sum_period[idx] = cola_sum_period[idx] + window[i];
215        }
216
217        let zero = T::zero();
218        let min_sum = cola_sum_period
219            .iter()
220            .min_by(|a, b| a.partial_cmp(b).unwrap())
221            .unwrap_or(&zero);
222        let max_sum = cola_sum_period
223            .iter()
224            .max_by(|a, b| a.partial_cmp(b).unwrap())
225            .unwrap_or(&zero);
226
227        let epsilon = T::from(1e-9).unwrap();
228        if *max_sum < epsilon {
229            return Err(ConfigError::ColaViolation {
230                max_deviation: T::infinity(),
231                threshold: Self::cola_relative_tolerance(),
232            });
233        }
234
235        let ripple = (*max_sum - *min_sum) / *max_sum;
236
237        let is_compliant = ripple < Self::cola_relative_tolerance();
238
239        if !is_compliant {
240            return Err(ConfigError::ColaViolation {
241                max_deviation: ripple,
242                threshold: Self::cola_relative_tolerance(),
243            });
244        }
245        Ok(())
246    }
247}
248
249fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
250    let pi = T::from(std::f64::consts::PI).unwrap();
251    let two = T::from(2.0).unwrap();
252
253    match window_type {
254        WindowType::Hann => (0..size)
255            .map(|i| {
256                let half = T::from(0.5).unwrap();
257                let one = T::one();
258                let i_t = T::from(i).unwrap();
259                let size_m1 = T::from(size - 1).unwrap();
260                half * (one - (two * pi * i_t / size_m1).cos())
261            })
262            .collect(),
263        WindowType::Hamming => (0..size)
264            .map(|i| {
265                let i_t = T::from(i).unwrap();
266                let size_m1 = T::from(size - 1).unwrap();
267                T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_m1).cos()
268            })
269            .collect(),
270        WindowType::Blackman => (0..size)
271            .map(|i| {
272                let i_t = T::from(i).unwrap();
273                let size_m1 = T::from(size - 1).unwrap();
274                let angle = two * pi * i_t / size_m1;
275                T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
276                    + T::from(0.08).unwrap() * (two * angle).cos()
277            })
278            .collect(),
279    }
280}
281
282#[derive(Clone)]
283pub struct SpectrumFrame<T: Float> {
284    pub freq_bins: usize,
285    pub data: Vec<Complex<T>>,
286}
287
288impl<T: Float> SpectrumFrame<T> {
289    pub fn new(freq_bins: usize) -> Self {
290        Self {
291            freq_bins,
292            data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
293        }
294    }
295
296    pub fn from_data(data: Vec<Complex<T>>) -> Self {
297        let freq_bins = data.len();
298        Self { freq_bins, data }
299    }
300
301    /// Prepare frame for reuse by clearing data (keeps capacity)
302    pub fn clear(&mut self) {
303        for val in &mut self.data {
304            *val = Complex::new(T::zero(), T::zero());
305        }
306    }
307
308    /// Resize frame if needed to match freq_bins
309    pub fn resize_if_needed(&mut self, freq_bins: usize) {
310        if self.freq_bins != freq_bins {
311            self.freq_bins = freq_bins;
312            self.data
313                .resize(freq_bins, Complex::new(T::zero(), T::zero()));
314        }
315    }
316
317    /// Write data from a slice into this frame
318    pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
319        self.resize_if_needed(data.len());
320        self.data.copy_from_slice(data);
321    }
322
323    /// Get the magnitude of a frequency bin
324    #[inline]
325    pub fn magnitude(&self, bin: usize) -> T {
326        let c = &self.data[bin];
327        (c.re * c.re + c.im * c.im).sqrt()
328    }
329
330    /// Get the phase of a frequency bin in radians
331    #[inline]
332    pub fn phase(&self, bin: usize) -> T {
333        let c = &self.data[bin];
334        c.im.atan2(c.re)
335    }
336
337    /// Set a frequency bin from magnitude and phase
338    pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
339        self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
340    }
341
342    /// Create a SpectrumFrame from magnitude and phase arrays
343    pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
344        assert_eq!(
345            magnitudes.len(),
346            phases.len(),
347            "Magnitude and phase arrays must have same length"
348        );
349        let freq_bins = magnitudes.len();
350        let data: Vec<Complex<T>> = magnitudes
351            .iter()
352            .zip(phases.iter())
353            .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
354            .collect();
355        Self { freq_bins, data }
356    }
357
358    /// Get all magnitudes as a Vec
359    pub fn magnitudes(&self) -> Vec<T> {
360        self.data
361            .iter()
362            .map(|c| (c.re * c.re + c.im * c.im).sqrt())
363            .collect()
364    }
365
366    /// Get all phases as a Vec
367    pub fn phases(&self) -> Vec<T> {
368        self.data.iter().map(|c| c.im.atan2(c.re)).collect()
369    }
370}
371
372#[derive(Clone)]
373pub struct Spectrum<T: Float> {
374    pub num_frames: usize,
375    pub freq_bins: usize,
376    pub data: Vec<T>,
377}
378
379impl<T: Float> Spectrum<T> {
380    pub fn new(num_frames: usize, freq_bins: usize) -> Self {
381        Self {
382            num_frames,
383            freq_bins,
384            data: vec![T::zero(); 2 * num_frames * freq_bins],
385        }
386    }
387
388    #[inline]
389    pub fn real(&self, frame: usize, bin: usize) -> T {
390        self.data[frame * self.freq_bins + bin]
391    }
392
393    #[inline]
394    pub fn imag(&self, frame: usize, bin: usize) -> T {
395        let offset = self.num_frames * self.freq_bins;
396        self.data[offset + frame * self.freq_bins + bin]
397    }
398
399    #[inline]
400    pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
401        Complex::new(self.real(frame, bin), self.imag(frame, bin))
402    }
403
404    pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
405        (0..self.num_frames).map(move |frame_idx| {
406            let data: Vec<Complex<T>> = (0..self.freq_bins)
407                .map(|bin| self.get_complex(frame_idx, bin))
408                .collect();
409            SpectrumFrame::from_data(data)
410        })
411    }
412
413    /// Set the real part of a bin
414    #[inline]
415    pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
416        self.data[frame * self.freq_bins + bin] = value;
417    }
418
419    /// Set the imaginary part of a bin
420    #[inline]
421    pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
422        let offset = self.num_frames * self.freq_bins;
423        self.data[offset + frame * self.freq_bins + bin] = value;
424    }
425
426    /// Set a bin from a complex value
427    #[inline]
428    pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
429        self.set_real(frame, bin, value.re);
430        self.set_imag(frame, bin, value.im);
431    }
432
433    /// Get the magnitude of a frequency bin
434    #[inline]
435    pub fn magnitude(&self, frame: usize, bin: usize) -> T {
436        let re = self.real(frame, bin);
437        let im = self.imag(frame, bin);
438        (re * re + im * im).sqrt()
439    }
440
441    /// Get the phase of a frequency bin in radians
442    #[inline]
443    pub fn phase(&self, frame: usize, bin: usize) -> T {
444        let re = self.real(frame, bin);
445        let im = self.imag(frame, bin);
446        im.atan2(re)
447    }
448
449    /// Set a frequency bin from magnitude and phase
450    pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
451        self.set_real(frame, bin, magnitude * phase.cos());
452        self.set_imag(frame, bin, magnitude * phase.sin());
453    }
454
455    /// Get all magnitudes for a frame
456    pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
457        (0..self.freq_bins)
458            .map(|bin| self.magnitude(frame, bin))
459            .collect()
460    }
461
462    /// Get all phases for a frame
463    pub fn frame_phases(&self, frame: usize) -> Vec<T> {
464        (0..self.freq_bins)
465            .map(|bin| self.phase(frame, bin))
466            .collect()
467    }
468
469    /// Apply a function to all bins
470    pub fn apply<F>(&mut self, mut f: F)
471    where
472        F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
473    {
474        for frame in 0..self.num_frames {
475            for bin in 0..self.freq_bins {
476                let c = self.get_complex(frame, bin);
477                let new_c = f(frame, bin, c);
478                self.set_complex(frame, bin, new_c);
479            }
480        }
481    }
482
483    /// Apply a gain to a range of bins across all frames
484    pub fn apply_gain(&mut self, bin_range: std::ops::Range<usize>, gain: T) {
485        for frame in 0..self.num_frames {
486            for bin in bin_range.clone() {
487                if bin < self.freq_bins {
488                    let c = self.get_complex(frame, bin);
489                    self.set_complex(frame, bin, c * gain);
490                }
491            }
492        }
493    }
494
495    /// Zero out a range of bins across all frames
496    pub fn zero_bins(&mut self, bin_range: std::ops::Range<usize>) {
497        for frame in 0..self.num_frames {
498            for bin in bin_range.clone() {
499                if bin < self.freq_bins {
500                    self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
501                }
502            }
503        }
504    }
505}
506
507pub struct BatchStft<T: Float + FftNum> {
508    config: StftConfig<T>,
509    window: Vec<T>,
510    fft: Arc<dyn Fft<T>>,
511}
512
513impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
514    pub fn new(config: StftConfig<T>) -> Self {
515        let window = config.generate_window();
516        let mut planner = FftPlanner::new();
517        let fft = planner.plan_fft_forward(config.fft_size);
518
519        Self {
520            config,
521            window,
522            fft,
523        }
524    }
525
526    pub fn process(&self, signal: &[T]) -> Spectrum<T> {
527        self.process_padded(signal, PadMode::Reflect)
528    }
529
530    pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
531        let pad_amount = self.config.fft_size / 2;
532        let padded = apply_padding(signal, pad_amount, pad_mode);
533
534        let num_frames = if padded.len() >= self.config.fft_size {
535            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
536        } else {
537            0
538        };
539
540        let freq_bins = self.config.freq_bins();
541        let mut result = Spectrum::new(num_frames, freq_bins);
542
543        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
544
545        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
546            .step_by(self.config.hop_size)
547            .enumerate()
548        {
549            // Apply window and prepare FFT input
550            for i in 0..self.config.fft_size {
551                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
552            }
553
554            // Compute FFT
555            self.fft.process(&mut fft_buffer);
556
557            // Store positive frequencies in flat layout
558            for bin in 0..freq_bins {
559                let idx = frame_idx * freq_bins + bin;
560                result.data[idx] = fft_buffer[bin].re;
561                result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
562            }
563        }
564
565        result
566    }
567
568    /// Process signal and write into a pre-allocated Spectrum.
569    /// The spectrum must have the correct dimensions (num_frames x freq_bins).
570    /// Returns true if successful, false if dimensions don't match.
571    pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
572        self.process_padded_into(signal, PadMode::Reflect, spectrum)
573    }
574
575    /// Process signal with padding and write into a pre-allocated Spectrum.
576    pub fn process_padded_into(
577        &self,
578        signal: &[T],
579        pad_mode: PadMode,
580        spectrum: &mut Spectrum<T>,
581    ) -> bool {
582        let pad_amount = self.config.fft_size / 2;
583        let padded = apply_padding(signal, pad_amount, pad_mode);
584
585        let num_frames = if padded.len() >= self.config.fft_size {
586            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
587        } else {
588            0
589        };
590
591        let freq_bins = self.config.freq_bins();
592
593        // Check dimensions
594        if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
595            return false;
596        }
597
598        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
599
600        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
601            .step_by(self.config.hop_size)
602            .enumerate()
603        {
604            // Apply window and prepare FFT input
605            for i in 0..self.config.fft_size {
606                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
607            }
608
609            // Compute FFT
610            self.fft.process(&mut fft_buffer);
611
612            // Store positive frequencies in flat layout
613            for bin in 0..freq_bins {
614                let idx = frame_idx * freq_bins + bin;
615                spectrum.data[idx] = fft_buffer[bin].re;
616                spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
617            }
618        }
619
620        true
621    }
622}
623
624pub struct BatchIstft<T: Float + FftNum> {
625    config: StftConfig<T>,
626    window: Vec<T>,
627    ifft: Arc<dyn Fft<T>>,
628}
629
630impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
631    pub fn new(config: StftConfig<T>) -> Self {
632        let window = config.generate_window();
633        let mut planner = FftPlanner::new();
634        let ifft = planner.plan_fft_inverse(config.fft_size);
635
636        Self {
637            config,
638            window,
639            ifft,
640        }
641    }
642
643    pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
644        assert_eq!(
645            spectrum.freq_bins,
646            self.config.freq_bins(),
647            "Frequency bins mismatch"
648        );
649
650        let num_frames = spectrum.num_frames;
651        let original_time_len = (num_frames - 1) * self.config.hop_size;
652        let pad_amount = self.config.fft_size / 2;
653        let padded_len = original_time_len + 2 * pad_amount;
654
655        let mut overlap_buffer = vec![T::zero(); padded_len];
656        let mut window_energy = vec![T::zero(); padded_len];
657        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
658
659        // Precompute window energy normalization
660        for frame_idx in 0..num_frames {
661            let pos = frame_idx * self.config.hop_size;
662            for i in 0..self.config.fft_size {
663                match self.config.reconstruction_mode {
664                    ReconstructionMode::Ola => {
665                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
666                    }
667                    ReconstructionMode::Wola => {
668                        window_energy[pos + i] =
669                            window_energy[pos + i] + self.window[i] * self.window[i];
670                    }
671                }
672            }
673        }
674
675        // Process each frame
676        for frame_idx in 0..num_frames {
677            // Build full spectrum with conjugate symmetry
678            for bin in 0..spectrum.freq_bins {
679                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
680            }
681
682            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
683            for bin in 1..(spectrum.freq_bins - 1) {
684                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
685            }
686
687            // Compute IFFT
688            self.ifft.process(&mut ifft_buffer);
689
690            // Overlap-add
691            let pos = frame_idx * self.config.hop_size;
692            for i in 0..self.config.fft_size {
693                let fft_size_t = T::from(self.config.fft_size).unwrap();
694                let sample = ifft_buffer[i].re / fft_size_t;
695
696                match self.config.reconstruction_mode {
697                    ReconstructionMode::Ola => {
698                        // OLA: no windowing on inverse
699                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
700                    }
701                    ReconstructionMode::Wola => {
702                        // WOLA: apply window on inverse
703                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
704                    }
705                }
706            }
707        }
708
709        // Normalize by window energy
710        let threshold = T::from(1e-8).unwrap();
711        for i in 0..padded_len {
712            if window_energy[i] > threshold {
713                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
714            }
715        }
716
717        // Remove padding
718        overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
719    }
720
721    /// Process spectrum and write into a pre-allocated output buffer.
722    /// The output buffer will be resized if needed.
723    pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
724        assert_eq!(
725            spectrum.freq_bins,
726            self.config.freq_bins(),
727            "Frequency bins mismatch"
728        );
729
730        let num_frames = spectrum.num_frames;
731        let original_time_len = (num_frames - 1) * self.config.hop_size;
732        let pad_amount = self.config.fft_size / 2;
733        let padded_len = original_time_len + 2 * pad_amount;
734
735        let mut overlap_buffer = vec![T::zero(); padded_len];
736        let mut window_energy = vec![T::zero(); padded_len];
737        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
738
739        // Precompute window energy normalization
740        for frame_idx in 0..num_frames {
741            let pos = frame_idx * self.config.hop_size;
742            for i in 0..self.config.fft_size {
743                match self.config.reconstruction_mode {
744                    ReconstructionMode::Ola => {
745                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
746                    }
747                    ReconstructionMode::Wola => {
748                        window_energy[pos + i] =
749                            window_energy[pos + i] + self.window[i] * self.window[i];
750                    }
751                }
752            }
753        }
754
755        // Process each frame
756        for frame_idx in 0..num_frames {
757            // Build full spectrum with conjugate symmetry
758            for bin in 0..spectrum.freq_bins {
759                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
760            }
761
762            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
763            for bin in 1..(spectrum.freq_bins - 1) {
764                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
765            }
766
767            // Compute IFFT
768            self.ifft.process(&mut ifft_buffer);
769
770            // Overlap-add
771            let pos = frame_idx * self.config.hop_size;
772            for i in 0..self.config.fft_size {
773                let fft_size_t = T::from(self.config.fft_size).unwrap();
774                let sample = ifft_buffer[i].re / fft_size_t;
775
776                match self.config.reconstruction_mode {
777                    ReconstructionMode::Ola => {
778                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
779                    }
780                    ReconstructionMode::Wola => {
781                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
782                    }
783                }
784            }
785        }
786
787        // Normalize by window energy
788        let threshold = T::from(1e-8).unwrap();
789        for i in 0..padded_len {
790            if window_energy[i] > threshold {
791                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
792            }
793        }
794
795        // Copy to output (resize if needed)
796        output.clear();
797        output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
798    }
799}
800
801pub struct StreamingStft<T: Float + FftNum> {
802    config: StftConfig<T>,
803    window: Vec<T>,
804    fft: Arc<dyn Fft<T>>,
805    input_buffer: VecDeque<T>,
806    frame_index: usize,
807    fft_buffer: Vec<Complex<T>>,
808}
809
810impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
811    pub fn new(config: StftConfig<T>) -> Self {
812        let window = config.generate_window();
813        let mut planner = FftPlanner::new();
814        let fft = planner.plan_fft_forward(config.fft_size);
815        let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
816
817        Self {
818            config,
819            window,
820            fft,
821            input_buffer: VecDeque::new(),
822            frame_index: 0,
823            fft_buffer,
824        }
825    }
826
827    pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
828        self.input_buffer.extend(samples.iter().copied());
829
830        let mut frames = Vec::new();
831
832        while self.input_buffer.len() >= self.config.fft_size {
833            // Process one frame
834            for i in 0..self.config.fft_size {
835                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
836            }
837
838            self.fft.process(&mut self.fft_buffer);
839
840            let freq_bins = self.config.freq_bins();
841            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
842            frames.push(SpectrumFrame::from_data(data));
843
844            // Advance by hop size
845            self.input_buffer.drain(..self.config.hop_size);
846            self.frame_index += 1;
847        }
848
849        frames
850    }
851
852    /// Push samples and write frames into a pre-allocated buffer.
853    /// Returns the number of frames written.
854    pub fn push_samples_into(
855        &mut self,
856        samples: &[T],
857        output: &mut Vec<SpectrumFrame<T>>,
858    ) -> usize {
859        self.input_buffer.extend(samples.iter().copied());
860
861        let initial_len = output.len();
862
863        while self.input_buffer.len() >= self.config.fft_size {
864            // Process one frame
865            for i in 0..self.config.fft_size {
866                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
867            }
868
869            self.fft.process(&mut self.fft_buffer);
870
871            let freq_bins = self.config.freq_bins();
872            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
873            output.push(SpectrumFrame::from_data(data));
874
875            // Advance by hop size
876            self.input_buffer.drain(..self.config.hop_size);
877            self.frame_index += 1;
878        }
879
880        output.len() - initial_len
881    }
882
883    /// Push samples and write directly into pre-existing SpectrumFrame buffers.
884    /// This is a zero-allocation method - frames must be pre-allocated with correct size.
885    /// Returns the number of frames written.
886    ///
887    /// # Example
888    /// ```ignore
889    /// let mut frame_pool = vec![SpectrumFrame::new(config.freq_bins()); 16];
890    /// let mut frame_index = 0;
891    ///
892    /// let frames_written = stft.push_samples_write(chunk, &mut frame_pool, &mut frame_index);
893    /// // Process frames 0..frames_written
894    /// ```
895    pub fn push_samples_write(
896        &mut self,
897        samples: &[T],
898        frame_pool: &mut [SpectrumFrame<T>],
899        pool_index: &mut usize,
900    ) -> usize {
901        self.input_buffer.extend(samples.iter().copied());
902
903        let initial_index = *pool_index;
904        let freq_bins = self.config.freq_bins();
905
906        while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
907            // Process one frame
908            for i in 0..self.config.fft_size {
909                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
910            }
911
912            self.fft.process(&mut self.fft_buffer);
913
914            // Write directly into the pre-allocated frame
915            let frame = &mut frame_pool[*pool_index];
916            debug_assert_eq!(
917                frame.freq_bins, freq_bins,
918                "Frame pool frames must match freq_bins"
919            );
920            frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
921
922            // Advance by hop size
923            self.input_buffer.drain(..self.config.hop_size);
924            self.frame_index += 1;
925            *pool_index += 1;
926        }
927
928        *pool_index - initial_index
929    }
930
931    pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
932        // For streaming, we typically don't process partial frames
933        // Could zero-pad if needed, but that changes the signal
934        Vec::new()
935    }
936
937    pub fn reset(&mut self) {
938        self.input_buffer.clear();
939        self.frame_index = 0;
940    }
941
942    pub fn buffered_samples(&self) -> usize {
943        self.input_buffer.len()
944    }
945}
946
947pub struct StreamingIstft<T: Float + FftNum> {
948    config: StftConfig<T>,
949    window: Vec<T>,
950    ifft: Arc<dyn Fft<T>>,
951    overlap_buffer: Vec<T>,
952    window_energy: Vec<T>,
953    output_position: usize,
954    frames_processed: usize,
955    ifft_buffer: Vec<Complex<T>>,
956}
957
958impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
959    pub fn new(config: StftConfig<T>) -> Self {
960        let window = config.generate_window();
961        let mut planner = FftPlanner::new();
962        let ifft = planner.plan_fft_inverse(config.fft_size);
963
964        // Buffer needs to hold enough samples for full overlap
965        // For proper reconstruction, need at least fft_size samples
966        let buffer_size = config.fft_size * 2;
967        let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
968
969        Self {
970            config,
971            window,
972            ifft,
973            overlap_buffer: vec![T::zero(); buffer_size],
974            window_energy: vec![T::zero(); buffer_size],
975            output_position: 0,
976            frames_processed: 0,
977            ifft_buffer,
978        }
979    }
980
981    pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
982        assert_eq!(
983            frame.freq_bins,
984            self.config.freq_bins(),
985            "Frequency bins mismatch"
986        );
987
988        // Build full spectrum with conjugate symmetry
989        for bin in 0..frame.freq_bins {
990            self.ifft_buffer[bin] = frame.data[bin];
991        }
992
993        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
994        for bin in 1..(frame.freq_bins - 1) {
995            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
996        }
997
998        // Compute IFFT
999        self.ifft.process(&mut self.ifft_buffer);
1000
1001        // Overlap-add into buffer at the current write position
1002        let write_pos = self.frames_processed * self.config.hop_size;
1003        for i in 0..self.config.fft_size {
1004            let fft_size_t = T::from(self.config.fft_size).unwrap();
1005            let sample = self.ifft_buffer[i].re / fft_size_t;
1006            let buf_idx = write_pos + i;
1007
1008            // Extend buffers if needed
1009            if buf_idx >= self.overlap_buffer.len() {
1010                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1011                self.window_energy.resize(buf_idx + 1, T::zero());
1012            }
1013
1014            match self.config.reconstruction_mode {
1015                ReconstructionMode::Ola => {
1016                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1017                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1018                }
1019                ReconstructionMode::Wola => {
1020                    self.overlap_buffer[buf_idx] =
1021                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1022                    self.window_energy[buf_idx] =
1023                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1024                }
1025            }
1026        }
1027
1028        self.frames_processed += 1;
1029
1030        // Calculate how many samples are "ready" (have full window energy)
1031        // Samples are ready when no future frames will contribute to them
1032        let ready_until = if self.frames_processed == 1 {
1033            0 // First frame: no output yet, need overlap
1034        } else {
1035            // Samples before the current frame's start position are complete
1036            (self.frames_processed - 1) * self.config.hop_size
1037        };
1038
1039        // Extract ready samples
1040        let output_start = self.output_position;
1041        let output_end = ready_until;
1042        let mut output = Vec::new();
1043
1044        let threshold = T::from(1e-8).unwrap();
1045        if output_end > output_start {
1046            for i in output_start..output_end {
1047                let normalized = if self.window_energy[i] > threshold {
1048                    self.overlap_buffer[i] / self.window_energy[i]
1049                } else {
1050                    T::zero()
1051                };
1052                output.push(normalized);
1053            }
1054            self.output_position = output_end;
1055        }
1056
1057        output
1058    }
1059
1060    /// Push a frame and write output samples into a pre-allocated buffer.
1061    /// Returns the number of samples written.
1062    pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1063        assert_eq!(
1064            frame.freq_bins,
1065            self.config.freq_bins(),
1066            "Frequency bins mismatch"
1067        );
1068
1069        // Build full spectrum with conjugate symmetry
1070        for bin in 0..frame.freq_bins {
1071            self.ifft_buffer[bin] = frame.data[bin];
1072        }
1073
1074        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
1075        for bin in 1..(frame.freq_bins - 1) {
1076            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1077        }
1078
1079        // Compute IFFT
1080        self.ifft.process(&mut self.ifft_buffer);
1081
1082        // Overlap-add into buffer at the current write position
1083        let write_pos = self.frames_processed * self.config.hop_size;
1084        for i in 0..self.config.fft_size {
1085            let fft_size_t = T::from(self.config.fft_size).unwrap();
1086            let sample = self.ifft_buffer[i].re / fft_size_t;
1087            let buf_idx = write_pos + i;
1088
1089            // Extend buffers if needed
1090            if buf_idx >= self.overlap_buffer.len() {
1091                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1092                self.window_energy.resize(buf_idx + 1, T::zero());
1093            }
1094
1095            match self.config.reconstruction_mode {
1096                ReconstructionMode::Ola => {
1097                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1098                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1099                }
1100                ReconstructionMode::Wola => {
1101                    self.overlap_buffer[buf_idx] =
1102                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1103                    self.window_energy[buf_idx] =
1104                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1105                }
1106            }
1107        }
1108
1109        self.frames_processed += 1;
1110
1111        // Calculate how many samples are "ready" (have full window energy)
1112        // Samples are ready when no future frames will contribute to them
1113        let ready_until = if self.frames_processed == 1 {
1114            0 // First frame: no output yet, need overlap
1115        } else {
1116            // Samples before the current frame's start position are complete
1117            (self.frames_processed - 1) * self.config.hop_size
1118        };
1119
1120        // Extract ready samples
1121        let output_start = self.output_position;
1122        let output_end = ready_until;
1123        let initial_len = output.len();
1124
1125        let threshold = T::from(1e-8).unwrap();
1126        if output_end > output_start {
1127            for i in output_start..output_end {
1128                let normalized = if self.window_energy[i] > threshold {
1129                    self.overlap_buffer[i] / self.window_energy[i]
1130                } else {
1131                    T::zero()
1132                };
1133                output.push(normalized);
1134            }
1135            self.output_position = output_end;
1136        }
1137
1138        output.len() - initial_len
1139    }
1140
1141    pub fn flush(&mut self) -> Vec<T> {
1142        // Return all remaining samples in buffer
1143        let mut output = Vec::new();
1144        let threshold = T::from(1e-8).unwrap();
1145        for i in self.output_position..self.overlap_buffer.len() {
1146            if self.window_energy[i] > threshold {
1147                output.push(self.overlap_buffer[i] / self.window_energy[i]);
1148            } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1149                output.push(T::zero()); // Sample in valid range but no window energy
1150            } else {
1151                break; // Past the end of valid data
1152            }
1153        }
1154
1155        // Determine the actual end of valid data
1156        let valid_end =
1157            (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1158        if output.len() > valid_end - self.output_position {
1159            output.truncate(valid_end - self.output_position);
1160        }
1161
1162        self.reset();
1163        output
1164    }
1165
1166    pub fn reset(&mut self) {
1167        self.overlap_buffer.clear();
1168        self.overlap_buffer
1169            .resize(self.config.fft_size * 2, T::zero());
1170        self.window_energy.clear();
1171        self.window_energy
1172            .resize(self.config.fft_size * 2, T::zero());
1173        self.output_position = 0;
1174        self.frames_processed = 0;
1175    }
1176}
1177
1178/// Apply padding to a signal.
1179/// Streaming applications should pad manually to match batch processing quality.
1180pub fn apply_padding<T: Float>(signal: &[T], pad_amount: usize, mode: PadMode) -> Vec<T> {
1181    let total_len = signal.len() + 2 * pad_amount;
1182    let mut padded = vec![T::zero(); total_len];
1183
1184    padded[pad_amount..pad_amount + signal.len()].copy_from_slice(signal);
1185
1186    match mode {
1187        PadMode::Reflect => {
1188            for i in 0..pad_amount {
1189                if i + 1 < signal.len() {
1190                    padded[pad_amount - 1 - i] = signal[i + 1];
1191                }
1192            }
1193
1194            let n = signal.len();
1195            for i in 0..pad_amount {
1196                if n >= 2 && n - 2 >= i {
1197                    padded[pad_amount + n + i] = signal[n - 2 - i];
1198                }
1199            }
1200        }
1201        PadMode::Zero => {}
1202        PadMode::Edge => {
1203            if !signal.is_empty() {
1204                for i in 0..pad_amount {
1205                    padded[i] = signal[0];
1206                }
1207                for i in 0..pad_amount {
1208                    padded[pad_amount + signal.len() + i] = signal[signal.len() - 1];
1209                }
1210            }
1211        }
1212    }
1213
1214    padded
1215}
1216
1217// Type aliases for common float types
1218pub type StftConfigF32 = StftConfig<f32>;
1219pub type StftConfigF64 = StftConfig<f64>;
1220
1221pub type BatchStftF32 = BatchStft<f32>;
1222pub type BatchStftF64 = BatchStft<f64>;
1223
1224pub type BatchIstftF32 = BatchIstft<f32>;
1225pub type BatchIstftF64 = BatchIstft<f64>;
1226
1227pub type StreamingStftF32 = StreamingStft<f32>;
1228pub type StreamingStftF64 = StreamingStft<f64>;
1229
1230pub type StreamingIstftF32 = StreamingIstft<f32>;
1231pub type StreamingIstftF64 = StreamingIstft<f64>;
1232
1233pub type SpectrumF32 = Spectrum<f32>;
1234pub type SpectrumF64 = Spectrum<f64>;
1235
1236pub type SpectrumFrameF32 = SpectrumFrame<f32>;
1237pub type SpectrumFrameF64 = SpectrumFrame<f64>;