stft_rs/
lib.rs

1/*MIT License
2
3Copyright (c) 2025 David Maseda Neira
4
5Permission is hereby granted, free of charge, to any person obtaining a copy
6of this software and associated documentation files (the "Software"), to deal
7in the Software without restriction, including without limitation the rights
8to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9copies of the Software, and to permit persons to whom the Software is
10furnished to do so, subject to the following conditions:
11
12The above copyright notice and this permission notice shall be included in all
13copies or substantial portions of the Software.
14
15THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21SOFTWARE.
22*/
23
24use num_traits::{Float, FromPrimitive};
25use rustfft::num_complex::Complex;
26use rustfft::{Fft, FftNum, FftPlanner};
27use std::collections::VecDeque;
28use std::fmt;
29use std::sync::Arc;
30
31mod utils;
32pub use utils::{apply_padding, deinterleave, deinterleave_into, interleave, interleave_into};
33
34pub mod prelude {
35    pub use crate::utils::{
36        apply_padding, deinterleave, deinterleave_into, interleave, interleave_into,
37    };
38    pub use crate::{
39        BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64,
40        MultiChannelStreamingIstft, MultiChannelStreamingIstftF32, MultiChannelStreamingIstftF64,
41        MultiChannelStreamingStft, MultiChannelStreamingStftF32, MultiChannelStreamingStftF64,
42        PadMode, ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame,
43        SpectrumFrameF32, SpectrumFrameF64, StftConfig, StftConfigF32, StftConfigF64,
44        StreamingIstft, StreamingIstftF32, StreamingIstftF64, StreamingStft, StreamingStftF32,
45        StreamingStftF64, WindowType,
46    };
47}
48
/// How the inverse STFT normalises overlapping frames.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReconstructionMode {
    /// Overlap-Add: inverse frames are summed without a synthesis window and the
    /// result is divided by the overlapped `sum(w)`; configs are checked for COLA.
    Ola,

    /// Weighted Overlap-Add: inverse frames are multiplied by the window before
    /// summing and the result is divided by the overlapped `sum(w^2)`; configs are
    /// checked for NOLA.
    Wola,
}
57
/// Window shape used for analysis and synthesis.
///
/// Windows are generated symmetrically over `fft_size` samples, with
/// `phase = 2*pi*i / (fft_size - 1)`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowType {
    /// `0.5 * (1 - cos(phase))`.
    Hann,
    /// `0.54 - 0.46 * cos(phase)`.
    Hamming,
    /// `0.42 - 0.5 * cos(phase) + 0.08 * cos(2 * phase)`.
    Blackman,
}
64
65#[derive(Debug, Clone)]
66pub enum ConfigError<T: Float + fmt::Debug> {
67    NolaViolation { min_energy: T, threshold: T },
68    ColaViolation { max_deviation: T, threshold: T },
69    InvalidHopSize,
70    InvalidFftSize,
71}
72
73impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        match self {
76            ConfigError::NolaViolation {
77                min_energy,
78                threshold,
79            } => {
80                write!(
81                    f,
82                    "NOLA condition violated: min_energy={} < threshold={}",
83                    min_energy, threshold
84                )
85            }
86            ConfigError::ColaViolation {
87                max_deviation,
88                threshold,
89            } => {
90                write!(
91                    f,
92                    "COLA condition violated: max_deviation={} > threshold={}",
93                    max_deviation, threshold
94                )
95            }
96            ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
97            ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
98        }
99    }
100}
101
102impl<T: Float + fmt::Display + fmt::Debug> std::error::Error for ConfigError<T> {}
103
/// Boundary extension applied to the signal before framing.
///
/// The mode is forwarded to `utils::apply_padding`, which pads `fft_size / 2`
/// samples on each side. The exact extension rule lives in that helper. Presumably
/// `Reflect` mirrors around the edge, `Zero` pads with zeros and `Edge` repeats the
/// boundary sample (TODO confirm against `utils`).
///
/// `PartialEq`/`Eq` are derived for consistency with the crate's other
/// configuration enums, so callers can compare and assert on pad modes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadMode {
    Reflect,
    Zero,
    Edge,
}
110
111#[derive(Clone)]
112pub struct StftConfig<T: Float> {
113    pub fft_size: usize,
114    pub hop_size: usize,
115    pub window: WindowType,
116    pub reconstruction_mode: ReconstructionMode,
117    _phantom: std::marker::PhantomData<T>,
118}
119
120impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
121    fn nola_threshold() -> T {
122        T::from(1e-8).unwrap()
123    }
124
125    fn cola_relative_tolerance() -> T {
126        T::from(1e-4).unwrap()
127    }
128
129    pub fn new(
130        fft_size: usize,
131        hop_size: usize,
132        window: WindowType,
133        reconstruction_mode: ReconstructionMode,
134    ) -> Result<Self, ConfigError<T>> {
135        if fft_size == 0 || !fft_size.is_power_of_two() {
136            return Err(ConfigError::InvalidFftSize);
137        }
138        if hop_size == 0 || hop_size > fft_size {
139            return Err(ConfigError::InvalidHopSize);
140        }
141
142        let config = Self {
143            fft_size,
144            hop_size,
145            window,
146            reconstruction_mode,
147            _phantom: std::marker::PhantomData,
148        };
149
150        // Validate appropriate condition based on reconstruction mode
151        match reconstruction_mode {
152            ReconstructionMode::Ola => config.validate_cola()?,
153            ReconstructionMode::Wola => config.validate_nola()?,
154        }
155
156        Ok(config)
157    }
158
159    /// Default: 4096 FFT, 1024 hop, Hann window, OLA mode
160    pub fn default_4096() -> Self {
161        Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
162            .expect("Default config should always be valid")
163    }
164
165    pub fn freq_bins(&self) -> usize {
166        self.fft_size / 2 + 1
167    }
168
169    pub fn overlap_percent(&self) -> T {
170        let one = T::one();
171        let hundred = T::from(100.0).unwrap();
172        (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
173    }
174
175    fn generate_window(&self) -> Vec<T> {
176        generate_window(self.window, self.fft_size)
177    }
178
179    /// Validate NOLA condition: sum(w^2) > threshold everywhere
180    pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
181        let window = self.generate_window();
182        let num_overlaps = (self.fft_size + self.hop_size - 1) / self.hop_size;
183        let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
184        let mut energy = vec![T::zero(); test_len];
185
186        for i in 0..num_overlaps {
187            let offset = i * self.hop_size;
188            for j in 0..self.fft_size {
189                if offset + j < test_len {
190                    energy[offset + j] = energy[offset + j] + window[j] * window[j];
191                }
192            }
193        }
194
195        // Check the steady-state region (skip edges)
196        let start = self.fft_size / 2;
197        let end = test_len - self.fft_size / 2;
198        let min_energy = energy[start..end]
199            .iter()
200            .copied()
201            .min_by(|a, b| a.partial_cmp(b).unwrap())
202            .unwrap_or_else(T::zero);
203
204        if min_energy < Self::nola_threshold() {
205            return Err(ConfigError::NolaViolation {
206                min_energy,
207                threshold: Self::nola_threshold(),
208            });
209        }
210
211        Ok(())
212    }
213
214    /// Validate weak COLA condition: sum(w) is constant (within relative tolerance)
215    pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
216        let window = self.generate_window();
217        let window_len = window.len();
218
219        let mut cola_sum_period = vec![T::zero(); self.hop_size];
220        for i in 0..window_len {
221            let idx = i % self.hop_size;
222            cola_sum_period[idx] = cola_sum_period[idx] + window[i];
223        }
224
225        let zero = T::zero();
226        let min_sum = cola_sum_period
227            .iter()
228            .min_by(|a, b| a.partial_cmp(b).unwrap())
229            .unwrap_or(&zero);
230        let max_sum = cola_sum_period
231            .iter()
232            .max_by(|a, b| a.partial_cmp(b).unwrap())
233            .unwrap_or(&zero);
234
235        let epsilon = T::from(1e-9).unwrap();
236        if *max_sum < epsilon {
237            return Err(ConfigError::ColaViolation {
238                max_deviation: T::infinity(),
239                threshold: Self::cola_relative_tolerance(),
240            });
241        }
242
243        let ripple = (*max_sum - *min_sum) / *max_sum;
244
245        let is_compliant = ripple < Self::cola_relative_tolerance();
246
247        if !is_compliant {
248            return Err(ConfigError::ColaViolation {
249                max_deviation: ripple,
250                threshold: Self::cola_relative_tolerance(),
251            });
252        }
253        Ok(())
254    }
255}
256
257fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
258    let pi = T::from(std::f64::consts::PI).unwrap();
259    let two = T::from(2.0).unwrap();
260
261    match window_type {
262        WindowType::Hann => (0..size)
263            .map(|i| {
264                let half = T::from(0.5).unwrap();
265                let one = T::one();
266                let i_t = T::from(i).unwrap();
267                let size_m1 = T::from(size - 1).unwrap();
268                half * (one - (two * pi * i_t / size_m1).cos())
269            })
270            .collect(),
271        WindowType::Hamming => (0..size)
272            .map(|i| {
273                let i_t = T::from(i).unwrap();
274                let size_m1 = T::from(size - 1).unwrap();
275                T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_m1).cos()
276            })
277            .collect(),
278        WindowType::Blackman => (0..size)
279            .map(|i| {
280                let i_t = T::from(i).unwrap();
281                let size_m1 = T::from(size - 1).unwrap();
282                let angle = two * pi * i_t / size_m1;
283                T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
284                    + T::from(0.08).unwrap() * (two * angle).cos()
285            })
286            .collect(),
287    }
288}
289
290#[derive(Clone)]
291pub struct SpectrumFrame<T: Float> {
292    pub freq_bins: usize,
293    pub data: Vec<Complex<T>>,
294}
295
296impl<T: Float> SpectrumFrame<T> {
297    pub fn new(freq_bins: usize) -> Self {
298        Self {
299            freq_bins,
300            data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
301        }
302    }
303
304    pub fn from_data(data: Vec<Complex<T>>) -> Self {
305        let freq_bins = data.len();
306        Self { freq_bins, data }
307    }
308
309    /// Prepare frame for reuse by clearing data (keeps capacity)
310    pub fn clear(&mut self) {
311        for val in &mut self.data {
312            *val = Complex::new(T::zero(), T::zero());
313        }
314    }
315
316    /// Resize frame if needed to match freq_bins
317    pub fn resize_if_needed(&mut self, freq_bins: usize) {
318        if self.freq_bins != freq_bins {
319            self.freq_bins = freq_bins;
320            self.data
321                .resize(freq_bins, Complex::new(T::zero(), T::zero()));
322        }
323    }
324
325    /// Write data from a slice into this frame
326    pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
327        self.resize_if_needed(data.len());
328        self.data.copy_from_slice(data);
329    }
330
331    /// Get the magnitude of a frequency bin
332    #[inline]
333    pub fn magnitude(&self, bin: usize) -> T {
334        let c = &self.data[bin];
335        (c.re * c.re + c.im * c.im).sqrt()
336    }
337
338    /// Get the phase of a frequency bin in radians
339    #[inline]
340    pub fn phase(&self, bin: usize) -> T {
341        let c = &self.data[bin];
342        c.im.atan2(c.re)
343    }
344
345    /// Set a frequency bin from magnitude and phase
346    pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
347        self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
348    }
349
350    /// Create a SpectrumFrame from magnitude and phase arrays
351    pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
352        assert_eq!(
353            magnitudes.len(),
354            phases.len(),
355            "Magnitude and phase arrays must have same length"
356        );
357        let freq_bins = magnitudes.len();
358        let data: Vec<Complex<T>> = magnitudes
359            .iter()
360            .zip(phases.iter())
361            .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
362            .collect();
363        Self { freq_bins, data }
364    }
365
366    /// Get all magnitudes as a Vec
367    pub fn magnitudes(&self) -> Vec<T> {
368        self.data
369            .iter()
370            .map(|c| (c.re * c.re + c.im * c.im).sqrt())
371            .collect()
372    }
373
374    /// Get all phases as a Vec
375    pub fn phases(&self) -> Vec<T> {
376        self.data.iter().map(|c| c.im.atan2(c.re)).collect()
377    }
378}
379
380#[derive(Clone)]
381pub struct Spectrum<T: Float> {
382    pub num_frames: usize,
383    pub freq_bins: usize,
384    pub data: Vec<T>,
385}
386
387impl<T: Float> Spectrum<T> {
388    pub fn new(num_frames: usize, freq_bins: usize) -> Self {
389        Self {
390            num_frames,
391            freq_bins,
392            data: vec![T::zero(); 2 * num_frames * freq_bins],
393        }
394    }
395
396    #[inline]
397    pub fn real(&self, frame: usize, bin: usize) -> T {
398        self.data[frame * self.freq_bins + bin]
399    }
400
401    #[inline]
402    pub fn imag(&self, frame: usize, bin: usize) -> T {
403        let offset = self.num_frames * self.freq_bins;
404        self.data[offset + frame * self.freq_bins + bin]
405    }
406
407    #[inline]
408    pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
409        Complex::new(self.real(frame, bin), self.imag(frame, bin))
410    }
411
412    pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
413        (0..self.num_frames).map(move |frame_idx| {
414            let data: Vec<Complex<T>> = (0..self.freq_bins)
415                .map(|bin| self.get_complex(frame_idx, bin))
416                .collect();
417            SpectrumFrame::from_data(data)
418        })
419    }
420
421    /// Set the real part of a bin
422    #[inline]
423    pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
424        self.data[frame * self.freq_bins + bin] = value;
425    }
426
427    /// Set the imaginary part of a bin
428    #[inline]
429    pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
430        let offset = self.num_frames * self.freq_bins;
431        self.data[offset + frame * self.freq_bins + bin] = value;
432    }
433
434    /// Set a bin from a complex value
435    #[inline]
436    pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
437        self.set_real(frame, bin, value.re);
438        self.set_imag(frame, bin, value.im);
439    }
440
441    /// Get the magnitude of a frequency bin
442    #[inline]
443    pub fn magnitude(&self, frame: usize, bin: usize) -> T {
444        let re = self.real(frame, bin);
445        let im = self.imag(frame, bin);
446        (re * re + im * im).sqrt()
447    }
448
449    /// Get the phase of a frequency bin in radians
450    #[inline]
451    pub fn phase(&self, frame: usize, bin: usize) -> T {
452        let re = self.real(frame, bin);
453        let im = self.imag(frame, bin);
454        im.atan2(re)
455    }
456
457    /// Set a frequency bin from magnitude and phase
458    pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
459        self.set_real(frame, bin, magnitude * phase.cos());
460        self.set_imag(frame, bin, magnitude * phase.sin());
461    }
462
463    /// Get all magnitudes for a frame
464    pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
465        (0..self.freq_bins)
466            .map(|bin| self.magnitude(frame, bin))
467            .collect()
468    }
469
470    /// Get all phases for a frame
471    pub fn frame_phases(&self, frame: usize) -> Vec<T> {
472        (0..self.freq_bins)
473            .map(|bin| self.phase(frame, bin))
474            .collect()
475    }
476
477    /// Apply a function to all bins
478    pub fn apply<F>(&mut self, mut f: F)
479    where
480        F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
481    {
482        for frame in 0..self.num_frames {
483            for bin in 0..self.freq_bins {
484                let c = self.get_complex(frame, bin);
485                let new_c = f(frame, bin, c);
486                self.set_complex(frame, bin, new_c);
487            }
488        }
489    }
490
491    /// Apply a gain to a range of bins across all frames
492    pub fn apply_gain(&mut self, bin_range: std::ops::Range<usize>, gain: T) {
493        for frame in 0..self.num_frames {
494            for bin in bin_range.clone() {
495                if bin < self.freq_bins {
496                    let c = self.get_complex(frame, bin);
497                    self.set_complex(frame, bin, c * gain);
498                }
499            }
500        }
501    }
502
503    /// Zero out a range of bins across all frames
504    pub fn zero_bins(&mut self, bin_range: std::ops::Range<usize>) {
505        for frame in 0..self.num_frames {
506            for bin in bin_range.clone() {
507                if bin < self.freq_bins {
508                    self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
509                }
510            }
511        }
512    }
513}
514
515pub struct BatchStft<T: Float + FftNum> {
516    config: StftConfig<T>,
517    window: Vec<T>,
518    fft: Arc<dyn Fft<T>>,
519}
520
521impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
522    pub fn new(config: StftConfig<T>) -> Self {
523        let window = config.generate_window();
524        let mut planner = FftPlanner::new();
525        let fft = planner.plan_fft_forward(config.fft_size);
526
527        Self {
528            config,
529            window,
530            fft,
531        }
532    }
533
534    pub fn process(&self, signal: &[T]) -> Spectrum<T> {
535        self.process_padded(signal, PadMode::Reflect)
536    }
537
538    pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
539        let pad_amount = self.config.fft_size / 2;
540        let padded = utils::apply_padding(signal, pad_amount, pad_mode);
541
542        let num_frames = if padded.len() >= self.config.fft_size {
543            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
544        } else {
545            0
546        };
547
548        let freq_bins = self.config.freq_bins();
549        let mut result = Spectrum::new(num_frames, freq_bins);
550
551        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
552
553        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
554            .step_by(self.config.hop_size)
555            .enumerate()
556        {
557            // Apply window and prepare FFT input
558            for i in 0..self.config.fft_size {
559                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
560            }
561
562            // Compute FFT
563            self.fft.process(&mut fft_buffer);
564
565            // Store positive frequencies in flat layout
566            for bin in 0..freq_bins {
567                let idx = frame_idx * freq_bins + bin;
568                result.data[idx] = fft_buffer[bin].re;
569                result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
570            }
571        }
572
573        result
574    }
575
576    /// Process signal and write into a pre-allocated Spectrum.
577    /// The spectrum must have the correct dimensions (num_frames x freq_bins).
578    /// Returns true if successful, false if dimensions don't match.
579    pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
580        self.process_padded_into(signal, PadMode::Reflect, spectrum)
581    }
582
583    /// Process signal with padding and write into a pre-allocated Spectrum.
584    pub fn process_padded_into(
585        &self,
586        signal: &[T],
587        pad_mode: PadMode,
588        spectrum: &mut Spectrum<T>,
589    ) -> bool {
590        let pad_amount = self.config.fft_size / 2;
591        let padded = utils::apply_padding(signal, pad_amount, pad_mode);
592
593        let num_frames = if padded.len() >= self.config.fft_size {
594            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
595        } else {
596            0
597        };
598
599        let freq_bins = self.config.freq_bins();
600
601        // Check dimensions
602        if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
603            return false;
604        }
605
606        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
607
608        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
609            .step_by(self.config.hop_size)
610            .enumerate()
611        {
612            // Apply window and prepare FFT input
613            for i in 0..self.config.fft_size {
614                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
615            }
616
617            // Compute FFT
618            self.fft.process(&mut fft_buffer);
619
620            // Store positive frequencies in flat layout
621            for bin in 0..freq_bins {
622                let idx = frame_idx * freq_bins + bin;
623                spectrum.data[idx] = fft_buffer[bin].re;
624                spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
625            }
626        }
627
628        true
629    }
630
631    /// Process multiple channels independently.
632    /// Returns one Spectrum per channel.
633    ///
634    /// # Arguments
635    ///
636    /// * `channels` - Slice of audio channels, each as a separate Vec
637    ///
638    /// # Panics
639    ///
640    /// Panics if channels is empty or if channels have different lengths.
641    ///
642    /// # Example
643    ///
644    /// ```
645    /// use stft_rs::prelude::*;
646    ///
647    /// let config = StftConfigF32::default_4096();
648    /// let stft = BatchStftF32::new(config);
649    ///
650    /// let left = vec![0.0; 44100];
651    /// let right = vec![0.0; 44100];
652    /// let channels = vec![left, right];
653    ///
654    /// let spectra = stft.process_multichannel(&channels);
655    /// assert_eq!(spectra.len(), 2); // One spectrum per channel
656    /// ```
657    pub fn process_multichannel(&self, channels: &[Vec<T>]) -> Vec<Spectrum<T>> {
658        assert!(!channels.is_empty(), "channels must not be empty");
659
660        // Validate all channels have same length
661        let expected_len = channels[0].len();
662        for (i, channel) in channels.iter().enumerate() {
663            assert_eq!(
664                channel.len(),
665                expected_len,
666                "Channel {} has length {}, expected {}",
667                i,
668                channel.len(),
669                expected_len
670            );
671        }
672
673        // Process each channel independently
674        #[cfg(feature = "rayon")]
675        {
676            use rayon::prelude::*;
677            channels
678                .par_iter()
679                .map(|channel| self.process(channel))
680                .collect()
681        }
682        #[cfg(not(feature = "rayon"))]
683        {
684            channels
685                .iter()
686                .map(|channel| self.process(channel))
687                .collect()
688        }
689    }
690
691    /// Process interleaved multi-channel audio.
692    /// Converts interleaved format (e.g., `[L,R,L,R,L,R,...]` for stereo)
693    /// into separate Spectrum for each channel.
694    ///
695    /// # Arguments
696    ///
697    /// * `data` - Interleaved audio data
698    /// * `num_channels` - Number of channels
699    ///
700    /// # Panics
701    ///
702    /// Panics if `num_channels` is 0 or if `data.len()` is not divisible by `num_channels`.
703    ///
704    /// # Example
705    ///
706    /// ```
707    /// use stft_rs::prelude::*;
708    ///
709    /// let config = StftConfigF32::default_4096();
710    /// let stft = BatchStftF32::new(config);
711    ///
712    /// // Stereo interleaved: L,R,L,R,L,R,...
713    /// let interleaved = vec![0.0; 88200]; // 2 channels * 44100 samples
714    ///
715    /// let spectra = stft.process_interleaved(&interleaved, 2);
716    /// assert_eq!(spectra.len(), 2); // One spectrum per channel
717    /// ```
718    pub fn process_interleaved(&self, data: &[T], num_channels: usize) -> Vec<Spectrum<T>> {
719        let channels = utils::deinterleave(data, num_channels);
720        self.process_multichannel(&channels)
721    }
722}
723
724pub struct BatchIstft<T: Float + FftNum> {
725    config: StftConfig<T>,
726    window: Vec<T>,
727    ifft: Arc<dyn Fft<T>>,
728}
729
730impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
731    pub fn new(config: StftConfig<T>) -> Self {
732        let window = config.generate_window();
733        let mut planner = FftPlanner::new();
734        let ifft = planner.plan_fft_inverse(config.fft_size);
735
736        Self {
737            config,
738            window,
739            ifft,
740        }
741    }
742
743    pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
744        assert_eq!(
745            spectrum.freq_bins,
746            self.config.freq_bins(),
747            "Frequency bins mismatch"
748        );
749
750        let num_frames = spectrum.num_frames;
751        let original_time_len = (num_frames - 1) * self.config.hop_size;
752        let pad_amount = self.config.fft_size / 2;
753        let padded_len = original_time_len + 2 * pad_amount;
754
755        let mut overlap_buffer = vec![T::zero(); padded_len];
756        let mut window_energy = vec![T::zero(); padded_len];
757        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
758
759        // Precompute window energy normalization
760        for frame_idx in 0..num_frames {
761            let pos = frame_idx * self.config.hop_size;
762            for i in 0..self.config.fft_size {
763                match self.config.reconstruction_mode {
764                    ReconstructionMode::Ola => {
765                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
766                    }
767                    ReconstructionMode::Wola => {
768                        window_energy[pos + i] =
769                            window_energy[pos + i] + self.window[i] * self.window[i];
770                    }
771                }
772            }
773        }
774
775        // Process each frame
776        for frame_idx in 0..num_frames {
777            // Build full spectrum with conjugate symmetry
778            for bin in 0..spectrum.freq_bins {
779                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
780            }
781
782            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
783            for bin in 1..(spectrum.freq_bins - 1) {
784                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
785            }
786
787            // Compute IFFT
788            self.ifft.process(&mut ifft_buffer);
789
790            // Overlap-add
791            let pos = frame_idx * self.config.hop_size;
792            for i in 0..self.config.fft_size {
793                let fft_size_t = T::from(self.config.fft_size).unwrap();
794                let sample = ifft_buffer[i].re / fft_size_t;
795
796                match self.config.reconstruction_mode {
797                    ReconstructionMode::Ola => {
798                        // OLA: no windowing on inverse
799                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
800                    }
801                    ReconstructionMode::Wola => {
802                        // WOLA: apply window on inverse
803                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
804                    }
805                }
806            }
807        }
808
809        // Normalize by window energy
810        let threshold = T::from(1e-8).unwrap();
811        for i in 0..padded_len {
812            if window_energy[i] > threshold {
813                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
814            }
815        }
816
817        // Remove padding
818        overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
819    }
820
821    /// Process spectrum and write into a pre-allocated output buffer.
822    /// The output buffer will be resized if needed.
823    pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
824        assert_eq!(
825            spectrum.freq_bins,
826            self.config.freq_bins(),
827            "Frequency bins mismatch"
828        );
829
830        let num_frames = spectrum.num_frames;
831        let original_time_len = (num_frames - 1) * self.config.hop_size;
832        let pad_amount = self.config.fft_size / 2;
833        let padded_len = original_time_len + 2 * pad_amount;
834
835        let mut overlap_buffer = vec![T::zero(); padded_len];
836        let mut window_energy = vec![T::zero(); padded_len];
837        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
838
839        // Precompute window energy normalization
840        for frame_idx in 0..num_frames {
841            let pos = frame_idx * self.config.hop_size;
842            for i in 0..self.config.fft_size {
843                match self.config.reconstruction_mode {
844                    ReconstructionMode::Ola => {
845                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
846                    }
847                    ReconstructionMode::Wola => {
848                        window_energy[pos + i] =
849                            window_energy[pos + i] + self.window[i] * self.window[i];
850                    }
851                }
852            }
853        }
854
855        // Process each frame
856        for frame_idx in 0..num_frames {
857            // Build full spectrum with conjugate symmetry
858            for bin in 0..spectrum.freq_bins {
859                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
860            }
861
862            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
863            for bin in 1..(spectrum.freq_bins - 1) {
864                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
865            }
866
867            // Compute IFFT
868            self.ifft.process(&mut ifft_buffer);
869
870            // Overlap-add
871            let pos = frame_idx * self.config.hop_size;
872            for i in 0..self.config.fft_size {
873                let fft_size_t = T::from(self.config.fft_size).unwrap();
874                let sample = ifft_buffer[i].re / fft_size_t;
875
876                match self.config.reconstruction_mode {
877                    ReconstructionMode::Ola => {
878                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
879                    }
880                    ReconstructionMode::Wola => {
881                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
882                    }
883                }
884            }
885        }
886
887        // Normalize by window energy
888        let threshold = T::from(1e-8).unwrap();
889        for i in 0..padded_len {
890            if window_energy[i] > threshold {
891                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
892            }
893        }
894
895        // Copy to output (resize if needed)
896        output.clear();
897        output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
898    }
899
900    /// Reconstruct multiple channels from their spectra.
901    /// Returns one Vec per channel.
902    ///
903    /// # Arguments
904    ///
905    /// * `spectra` - Slice of Spectrum, one per channel
906    ///
907    /// # Panics
908    ///
909    /// Panics if spectra is empty.
910    ///
911    /// # Example
912    ///
913    /// ```
914    /// use stft_rs::prelude::*;
915    ///
916    /// let config = StftConfigF32::default_4096();
917    /// let stft = BatchStftF32::new(config.clone());
918    /// let istft = BatchIstftF32::new(config);
919    ///
920    /// let left = vec![0.0; 44100];
921    /// let right = vec![0.0; 44100];
922    /// let channels = vec![left, right];
923    ///
924    /// let spectra = stft.process_multichannel(&channels);
925    /// let reconstructed = istft.process_multichannel(&spectra);
926    ///
927    /// assert_eq!(reconstructed.len(), 2); // One channel per spectrum
928    /// ```
929    pub fn process_multichannel(&self, spectra: &[Spectrum<T>]) -> Vec<Vec<T>> {
930        assert!(!spectra.is_empty(), "spectra must not be empty");
931
932        // Process each spectrum independently
933        #[cfg(feature = "rayon")]
934        {
935            use rayon::prelude::*;
936            spectra
937                .par_iter()
938                .map(|spectrum| self.process(spectrum))
939                .collect()
940        }
941        #[cfg(not(feature = "rayon"))]
942        {
943            spectra
944                .iter()
945                .map(|spectrum| self.process(spectrum))
946                .collect()
947        }
948    }
949
950    /// Reconstruct multiple channels and interleave them into a single buffer.
951    /// Converts separate channels back to interleaved format (e.g., `[L,R,L,R,L,R,...]` for stereo).
952    ///
953    /// # Arguments
954    ///
955    /// * `spectra` - Slice of Spectrum, one per channel
956    ///
957    /// # Panics
958    ///
959    /// Panics if spectra is empty or if channels have different lengths.
960    ///
961    /// # Example
962    ///
963    /// ```
964    /// use stft_rs::prelude::*;
965    ///
966    /// let config = StftConfigF32::default_4096();
967    /// let stft = BatchStftF32::new(config.clone());
968    /// let istft = BatchIstftF32::new(config);
969    ///
970    /// // Process interleaved stereo
971    /// let interleaved = vec![0.0; 88200]; // 2 channels * 44100 samples
972    /// let spectra = stft.process_interleaved(&interleaved, 2);
973    ///
974    /// // Reconstruct back to interleaved
975    /// let output = istft.process_multichannel_interleaved(&spectra);
976    /// // Output length may differ slightly due to padding/framing
977    /// assert_eq!(output.len() / 2, 44032); // samples per channel after reconstruction
978    /// ```
979    pub fn process_multichannel_interleaved(&self, spectra: &[Spectrum<T>]) -> Vec<T> {
980        let channels = self.process_multichannel(spectra);
981        utils::interleave(&channels)
982    }
983}
984
985pub struct StreamingStft<T: Float + FftNum> {
986    config: StftConfig<T>,
987    window: Vec<T>,
988    fft: Arc<dyn Fft<T>>,
989    input_buffer: VecDeque<T>,
990    frame_index: usize,
991    fft_buffer: Vec<Complex<T>>,
992}
993
994impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
995    pub fn new(config: StftConfig<T>) -> Self {
996        let window = config.generate_window();
997        let mut planner = FftPlanner::new();
998        let fft = planner.plan_fft_forward(config.fft_size);
999        let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1000
1001        Self {
1002            config,
1003            window,
1004            fft,
1005            input_buffer: VecDeque::new(),
1006            frame_index: 0,
1007            fft_buffer,
1008        }
1009    }
1010
1011    pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
1012        self.input_buffer.extend(samples.iter().copied());
1013
1014        let mut frames = Vec::new();
1015
1016        while self.input_buffer.len() >= self.config.fft_size {
1017            // Process one frame
1018            for i in 0..self.config.fft_size {
1019                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1020            }
1021
1022            self.fft.process(&mut self.fft_buffer);
1023
1024            let freq_bins = self.config.freq_bins();
1025            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1026            frames.push(SpectrumFrame::from_data(data));
1027
1028            // Advance by hop size
1029            self.input_buffer.drain(..self.config.hop_size);
1030            self.frame_index += 1;
1031        }
1032
1033        frames
1034    }
1035
1036    /// Push samples and write frames into a pre-allocated buffer.
1037    /// Returns the number of frames written.
1038    pub fn push_samples_into(
1039        &mut self,
1040        samples: &[T],
1041        output: &mut Vec<SpectrumFrame<T>>,
1042    ) -> usize {
1043        self.input_buffer.extend(samples.iter().copied());
1044
1045        let initial_len = output.len();
1046
1047        while self.input_buffer.len() >= self.config.fft_size {
1048            // Process one frame
1049            for i in 0..self.config.fft_size {
1050                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1051            }
1052
1053            self.fft.process(&mut self.fft_buffer);
1054
1055            let freq_bins = self.config.freq_bins();
1056            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1057            output.push(SpectrumFrame::from_data(data));
1058
1059            // Advance by hop size
1060            self.input_buffer.drain(..self.config.hop_size);
1061            self.frame_index += 1;
1062        }
1063
1064        output.len() - initial_len
1065    }
1066
1067    /// Push samples and write directly into pre-existing SpectrumFrame buffers.
1068    /// This is a zero-allocation method - frames must be pre-allocated with correct size.
1069    /// Returns the number of frames written.
1070    ///
1071    /// # Example
1072    /// ```ignore
1073    /// let mut frame_pool = vec![SpectrumFrame::new(config.freq_bins()); 16];
1074    /// let mut frame_index = 0;
1075    ///
1076    /// let frames_written = stft.push_samples_write(chunk, &mut frame_pool, &mut frame_index);
1077    /// // Process frames 0..frames_written
1078    /// ```
1079    pub fn push_samples_write(
1080        &mut self,
1081        samples: &[T],
1082        frame_pool: &mut [SpectrumFrame<T>],
1083        pool_index: &mut usize,
1084    ) -> usize {
1085        self.input_buffer.extend(samples.iter().copied());
1086
1087        let initial_index = *pool_index;
1088        let freq_bins = self.config.freq_bins();
1089
1090        while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
1091            // Process one frame
1092            for i in 0..self.config.fft_size {
1093                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1094            }
1095
1096            self.fft.process(&mut self.fft_buffer);
1097
1098            // Write directly into the pre-allocated frame
1099            let frame = &mut frame_pool[*pool_index];
1100            debug_assert_eq!(
1101                frame.freq_bins, freq_bins,
1102                "Frame pool frames must match freq_bins"
1103            );
1104            frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
1105
1106            // Advance by hop size
1107            self.input_buffer.drain(..self.config.hop_size);
1108            self.frame_index += 1;
1109            *pool_index += 1;
1110        }
1111
1112        *pool_index - initial_index
1113    }
1114
1115    pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
1116        // For streaming, we typically don't process partial frames
1117        // Could zero-pad if needed, but that changes the signal
1118        Vec::new()
1119    }
1120
1121    pub fn reset(&mut self) {
1122        self.input_buffer.clear();
1123        self.frame_index = 0;
1124    }
1125
1126    pub fn buffered_samples(&self) -> usize {
1127        self.input_buffer.len()
1128    }
1129}
1130
1131/// Multi-channel streaming STFT processor with independent state per channel.
1132pub struct MultiChannelStreamingStft<T: Float + FftNum> {
1133    processors: Vec<StreamingStft<T>>,
1134}
1135
1136impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingStft<T> {
1137    /// Create a new multi-channel streaming STFT processor.
1138    ///
1139    /// # Arguments
1140    ///
1141    /// * `config` - STFT configuration
1142    /// * `num_channels` - Number of channels
1143    pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1144        assert!(num_channels > 0, "num_channels must be > 0");
1145        let processors = (0..num_channels)
1146            .map(|_| StreamingStft::new(config.clone()))
1147            .collect();
1148        Self { processors }
1149    }
1150
1151    /// Push samples for all channels and get frames for each channel.
1152    /// Returns Vec<Vec<SpectrumFrame>>, outer Vec = channels, inner Vec = frames.
1153    ///
1154    /// # Arguments
1155    ///
1156    /// * `channels` - Slice of sample slices, one per channel
1157    ///
1158    /// # Panics
1159    ///
1160    /// Panics if channels.len() doesn't match num_channels.
1161    pub fn push_samples(&mut self, channels: &[&[T]]) -> Vec<Vec<SpectrumFrame<T>>> {
1162        assert_eq!(
1163            channels.len(),
1164            self.processors.len(),
1165            "Expected {} channels, got {}",
1166            self.processors.len(),
1167            channels.len()
1168        );
1169
1170        #[cfg(feature = "rayon")]
1171        {
1172            use rayon::prelude::*;
1173            self.processors
1174                .par_iter_mut()
1175                .zip(channels.par_iter())
1176                .map(|(stft, channel)| stft.push_samples(channel))
1177                .collect()
1178        }
1179        #[cfg(not(feature = "rayon"))]
1180        {
1181            self.processors
1182                .iter_mut()
1183                .zip(channels.iter())
1184                .map(|(stft, channel)| stft.push_samples(channel))
1185                .collect()
1186        }
1187    }
1188
1189    /// Flush all channels and return remaining frames.
1190    pub fn flush(&mut self) -> Vec<Vec<SpectrumFrame<T>>> {
1191        #[cfg(feature = "rayon")]
1192        {
1193            use rayon::prelude::*;
1194            self.processors
1195                .par_iter_mut()
1196                .map(|stft| stft.flush())
1197                .collect()
1198        }
1199        #[cfg(not(feature = "rayon"))]
1200        {
1201            self.processors
1202                .iter_mut()
1203                .map(|stft| stft.flush())
1204                .collect()
1205        }
1206    }
1207
1208    /// Reset all channels.
1209    pub fn reset(&mut self) {
1210        #[cfg(feature = "rayon")]
1211        {
1212            use rayon::prelude::*;
1213            self.processors.par_iter_mut().for_each(|stft| stft.reset());
1214        }
1215        #[cfg(not(feature = "rayon"))]
1216        {
1217            self.processors.iter_mut().for_each(|stft| stft.reset());
1218        }
1219    }
1220
1221    /// Get the number of channels.
1222    pub fn num_channels(&self) -> usize {
1223        self.processors.len()
1224    }
1225}
1226
1227pub struct StreamingIstft<T: Float + FftNum> {
1228    config: StftConfig<T>,
1229    window: Vec<T>,
1230    ifft: Arc<dyn Fft<T>>,
1231    overlap_buffer: Vec<T>,
1232    window_energy: Vec<T>,
1233    output_position: usize,
1234    frames_processed: usize,
1235    ifft_buffer: Vec<Complex<T>>,
1236}
1237
1238impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
1239    pub fn new(config: StftConfig<T>) -> Self {
1240        let window = config.generate_window();
1241        let mut planner = FftPlanner::new();
1242        let ifft = planner.plan_fft_inverse(config.fft_size);
1243
1244        // Buffer needs to hold enough samples for full overlap
1245        // For proper reconstruction, need at least fft_size samples
1246        let buffer_size = config.fft_size * 2;
1247        let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1248
1249        Self {
1250            config,
1251            window,
1252            ifft,
1253            overlap_buffer: vec![T::zero(); buffer_size],
1254            window_energy: vec![T::zero(); buffer_size],
1255            output_position: 0,
1256            frames_processed: 0,
1257            ifft_buffer,
1258        }
1259    }
1260
1261    pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
1262        assert_eq!(
1263            frame.freq_bins,
1264            self.config.freq_bins(),
1265            "Frequency bins mismatch"
1266        );
1267
1268        // Build full spectrum with conjugate symmetry
1269        for bin in 0..frame.freq_bins {
1270            self.ifft_buffer[bin] = frame.data[bin];
1271        }
1272
1273        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
1274        for bin in 1..(frame.freq_bins - 1) {
1275            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1276        }
1277
1278        // Compute IFFT
1279        self.ifft.process(&mut self.ifft_buffer);
1280
1281        // Overlap-add into buffer at the current write position
1282        let write_pos = self.frames_processed * self.config.hop_size;
1283        for i in 0..self.config.fft_size {
1284            let fft_size_t = T::from(self.config.fft_size).unwrap();
1285            let sample = self.ifft_buffer[i].re / fft_size_t;
1286            let buf_idx = write_pos + i;
1287
1288            // Extend buffers if needed
1289            if buf_idx >= self.overlap_buffer.len() {
1290                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1291                self.window_energy.resize(buf_idx + 1, T::zero());
1292            }
1293
1294            match self.config.reconstruction_mode {
1295                ReconstructionMode::Ola => {
1296                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1297                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1298                }
1299                ReconstructionMode::Wola => {
1300                    self.overlap_buffer[buf_idx] =
1301                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1302                    self.window_energy[buf_idx] =
1303                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1304                }
1305            }
1306        }
1307
1308        self.frames_processed += 1;
1309
1310        // Calculate how many samples are "ready" (have full window energy)
1311        // Samples are ready when no future frames will contribute to them
1312        let ready_until = if self.frames_processed == 1 {
1313            0 // First frame: no output yet, need overlap
1314        } else {
1315            // Samples before the current frame's start position are complete
1316            (self.frames_processed - 1) * self.config.hop_size
1317        };
1318
1319        // Extract ready samples
1320        let output_start = self.output_position;
1321        let output_end = ready_until;
1322        let mut output = Vec::new();
1323
1324        let threshold = T::from(1e-8).unwrap();
1325        if output_end > output_start {
1326            for i in output_start..output_end {
1327                let normalized = if self.window_energy[i] > threshold {
1328                    self.overlap_buffer[i] / self.window_energy[i]
1329                } else {
1330                    T::zero()
1331                };
1332                output.push(normalized);
1333            }
1334            self.output_position = output_end;
1335        }
1336
1337        output
1338    }
1339
1340    /// Push a frame and write output samples into a pre-allocated buffer.
1341    /// Returns the number of samples written.
1342    pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1343        assert_eq!(
1344            frame.freq_bins,
1345            self.config.freq_bins(),
1346            "Frequency bins mismatch"
1347        );
1348
1349        // Build full spectrum with conjugate symmetry
1350        for bin in 0..frame.freq_bins {
1351            self.ifft_buffer[bin] = frame.data[bin];
1352        }
1353
1354        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
1355        for bin in 1..(frame.freq_bins - 1) {
1356            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1357        }
1358
1359        // Compute IFFT
1360        self.ifft.process(&mut self.ifft_buffer);
1361
1362        // Overlap-add into buffer at the current write position
1363        let write_pos = self.frames_processed * self.config.hop_size;
1364        for i in 0..self.config.fft_size {
1365            let fft_size_t = T::from(self.config.fft_size).unwrap();
1366            let sample = self.ifft_buffer[i].re / fft_size_t;
1367            let buf_idx = write_pos + i;
1368
1369            // Extend buffers if needed
1370            if buf_idx >= self.overlap_buffer.len() {
1371                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1372                self.window_energy.resize(buf_idx + 1, T::zero());
1373            }
1374
1375            match self.config.reconstruction_mode {
1376                ReconstructionMode::Ola => {
1377                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1378                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1379                }
1380                ReconstructionMode::Wola => {
1381                    self.overlap_buffer[buf_idx] =
1382                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1383                    self.window_energy[buf_idx] =
1384                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1385                }
1386            }
1387        }
1388
1389        self.frames_processed += 1;
1390
1391        // Calculate how many samples are "ready" (have full window energy)
1392        // Samples are ready when no future frames will contribute to them
1393        let ready_until = if self.frames_processed == 1 {
1394            0 // First frame: no output yet, need overlap
1395        } else {
1396            // Samples before the current frame's start position are complete
1397            (self.frames_processed - 1) * self.config.hop_size
1398        };
1399
1400        // Extract ready samples
1401        let output_start = self.output_position;
1402        let output_end = ready_until;
1403        let initial_len = output.len();
1404
1405        let threshold = T::from(1e-8).unwrap();
1406        if output_end > output_start {
1407            for i in output_start..output_end {
1408                let normalized = if self.window_energy[i] > threshold {
1409                    self.overlap_buffer[i] / self.window_energy[i]
1410                } else {
1411                    T::zero()
1412                };
1413                output.push(normalized);
1414            }
1415            self.output_position = output_end;
1416        }
1417
1418        output.len() - initial_len
1419    }
1420
1421    pub fn flush(&mut self) -> Vec<T> {
1422        // Return all remaining samples in buffer
1423        let mut output = Vec::new();
1424        let threshold = T::from(1e-8).unwrap();
1425        for i in self.output_position..self.overlap_buffer.len() {
1426            if self.window_energy[i] > threshold {
1427                output.push(self.overlap_buffer[i] / self.window_energy[i]);
1428            } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1429                output.push(T::zero()); // Sample in valid range but no window energy
1430            } else {
1431                break; // Past the end of valid data
1432            }
1433        }
1434
1435        // Determine the actual end of valid data
1436        let valid_end =
1437            (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1438        if output.len() > valid_end - self.output_position {
1439            output.truncate(valid_end - self.output_position);
1440        }
1441
1442        self.reset();
1443        output
1444    }
1445
1446    pub fn reset(&mut self) {
1447        self.overlap_buffer.clear();
1448        self.overlap_buffer
1449            .resize(self.config.fft_size * 2, T::zero());
1450        self.window_energy.clear();
1451        self.window_energy
1452            .resize(self.config.fft_size * 2, T::zero());
1453        self.output_position = 0;
1454        self.frames_processed = 0;
1455    }
1456}
1457
1458/// Multi-channel streaming iSTFT processor with independent state per channel.
1459pub struct MultiChannelStreamingIstft<T: Float + FftNum> {
1460    processors: Vec<StreamingIstft<T>>,
1461}
1462
1463impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingIstft<T> {
1464    /// Create a new multi-channel streaming iSTFT processor.
1465    ///
1466    /// # Arguments
1467    ///
1468    /// * `config` - STFT configuration
1469    /// * `num_channels` - Number of channels
1470    pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1471        assert!(num_channels > 0, "num_channels must be > 0");
1472        let processors = (0..num_channels)
1473            .map(|_| StreamingIstft::new(config.clone()))
1474            .collect();
1475        Self { processors }
1476    }
1477
1478    /// Push frames for all channels and get samples for each channel.
1479    /// Returns Vec<Vec<T>>, outer Vec = channels, inner Vec = samples.
1480    ///
1481    /// # Arguments
1482    ///
1483    /// * `frames` - Slice of frames, one per channel
1484    ///
1485    /// # Panics
1486    ///
1487    /// Panics if frames.len() doesn't match num_channels.
1488    pub fn push_frames(&mut self, frames: &[&SpectrumFrame<T>]) -> Vec<Vec<T>> {
1489        assert_eq!(
1490            frames.len(),
1491            self.processors.len(),
1492            "Expected {} channels, got {}",
1493            self.processors.len(),
1494            frames.len()
1495        );
1496
1497        #[cfg(feature = "rayon")]
1498        {
1499            use rayon::prelude::*;
1500            self.processors
1501                .par_iter_mut()
1502                .zip(frames.par_iter())
1503                .map(|(istft, frame)| istft.push_frame(frame))
1504                .collect()
1505        }
1506        #[cfg(not(feature = "rayon"))]
1507        {
1508            self.processors
1509                .iter_mut()
1510                .zip(frames.iter())
1511                .map(|(istft, frame)| istft.push_frame(frame))
1512                .collect()
1513        }
1514    }
1515
1516    /// Flush all channels and return remaining samples.
1517    pub fn flush(&mut self) -> Vec<Vec<T>> {
1518        #[cfg(feature = "rayon")]
1519        {
1520            use rayon::prelude::*;
1521            self.processors
1522                .par_iter_mut()
1523                .map(|istft| istft.flush())
1524                .collect()
1525        }
1526        #[cfg(not(feature = "rayon"))]
1527        {
1528            self.processors
1529                .iter_mut()
1530                .map(|istft| istft.flush())
1531                .collect()
1532        }
1533    }
1534
1535    /// Reset all channels.
1536    pub fn reset(&mut self) {
1537        #[cfg(feature = "rayon")]
1538        {
1539            use rayon::prelude::*;
1540            self.processors
1541                .par_iter_mut()
1542                .for_each(|istft| istft.reset());
1543        }
1544        #[cfg(not(feature = "rayon"))]
1545        {
1546            self.processors.iter_mut().for_each(|istft| istft.reset());
1547        }
1548    }
1549
1550    /// Get the number of channels.
1551    pub fn num_channels(&self) -> usize {
1552        self.processors.len()
1553    }
1554}
1555
1556// Type aliases for common float types
1557pub type StftConfigF32 = StftConfig<f32>;
1558pub type StftConfigF64 = StftConfig<f64>;
1559
1560pub type BatchStftF32 = BatchStft<f32>;
1561pub type BatchStftF64 = BatchStft<f64>;
1562
1563pub type BatchIstftF32 = BatchIstft<f32>;
1564pub type BatchIstftF64 = BatchIstft<f64>;
1565
1566pub type StreamingStftF32 = StreamingStft<f32>;
1567pub type StreamingStftF64 = StreamingStft<f64>;
1568
1569pub type StreamingIstftF32 = StreamingIstft<f32>;
1570pub type StreamingIstftF64 = StreamingIstft<f64>;
1571
1572pub type SpectrumF32 = Spectrum<f32>;
1573pub type SpectrumF64 = Spectrum<f64>;
1574
1575pub type SpectrumFrameF32 = SpectrumFrame<f32>;
1576pub type SpectrumFrameF64 = SpectrumFrame<f64>;
1577
1578pub type MultiChannelStreamingStftF32 = MultiChannelStreamingStft<f32>;
1579pub type MultiChannelStreamingStftF64 = MultiChannelStreamingStft<f64>;
1580
1581pub type MultiChannelStreamingIstftF32 = MultiChannelStreamingIstft<f32>;
1582pub type MultiChannelStreamingIstftF64 = MultiChannelStreamingIstft<f64>;