stft_rs/
lib.rs

1/*MIT License
2
3Copyright (c) 2025 David Maseda Neira
4
5Permission is hereby granted, free of charge, to any person obtaining a copy
6of this software and associated documentation files (the "Software"), to deal
7in the Software without restriction, including without limitation the rights
8to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9copies of the Software, and to permit persons to whom the Software is
10furnished to do so, subject to the following conditions:
11
12The above copyright notice and this permission notice shall be included in all
13copies or substantial portions of the Software.
14
15THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21SOFTWARE.
22*/
23
24use num_traits::{Float, FromPrimitive};
25use rustfft::num_complex::Complex;
26use rustfft::{Fft, FftNum, FftPlanner};
27use std::collections::VecDeque;
28use std::fmt;
29use std::sync::Arc;
30
31mod utils;
32pub use utils::{apply_padding, deinterleave, deinterleave_into, interleave, interleave_into};
33
34pub mod prelude {
35    pub use crate::utils::{
36        apply_padding, deinterleave, deinterleave_into, interleave, interleave_into,
37    };
38    pub use crate::{
39        BatchIstft, BatchIstftF32, BatchIstftF64, BatchStft, BatchStftF32, BatchStftF64,
40        MultiChannelStreamingIstft, MultiChannelStreamingIstftF32, MultiChannelStreamingIstftF64,
41        MultiChannelStreamingStft, MultiChannelStreamingStftF32, MultiChannelStreamingStftF64,
42        PadMode, ReconstructionMode, Spectrum, SpectrumF32, SpectrumF64, SpectrumFrame,
43        SpectrumFrameF32, SpectrumFrameF64, StftConfig, StftConfigBuilder, StftConfigBuilderF32,
44        StftConfigBuilderF64, StftConfigF32, StftConfigF64, StreamingIstft, StreamingIstftF32,
45        StreamingIstftF64, StreamingStft, StreamingStftF32, StreamingStftF64, WindowType,
46    };
47}
48
/// Strategy used to recombine overlapping frames into a time-domain signal.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReconstructionMode {
    /// Overlap-add: the summed frames are normalized by the overlapped window
    /// sum `sum(w)`. The configuration must satisfy the COLA condition.
    Ola,

    /// Weighted overlap-add: the summed frames are normalized by the overlapped
    /// window energy `sum(w^2)`. The configuration must satisfy the NOLA condition.
    Wola,
}
57
/// Window shapes the crate can generate.
///
/// With `a = 2*pi*i / (size - 1)` for sample index `i`:
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowType {
    /// `0.5 * (1 - cos(a))`
    Hann,
    /// `0.54 - 0.46 * cos(a)`
    Hamming,
    /// `0.42 - 0.5 * cos(a) + 0.08 * cos(2a)`
    Blackman,
}
64
65#[derive(Debug, Clone)]
66pub enum ConfigError<T: Float + fmt::Debug> {
67    NolaViolation { min_energy: T, threshold: T },
68    ColaViolation { max_deviation: T, threshold: T },
69    InvalidHopSize,
70    InvalidFftSize,
71}
72
73impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        match self {
76            ConfigError::NolaViolation {
77                min_energy,
78                threshold,
79            } => {
80                write!(
81                    f,
82                    "NOLA condition violated: min_energy={} < threshold={}",
83                    min_energy, threshold
84                )
85            }
86            ConfigError::ColaViolation {
87                max_deviation,
88                threshold,
89            } => {
90                write!(
91                    f,
92                    "COLA condition violated: max_deviation={} > threshold={}",
93                    max_deviation, threshold
94                )
95            }
96            ConfigError::InvalidHopSize => write!(f, "Invalid hop size"),
97            ConfigError::InvalidFftSize => write!(f, "Invalid FFT size"),
98        }
99    }
100}
101
102impl<T: Float + fmt::Display + fmt::Debug> std::error::Error for ConfigError<T> {}
103
/// Selects how a signal is extended at its edges before framing.
///
/// The value is forwarded to `utils::apply_padding`, which defines the exact
/// behaviour of each variant.
// NOTE(review): the per-variant descriptions below are inferred from the
// variant names; confirm against `utils::apply_padding`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadMode {
    /// Presumably mirrors the signal around its edges.
    Reflect,
    /// Presumably pads with zeros.
    Zero,
    /// Presumably repeats the first/last sample.
    Edge,
}
110
111#[derive(Clone)]
112pub struct StftConfig<T: Float> {
113    pub fft_size: usize,
114    pub hop_size: usize,
115    pub window: WindowType,
116    pub reconstruction_mode: ReconstructionMode,
117    _phantom: std::marker::PhantomData<T>,
118}
119
120impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
121    fn nola_threshold() -> T {
122        T::from(1e-8).unwrap()
123    }
124
125    fn cola_relative_tolerance() -> T {
126        T::from(1e-4).unwrap()
127    }
128
129    #[deprecated(
130        since = "0.4.0",
131        note = "Use `StftConfig::builder()` instead for a more flexible API"
132    )]
133    pub fn new(
134        fft_size: usize,
135        hop_size: usize,
136        window: WindowType,
137        reconstruction_mode: ReconstructionMode,
138    ) -> Result<Self, ConfigError<T>> {
139        if fft_size == 0 || !fft_size.is_power_of_two() {
140            return Err(ConfigError::InvalidFftSize);
141        }
142        if hop_size == 0 || hop_size > fft_size {
143            return Err(ConfigError::InvalidHopSize);
144        }
145
146        let config = Self {
147            fft_size,
148            hop_size,
149            window,
150            reconstruction_mode,
151            _phantom: std::marker::PhantomData,
152        };
153
154        // Validate appropriate condition based on reconstruction mode
155        match reconstruction_mode {
156            ReconstructionMode::Ola => config.validate_cola()?,
157            ReconstructionMode::Wola => config.validate_nola()?,
158        }
159
160        Ok(config)
161    }
162
163    /// Create a new builder for StftConfig
164    pub fn builder() -> StftConfigBuilder<T> {
165        StftConfigBuilder::new()
166    }
167
168    /// Default: 4096 FFT, 1024 hop, Hann window, OLA mode
169    #[allow(deprecated)]
170    pub fn default_4096() -> Self {
171        Self::new(4096, 1024, WindowType::Hann, ReconstructionMode::Ola)
172            .expect("Default config should always be valid")
173    }
174
175    pub fn freq_bins(&self) -> usize {
176        self.fft_size / 2 + 1
177    }
178
179    pub fn overlap_percent(&self) -> T {
180        let one = T::one();
181        let hundred = T::from(100.0).unwrap();
182        (one - T::from(self.hop_size).unwrap() / T::from(self.fft_size).unwrap()) * hundred
183    }
184
185    fn generate_window(&self) -> Vec<T> {
186        generate_window(self.window, self.fft_size)
187    }
188
189    /// Validate NOLA condition: sum(w^2) > threshold everywhere
190    pub fn validate_nola(&self) -> Result<(), ConfigError<T>> {
191        let window = self.generate_window();
192        let num_overlaps = (self.fft_size + self.hop_size - 1) / self.hop_size;
193        let test_len = self.fft_size + (num_overlaps - 1) * self.hop_size;
194        let mut energy = vec![T::zero(); test_len];
195
196        for i in 0..num_overlaps {
197            let offset = i * self.hop_size;
198            for j in 0..self.fft_size {
199                if offset + j < test_len {
200                    energy[offset + j] = energy[offset + j] + window[j] * window[j];
201                }
202            }
203        }
204
205        // Check the steady-state region (skip edges)
206        let start = self.fft_size / 2;
207        let end = test_len - self.fft_size / 2;
208        let min_energy = energy[start..end]
209            .iter()
210            .copied()
211            .min_by(|a, b| a.partial_cmp(b).unwrap())
212            .unwrap_or_else(T::zero);
213
214        if min_energy < Self::nola_threshold() {
215            return Err(ConfigError::NolaViolation {
216                min_energy,
217                threshold: Self::nola_threshold(),
218            });
219        }
220
221        Ok(())
222    }
223
224    /// Validate weak COLA condition: sum(w) is constant (within relative tolerance)
225    pub fn validate_cola(&self) -> Result<(), ConfigError<T>> {
226        let window = self.generate_window();
227        let window_len = window.len();
228
229        let mut cola_sum_period = vec![T::zero(); self.hop_size];
230        for i in 0..window_len {
231            let idx = i % self.hop_size;
232            cola_sum_period[idx] = cola_sum_period[idx] + window[i];
233        }
234
235        let zero = T::zero();
236        let min_sum = cola_sum_period
237            .iter()
238            .min_by(|a, b| a.partial_cmp(b).unwrap())
239            .unwrap_or(&zero);
240        let max_sum = cola_sum_period
241            .iter()
242            .max_by(|a, b| a.partial_cmp(b).unwrap())
243            .unwrap_or(&zero);
244
245        let epsilon = T::from(1e-9).unwrap();
246        if *max_sum < epsilon {
247            return Err(ConfigError::ColaViolation {
248                max_deviation: T::infinity(),
249                threshold: Self::cola_relative_tolerance(),
250            });
251        }
252
253        let ripple = (*max_sum - *min_sum) / *max_sum;
254
255        let is_compliant = ripple < Self::cola_relative_tolerance();
256
257        if !is_compliant {
258            return Err(ConfigError::ColaViolation {
259                max_deviation: ripple,
260                threshold: Self::cola_relative_tolerance(),
261            });
262        }
263        Ok(())
264    }
265}
266
267/// Builder for StftConfig with fluent API
268pub struct StftConfigBuilder<T: Float> {
269    fft_size: Option<usize>,
270    hop_size: Option<usize>,
271    window: WindowType,
272    reconstruction_mode: ReconstructionMode,
273    _phantom: std::marker::PhantomData<T>,
274}
275
276impl<T: Float + FromPrimitive + fmt::Debug> StftConfigBuilder<T> {
277    /// Create a new builder with default values (Hann window, OLA mode)
278    pub fn new() -> Self {
279        Self {
280            fft_size: None,
281            hop_size: None,
282            window: WindowType::Hann,
283            reconstruction_mode: ReconstructionMode::Ola,
284            _phantom: std::marker::PhantomData,
285        }
286    }
287
288    /// Set the FFT size (must be a power of two)
289    pub fn fft_size(mut self, fft_size: usize) -> Self {
290        self.fft_size = Some(fft_size);
291        self
292    }
293
294    /// Set the hop size (must be > 0 and <= fft_size)
295    pub fn hop_size(mut self, hop_size: usize) -> Self {
296        self.hop_size = Some(hop_size);
297        self
298    }
299
300    /// Set the window type (default: Hann)
301    pub fn window(mut self, window: WindowType) -> Self {
302        self.window = window;
303        self
304    }
305
306    /// Set the reconstruction mode (default: OLA)
307    pub fn reconstruction_mode(mut self, mode: ReconstructionMode) -> Self {
308        self.reconstruction_mode = mode;
309        self
310    }
311
312    /// Build the StftConfig, validating all parameters
313    ///
314    /// Returns an error if:
315    /// - fft_size is not set or not a power of two
316    /// - hop_size is not set, zero, or greater than fft_size
317    /// - COLA/NOLA conditions are violated
318    #[allow(deprecated)]
319    pub fn build(self) -> Result<StftConfig<T>, ConfigError<T>> {
320        let fft_size = self.fft_size.ok_or(ConfigError::InvalidFftSize)?;
321        let hop_size = self.hop_size.ok_or(ConfigError::InvalidHopSize)?;
322
323        StftConfig::new(fft_size, hop_size, self.window, self.reconstruction_mode)
324    }
325}
326
327impl<T: Float + FromPrimitive + fmt::Debug> Default for StftConfigBuilder<T> {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
333fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usize) -> Vec<T> {
334    let pi = T::from(std::f64::consts::PI).unwrap();
335    let two = T::from(2.0).unwrap();
336
337    match window_type {
338        WindowType::Hann => (0..size)
339            .map(|i| {
340                let half = T::from(0.5).unwrap();
341                let one = T::one();
342                let i_t = T::from(i).unwrap();
343                let size_m1 = T::from(size - 1).unwrap();
344                half * (one - (two * pi * i_t / size_m1).cos())
345            })
346            .collect(),
347        WindowType::Hamming => (0..size)
348            .map(|i| {
349                let i_t = T::from(i).unwrap();
350                let size_m1 = T::from(size - 1).unwrap();
351                T::from(0.54).unwrap() - T::from(0.46).unwrap() * (two * pi * i_t / size_m1).cos()
352            })
353            .collect(),
354        WindowType::Blackman => (0..size)
355            .map(|i| {
356                let i_t = T::from(i).unwrap();
357                let size_m1 = T::from(size - 1).unwrap();
358                let angle = two * pi * i_t / size_m1;
359                T::from(0.42).unwrap() - T::from(0.5).unwrap() * angle.cos()
360                    + T::from(0.08).unwrap() * (two * angle).cos()
361            })
362            .collect(),
363    }
364}
365
366#[derive(Clone)]
367pub struct SpectrumFrame<T: Float> {
368    pub freq_bins: usize,
369    pub data: Vec<Complex<T>>,
370}
371
372impl<T: Float> SpectrumFrame<T> {
373    pub fn new(freq_bins: usize) -> Self {
374        Self {
375            freq_bins,
376            data: vec![Complex::new(T::zero(), T::zero()); freq_bins],
377        }
378    }
379
380    pub fn from_data(data: Vec<Complex<T>>) -> Self {
381        let freq_bins = data.len();
382        Self { freq_bins, data }
383    }
384
385    /// Prepare frame for reuse by clearing data (keeps capacity)
386    pub fn clear(&mut self) {
387        for val in &mut self.data {
388            *val = Complex::new(T::zero(), T::zero());
389        }
390    }
391
392    /// Resize frame if needed to match freq_bins
393    pub fn resize_if_needed(&mut self, freq_bins: usize) {
394        if self.freq_bins != freq_bins {
395            self.freq_bins = freq_bins;
396            self.data
397                .resize(freq_bins, Complex::new(T::zero(), T::zero()));
398        }
399    }
400
401    /// Write data from a slice into this frame
402    pub fn write_from_slice(&mut self, data: &[Complex<T>]) {
403        self.resize_if_needed(data.len());
404        self.data.copy_from_slice(data);
405    }
406
407    /// Get the magnitude of a frequency bin
408    #[inline]
409    pub fn magnitude(&self, bin: usize) -> T {
410        let c = &self.data[bin];
411        (c.re * c.re + c.im * c.im).sqrt()
412    }
413
414    /// Get the phase of a frequency bin in radians
415    #[inline]
416    pub fn phase(&self, bin: usize) -> T {
417        let c = &self.data[bin];
418        c.im.atan2(c.re)
419    }
420
421    /// Set a frequency bin from magnitude and phase
422    pub fn set_magnitude_phase(&mut self, bin: usize, magnitude: T, phase: T) {
423        self.data[bin] = Complex::new(magnitude * phase.cos(), magnitude * phase.sin());
424    }
425
426    /// Create a SpectrumFrame from magnitude and phase arrays
427    pub fn from_magnitude_phase(magnitudes: &[T], phases: &[T]) -> Self {
428        assert_eq!(
429            magnitudes.len(),
430            phases.len(),
431            "Magnitude and phase arrays must have same length"
432        );
433        let freq_bins = magnitudes.len();
434        let data: Vec<Complex<T>> = magnitudes
435            .iter()
436            .zip(phases.iter())
437            .map(|(mag, phase)| Complex::new(*mag * phase.cos(), *mag * phase.sin()))
438            .collect();
439        Self { freq_bins, data }
440    }
441
442    /// Get all magnitudes as a Vec
443    pub fn magnitudes(&self) -> Vec<T> {
444        self.data
445            .iter()
446            .map(|c| (c.re * c.re + c.im * c.im).sqrt())
447            .collect()
448    }
449
450    /// Get all phases as a Vec
451    pub fn phases(&self) -> Vec<T> {
452        self.data.iter().map(|c| c.im.atan2(c.re)).collect()
453    }
454}
455
456#[derive(Clone)]
457pub struct Spectrum<T: Float> {
458    pub num_frames: usize,
459    pub freq_bins: usize,
460    pub data: Vec<T>,
461}
462
463impl<T: Float> Spectrum<T> {
464    pub fn new(num_frames: usize, freq_bins: usize) -> Self {
465        Self {
466            num_frames,
467            freq_bins,
468            data: vec![T::zero(); 2 * num_frames * freq_bins],
469        }
470    }
471
472    #[inline]
473    pub fn real(&self, frame: usize, bin: usize) -> T {
474        self.data[frame * self.freq_bins + bin]
475    }
476
477    #[inline]
478    pub fn imag(&self, frame: usize, bin: usize) -> T {
479        let offset = self.num_frames * self.freq_bins;
480        self.data[offset + frame * self.freq_bins + bin]
481    }
482
483    #[inline]
484    pub fn get_complex(&self, frame: usize, bin: usize) -> Complex<T> {
485        Complex::new(self.real(frame, bin), self.imag(frame, bin))
486    }
487
488    pub fn frames(&self) -> impl Iterator<Item = SpectrumFrame<T>> + '_ {
489        (0..self.num_frames).map(move |frame_idx| {
490            let data: Vec<Complex<T>> = (0..self.freq_bins)
491                .map(|bin| self.get_complex(frame_idx, bin))
492                .collect();
493            SpectrumFrame::from_data(data)
494        })
495    }
496
497    /// Set the real part of a bin
498    #[inline]
499    pub fn set_real(&mut self, frame: usize, bin: usize, value: T) {
500        self.data[frame * self.freq_bins + bin] = value;
501    }
502
503    /// Set the imaginary part of a bin
504    #[inline]
505    pub fn set_imag(&mut self, frame: usize, bin: usize, value: T) {
506        let offset = self.num_frames * self.freq_bins;
507        self.data[offset + frame * self.freq_bins + bin] = value;
508    }
509
510    /// Set a bin from a complex value
511    #[inline]
512    pub fn set_complex(&mut self, frame: usize, bin: usize, value: Complex<T>) {
513        self.set_real(frame, bin, value.re);
514        self.set_imag(frame, bin, value.im);
515    }
516
517    /// Get the magnitude of a frequency bin
518    #[inline]
519    pub fn magnitude(&self, frame: usize, bin: usize) -> T {
520        let re = self.real(frame, bin);
521        let im = self.imag(frame, bin);
522        (re * re + im * im).sqrt()
523    }
524
525    /// Get the phase of a frequency bin in radians
526    #[inline]
527    pub fn phase(&self, frame: usize, bin: usize) -> T {
528        let re = self.real(frame, bin);
529        let im = self.imag(frame, bin);
530        im.atan2(re)
531    }
532
533    /// Set a frequency bin from magnitude and phase
534    pub fn set_magnitude_phase(&mut self, frame: usize, bin: usize, magnitude: T, phase: T) {
535        self.set_real(frame, bin, magnitude * phase.cos());
536        self.set_imag(frame, bin, magnitude * phase.sin());
537    }
538
539    /// Get all magnitudes for a frame
540    pub fn frame_magnitudes(&self, frame: usize) -> Vec<T> {
541        (0..self.freq_bins)
542            .map(|bin| self.magnitude(frame, bin))
543            .collect()
544    }
545
546    /// Get all phases for a frame
547    pub fn frame_phases(&self, frame: usize) -> Vec<T> {
548        (0..self.freq_bins)
549            .map(|bin| self.phase(frame, bin))
550            .collect()
551    }
552
553    /// Apply a function to all bins
554    pub fn apply<F>(&mut self, mut f: F)
555    where
556        F: FnMut(usize, usize, Complex<T>) -> Complex<T>,
557    {
558        for frame in 0..self.num_frames {
559            for bin in 0..self.freq_bins {
560                let c = self.get_complex(frame, bin);
561                let new_c = f(frame, bin, c);
562                self.set_complex(frame, bin, new_c);
563            }
564        }
565    }
566
567    /// Apply a gain to a range of bins across all frames
568    pub fn apply_gain(&mut self, bin_range: std::ops::Range<usize>, gain: T) {
569        for frame in 0..self.num_frames {
570            for bin in bin_range.clone() {
571                if bin < self.freq_bins {
572                    let c = self.get_complex(frame, bin);
573                    self.set_complex(frame, bin, c * gain);
574                }
575            }
576        }
577    }
578
579    /// Zero out a range of bins across all frames
580    pub fn zero_bins(&mut self, bin_range: std::ops::Range<usize>) {
581        for frame in 0..self.num_frames {
582            for bin in bin_range.clone() {
583                if bin < self.freq_bins {
584                    self.set_complex(frame, bin, Complex::new(T::zero(), T::zero()));
585                }
586            }
587        }
588    }
589}
590
591pub struct BatchStft<T: Float + FftNum> {
592    config: StftConfig<T>,
593    window: Vec<T>,
594    fft: Arc<dyn Fft<T>>,
595}
596
597impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
598    pub fn new(config: StftConfig<T>) -> Self {
599        let window = config.generate_window();
600        let mut planner = FftPlanner::new();
601        let fft = planner.plan_fft_forward(config.fft_size);
602
603        Self {
604            config,
605            window,
606            fft,
607        }
608    }
609
610    pub fn process(&self, signal: &[T]) -> Spectrum<T> {
611        self.process_padded(signal, PadMode::Reflect)
612    }
613
614    pub fn process_padded(&self, signal: &[T], pad_mode: PadMode) -> Spectrum<T> {
615        let pad_amount = self.config.fft_size / 2;
616        let padded = utils::apply_padding(signal, pad_amount, pad_mode);
617
618        let num_frames = if padded.len() >= self.config.fft_size {
619            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
620        } else {
621            0
622        };
623
624        let freq_bins = self.config.freq_bins();
625        let mut result = Spectrum::new(num_frames, freq_bins);
626
627        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
628
629        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
630            .step_by(self.config.hop_size)
631            .enumerate()
632        {
633            // Apply window and prepare FFT input
634            for i in 0..self.config.fft_size {
635                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
636            }
637
638            // Compute FFT
639            self.fft.process(&mut fft_buffer);
640
641            // Store positive frequencies in flat layout
642            for bin in 0..freq_bins {
643                let idx = frame_idx * freq_bins + bin;
644                result.data[idx] = fft_buffer[bin].re;
645                result.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
646            }
647        }
648
649        result
650    }
651
652    /// Process signal and write into a pre-allocated Spectrum.
653    /// The spectrum must have the correct dimensions (num_frames x freq_bins).
654    /// Returns true if successful, false if dimensions don't match.
655    pub fn process_into(&self, signal: &[T], spectrum: &mut Spectrum<T>) -> bool {
656        self.process_padded_into(signal, PadMode::Reflect, spectrum)
657    }
658
659    /// Process signal with padding and write into a pre-allocated Spectrum.
660    pub fn process_padded_into(
661        &self,
662        signal: &[T],
663        pad_mode: PadMode,
664        spectrum: &mut Spectrum<T>,
665    ) -> bool {
666        let pad_amount = self.config.fft_size / 2;
667        let padded = utils::apply_padding(signal, pad_amount, pad_mode);
668
669        let num_frames = if padded.len() >= self.config.fft_size {
670            (padded.len() - self.config.fft_size) / self.config.hop_size + 1
671        } else {
672            0
673        };
674
675        let freq_bins = self.config.freq_bins();
676
677        // Check dimensions
678        if spectrum.num_frames != num_frames || spectrum.freq_bins != freq_bins {
679            return false;
680        }
681
682        let mut fft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
683
684        for (frame_idx, frame_start) in (0..padded.len() - self.config.fft_size + 1)
685            .step_by(self.config.hop_size)
686            .enumerate()
687        {
688            // Apply window and prepare FFT input
689            for i in 0..self.config.fft_size {
690                fft_buffer[i] = Complex::new(padded[frame_start + i] * self.window[i], T::zero());
691            }
692
693            // Compute FFT
694            self.fft.process(&mut fft_buffer);
695
696            // Store positive frequencies in flat layout
697            for bin in 0..freq_bins {
698                let idx = frame_idx * freq_bins + bin;
699                spectrum.data[idx] = fft_buffer[bin].re;
700                spectrum.data[num_frames * freq_bins + idx] = fft_buffer[bin].im;
701            }
702        }
703
704        true
705    }
706
707    /// Process multiple channels independently.
708    /// Returns one Spectrum per channel.
709    ///
710    /// # Arguments
711    ///
712    /// * `channels` - Slice of audio channels, each as a separate Vec
713    ///
714    /// # Panics
715    ///
716    /// Panics if channels is empty or if channels have different lengths.
717    ///
718    /// # Example
719    ///
720    /// ```
721    /// use stft_rs::prelude::*;
722    ///
723    /// let config = StftConfigF32::default_4096();
724    /// let stft = BatchStftF32::new(config);
725    ///
726    /// let left = vec![0.0; 44100];
727    /// let right = vec![0.0; 44100];
728    /// let channels = vec![left, right];
729    ///
730    /// let spectra = stft.process_multichannel(&channels);
731    /// assert_eq!(spectra.len(), 2); // One spectrum per channel
732    /// ```
733    pub fn process_multichannel(&self, channels: &[Vec<T>]) -> Vec<Spectrum<T>> {
734        assert!(!channels.is_empty(), "channels must not be empty");
735
736        // Validate all channels have same length
737        let expected_len = channels[0].len();
738        for (i, channel) in channels.iter().enumerate() {
739            assert_eq!(
740                channel.len(),
741                expected_len,
742                "Channel {} has length {}, expected {}",
743                i,
744                channel.len(),
745                expected_len
746            );
747        }
748
749        // Process each channel independently
750        #[cfg(feature = "rayon")]
751        {
752            use rayon::prelude::*;
753            channels
754                .par_iter()
755                .map(|channel| self.process(channel))
756                .collect()
757        }
758        #[cfg(not(feature = "rayon"))]
759        {
760            channels
761                .iter()
762                .map(|channel| self.process(channel))
763                .collect()
764        }
765    }
766
767    /// Process interleaved multi-channel audio.
768    /// Converts interleaved format (e.g., `[L,R,L,R,L,R,...]` for stereo)
769    /// into separate Spectrum for each channel.
770    ///
771    /// # Arguments
772    ///
773    /// * `data` - Interleaved audio data
774    /// * `num_channels` - Number of channels
775    ///
776    /// # Panics
777    ///
778    /// Panics if `num_channels` is 0 or if `data.len()` is not divisible by `num_channels`.
779    ///
780    /// # Example
781    ///
782    /// ```
783    /// use stft_rs::prelude::*;
784    ///
785    /// let config = StftConfigF32::default_4096();
786    /// let stft = BatchStftF32::new(config);
787    ///
788    /// // Stereo interleaved: L,R,L,R,L,R,...
789    /// let interleaved = vec![0.0; 88200]; // 2 channels * 44100 samples
790    ///
791    /// let spectra = stft.process_interleaved(&interleaved, 2);
792    /// assert_eq!(spectra.len(), 2); // One spectrum per channel
793    /// ```
794    pub fn process_interleaved(&self, data: &[T], num_channels: usize) -> Vec<Spectrum<T>> {
795        let channels = utils::deinterleave(data, num_channels);
796        self.process_multichannel(&channels)
797    }
798}
799
800pub struct BatchIstft<T: Float + FftNum> {
801    config: StftConfig<T>,
802    window: Vec<T>,
803    ifft: Arc<dyn Fft<T>>,
804}
805
806impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
807    pub fn new(config: StftConfig<T>) -> Self {
808        let window = config.generate_window();
809        let mut planner = FftPlanner::new();
810        let ifft = planner.plan_fft_inverse(config.fft_size);
811
812        Self {
813            config,
814            window,
815            ifft,
816        }
817    }
818
819    pub fn process(&self, spectrum: &Spectrum<T>) -> Vec<T> {
820        assert_eq!(
821            spectrum.freq_bins,
822            self.config.freq_bins(),
823            "Frequency bins mismatch"
824        );
825
826        let num_frames = spectrum.num_frames;
827        let original_time_len = (num_frames - 1) * self.config.hop_size;
828        let pad_amount = self.config.fft_size / 2;
829        let padded_len = original_time_len + 2 * pad_amount;
830
831        let mut overlap_buffer = vec![T::zero(); padded_len];
832        let mut window_energy = vec![T::zero(); padded_len];
833        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
834
835        // Precompute window energy normalization
836        for frame_idx in 0..num_frames {
837            let pos = frame_idx * self.config.hop_size;
838            for i in 0..self.config.fft_size {
839                match self.config.reconstruction_mode {
840                    ReconstructionMode::Ola => {
841                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
842                    }
843                    ReconstructionMode::Wola => {
844                        window_energy[pos + i] =
845                            window_energy[pos + i] + self.window[i] * self.window[i];
846                    }
847                }
848            }
849        }
850
851        // Process each frame
852        for frame_idx in 0..num_frames {
853            // Build full spectrum with conjugate symmetry
854            for bin in 0..spectrum.freq_bins {
855                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
856            }
857
858            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
859            for bin in 1..(spectrum.freq_bins - 1) {
860                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
861            }
862
863            // Compute IFFT
864            self.ifft.process(&mut ifft_buffer);
865
866            // Overlap-add
867            let pos = frame_idx * self.config.hop_size;
868            for i in 0..self.config.fft_size {
869                let fft_size_t = T::from(self.config.fft_size).unwrap();
870                let sample = ifft_buffer[i].re / fft_size_t;
871
872                match self.config.reconstruction_mode {
873                    ReconstructionMode::Ola => {
874                        // OLA: no windowing on inverse
875                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
876                    }
877                    ReconstructionMode::Wola => {
878                        // WOLA: apply window on inverse
879                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
880                    }
881                }
882            }
883        }
884
885        // Normalize by window energy
886        let threshold = T::from(1e-8).unwrap();
887        for i in 0..padded_len {
888            if window_energy[i] > threshold {
889                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
890            }
891        }
892
893        // Remove padding
894        overlap_buffer[pad_amount..pad_amount + original_time_len].to_vec()
895    }
896
897    /// Process spectrum and write into a pre-allocated output buffer.
898    /// The output buffer will be resized if needed.
899    pub fn process_into(&self, spectrum: &Spectrum<T>, output: &mut Vec<T>) {
900        assert_eq!(
901            spectrum.freq_bins,
902            self.config.freq_bins(),
903            "Frequency bins mismatch"
904        );
905
906        let num_frames = spectrum.num_frames;
907        let original_time_len = (num_frames - 1) * self.config.hop_size;
908        let pad_amount = self.config.fft_size / 2;
909        let padded_len = original_time_len + 2 * pad_amount;
910
911        let mut overlap_buffer = vec![T::zero(); padded_len];
912        let mut window_energy = vec![T::zero(); padded_len];
913        let mut ifft_buffer = vec![Complex::new(T::zero(), T::zero()); self.config.fft_size];
914
915        // Precompute window energy normalization
916        for frame_idx in 0..num_frames {
917            let pos = frame_idx * self.config.hop_size;
918            for i in 0..self.config.fft_size {
919                match self.config.reconstruction_mode {
920                    ReconstructionMode::Ola => {
921                        window_energy[pos + i] = window_energy[pos + i] + self.window[i];
922                    }
923                    ReconstructionMode::Wola => {
924                        window_energy[pos + i] =
925                            window_energy[pos + i] + self.window[i] * self.window[i];
926                    }
927                }
928            }
929        }
930
931        // Process each frame
932        for frame_idx in 0..num_frames {
933            // Build full spectrum with conjugate symmetry
934            for bin in 0..spectrum.freq_bins {
935                ifft_buffer[bin] = spectrum.get_complex(frame_idx, bin);
936            }
937
938            // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
939            for bin in 1..(spectrum.freq_bins - 1) {
940                ifft_buffer[self.config.fft_size - bin] = ifft_buffer[bin].conj();
941            }
942
943            // Compute IFFT
944            self.ifft.process(&mut ifft_buffer);
945
946            // Overlap-add
947            let pos = frame_idx * self.config.hop_size;
948            for i in 0..self.config.fft_size {
949                let fft_size_t = T::from(self.config.fft_size).unwrap();
950                let sample = ifft_buffer[i].re / fft_size_t;
951
952                match self.config.reconstruction_mode {
953                    ReconstructionMode::Ola => {
954                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample;
955                    }
956                    ReconstructionMode::Wola => {
957                        overlap_buffer[pos + i] = overlap_buffer[pos + i] + sample * self.window[i];
958                    }
959                }
960            }
961        }
962
963        // Normalize by window energy
964        let threshold = T::from(1e-8).unwrap();
965        for i in 0..padded_len {
966            if window_energy[i] > threshold {
967                overlap_buffer[i] = overlap_buffer[i] / window_energy[i];
968            }
969        }
970
971        // Copy to output (resize if needed)
972        output.clear();
973        output.extend_from_slice(&overlap_buffer[pad_amount..pad_amount + original_time_len]);
974    }
975
976    /// Reconstruct multiple channels from their spectra.
977    /// Returns one Vec per channel.
978    ///
979    /// # Arguments
980    ///
981    /// * `spectra` - Slice of Spectrum, one per channel
982    ///
983    /// # Panics
984    ///
985    /// Panics if spectra is empty.
986    ///
987    /// # Example
988    ///
989    /// ```
990    /// use stft_rs::prelude::*;
991    ///
992    /// let config = StftConfigF32::default_4096();
993    /// let stft = BatchStftF32::new(config.clone());
994    /// let istft = BatchIstftF32::new(config);
995    ///
996    /// let left = vec![0.0; 44100];
997    /// let right = vec![0.0; 44100];
998    /// let channels = vec![left, right];
999    ///
1000    /// let spectra = stft.process_multichannel(&channels);
1001    /// let reconstructed = istft.process_multichannel(&spectra);
1002    ///
1003    /// assert_eq!(reconstructed.len(), 2); // One channel per spectrum
1004    /// ```
1005    pub fn process_multichannel(&self, spectra: &[Spectrum<T>]) -> Vec<Vec<T>> {
1006        assert!(!spectra.is_empty(), "spectra must not be empty");
1007
1008        // Process each spectrum independently
1009        #[cfg(feature = "rayon")]
1010        {
1011            use rayon::prelude::*;
1012            spectra
1013                .par_iter()
1014                .map(|spectrum| self.process(spectrum))
1015                .collect()
1016        }
1017        #[cfg(not(feature = "rayon"))]
1018        {
1019            spectra
1020                .iter()
1021                .map(|spectrum| self.process(spectrum))
1022                .collect()
1023        }
1024    }
1025
1026    /// Reconstruct multiple channels and interleave them into a single buffer.
1027    /// Converts separate channels back to interleaved format (e.g., `[L,R,L,R,L,R,...]` for stereo).
1028    ///
1029    /// # Arguments
1030    ///
1031    /// * `spectra` - Slice of Spectrum, one per channel
1032    ///
1033    /// # Panics
1034    ///
1035    /// Panics if spectra is empty or if channels have different lengths.
1036    ///
1037    /// # Example
1038    ///
1039    /// ```
1040    /// use stft_rs::prelude::*;
1041    ///
1042    /// let config = StftConfigF32::default_4096();
1043    /// let stft = BatchStftF32::new(config.clone());
1044    /// let istft = BatchIstftF32::new(config);
1045    ///
1046    /// // Process interleaved stereo
1047    /// let interleaved = vec![0.0; 88200]; // 2 channels * 44100 samples
1048    /// let spectra = stft.process_interleaved(&interleaved, 2);
1049    ///
1050    /// // Reconstruct back to interleaved
1051    /// let output = istft.process_multichannel_interleaved(&spectra);
1052    /// // Output length may differ slightly due to padding/framing
1053    /// assert_eq!(output.len() / 2, 44032); // samples per channel after reconstruction
1054    /// ```
1055    pub fn process_multichannel_interleaved(&self, spectra: &[Spectrum<T>]) -> Vec<T> {
1056        let channels = self.process_multichannel(spectra);
1057        utils::interleave(&channels)
1058    }
1059}
1060
1061pub struct StreamingStft<T: Float + FftNum> {
1062    config: StftConfig<T>,
1063    window: Vec<T>,
1064    fft: Arc<dyn Fft<T>>,
1065    input_buffer: VecDeque<T>,
1066    frame_index: usize,
1067    fft_buffer: Vec<Complex<T>>,
1068}
1069
1070impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
1071    pub fn new(config: StftConfig<T>) -> Self {
1072        let window = config.generate_window();
1073        let mut planner = FftPlanner::new();
1074        let fft = planner.plan_fft_forward(config.fft_size);
1075        let fft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1076
1077        Self {
1078            config,
1079            window,
1080            fft,
1081            input_buffer: VecDeque::new(),
1082            frame_index: 0,
1083            fft_buffer,
1084        }
1085    }
1086
1087    pub fn push_samples(&mut self, samples: &[T]) -> Vec<SpectrumFrame<T>> {
1088        self.input_buffer.extend(samples.iter().copied());
1089
1090        let mut frames = Vec::new();
1091
1092        while self.input_buffer.len() >= self.config.fft_size {
1093            // Process one frame
1094            for i in 0..self.config.fft_size {
1095                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1096            }
1097
1098            self.fft.process(&mut self.fft_buffer);
1099
1100            let freq_bins = self.config.freq_bins();
1101            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1102            frames.push(SpectrumFrame::from_data(data));
1103
1104            // Advance by hop size
1105            self.input_buffer.drain(..self.config.hop_size);
1106            self.frame_index += 1;
1107        }
1108
1109        frames
1110    }
1111
1112    /// Push samples and write frames into a pre-allocated buffer.
1113    /// Returns the number of frames written.
1114    pub fn push_samples_into(
1115        &mut self,
1116        samples: &[T],
1117        output: &mut Vec<SpectrumFrame<T>>,
1118    ) -> usize {
1119        self.input_buffer.extend(samples.iter().copied());
1120
1121        let initial_len = output.len();
1122
1123        while self.input_buffer.len() >= self.config.fft_size {
1124            // Process one frame
1125            for i in 0..self.config.fft_size {
1126                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1127            }
1128
1129            self.fft.process(&mut self.fft_buffer);
1130
1131            let freq_bins = self.config.freq_bins();
1132            let data: Vec<Complex<T>> = self.fft_buffer[..freq_bins].to_vec();
1133            output.push(SpectrumFrame::from_data(data));
1134
1135            // Advance by hop size
1136            self.input_buffer.drain(..self.config.hop_size);
1137            self.frame_index += 1;
1138        }
1139
1140        output.len() - initial_len
1141    }
1142
1143    /// Push samples and write directly into pre-existing SpectrumFrame buffers.
1144    /// This is a zero-allocation method - frames must be pre-allocated with correct size.
1145    /// Returns the number of frames written.
1146    ///
1147    /// # Example
1148    /// ```ignore
1149    /// let mut frame_pool = vec![SpectrumFrame::new(config.freq_bins()); 16];
1150    /// let mut frame_index = 0;
1151    ///
1152    /// let frames_written = stft.push_samples_write(chunk, &mut frame_pool, &mut frame_index);
1153    /// // Process frames 0..frames_written
1154    /// ```
1155    pub fn push_samples_write(
1156        &mut self,
1157        samples: &[T],
1158        frame_pool: &mut [SpectrumFrame<T>],
1159        pool_index: &mut usize,
1160    ) -> usize {
1161        self.input_buffer.extend(samples.iter().copied());
1162
1163        let initial_index = *pool_index;
1164        let freq_bins = self.config.freq_bins();
1165
1166        while self.input_buffer.len() >= self.config.fft_size && *pool_index < frame_pool.len() {
1167            // Process one frame
1168            for i in 0..self.config.fft_size {
1169                self.fft_buffer[i] = Complex::new(self.input_buffer[i] * self.window[i], T::zero());
1170            }
1171
1172            self.fft.process(&mut self.fft_buffer);
1173
1174            // Write directly into the pre-allocated frame
1175            let frame = &mut frame_pool[*pool_index];
1176            debug_assert_eq!(
1177                frame.freq_bins, freq_bins,
1178                "Frame pool frames must match freq_bins"
1179            );
1180            frame.data[..freq_bins].copy_from_slice(&self.fft_buffer[..freq_bins]);
1181
1182            // Advance by hop size
1183            self.input_buffer.drain(..self.config.hop_size);
1184            self.frame_index += 1;
1185            *pool_index += 1;
1186        }
1187
1188        *pool_index - initial_index
1189    }
1190
1191    pub fn flush(&mut self) -> Vec<SpectrumFrame<T>> {
1192        // For streaming, we typically don't process partial frames
1193        // Could zero-pad if needed, but that changes the signal
1194        Vec::new()
1195    }
1196
1197    pub fn reset(&mut self) {
1198        self.input_buffer.clear();
1199        self.frame_index = 0;
1200    }
1201
1202    pub fn buffered_samples(&self) -> usize {
1203        self.input_buffer.len()
1204    }
1205}
1206
1207/// Multi-channel streaming STFT processor with independent state per channel.
1208pub struct MultiChannelStreamingStft<T: Float + FftNum> {
1209    processors: Vec<StreamingStft<T>>,
1210}
1211
1212impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingStft<T> {
1213    /// Create a new multi-channel streaming STFT processor.
1214    ///
1215    /// # Arguments
1216    ///
1217    /// * `config` - STFT configuration
1218    /// * `num_channels` - Number of channels
1219    pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1220        assert!(num_channels > 0, "num_channels must be > 0");
1221        let processors = (0..num_channels)
1222            .map(|_| StreamingStft::new(config.clone()))
1223            .collect();
1224        Self { processors }
1225    }
1226
1227    /// Push samples for all channels and get frames for each channel.
1228    /// Returns Vec<Vec<SpectrumFrame>>, outer Vec = channels, inner Vec = frames.
1229    ///
1230    /// # Arguments
1231    ///
1232    /// * `channels` - Slice of sample slices, one per channel
1233    ///
1234    /// # Panics
1235    ///
1236    /// Panics if channels.len() doesn't match num_channels.
1237    pub fn push_samples(&mut self, channels: &[&[T]]) -> Vec<Vec<SpectrumFrame<T>>> {
1238        assert_eq!(
1239            channels.len(),
1240            self.processors.len(),
1241            "Expected {} channels, got {}",
1242            self.processors.len(),
1243            channels.len()
1244        );
1245
1246        #[cfg(feature = "rayon")]
1247        {
1248            use rayon::prelude::*;
1249            self.processors
1250                .par_iter_mut()
1251                .zip(channels.par_iter())
1252                .map(|(stft, channel)| stft.push_samples(channel))
1253                .collect()
1254        }
1255        #[cfg(not(feature = "rayon"))]
1256        {
1257            self.processors
1258                .iter_mut()
1259                .zip(channels.iter())
1260                .map(|(stft, channel)| stft.push_samples(channel))
1261                .collect()
1262        }
1263    }
1264
1265    /// Flush all channels and return remaining frames.
1266    pub fn flush(&mut self) -> Vec<Vec<SpectrumFrame<T>>> {
1267        #[cfg(feature = "rayon")]
1268        {
1269            use rayon::prelude::*;
1270            self.processors
1271                .par_iter_mut()
1272                .map(|stft| stft.flush())
1273                .collect()
1274        }
1275        #[cfg(not(feature = "rayon"))]
1276        {
1277            self.processors
1278                .iter_mut()
1279                .map(|stft| stft.flush())
1280                .collect()
1281        }
1282    }
1283
1284    /// Reset all channels.
1285    pub fn reset(&mut self) {
1286        #[cfg(feature = "rayon")]
1287        {
1288            use rayon::prelude::*;
1289            self.processors.par_iter_mut().for_each(|stft| stft.reset());
1290        }
1291        #[cfg(not(feature = "rayon"))]
1292        {
1293            self.processors.iter_mut().for_each(|stft| stft.reset());
1294        }
1295    }
1296
1297    /// Get the number of channels.
1298    pub fn num_channels(&self) -> usize {
1299        self.processors.len()
1300    }
1301}
1302
1303pub struct StreamingIstft<T: Float + FftNum> {
1304    config: StftConfig<T>,
1305    window: Vec<T>,
1306    ifft: Arc<dyn Fft<T>>,
1307    overlap_buffer: Vec<T>,
1308    window_energy: Vec<T>,
1309    output_position: usize,
1310    frames_processed: usize,
1311    ifft_buffer: Vec<Complex<T>>,
1312}
1313
1314impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
1315    pub fn new(config: StftConfig<T>) -> Self {
1316        let window = config.generate_window();
1317        let mut planner = FftPlanner::new();
1318        let ifft = planner.plan_fft_inverse(config.fft_size);
1319
1320        // Buffer needs to hold enough samples for full overlap
1321        // For proper reconstruction, need at least fft_size samples
1322        let buffer_size = config.fft_size * 2;
1323        let ifft_buffer = vec![Complex::new(T::zero(), T::zero()); config.fft_size];
1324
1325        Self {
1326            config,
1327            window,
1328            ifft,
1329            overlap_buffer: vec![T::zero(); buffer_size],
1330            window_energy: vec![T::zero(); buffer_size],
1331            output_position: 0,
1332            frames_processed: 0,
1333            ifft_buffer,
1334        }
1335    }
1336
1337    pub fn push_frame(&mut self, frame: &SpectrumFrame<T>) -> Vec<T> {
1338        assert_eq!(
1339            frame.freq_bins,
1340            self.config.freq_bins(),
1341            "Frequency bins mismatch"
1342        );
1343
1344        // Build full spectrum with conjugate symmetry
1345        for bin in 0..frame.freq_bins {
1346            self.ifft_buffer[bin] = frame.data[bin];
1347        }
1348
1349        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
1350        for bin in 1..(frame.freq_bins - 1) {
1351            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1352        }
1353
1354        // Compute IFFT
1355        self.ifft.process(&mut self.ifft_buffer);
1356
1357        // Overlap-add into buffer at the current write position
1358        let write_pos = self.frames_processed * self.config.hop_size;
1359        for i in 0..self.config.fft_size {
1360            let fft_size_t = T::from(self.config.fft_size).unwrap();
1361            let sample = self.ifft_buffer[i].re / fft_size_t;
1362            let buf_idx = write_pos + i;
1363
1364            // Extend buffers if needed
1365            if buf_idx >= self.overlap_buffer.len() {
1366                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1367                self.window_energy.resize(buf_idx + 1, T::zero());
1368            }
1369
1370            match self.config.reconstruction_mode {
1371                ReconstructionMode::Ola => {
1372                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1373                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1374                }
1375                ReconstructionMode::Wola => {
1376                    self.overlap_buffer[buf_idx] =
1377                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1378                    self.window_energy[buf_idx] =
1379                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1380                }
1381            }
1382        }
1383
1384        self.frames_processed += 1;
1385
1386        // Calculate how many samples are "ready" (have full window energy)
1387        // Samples are ready when no future frames will contribute to them
1388        let ready_until = if self.frames_processed == 1 {
1389            0 // First frame: no output yet, need overlap
1390        } else {
1391            // Samples before the current frame's start position are complete
1392            (self.frames_processed - 1) * self.config.hop_size
1393        };
1394
1395        // Extract ready samples
1396        let output_start = self.output_position;
1397        let output_end = ready_until;
1398        let mut output = Vec::new();
1399
1400        let threshold = T::from(1e-8).unwrap();
1401        if output_end > output_start {
1402            for i in output_start..output_end {
1403                let normalized = if self.window_energy[i] > threshold {
1404                    self.overlap_buffer[i] / self.window_energy[i]
1405                } else {
1406                    T::zero()
1407                };
1408                output.push(normalized);
1409            }
1410            self.output_position = output_end;
1411        }
1412
1413        output
1414    }
1415
1416    /// Push a frame and write output samples into a pre-allocated buffer.
1417    /// Returns the number of samples written.
1418    pub fn push_frame_into(&mut self, frame: &SpectrumFrame<T>, output: &mut Vec<T>) -> usize {
1419        assert_eq!(
1420            frame.freq_bins,
1421            self.config.freq_bins(),
1422            "Frequency bins mismatch"
1423        );
1424
1425        // Build full spectrum with conjugate symmetry
1426        for bin in 0..frame.freq_bins {
1427            self.ifft_buffer[bin] = frame.data[bin];
1428        }
1429
1430        // Conjugate symmetry for negative frequencies (skip DC and Nyquist)
1431        for bin in 1..(frame.freq_bins - 1) {
1432            self.ifft_buffer[self.config.fft_size - bin] = self.ifft_buffer[bin].conj();
1433        }
1434
1435        // Compute IFFT
1436        self.ifft.process(&mut self.ifft_buffer);
1437
1438        // Overlap-add into buffer at the current write position
1439        let write_pos = self.frames_processed * self.config.hop_size;
1440        for i in 0..self.config.fft_size {
1441            let fft_size_t = T::from(self.config.fft_size).unwrap();
1442            let sample = self.ifft_buffer[i].re / fft_size_t;
1443            let buf_idx = write_pos + i;
1444
1445            // Extend buffers if needed
1446            if buf_idx >= self.overlap_buffer.len() {
1447                self.overlap_buffer.resize(buf_idx + 1, T::zero());
1448                self.window_energy.resize(buf_idx + 1, T::zero());
1449            }
1450
1451            match self.config.reconstruction_mode {
1452                ReconstructionMode::Ola => {
1453                    self.overlap_buffer[buf_idx] = self.overlap_buffer[buf_idx] + sample;
1454                    self.window_energy[buf_idx] = self.window_energy[buf_idx] + self.window[i];
1455                }
1456                ReconstructionMode::Wola => {
1457                    self.overlap_buffer[buf_idx] =
1458                        self.overlap_buffer[buf_idx] + sample * self.window[i];
1459                    self.window_energy[buf_idx] =
1460                        self.window_energy[buf_idx] + self.window[i] * self.window[i];
1461                }
1462            }
1463        }
1464
1465        self.frames_processed += 1;
1466
1467        // Calculate how many samples are "ready" (have full window energy)
1468        // Samples are ready when no future frames will contribute to them
1469        let ready_until = if self.frames_processed == 1 {
1470            0 // First frame: no output yet, need overlap
1471        } else {
1472            // Samples before the current frame's start position are complete
1473            (self.frames_processed - 1) * self.config.hop_size
1474        };
1475
1476        // Extract ready samples
1477        let output_start = self.output_position;
1478        let output_end = ready_until;
1479        let initial_len = output.len();
1480
1481        let threshold = T::from(1e-8).unwrap();
1482        if output_end > output_start {
1483            for i in output_start..output_end {
1484                let normalized = if self.window_energy[i] > threshold {
1485                    self.overlap_buffer[i] / self.window_energy[i]
1486                } else {
1487                    T::zero()
1488                };
1489                output.push(normalized);
1490            }
1491            self.output_position = output_end;
1492        }
1493
1494        output.len() - initial_len
1495    }
1496
1497    pub fn flush(&mut self) -> Vec<T> {
1498        // Return all remaining samples in buffer
1499        let mut output = Vec::new();
1500        let threshold = T::from(1e-8).unwrap();
1501        for i in self.output_position..self.overlap_buffer.len() {
1502            if self.window_energy[i] > threshold {
1503                output.push(self.overlap_buffer[i] / self.window_energy[i]);
1504            } else if i < (self.frames_processed * self.config.hop_size + self.config.fft_size) {
1505                output.push(T::zero()); // Sample in valid range but no window energy
1506            } else {
1507                break; // Past the end of valid data
1508            }
1509        }
1510
1511        // Determine the actual end of valid data
1512        let valid_end =
1513            (self.frames_processed.saturating_sub(1)) * self.config.hop_size + self.config.fft_size;
1514        if output.len() > valid_end - self.output_position {
1515            output.truncate(valid_end - self.output_position);
1516        }
1517
1518        self.reset();
1519        output
1520    }
1521
1522    pub fn reset(&mut self) {
1523        self.overlap_buffer.clear();
1524        self.overlap_buffer
1525            .resize(self.config.fft_size * 2, T::zero());
1526        self.window_energy.clear();
1527        self.window_energy
1528            .resize(self.config.fft_size * 2, T::zero());
1529        self.output_position = 0;
1530        self.frames_processed = 0;
1531    }
1532}
1533
1534/// Multi-channel streaming iSTFT processor with independent state per channel.
1535pub struct MultiChannelStreamingIstft<T: Float + FftNum> {
1536    processors: Vec<StreamingIstft<T>>,
1537}
1538
1539impl<T: Float + FftNum + FromPrimitive + fmt::Debug> MultiChannelStreamingIstft<T> {
1540    /// Create a new multi-channel streaming iSTFT processor.
1541    ///
1542    /// # Arguments
1543    ///
1544    /// * `config` - STFT configuration
1545    /// * `num_channels` - Number of channels
1546    pub fn new(config: StftConfig<T>, num_channels: usize) -> Self {
1547        assert!(num_channels > 0, "num_channels must be > 0");
1548        let processors = (0..num_channels)
1549            .map(|_| StreamingIstft::new(config.clone()))
1550            .collect();
1551        Self { processors }
1552    }
1553
1554    /// Push frames for all channels and get samples for each channel.
1555    /// Returns Vec<Vec<T>>, outer Vec = channels, inner Vec = samples.
1556    ///
1557    /// # Arguments
1558    ///
1559    /// * `frames` - Slice of frames, one per channel
1560    ///
1561    /// # Panics
1562    ///
1563    /// Panics if frames.len() doesn't match num_channels.
1564    pub fn push_frames(&mut self, frames: &[&SpectrumFrame<T>]) -> Vec<Vec<T>> {
1565        assert_eq!(
1566            frames.len(),
1567            self.processors.len(),
1568            "Expected {} channels, got {}",
1569            self.processors.len(),
1570            frames.len()
1571        );
1572
1573        #[cfg(feature = "rayon")]
1574        {
1575            use rayon::prelude::*;
1576            self.processors
1577                .par_iter_mut()
1578                .zip(frames.par_iter())
1579                .map(|(istft, frame)| istft.push_frame(frame))
1580                .collect()
1581        }
1582        #[cfg(not(feature = "rayon"))]
1583        {
1584            self.processors
1585                .iter_mut()
1586                .zip(frames.iter())
1587                .map(|(istft, frame)| istft.push_frame(frame))
1588                .collect()
1589        }
1590    }
1591
1592    /// Flush all channels and return remaining samples.
1593    pub fn flush(&mut self) -> Vec<Vec<T>> {
1594        #[cfg(feature = "rayon")]
1595        {
1596            use rayon::prelude::*;
1597            self.processors
1598                .par_iter_mut()
1599                .map(|istft| istft.flush())
1600                .collect()
1601        }
1602        #[cfg(not(feature = "rayon"))]
1603        {
1604            self.processors
1605                .iter_mut()
1606                .map(|istft| istft.flush())
1607                .collect()
1608        }
1609    }
1610
1611    /// Reset all channels.
1612    pub fn reset(&mut self) {
1613        #[cfg(feature = "rayon")]
1614        {
1615            use rayon::prelude::*;
1616            self.processors
1617                .par_iter_mut()
1618                .for_each(|istft| istft.reset());
1619        }
1620        #[cfg(not(feature = "rayon"))]
1621        {
1622            self.processors.iter_mut().for_each(|istft| istft.reset());
1623        }
1624    }
1625
1626    /// Get the number of channels.
1627    pub fn num_channels(&self) -> usize {
1628        self.processors.len()
1629    }
1630}
1631
// Type aliases for common float types
//
// Every generic public type gets an `F32` and an `F64` alias that fixes the
// sample type parameter to `f32` or `f64` respectively.

