stft_rs/
lib.rs

1/*MIT License
2
3Copyright (c) 2025 David Maseda Neira
4
5Permission is hereby granted, free of charge, to any person obtaining a copy
6of this software and associated documentation files (the "Software"), to deal
7in the Software without restriction, including without limitation the rights
8to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9copies of the Software, and to permit persons to whom the Software is
10furnished to do so, subject to the following conditions:
11
12The above copyright notice and this permission notice shall be included in all
13copies or substantial portions of the Software.
14
15THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21SOFTWARE.
22*/
23
24use num_traits::{Float, FromPrimitive};
25use rustfft::num_complex::Complex;
26use rustfft::{Fft, FftNum, FftPlanner};
27use std::collections::VecDeque;
28use std::fmt;
29use std::sync::Arc;
30
31mod utils;
32pub use utils::{apply_padding, deinterleave, deinterleave_into, interleave, interleave_into};
33
34pub mod mel;
35
36pub mod prelude {
37    pub use crate::mel::{
38        BatchMelSpectrogram, BatchMelSpectrogramF32, BatchMelSpectrogramF64, MelConfig,
39        MelConfigF32, MelConfigF64, MelFilterbank, MelFilterbankF32, MelFilterbankF64, MelNorm,
40        MelScale, MelSpectrum, MelSpectrumF32, MelSpectrumF64, StreamingMelSpectrogram,
41        StreamingMelSpectrogramF32, StreamingMelSpectrogramF64,
42    };
43    pub use crate::utils::{
44        apply_padding, deinterleave, deinterleave_into, interleave, interleave_into,
45    };
46    pub use crate::{
47        BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64,
48        MultiChannelStreamingIstft, MultiChannelStreamingIstftF32, MultiChannelStreamingIstftF64,
49        MultiChannelStreamingStft, MultiChannelStreamingStftF32, MultiChannelStreamingStftF64,
50        PadMode, ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame,
51        SpectrumFrameF32, SpectrumFrameF64, StftConfig, StftConfigBuilder, StftConfigBuilderF32,
52        StftConfigBuilderF64, StftConfigF32, StftConfigF64, StreamingIstft, StreamingIstftF32,
53        StreamingIstftF64, StreamingStft, StreamingStftF32, StreamingStftF64, WindowType,
54    };
55}
56
/// Normalisation strategy the inverse STFT uses when overlap-adding frames.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReconstructionMode {
    /// Overlap-add: frames are summed without a synthesis window and the result is divided by
    /// the overlapping `sum(w)`. The window/hop pair must satisfy COLA.
    Ola,

    /// Weighted overlap-add: frames are windowed a second time before summing and the result
    /// is divided by the overlapping `sum(w^2)`. The window/hop pair must satisfy NOLA.
    Wola,
}
65
/// Analysis window shape.
///
/// For an `N`-sample window with `N >= 2`, sample `i` uses the angle `θ = 2π·i / (N - 1)`,
/// i.e. the windows are generated in symmetric form.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowType {
    /// `0.5 · (1 − cos θ)`; zero at both ends.
    Hann,
    /// `0.54 − 0.46 · cos θ`.
    Hamming,
    /// `0.42 − 0.5 · cos θ + 0.08 · cos 2θ`.
    Blackman,
}
72
/// Reasons a `StftConfig` can be rejected.
///
/// `T` is the sample type of the configuration; it only appears in the numeric diagnostics of
/// the COLA/NOLA variants. The bounds needed for `Display`/`Error` live on those impls rather
/// than on the type itself.
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigError<T> {
    /// WOLA reconstruction was requested, but the overlapped `sum(w^2)` falls below
    /// `threshold` somewhere, so the inverse transform could not normalise there.
    NolaViolation { min_energy: T, threshold: T },
    /// OLA reconstruction was requested, but the overlapped `sum(w)` is not constant: its
    /// relative ripple `max_deviation` is not below `threshold`.
    ColaViolation { max_deviation: T, threshold: T },
    /// `hop_size` was missing, zero, or larger than `fft_size`.
    InvalidHopSize,
    /// `fft_size` was missing, zero, or not a power of two.
    InvalidFftSize,
}
80
81impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        match self {
84            ConfigError::NolaViolation {
85                min_energy,
86                threshold,
87            } => {
88                write!(
89                    f,
90                    "NOLA condition violated: min_energy={} < threshold={}",
91                    min_energy, threshold
92                )
93            }
94            ConfigError::ColaViolation {
95                max_deviation,
96                threshold,
97            } => {
98                write!(
99                    f,
100                    "COLA condition violated: max_deviation={} > threshold={}",
101                    max_deviation, threshold
102                )
103            }
104            ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
105            ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
106        }
107    }
108}
109
110impl<T: Float + fmt::Display + fmt::Debug> std::error::Error for ConfigError<T> {}
111
/// How `BatchStft` extends the signal by `fft_size / 2` samples on each side before framing.
///
/// The padding is performed by `utils::apply_padding`. NOTE(review): the per-variant
/// descriptions below follow the conventional meaning of these names; confirm against
/// `apply_padding`.
///
/// Derives the same comparison traits as the crate's other option enums, so modes can be
/// compared, asserted on, and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PadMode {
    /// Mirror the signal about its edge samples (the default used by `BatchStft::process`).
    Reflect,
    /// Pad with zeros.
    Zero,
    /// Repeat the first/last sample.
    Edge,
}
118
119#[derive(Clone)]
120pub struct StftConfig<T: Float> {
121    pub fft_size: usize,
122    pub hop_size: usize,
123    pub window: WindowType,
124    pub reconstruction_mode: ReconstructionMode,
125    _phantom: std::marker::PhantomData<T>,
126}
127
128impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
129    fn nola_threshold() -> T {
130        T::from(1e-8).unwrap()
131    }
132
133    fn cola_relative_tolerance() -> T {
134        T::from(1e-4).unwrap()
135    }
136
137    #[deprecated(
138        since = "0.4.0",
139        note = "Use `StftConfig::builder()` instead for a more flexible API"
140    )]
141    pub fn new(
142        fft_size: usize,
143        hop_size: usize,
144        window: WindowType,
145        reconstruction_mode: ReconstructionMode,
146    ) -> Result<Self, ConfigError<T>> {
147        if fft_size == 0 || !fft_size.is_power_of_two() {
148            return Err(ConfigError::InvalidFftSize);
149        }
150        if hop_size == 0 || hop_size > fft_size {
151            return Err(ConfigError::InvalidHopSize);
152        }
153
154        let config = Self {
155            fft_size,
156            hop_size,
157            window,
158            reconstruction_mode,
159            _phantom: std::marker::PhantomData,
160        };
161
162        // Validate appropriate condition based on reconstruction mode
163        match reconstruction_mode {
164            ReconstructionMode::Ola => config.validate_cola()?,
165            ReconstructionMode::Wola => config.validate_nola()?,
166        }
167
168        Ok(config)
169    }
170
171    /// Create a new builder for StftConfig
172    pub fn builder() -> StftConfigBuilder<T> {
173        StftConfigBuilder::new()
174    }
175
176    /// Default: 4096 FFT, 1024 hop, Hann window, OLA mode
177    #[allow(deprecated)]
178    pub fn default_4096() -> Self {
179        Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
180            .expect("Default config should always be valid")
181    }
182
183    pub fn freq_bins(&self) -> usize {
184        self.fft_size / 2 + 1
185    }
186
187    pub fn overlap_percent(&self) -> T {
188        let one = T::one();
189        let hundred = T::from(100.0).unwrap();
190        (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
191    }
192
193    fn generate_window(&self) -> Vec<T> {
194        generate_window(self.window, self.fft_size)
195    }
196
197    /// Validate NOLA condition: sum(w^2) > threshold everywhere
198    pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
199        let window = self.generate_window();
200        let num_overlaps = self.fft_size.div_ceil(self.hop_size);
201        let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
202        let mut energy = vec![T::zero(); test_len];
203
204        for i in 0..num_overlaps {
205            let offset = i * self.hop_size;
206            for j in 0..self.fft_size {
207                if offset + j < test_len {
208                    energy[offset + j] = energy[offset + j] + window[j] * window[j];
209                }
210            }
211        }
212
213        // Check the steady-state region (skip edges)
214        let start = self.fft_size / 2;
215        let end = test_len - self.fft_size / 2;
216        let min_energy = energy[start..end]
217            .iter()
218            .copied()
219            .min_by(|a, b| a.partial_cmp(b).unwrap())
220            .unwrap_or_else(T::zero);
221
222        if min_energy < Self::nola_threshold() {
223            return Err(ConfigError::NolaViolation {
224                min_energy,
225                threshold: Self::nola_threshold(),
226            });
227        }
228
229        Ok(())
230    }
231
232    /// Validate weak COLA condition: sum(w) is constant (within relative tolerance)
233    pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
234        let window = self.generate_window();
235        let window_len = window.len();
236
237        let mut cola_sum_period = vec![T::zero(); self.hop_size];
238        (0..window_len).for_each(|i| {
239            let idx = i % self.hop_size;
240            cola_sum_period[idx] = cola_sum_period[idx] + window[i];
241        });
242
243        let zero = T::zero();
244        let min_sum = cola_sum_period
245            .iter()
246            .min_by(|a, b| a.partial_cmp(b).unwrap())
247            .unwrap_or(&zero);
248        let max_sum = cola_sum_period
249            .iter()
250            .max_by(|a, b| a.partial_cmp(b).unwrap())
251            .unwrap_or(&zero);
252
253        let epsilon = T::from(1e-9).unwrap();
254        if *max_sum < epsilon {
255            return Err(ConfigError::ColaViolation {
256                max_deviation: T::infinity(),
257                threshold: Self::cola_relative_tolerance(),
258            });
259        }
260
261        let ripple = (*max_sum - *min_sum) / *max_sum;
262
263        let is_compliant = ripple < Self::cola_relative_tolerance();
264
265        if !is_compliant {
266            return Err(ConfigError::ColaViolation {
267                max_deviation: ripple,
268                threshold: Self::cola_relative_tolerance(),
269            });
270        }
271        Ok(())
272    }
273}
274
275/// Builder for StftConfig with fluent API
276pub struct StftConfigBuilder<T: Float> {
277    fft_size: Option<usize>,
278    hop_size: Option<usize>,
279    window: WindowType,
280    reconstruction_mode: ReconstructionMode,
281    _phantom: std::marker::PhantomData<T>,
282}
283
284impl<T: Float + FromPrimitive + fmt::Debug> StftConfigBuilder<T> {
285    /// Create a new builder with default values (Hann window, OLA mode)
286    pub fn new() -> Self {
287        Self {
288            fft_size: None,
289            hop_size: None,
290            window: WindowType::Hann,
291            reconstruction_mode: ReconstructionMode::Ola,
292            _phantom: std::marker::PhantomData,
293        }
294    }
295
296    /// Set the FFT size (must be a power of two)
297    pub fn fft_size(mut self, fft_size: usize) -> Self {
298        self.fft_size = Some(fft_size);
299        self
300    }
301
302    /// Set the hop size (must be > 0 and <= fft_size)
303    pub fn hop_size(mut self, hop_size: usize) -> Self {
304        self.hop_size = Some(hop_size);
305        self
306    }
307
308    /// Set the window type (default: Hann)
309    pub fn window(mut self, window: WindowType) -> Self {
310        self.window = window;
311        self
312    }
313
314    /// Set the reconstruction mode (default: OLA)
315    pub fn reconstruction_mode(mut self, mode: ReconstructionMode) -> Self {
316        self.reconstruction_mode = mode;
317        self
318    }
319
320    /// Build the StftConfig, validating all parameters
321    ///
322    /// Returns an error if:
323    /// - fft_size is not set or not a power of two
324    /// - hop_size is not set, zero, or greater than fft_size
325    /// - COLA/NOLA conditions are violated
326    #[allow(deprecated)]
327    pub fn build(self) -> Result<StftConfig<T>, ConfigError<T>> {
328        let fft_size = self.fft_size.ok_or(ConfigError::InvalidFftSize)?;
329        let hop_size = self.hop_size.ok_or(ConfigError::InvalidHopSize)?;
330
331        StftConfig::new(fft_size, hop_size, self.window, self.reconstruction_mode)
332    }
333}
334
335impl<T: Float + FromPrimitive + fmt::Debug> Default for StftConfigBuilder<T> {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
342    let pi = T::from(std::f64::consts::PI).unwrap();
343    let two = T::from(2.0).unwrap();
344
345    match window_type {
346        WindowType::Hann => (0..size)
347            .map(|i| {
348                let half = T::from(0.5).unwrap();
349                let one = T::one();
350                let i_t = T::from(i).unwrap();
351                let size_m1 = T::from(size - 1).unwrap();
352                half * (one - (two * pi * i_t / size_m1).cos())
353            })
354            .collect(),
355        WindowType::Hamming => (0..size)
356            .map(|i| {
357                let i_t = T::from(i).unwrap();
358                let size_m1 = T::from(size - 1).unwrap();
359                T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_m1).cos()
360            })
361            .collect(),
362        WindowType::Blackman => (0..size)
363            .map(|i| {
364                let i_t = T::from(i).unwrap();
365                let size_m1 = T::from(size - 1).unwrap();
366                let angle = two * pi * i_t / size_m1;
367                T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
368                    + T::from(0.08).unwrap() * (two * angle).cos()
369            })
370            .collect(),
371    }
372}
373
374#[derive(Clone)]
375pub struct SpectrumFrame<T: Float> {
376    pub freq_bins: usize,
377    pub data: Vec<Complex<T>>,
378}
379
380impl<T: Float> SpectrumFrame<T> {
381    pub fn new(freq_bins: usize) -> Self {
382        Self {
383            freq_bins,
384            data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
385        }
386    }
387
388    pub fn from_data(data: Vec<Complex<T>>) -> Self {
389        let freq_bins = data.len();
390        Self { freq_bins, data }
391    }
392
393    /// Prepare frame for reuse by clearing data (keeps capacity)
394    pub fn clear(&mut self) {
395        for val in &mut self.data {
396            *val = Complex::new(T::zero(), T::zero());
397        }
398    }
399
400    /// Resize frame if needed to match freq_bins
401    pub fn resize_if_needed(&mut self, freq_bins: usize) {
402        if self.freq_bins != freq_bins {
403            self.freq_bins = freq_bins;
404            self.data
405                .resize(freq_bins, Complex::new(T::zero(), T::zero()));
406        }
407    }
408
409    /// Write data from a slice into this frame
410    pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
411        self.resize_if_needed(data.len());
412        self.data.copy_from_slice(data);
413    }
414
415    /// Get the magnitude of a frequency bin
416    #[inline]
417    pub fn magnitude(&self, bin: usize) -> T {
418        let c = &self.data[bin];
419        (c.re * c.re + c.im * c.im).sqrt()
420    }
421
422    /// Get the phase of a frequency bin in radians
423    #[inline]
424    pub fn phase(&self, bin: usize) -> T {
425        let c = &self.data[bin];
426        c.im.atan2(c.re)
427    }
428
429    /// Set a frequency bin from magnitude and phase
430    pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
431        self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
432    }
433
434    /// Create a SpectrumFrame from magnitude and phase arrays
435    pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
436        assert_eq!(
437            magnitudes.len(),
438            phases.len(),
439            "Magnitude and phase arrays must have same length"
440        );
441        let freq_bins = magnitudes.len();
442        let data: Vec<Complex<T>> = magnitudes
443            .iter()
444            .zip(phases.iter())
445            .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
446            .collect();
447        Self { freq_bins, data }
448    }
449
450    /// Get all magnitudes as a Vec
451    pub fn magnitudes(&self) -> Vec<T> {
452        self.data
453            .iter()
454            .map(|c| (c.re * c.re + c.im * c.im).sqrt())
455            .collect()
456    }
457
458    /// Get all phases as a Vec
459    pub fn phases(&self) -> Vec<T> {
460        self.data.iter().map(|c| c.im.atan2(c.re)).collect()
461    }
462}
463
464#[derive(Clone)]
465pub struct Spectrum<T: Float> {
466    pub num_frames: usize,
467    pub freq_bins: usize,
468    pub data: Vec<T>,
469}
470
471impl<T: Float> Spectrum<T> {
472    pub fn new(num_frames: usize, freq_bins: usize) -> Self {
473        Self {
474            num_frames,
475            freq_bins,
476            data: vec![T::zero(); 2 * num_frames * freq_bins],
477        }
478    }
479
480    #[inline]
481    pub fn real(&self, frame: usize, bin: usize) -> T {
482        self.data[frame * self.freq_bins + bin]
483    }
484
485    #[inline]
486    pub fn imag(&self, frame: usize, bin: usize) -> T {
487        let offset = self.num_frames * self.freq_bins;
488        self.data[offset + frame * self.freq_bins + bin]
489    }
490
491    #[inline]
492    pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
493        Complex::new(self.real(frame, bin), self.imag(frame, bin))
494    }
495
496    pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
497        (0..self.num_frames).map(move |frame_idx| {
498            let data: Vec<Complex<T>> = (0..self.freq_bins)
499                .map(|bin| self.get_complex(frame_idx, bin))
500                .collect();
501            SpectrumFrame::from_data(data)
502        })
503    }
504
505    /// Set the real part of a bin
506    #[inline]
507    pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
508        self.data[frame * self.freq_bins + bin] = value;
509    }
510
511    /// Set the imaginary part of a bin
512    #[inline]
513    pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
514        let offset = self.num_frames * self.freq_bins;
515        self.data[offset + frame * self.freq_bins + bin] = value;
516    }
517
518    /// Set a bin from a complex value
519    #[inline]
520    pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
521        self.set_real(frame, bin, value.re);
522        self.set_imag(frame, bin, value.im);
523    }
524
525    /// Get the magnitude of a frequency bin
526    #[inline]
527    pub fn magnitude(&self, frame: usize, bin: usize) -> T {
528        let re = self.real(frame, bin);
529        let im = self.imag(frame, bin);
530        (re * re + im * im).sqrt()
531    }
532
533    /// Get the phase of a frequency bin in radians
534    #[inline]
535    pub fn phase(&self, frame: usize, bin: usize) -> T {
536        let re = self.real(frame, bin);
537        let im = self.imag(frame, bin);
538        im.atan2(re)
539    }
540
541    /// Set a frequency bin from magnitude and phase
542    pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
543        self.set_real(frame, bin, magnitude * phase.cos());
544        self.set_imag(frame, bin, magnitude * phase.sin());
545    }
546
547    /// Get all magnitudes for a frame
548    pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
549        (0..self.freq_bins)
550            .map(|bin| self.magnitude(frame, bin))
551            .collect()
552    }
553
554    /// Get all phases for a frame
555    pub fn frame_phases(&self, frame: usize) -> Vec<T> {
556        (0..self.freq_bins)
557            .map(|bin| self.phase(frame, bin))
558            .collect()
559    }
560
561    /// Apply a function to all bins
562    pub fn apply<F>(&mut self, mut f: F)
563    where
564        F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
565    {
566        for frame in 0..self.num_frames {
567            for bin in 0..self.freq_bins {
568                let c = self.get_complex(frame, bin);
569                let new_c = f(frame, bin, c);
570                self.set_complex(frame, bin, new_c);
571            }
572        }
573    }
574
575    /// Apply a gain to a range of bins across all frames
576    pub fn apply_gain(&mut self, bin_range: std::ops::Range<usize>, gain: T) {
577        for frame in 0..self.num_frames {
578            for bin in bin_range.clone() {
579                if bin < self.freq_bins {
580                    let c = self.get_complex(frame, bin);
581                    self.set_complex(frame, bin, c * gain);
582                }
583            }
584        }
585    }
586
587    /// Zero out a range of bins across all frames
588    pub fn zero_bins(&mut self, bin_range: std::ops::Range<usize>) {
589        for frame in 0..self.num_frames {
590            for bin in bin_range.clone() {
591                if bin < self.freq_bins {
592                    self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
593                }
594            }
595        }
596    }
597}
598
599pub struct BatchStft<T: Float + FftNum> {
600    config: StftConfig<T>,
601    window: Vec<T>,
602    fft: Arc<dyn Fft<T>>,
603}
604
605impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
606    pub fn new(config: StftConfig<T>) -> Self {
607        let window = config.generate_window();
608        let mut planner = FftPlanner::new();
609        let fft = planner.plan_fft_forward(config.fft_size);
610
611        Self {
612            config,
613            window,
614            fft,
615        }
616    }
617
618    pub fn process(&self, signal: &[T]) -> Spectrum<T> {
619        self.process_padded(signal, PadMode::Reflect)
620    }
621
622    pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
623        let pad_amount = self.config.fft_size / 2;
624        let padded = utils::apply_padding(signal, pad_amount, pad_mode);
625
626        let num_frames = if padded.len() >= self.config.fft_size {
627            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
628        } else {
629            0
630        };
631
632        let freq_bins = self.config.freq_bins();
633        let mut result = Spectrum::new(num_frames, freq_bins);
634
635        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
636
637        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
638            .step_by(self.config.hop_size)
639            .enumerate()
640        {
641            // Apply window and prepare FFT input
642            for i in 0..self.config.fft_size {
643                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
644            }
645
646            // Compute FFT
647            self.fft.process(&mut fft_buffer);
648
649            // Store positive frequencies in flat layout
650            (0..freq_bins).for_each(|bin| {
651                let idx = frame_idx * freq_bins + bin;
652                result.data[idx] = fft_buffer[bin].re;
653                result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
654            });
655        }
656
657        result
658    }
659
660    /// Process signal and write into a pre-allocated Spectrum.
661    /// The spectrum must have the correct dimensions (num_frames x freq_bins).
662    /// Returns true if successful, false if dimensions don't match.
663    pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
664        self.process_padded_into(signal, PadMode::Reflect, spectrum)
665    }
666
667    /// Process signal with padding and write into a pre-allocated Spectrum.
668    pub fn process_padded_into(
669        &self,
670        signal: &[T],
671        pad_mode: PadMode,
672        spectrum: &mut Spectrum<T>,
673    ) -> bool {
674        let pad_amount = self.config.fft_size / 2;
675        let padded = utils::apply_padding(signal, pad_amount, pad_mode);
676
677        let num_frames = if padded.len() >= self.config.fft_size {
678            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
679        } else {
680            0
681        };
682
683        let freq_bins = self.config.freq_bins();
684
685        // Check dimensions
686        if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
687            return false;
688        }
689
690        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
691
692        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
693            .step_by(self.config.hop_size)
694            .enumerate()
695        {
696            // Apply window and prepare FFT input
697            for i in 0..self.config.fft_size {
698                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
699            }
700
701            // Compute FFT
702            self.fft.process(&mut fft_buffer);
703
704            // Store positive frequencies in flat layout
705            (0..freq_bins).for_each(|bin| {
706                let idx = frame_idx * freq_bins + bin;
707                spectrum.data[idx] = fft_buffer[bin].re;
708                spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
709            });
710        }
711
712        true
713    }
714
715    /// Process multiple channels independently.
716    /// Returns one Spectrum per channel.
717    ///
718    /// # Arguments
719    ///
720    /// * `channels` - Slice of audio channels, each as a separate Vec
721    ///
722    /// # Panics
723    ///
724    /// Panics if channels is empty or if channels have different lengths.
725    ///
726    /// # Example
727    ///
728    /// ```
729    /// use stft_rs::prelude::*;
730    ///
731    /// let config = StftConfigF32::default_4096();
732    /// let stft = BatchStftF32::new(config);
733    ///
734    /// let left = vec![0.0; 44100];
735    /// let right = vec![0.0; 44100];
736    /// let channels = vec![left, right];
737    ///
738    /// let spectra = stft.process_multichannel(&channels);
739    /// assert_eq!(spectra.len(), 2); // One spectrum per channel
740    /// ```
741    pub fn process_multichannel(&self, channels: &[Vec<T>]) -> Vec<Spectrum<T>> {
742        assert!(!channels.is_empty(), "channels must not be empty");
743
744        // Validate all channels have same length
745        let expected_len = channels[0].len();
746        for (i, channel) in channels.iter().enumerate() {
747            assert_eq!(
748                channel.len(),
749                expected_len,
750                "Channel {} has length {}, expected {}",
751                i,
752                channel.len(),
753                expected_len
754            );
755        }
756
757        // Process each channel independently
758        #[cfg(feature = "rayon")]
759        {
760            use rayon::prelude::*;
761            channels
762                .par_iter()
763                .map(|channel| self.process(channel))
764                .collect()
765        }
766        #[cfg(not(feature = "rayon"))]
767        {
768            channels
769                .iter()
770                .map(|channel| self.process(channel))
771                .collect()
772        }
773    }
774
775    /// Process interleaved multi-channel audio.
776    /// Converts interleaved format (e.g., `[L,R,L,R,L,R,...]` for stereo)
777    /// into separate Spectrum for each channel.
778    ///
779    /// # Arguments
780    ///
781    /// * `data` - Interleaved audio data
782    /// * `num_channels` - Number of channels
783    ///
784    /// # Panics
785    ///
786    /// Panics if `num_channels` is 0 or if `data.len()` is not divisible by `num_channels`.
787    ///
788    /// # Example
789    ///
790    /// ```
791    /// use stft_rs::prelude::*;
792    ///
793    /// let config = StftConfigF32::default_4096();
794    /// let stft = BatchStftF32::new(config);
795    ///
796    /// // Stereo interleaved: L,R,L,R,L,R,...
797    /// let interleaved = vec![0.0; 88200]; // 2 channels * 44100 samples
798    ///
799    /// let spectra = stft.process_interleaved(&interleaved, 2);
800    /// assert_eq!(spectra.len(), 2); // One spectrum per channel
801    /// ```
802    pub fn process_interleaved(&self, data: &[T], num_channels: usize) -> Vec<Spectrum<T>> {
803        let channels = utils::deinterleave(data, num_channels);
804        self.process_multichannel(&channels)
805    }
806}
807
808pub struct BatchIstft<T: Float + FftNum> {
809    config: StftConfig<T>,
810    window: Vec<T>,
811    ifft: Arc<dyn Fft<T>>,
812}
813
814impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
815    pub fn new(config: StftConfig<T>) -> Self {
816        let window = config.generate_window();
817        let mut planner = FftPlanner::new();
818        let ifft = planner.plan_fft_inverse(config.fft_size);
819
820        Self {
821            config,
822            window,
823            ifft,
824        }
825    }
826
827    pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
828        assert_eq!(
829            spectrum.freq_bins,
830            self.config.freq_bins(),
831            "Frequency bins mismatch"
832        );
833
834        let num_frames = spectrum.num_frames;
835        let original_time_len = (num_frames - 1) * self.config.hop_size;
836        let pad_amount = self.config.fft_size / 2;
837        let padded_len = original_time_len + 2 * pad_amount;
838
839        let mut overlap_buffer = vec![T::zero(); padded_len];
840        let mut window_energy = vec![T::zero(); padded_len];
841        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
842
843        // Precompute window energy normalization
844        for frame_idx in 0..num_frames {
845            let pos = frame_idx * self.config.hop_size;
846            for i in 0..self.config.fft_size {
847                match self.config.reconstruction_mode {
848                    ReconstructionMode::Ola => {
849                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
850                    }
851                    ReconstructionMode::Wola => {
852                        window_energy[pos + i] =
853                            window_energy[pos + i] + self.window[i] * self.window[i];
854                    }
855                }
856            }
857        }
858
859        // Process each frame
860        for frame_idx in 0..num_frames {
861            // Build full spectrum with conjugate symmetry
862            (0..spectrum.freq_bins).for_each(|bin| {
863                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
864            });
865
866            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
867            for bin in 1..(spectrum.freq_bins - 1) {
868                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
869            }
870
871            // Compute IFFT
872            self.ifft.process(&mut ifft_buffer);
873
874            // Overlap-add
875            let pos = frame_idx * self.config.hop_size;
876            for i in 0..self.config.fft_size {
877                let fft_size_t = T::from(self.config.fft_size).unwrap();
878                let sample = ifft_buffer[i].re / fft_size_t;
879
880                match self.config.reconstruction_mode {
881                    ReconstructionMode::Ola => {
882                        // OLA: no windowing on inverse
883                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
884                    }
885                    ReconstructionMode::Wola => {
886                        // WOLA: apply window on inverse
887                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
888                    }
889                }
890            }
891        }
892
893        // Normalize by window energy
894        let threshold = T::from(1e-8).unwrap();
895        for i in 0..padded_len {
896            if window_energy[i] > threshold {
897                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
898            }
899        }
900
901        // Remove padding
902        overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
903    }
904
905    /// Process spectrum and write into a pre-allocated output buffer.
906    /// The output buffer will be resized if needed.
907    pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
908        assert_eq!(
909            spectrum.freq_bins,
910            self.config.freq_bins(),
911            "Frequency bins mismatch"
912        );
913
914        let num_frames = spectrum.num_frames;
915        let original_time_len = (num_frames - 1) * self.config.hop_size;
916        let pad_amount = self.config.fft_size / 2;
917        let padded_len = original_time_len + 2 * pad_amount;
918
919        let mut overlap_buffer = vec![T::zero(); padded_len];
920        let mut window_energy = vec![T::zero(); padded_len];
921        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
922
923        // Precompute window energy normalization
924        for frame_idx in 0..num_frames {
925            let pos = frame_idx * self.config.hop_size;
926            for i in 0..self.config.fft_size {
927                match self.config.reconstruction_mode {
928                    ReconstructionMode::Ola => {
929                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
930                    }
931                    ReconstructionMode::Wola => {
932                        window_energy[pos + i] =
933                            window_energy[pos + i] + self.window[i] * self.window[i];
934                    }
935                }
936            }
937        }
938
939        // Process each frame
940        for frame_idx in 0..num_frames {
941            // Build full spectrum with conjugate symmetry
942            (0..spectrum.freq_bins).for_each(|bin| {
943                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
944            });
945
946            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
947            for bin in 1..(spectrum.freq_bins - 1) {
948                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
949            }
950
951            // Compute IFFT
952            self.ifft.process(&mut ifft_buffer);
953
954            // Overlap-add
955            let pos = frame_idx * self.config.hop_size;
956            for i in 0..self.config.fft_size {
957                let fft_size_t = T::from(self.config.fft_size).unwrap();
958                let sample = ifft_buffer[i].re / fft_size_t;
959
960                match self.config.reconstruction_mode {
961                    ReconstructionMode::Ola => {
962                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
963                    }
964                    ReconstructionMode::Wola => {
965                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
966                    }
967                }
968            }
969        }
970
971        // Normalize by window energy
972        let threshold = T::from(1e-8).unwrap();
973        for i in 0..padded_len {
974            if window_energy[i] > threshold {
975                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
976            }
977        }
978
979        // Copy to output (resize if needed)
980        output.clear();
981        output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
982    }
983
984    /// Reconstruct multiple channels from their spectra.
985    /// Returns one Vec per channel.
986    ///
987    /// # Arguments
988    ///
989    /// * `spectra` - Slice of Spectrum, one per channel
990    ///
991    /// # Panics
992    ///
993    /// Panics if spectra is empty.
994    ///
995    /// # Example
996    ///
997    /// ```
998    /// use stft_rs::prelude::*;
999    ///
1000    /// let config = StftConfigF32::default_4096();
1001    /// let stft = BatchStftF32::new(config.clone());
1002    /// let istft = BatchIstftF32::new(config);
1003    ///
1004    /// let left = vec![0.0; 44100];
1005    /// let right = vec![0.0; 44100];
1006    /// let channels = vec![left, right];
1007    ///
1008    /// let spectra = stft.process_multichannel(&channels);
1009    /// let reconstructed = istft.process_multichannel(&spectra);
1010    ///
1011    /// assert_eq!(reconstructed.len(), 2); // One channel per spectrum
1012    /// ```
1013    pub fn process_multichannel(&self, spectra: &[Spectrum<T>]) -> Vec<Vec<T>> {
1014        assert!(!spectra.is_empty(), "spectra must not be empty");
1015
1016        // Process each spectrum independently
1017        #[cfg(feature = "rayon")]
1018        {
1019            use rayon::prelude::*;
1020            spectra
1021                .par_iter()
1022                .map(|spectrum| self.process(spectrum))
1023                .collect()
1024        }
1025        #[cfg(not(feature = "rayon"))]
1026        {
1027            spectra
1028                .iter()
1029                .map(|spectrum| self.process(spectrum))
1030                .collect()
1031        }
1032    }
1033
1034    /// Reconstruct multiple channels and interleave them into a single buffer.
1035    /// Converts separate channels back to interleaved format (e.g., `[L,R,L,R,L,R,...]` for stereo).
1036    ///
1037    /// # Arguments
1038    ///
1039    /// * `spectra` - Slice of Spectrum, one per channel
1040    ///
1041    /// # Panics
1042    ///
1043    /// Panics if spectra is empty or if channels have different lengths.
1044    ///
1045    /// # Example
1046    ///
1047    /// ```
1048    /// use stft_rs::prelude::*;
1049    ///
1050    /// let config = StftConfigF32::default_4096();
1051    /// let stft = BatchStftF32::new(config.clone());
1052    /// let istft = BatchIstftF32::new(config);
1053    ///
1054    /// // Process interleaved stereo
1055    /// let interleaved = vec![0.0; 88200]; // 2 channels * 44100 samples
1056    /// let spectra = stft.process_interleaved(&interleaved, 2);
1057    ///
1058    /// // Reconstruct back to interleaved
1059    /// let output = istft.process_multichannel_interleaved(&spectra);
1060    /// // Output length may differ slightly due to padding/framing
1061    /// assert_eq!(output.len() / 2, 44032); // samples per channel after reconstruction
1062    /// ```
1063    pub fn process_multichannel_interleaved(&self, spectra: &[Spectrum<T>]) -> Vec<T> {
1064        let channels = self.process_multichannel(spectra);
1065        utils::interleave(&channels)
1066    }
1067}
1068
1069pub struct StreamingStft<T: Float + FftNum> {
1070    config: StftConfig<T>,
1071    window: Vec<T>,
1072    fft: Arc<dyn Fft<T>>,
1073    input_buffer: VecDeque<T>,
1074    frame_index: usize,
1075    fft_buffer: Vec<Complex<T>>,
1076}
1077
1078impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
1079    pub fn new(config: StftConfig<T>) -> Self {
1080        let window = config.generate_window();
1081        let mut planner = FftPlanner::new();
1082        let fft = planner.plan_fft_forward(config.fft_size);
1083        let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1084
1085        Self {
1086            config,
1087            window,
1088            fft,
1089            input_buffer: VecDeque::new(),
1090            frame_index: 0,
1091            fft_buffer,
1092        }
1093    }
1094
1095    pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
1096        self.input_buffer.extend(samples.iter().copied());
1097
1098        let mut frames = Vec::new();
1099
1100        while self.input_buffer.len() >= self.config.fft_size {
1101            // Process one frame
1102            for i in 0..self.config.fft_size {
1103                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1104            }
1105
1106            self.fft.process(&mut self.fft_buffer);
1107
1108            let freq_bins = self.config.freq_bins();
1109            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1110            frames.push(SpectrumFrame::from_data(data));
1111
1112            // Advance by hop size
1113            self.input_buffer.drain(..self.config.hop_size);
1114            self.frame_index += 1;
1115        }
1116
1117        frames
1118    }
1119
1120    /// Push samples and write frames into a pre-allocated buffer.
1121    /// Returns the number of frames written.
1122    pub fn push_samples_into(
1123        &mut self,
1124        samples: &[T],
1125        output: &mut Vec<SpectrumFrame<T>>,
1126    ) -> usize {
1127        self.input_buffer.extend(samples.iter().copied());
1128
1129        let initial_len = output.len();
1130
1131        while self.input_buffer.len() >= self.config.fft_size {
1132            // Process one frame
1133            for i in 0..self.config.fft_size {
1134                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1135            }
1136
1137            self.fft.process(&mut self.fft_buffer);
1138
1139            let freq_bins = self.config.freq_bins();
1140            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1141            output.push(SpectrumFrame::from_data(data));
1142
1143            // Advance by hop size
1144            self.input_buffer.drain(..self.config.hop_size);
1145            self.frame_index += 1;
1146        }
1147
1148        output.len() - initial_len
1149    }
1150
1151    /// Push samples and write directly into pre-existing SpectrumFrame buffers.
1152    /// This is a zero-allocation method - frames must be pre-allocated with correct size.
1153    /// Returns the number of frames written.
1154    ///
1155    /// # Example
1156    /// ```ignore
1157    /// let mut frame_pool = vec![SpectrumFrame::new(config.freq_bins()); 16];
1158    /// let mut frame_index = 0;
1159    ///
1160    /// let frames_written = stft.push_samples_write(chunk, &mut frame_pool, &mut frame_index);
1161    /// // Process frames 0..frames_written
1162    /// ```
1163    pub fn push_samples_write(
1164        &mut self,
1165        samples: &[T],
1166        frame_pool: &mut [SpectrumFrame<T>],
1167        pool_index: &mut usize,
1168    ) -> usize {
1169        self.input_buffer.extend(samples.iter().copied());
1170
1171        let initial_index = *pool_index;
1172        let freq_bins = self.config.freq_bins();
1173
1174        while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
1175            // Process one frame
1176            for i in 0..self.config.fft_size {
1177                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1178            }
1179
1180            self.fft.process(&mut self.fft_buffer);
1181
1182            // Write directly into the pre-allocated frame
1183            let frame = &mut frame_pool[*pool_index];
1184            debug_assert_eq!(
1185                frame.freq_bins, freq_bins,
1186                "Frame pool frames must match freq_bins"
1187            );
1188            frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
1189
1190            // Advance by hop size
1191            self.input_buffer.drain(..self.config.hop_size);
1192            self.frame_index += 1;
1193            *pool_index += 1;
1194        }
1195
1196        *pool_index - initial_index
1197    }
1198
1199    pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
1200        // For streaming, we typically don't process partial frames
1201        // Could zero-pad if needed, but that changes the signal
1202        Vec::new()
1203    }
1204
1205    pub fn reset(&mut self) {
1206        self.input_buffer.clear();
1207        self.frame_index = 0;
1208    }
1209
1210    pub fn buffered_samples(&self) -> usize {
1211        self.input_buffer.len()
1212    }
1213}
1214
1215/// Multi-channel streaming STFT processor with independent state per channel.
1216pub struct MultiChannelStreamingStft<T: Float + FftNum> {
1217    processors: Vec<StreamingStft<T>>,
1218}
1219
1220impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingStft<T> {
1221    /// Create a new multi-channel streaming STFT processor.
1222    ///
1223    /// # Arguments
1224    ///
1225    /// * `config` - STFT configuration
1226    /// * `num_channels` - Number of channels
1227    pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1228        assert!(num_channels > 0, "num_channels must be > 0");
1229        let processors = (0..num_channels)
1230            .map(|_| StreamingStft::new(config.clone()))
1231            .collect();
1232        Self { processors }
1233    }
1234
1235    /// Push samples for all channels and get frames for each channel.
1236    /// Returns Vec<Vec<SpectrumFrame>>, outer Vec = channels, inner Vec = frames.
1237    ///
1238    /// # Arguments
1239    ///
1240    /// * `channels` - Slice of sample slices, one per channel
1241    ///
1242    /// # Panics
1243    ///
1244    /// Panics if channels.len() doesn't match num_channels.
1245    pub fn push_samples(&mut self, channels: &[&[T]]) -> Vec<Vec<SpectrumFrame<T>>> {
1246        assert_eq!(
1247            channels.len(),
1248            self.processors.len(),
1249            "Expected {} channels, got {}",
1250            self.processors.len(),
1251            channels.len()
1252        );
1253
1254        #[cfg(feature = "rayon")]
1255        {
1256            use rayon::prelude::*;
1257            self.processors
1258                .par_iter_mut()
1259                .zip(channels.par_iter())
1260                .map(|(stft, channel)| stft.push_samples(channel))
1261                .collect()
1262        }
1263        #[cfg(not(feature = "rayon"))]
1264        {
1265            self.processors
1266                .iter_mut()
1267                .zip(channels.iter())
1268                .map(|(stft, channel)| stft.push_samples(channel))
1269                .collect()
1270        }
1271    }
1272
1273    /// Flush all channels and return remaining frames.
1274    pub fn flush(&mut self) -> Vec<Vec<SpectrumFrame<T>>> {
1275        #[cfg(feature = "rayon")]
1276        {
1277            use rayon::prelude::*;
1278            self.processors
1279                .par_iter_mut()
1280                .map(|stft| stft.flush())
1281                .collect()
1282        }
1283        #[cfg(not(feature = "rayon"))]
1284        {
1285            self.processors
1286                .iter_mut()
1287                .map(|stft| stft.flush())
1288                .collect()
1289        }
1290    }
1291
1292    /// Reset all channels.
1293    pub fn reset(&mut self) {
1294        #[cfg(feature = "rayon")]
1295        {
1296            use rayon::prelude::*;
1297            self.processors.par_iter_mut().for_each(|stft| stft.reset());
1298        }
1299        #[cfg(not(feature = "rayon"))]
1300        {
1301            self.processors.iter_mut().for_each(|stft| stft.reset());
1302        }
1303    }
1304
1305    /// Get the number of channels.
1306    pub fn num_channels(&self) -> usize {
1307        self.processors.len()
1308    }
1309}
1310
1311pub struct StreamingIstft<T: Float + FftNum> {
1312    config: StftConfig<T>,
1313    window: Vec<T>,
1314    ifft: Arc<dyn Fft<T>>,
1315    overlap_buffer: Vec<T>,
1316    window_energy: Vec<T>,
1317    output_position: usize,
1318    frames_processed: usize,
1319    ifft_buffer: Vec<Complex<T>>,
1320}
1321
1322impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
1323    pub fn new(config: StftConfig<T>) -> Self {
1324        let window = config.generate_window();
1325        let mut planner = FftPlanner::new();
1326        let ifft = planner.plan_fft_inverse(config.fft_size);
1327
1328        // Buffer needs to hold enough samples for full overlap
1329        // For proper reconstruction, need at least fft_size samples
1330        let buffer_size = config.fft_size * 2;
1331        let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1332
1333        Self {
1334            config,
1335            window,
1336            ifft,
1337            overlap_buffer: vec![T::zero(); buffer_size],
1338            window_energy: vec![T::zero(); buffer_size],
1339            output_position: 0,
1340            frames_processed: 0,
1341            ifft_buffer,
1342        }
1343    }
1344
1345    pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
1346        assert_eq!(
1347            frame.freq_bins,
1348            self.config.freq_bins(),
1349            "Frequency bins mismatch"
1350        );
1351
1352        // Build full spectrum with conjugate symmetry
1353        for bin in 0..frame.freq_bins {
1354            self.ifft_buffer[bin] = frame.data[bin];
1355        }
1356
1357        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
1358        for bin in 1..(frame.freq_bins - 1) {
1359            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1360        }
1361
1362        // Compute IFFT
1363        self.ifft.process(&mut self.ifft_buffer);
1364
1365        // Overlap-add into buffer at the current write position
1366        let write_pos = self.frames_processed * self.config.hop_size;
1367        for i in 0..self.config.fft_size {
1368            let fft_size_t = T::from(self.config.fft_size).unwrap();
1369            let sample = self.ifft_buffer[i].re / fft_size_t;
1370            let buf_idx = write_pos + i;
1371
1372            // Extend buffers if needed
1373            if buf_idx >= self.overlap_buffer.len() {
1374                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1375                self.window_energy.resize(buf_idx + 1, T::zero());
1376            }
1377
1378            match self.config.reconstruction_mode {
1379                ReconstructionMode::Ola => {
1380                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1381                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1382                }
1383                ReconstructionMode::Wola => {
1384                    self.overlap_buffer[buf_idx] =
1385                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1386                    self.window_energy[buf_idx] =
1387                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1388                }
1389            }
1390        }
1391
1392        self.frames_processed += 1;
1393
1394        // Calculate how many samples are "ready" (have full window energy)
1395        // Samples are ready when no future frames will contribute to them
1396        let ready_until = if self.frames_processed == 1 {
1397            0 // First frame: no output yet, need overlap
1398        } else {
1399            // Samples before the current frame's start position are complete
1400            (self.frames_processed - 1) * self.config.hop_size
1401        };
1402
1403        // Extract ready samples
1404        let output_start = self.output_position;
1405        let output_end = ready_until;
1406        let mut output = Vec::new();
1407
1408        let threshold = T::from(1e-8).unwrap();
1409        if output_end > output_start {
1410            for i in output_start..output_end {
1411                let normalized = if self.window_energy[i] > threshold {
1412                    self.overlap_buffer[i] / self.window_energy[i]
1413                } else {
1414                    T::zero()
1415                };
1416                output.push(normalized);
1417            }
1418            self.output_position = output_end;
1419        }
1420
1421        output
1422    }
1423
1424    /// Push a frame and write output samples into a pre-allocated buffer.
1425    /// Returns the number of samples written.
1426    pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1427        assert_eq!(
1428            frame.freq_bins,
1429            self.config.freq_bins(),
1430            "Frequency bins mismatch"
1431        );
1432
1433        // Build full spectrum with conjugate symmetry
1434        for bin in 0..frame.freq_bins {
1435            self.ifft_buffer[bin] = frame.data[bin];
1436        }
1437
1438        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
1439        for bin in 1..(frame.freq_bins - 1) {
1440            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1441        }
1442
1443        // Compute IFFT
1444        self.ifft.process(&mut self.ifft_buffer);
1445
1446        // Overlap-add into buffer at the current write position
1447        let write_pos = self.frames_processed * self.config.hop_size;
1448        for i in 0..self.config.fft_size {
1449            let fft_size_t = T::from(self.config.fft_size).unwrap();
1450            let sample = self.ifft_buffer[i].re / fft_size_t;
1451            let buf_idx = write_pos + i;
1452
1453            // Extend buffers if needed
1454            if buf_idx >= self.overlap_buffer.len() {
1455                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1456                self.window_energy.resize(buf_idx + 1, T::zero());
1457            }
1458
1459            match self.config.reconstruction_mode {
1460                ReconstructionMode::Ola => {
1461                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1462                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1463                }
1464                ReconstructionMode::Wola => {
1465                    self.overlap_buffer[buf_idx] =
1466                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1467                    self.window_energy[buf_idx] =
1468                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1469                }
1470            }
1471        }
1472
1473        self.frames_processed += 1;
1474
1475        // Calculate how many samples are "ready" (have full window energy)
1476        // Samples are ready when no future frames will contribute to them
1477        let ready_until = if self.frames_processed == 1 {
1478            0 // First frame: no output yet, need overlap
1479        } else {
1480            // Samples before the current frame's start position are complete
1481            (self.frames_processed - 1) * self.config.hop_size
1482        };
1483
1484        // Extract ready samples
1485        let output_start = self.output_position;
1486        let output_end = ready_until;
1487        let initial_len = output.len();
1488
1489        let threshold = T::from(1e-8).unwrap();
1490        if output_end > output_start {
1491            for i in output_start..output_end {
1492                let normalized = if self.window_energy[i] > threshold {
1493                    self.overlap_buffer[i] / self.window_energy[i]
1494                } else {
1495                    T::zero()
1496                };
1497                output.push(normalized);
1498            }
1499            self.output_position = output_end;
1500        }
1501
1502        output.len() - initial_len
1503    }
1504
1505    pub fn flush(&mut self) -> Vec<T> {
1506        // Return all remaining samples in buffer
1507        let mut output = Vec::new();
1508        let threshold = T::from(1e-8).unwrap();
1509        for i in self.output_position..self.overlap_buffer.len() {
1510            if self.window_energy[i] > threshold {
1511                output.push(self.overlap_buffer[i] / self.window_energy[i]);
1512            } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1513                output.push(T::zero()); // Sample in valid range but no window energy
1514            } else {
1515                break; // Past the end of valid data
1516            }
1517        }
1518
1519        // Determine the actual end of valid data
1520        let valid_end =
1521            (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1522        if output.len() > valid_end - self.output_position {
1523            output.truncate(valid_end - self.output_position);
1524        }
1525
1526        self.reset();
1527        output
1528    }
1529
1530    pub fn reset(&mut self) {
1531        self.overlap_buffer.clear();
1532        self.overlap_buffer
1533            .resize(self.config.fft_size * 2, T::zero());
1534        self.window_energy.clear();
1535        self.window_energy
1536            .resize(self.config.fft_size * 2, T::zero());
1537        self.output_position = 0;
1538        self.frames_processed = 0;
1539    }
1540}
1541
1542/// Multi-channel streaming iSTFT processor with independent state per channel.
1543pub struct MultiChannelStreamingIstft<T: Float + FftNum> {
1544    processors: Vec<StreamingIstft<T>>,
1545}
1546
1547impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingIstft<T> {
1548    /// Create a new multi-channel streaming iSTFT processor.
1549    ///
1550    /// # Arguments
1551    ///
1552    /// * `config` - STFT configuration
1553    /// * `num_channels` - Number of channels
1554    pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1555        assert!(num_channels > 0, "num_channels must be > 0");
1556        let processors = (0..num_channels)
1557            .map(|_| StreamingIstft::new(config.clone()))
1558            .collect();
1559        Self { processors }
1560    }
1561
1562    /// Push frames for all channels and get samples for each channel.
1563    /// Returns Vec<Vec<T>>, outer Vec = channels, inner Vec = samples.
1564    ///
1565    /// # Arguments
1566    ///
1567    /// * `frames` - Slice of frames, one per channel
1568    ///
1569    /// # Panics
1570    ///
1571    /// Panics if frames.len() doesn't match num_channels.
1572    pub fn push_frames(&mut self, frames: &[&SpectrumFrame<T>]) -> Vec<Vec<T>> {
1573        assert_eq!(
1574            frames.len(),
1575            self.processors.len(),
1576            "Expected {} channels, got {}",
1577            self.processors.len(),
1578            frames.len()
1579        );
1580
1581        #[cfg(feature = "rayon")]
1582        {
1583            use rayon::prelude::*;
1584            self.processors
1585                .par_iter_mut()
1586                .zip(frames.par_iter())
1587                .map(|(istft, frame)| istft.push_frame(frame))
1588                .collect()
1589        }
1590        #[cfg(not(feature = "rayon"))]
1591        {
1592            self.processors
1593                .iter_mut()
1594                .zip(frames.iter())
1595                .map(|(istft, frame)| istft.push_frame(frame))
1596                .collect()
1597        }
1598    }
1599
1600    /// Flush all channels and return remaining samples.
1601    pub fn flush(&mut self) -> Vec<Vec<T>> {
1602        #[cfg(feature = "rayon")]
1603        {
1604            use rayon::prelude::*;
1605            self.processors
1606                .par_iter_mut()
1607                .map(|istft| istft.flush())
1608                .collect()
1609        }
1610        #[cfg(not(feature = "rayon"))]
1611        {
1612            self.processors
1613                .iter_mut()
1614                .map(|istft| istft.flush())
1615                .collect()
1616        }
1617    }
1618
1619    /// Reset all channels.
1620    pub fn reset(&mut self) {
1621        #[cfg(feature = "rayon")]
1622        {
1623            use rayon::prelude::*;
1624            self.processors
1625                .par_iter_mut()
1626                .for_each(|istft| istft.reset());
1627        }
1628        #[cfg(not(feature = "rayon"))]
1629        {
1630            self.processors.iter_mut().for_each(|istft| istft.reset());
1631        }
1632    }
1633
1634    /// Get the number of channels.
1635    pub fn num_channels(&self) -> usize {
1636        self.processors.len()
1637    }
1638}
1639
1640// Type aliases for common float types
1641pub type StftConfigF32 = StftConfig<f32>;
1642pub type StftConfigF64 = StftConfig<f64>;
1643
1644pub type StftConfigBuilderF32 = StftConfigBuilder<f32>;
1645pub type StftConfigBuilderF64 = StftConfigBuilder<f64>;
1646
1647pub type BatchStftF32 = BatchStft<f32>;
1648pub type BatchStftF64 = BatchStft<f64>;
1649
1650pub type BatchIstftF32 = BatchIstft<f32>;
1651pub type BatchIstftF64 = BatchIstft<f64>;
1652
1653pub type StreamingStftF32 = StreamingStft<f32>;
1654pub type StreamingStftF64 = StreamingStft<f64>;
1655
1656pub type StreamingIstftF32 = StreamingIstft<f32>;
1657pub type StreamingIstftF64 = StreamingIstft<f64>;
1658
1659pub type SpectrumF32 = Spectrum<f32>;
1660pub type SpectrumF64 = Spectrum<f64>;
1661
1662pub type SpectrumFrameF32 = SpectrumFrame<f32>;
1663pub type SpectrumFrameF64 = SpectrumFrame<f64>;
1664
1665pub type MultiChannelStreamingStftF32 = MultiChannelStreamingStft<f32>;
1666pub type MultiChannelStreamingStftF64 = MultiChannelStreamingStft<f64>;
1667
1668pub type MultiChannelStreamingIstftF32 = MultiChannelStreamingIstft<f32>;
1669pub type MultiChannelStreamingIstftF64 = MultiChannelStreamingIstft<f64>;