// Configuration
pub type StftConfigF32 = StftConfig<f32>;
pub type StftConfigF64 = StftConfig<f64>;

// Configuration builder
pub type StftConfigBuilderF32 = StftConfigBuilder<f32>;
pub type StftConfigBuilderF64 = StftConfigBuilder<f64>;

// Batch forward transform
pub type BatchStftF32 = BatchStft<f32>;
pub type BatchStftF64 = BatchStft<f64>;

// Batch inverse transform
pub type BatchIstftF32 = BatchIstft<f32>;
pub type BatchIstftF64 = BatchIstft<f64>;

// Single-channel streaming forward transform
pub type StreamingStftF32 = StreamingStft<f32>;
pub type StreamingStftF64 = StreamingStft<f64>;

// Single-channel streaming inverse transform
pub type StreamingIstftF32 = StreamingIstft<f32>;
pub type StreamingIstftF64 = StreamingIstft<f64>;

// Whole spectrum (all frames)
pub type SpectrumF32 = Spectrum<f32>;
pub type SpectrumF64 = Spectrum<f64>;

// Single spectrum frame
pub type SpectrumFrameF32 = SpectrumFrame<f32>;
pub type SpectrumFrameF64 = SpectrumFrame<f64>;

// Multi-channel streaming forward transform
pub type MultiChannelStreamingStftF32 = MultiChannelStreamingStft<f32>;
pub type MultiChannelStreamingStftF64 = MultiChannelStreamingStft<f64>;

// Multi-channel streaming inverse transform
pub type MultiChannelStreamingIstftF32 = MultiChannelStreamingIstft<f32>;
pub type MultiChannelStreamingIstftF64 = MultiChannelStreamingIstft<f64>;