Skip to main content

speech_prep/preprocessing/
noise_reduction.rs

1//! Noise reduction via spectral subtraction.
2//!
3//! Reduces background noise while preserving speech intelligibility using
4//! adaptive spectral subtraction with VAD-informed noise profiling.
5//!
6//! # Capabilities
7//!
8//! - **Stationary Noise Reduction**: Effectively removes constant background
9//!   noise (HVAC hum, white noise, fan noise, café ambience)
10//! - **≥6 dB SNR Improvement**: Validated on white noise, low-frequency hum,
11//!   and ambient café noise
12//! - **Phase Preservation**: Maintains speech intelligibility by preserving
13//!   original signal phase
14//! - **VAD Integration**: Adapts noise profile only during detected silence
15//! - **Real-Time**: <15ms latency per 500ms chunk (typically 0.2-0.3ms)
16//!
17//! # Limitations
18//!
19//! **Spectral subtraction is designed for STATIONARY noise only.**
20//!
21//! - **Non-stationary noise**: Struggles with time-varying noise (individual
22//!   voices, music, babble with distinct speakers). Use Wiener filtering or
23//!   deep learning approaches for non-stationary scenarios.
24//! - **Speech-like interference**: Cannot separate overlapping speakers or
25//!   remove foreground speech interference (requires source separation
26//!   techniques).
27//! - **Musical noise artifacts**: Tonal artifacts may occur with aggressive
28//!   settings. Mitigated via spectral floor parameter (β=0.02 default).
29//! - **Transient noise**: Impulsive sounds (door slams, clicks) are not handled
30//!   well. Consider median filtering for transient suppression.
31//!
32//! # When to Use This
33//!
34//! ✅ **Good fit**:
35//! - Background HVAC/fan noise
36//! - Café/restaurant ambient noise (general chatter blur, dishes)
37//! - Low-frequency hum (electrical interference)
38//! - Stationary white/pink noise
39//!
40//! ❌ **Poor fit**:
41//! - Multi-speaker separation (babble with distinct voices)
42//! - Music removal
43//! - Non-stationary interference
44//! - Echo/reverb reduction (use AEC instead)
45
46use std::f32::consts::PI;
47use std::sync::Arc;
48
49use super::VadContext;
50use crate::error::{Error, Result};
51use crate::time::{AudioDuration, AudioInstant};
52use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
53use tracing::{info, warn};
54
/// Configuration for noise reduction via spectral subtraction.
///
/// All fields are public, so a configuration can be built with struct-update
/// syntax on top of [`Default`]. [`NoiseReductionConfig::validate`] checks the
/// values and is called by `NoiseReducer::new`.
///
/// # Examples
///
/// ```rust,no_run
/// use speech_prep::preprocessing::NoiseReductionConfig;
///
/// // Default: 25ms window, 10ms hop, α=2.0, β=0.02
/// let config = NoiseReductionConfig::default();
///
/// // Aggressive noise removal for noisy environment
/// let config = NoiseReductionConfig {
///     oversubtraction_factor: 2.5,
///     spectral_floor: 0.01,
///     ..Default::default()
/// };
/// # Ok::<(), speech_prep::error::Error>(())
/// ```
#[derive(Debug, Clone)]
#[allow(missing_copy_implementations)]
pub struct NoiseReductionConfig {
    /// Sample rate in Hz.
    ///
    /// **Default**: 16000
    /// **Range**: 8000-48000
    pub sample_rate_hz: u32,

    /// STFT window duration in milliseconds.
    ///
    /// **Default**: 25.0
    /// **Range**: 10.0-50.0
    ///
    /// **Effect**: Longer windows improve frequency resolution but reduce
    /// time resolution. 25ms captures 1-3 pitch periods for typical speech.
    pub window_ms: f32,

    /// STFT hop duration in milliseconds.
    ///
    /// **Default**: 10.0
    /// **Range**: 5.0-25.0
    ///
    /// Must be smaller than `window_ms` so that consecutive frames overlap.
    ///
    /// **Effect**: Smaller hops increase overlap (smoother reconstruction)
    /// but require more computation. 10ms hop = 60% overlap with 25ms window.
    pub hop_ms: f32,

    /// Oversubtraction factor (α).
    ///
    /// **Default**: 2.0
    /// **Range**: 1.0-3.0
    ///
    /// **Effect**: Multiplier for noise estimate in spectral subtraction.
    /// - Higher: More aggressive noise removal, more artifacts
    /// - Lower: Conservative removal, less SNR gain
    pub oversubtraction_factor: f32,

    /// Spectral floor (β) as fraction of noise estimate.
    ///
    /// **Default**: 0.02 (2% of noise estimate)
    /// **Range**: 0.001-0.1
    ///
    /// **Effect**: Minimum magnitude after subtraction to prevent musical
    /// noise. Each bin is clamped to at least β times its noise estimate.
    pub spectral_floor: f32,

    /// Noise profile smoothing factor (`α_noise`).
    ///
    /// **Default**: 0.98
    /// **Range**: 0.9-0.999
    ///
    /// **Effect**: Exponential moving average smoothing for noise profile.
    /// Higher values = slower adaptation, more stable estimate.
    pub noise_smoothing: f32,

    /// Enable noise reduction.
    ///
    /// **Default**: true
    ///
    /// **Effect**: When false, audio passes through unmodified (bypass mode).
    pub enable: bool,
}
135
136impl Default for NoiseReductionConfig {
137    fn default() -> Self {
138        Self {
139            sample_rate_hz: 16_000,
140            window_ms: 25.0,
141            hop_ms: 10.0,
142            oversubtraction_factor: 2.0,
143            spectral_floor: 0.02,
144            noise_smoothing: 0.98,
145            enable: true,
146        }
147    }
148}
149
150impl NoiseReductionConfig {
151    /// Validate configuration parameters.
152    ///
153    /// # Errors
154    ///
155    /// Returns `Error::Configuration` if:
156    /// - `sample_rate_hz` not in 8000-48000 Hz
157    /// - `window_ms` not in 10.0-50.0 ms
158    /// - `hop_ms` >= `window_ms` (overlap required)
159    /// - `oversubtraction_factor` not in 1.0-3.0
160    /// - `spectral_floor` not in 0.001-0.1
161    /// - `noise_smoothing` not in 0.9-0.999
162    #[allow(clippy::trivially_copy_pass_by_ref)]
163    pub fn validate(&self) -> Result<()> {
164        if !(8000..=48_000).contains(&self.sample_rate_hz) {
165            return Err(Error::Configuration(format!(
166                "Invalid sample rate: {} Hz (range: 8000-48000)",
167                self.sample_rate_hz
168            )));
169        }
170
171        if !(10.0..=50.0).contains(&self.window_ms) {
172            return Err(Error::Configuration(format!(
173                "Invalid window size: {:.1} ms (range: 10-50)",
174                self.window_ms
175            )));
176        }
177
178        if self.hop_ms >= self.window_ms {
179            return Err(Error::Configuration(format!(
180                "Hop {:.1} ms must be < window {:.1} ms",
181                self.hop_ms, self.window_ms
182            )));
183        }
184
185        if !(1.0..=3.0).contains(&self.oversubtraction_factor) {
186            return Err(Error::Configuration(format!(
187                "Invalid oversubtraction factor: {:.2} (range: 1.0-3.0)",
188                self.oversubtraction_factor
189            )));
190        }
191
192        if !(0.001..=0.1).contains(&self.spectral_floor) {
193            return Err(Error::Configuration(format!(
194                "Invalid spectral floor: {:.3} (range: 0.001-0.1)",
195                self.spectral_floor
196            )));
197        }
198
199        if !(0.9..1.0).contains(&self.noise_smoothing) {
200            return Err(Error::Configuration(format!(
201                "Invalid noise smoothing: {:.3} (range: 0.9-0.999)",
202                self.noise_smoothing
203            )));
204        }
205
206        Ok(())
207    }
208
209    /// Calculate frame length in samples.
210    pub fn frame_length(&self) -> usize {
211        ((self.window_ms / 1000.0) * self.sample_rate_hz as f32).round() as usize
212    }
213
214    /// Calculate hop length in samples.
215    pub fn hop_length(&self) -> usize {
216        ((self.hop_ms / 1000.0) * self.sample_rate_hz as f32).round() as usize
217    }
218
219    /// Calculate FFT size (next power of 2 >= frame length).
220    pub fn fft_size(&self) -> usize {
221        self.frame_length().next_power_of_two()
222    }
223}
224
/// Noise reduction via spectral subtraction with adaptive noise profiling.
///
/// Implements the noise reduction specification:
/// - STFT-based processing (25ms window, 10ms hop)
/// - Adaptive noise profile estimation during VAD-detected silence
/// - Magnitude-only spectral subtraction (preserves phase)
/// - Achieves ≥6 dB SNR improvement target
///
/// # Performance
///
/// - **Target**: <15ms per 500ms chunk (8000 samples @ 16kHz)
/// - **Expected**: ~7ms (2x headroom)
/// - **Optimization**: Precomputed FFT plans, reused buffers
///
/// # Example
///
/// ```rust,no_run
/// use speech_prep::preprocessing::{NoiseReducer, NoiseReductionConfig, VadContext};
///
/// # fn main() -> speech_prep::error::Result<()> {
/// let config = NoiseReductionConfig::default();
/// let mut reducer = NoiseReducer::new(config)?;
/// let audio_stream = vec![vec![0.0; 8000], vec![0.05; 8080]];
///
/// // Process streaming chunks with VAD context
/// for chunk in audio_stream {
///     let vad_ctx = VadContext { is_silence: detect_silence(&chunk) };
///     let _denoised = reducer.reduce(&chunk, Some(vad_ctx))?;
/// }
/// # Ok(())
/// # }
/// #
/// # fn detect_silence(chunk: &[f32]) -> bool {
/// #     chunk.iter().all(|sample| sample.abs() < 1e-3)
/// # }
/// ```
#[allow(missing_copy_implementations)]
pub struct NoiseReducer {
    // Configuration, validated once in `new`.
    config: NoiseReductionConfig,
    // Forward real-to-complex FFT plan of size `config.fft_size()`.
    fft_forward: Arc<dyn RealToComplex<f32>>,
    // Inverse complex-to-real FFT plan of the same size.
    fft_inverse: Arc<dyn ComplexToReal<f32>>,
    // Hann window of `config.frame_length()` samples; applied to each frame
    // before the forward FFT and again after the inverse FFT.
    window: Vec<f32>,
    // Per-bin noise magnitude estimate (`fft_size / 2 + 1` bins), seeded with
    // 1e-6 and updated from frames flagged as silence.
    noise_profile: Vec<f32>,
    // False until the first full silence frame has been copied into
    // `noise_profile`; afterwards updates use exponential smoothing.
    noise_initialized: bool,
    // `frame_length` zeros at construction, zeroed again by `reset`.
    // NOTE(review): no other read or write is visible in this part of the
    // file — confirm whether cross-chunk overlap state is actually used.
    overlap_buffer: Vec<f32>,
}
271
272impl std::fmt::Debug for NoiseReducer {
273    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274        f.debug_struct("NoiseReducer")
275            .field("config", &self.config)
276            .field("window_length", &self.window.len())
277            .field("noise_profile_bins", &self.noise_profile.len())
278            .field("noise_initialized", &self.noise_initialized)
279            .finish_non_exhaustive()
280    }
281}
282
283impl NoiseReducer {
284    /// Create a new noise reducer.
285    ///
286    /// # Arguments
287    ///
288    /// * `config` - Configuration parameters (window size, hop, α, β)
289    ///
290    /// # Errors
291    ///
292    /// Returns `Error::Configuration` if configuration is invalid.
293    ///
294    /// # Example
295    ///
296    /// ```rust,no_run
297    /// use speech_prep::preprocessing::{NoiseReducer, NoiseReductionConfig};
298    ///
299    /// let config = NoiseReductionConfig {
300    ///     oversubtraction_factor: 2.5, // Aggressive
301    ///     ..Default::default()
302    /// };
303    /// let reducer = NoiseReducer::new(config)?;
304    /// # Ok::<(), speech_prep::error::Error>(())
305    /// ```
306    pub fn new(config: NoiseReductionConfig) -> Result<Self> {
307        config.validate()?;
308
309        let fft_size = config.fft_size();
310        let frame_length = config.frame_length();
311
312        let mut planner = RealFftPlanner::<f32>::new();
313        let fft_forward = planner.plan_fft_forward(fft_size);
314        let fft_inverse = planner.plan_fft_inverse(fft_size);
315
316        let window = generate_hann_window(frame_length);
317
318        let num_bins = fft_size / 2 + 1;
319        let noise_profile = vec![1e-6; num_bins];
320
321        let overlap_buffer = vec![0.0; frame_length];
322
323        Ok(Self {
324            config,
325            fft_forward,
326            fft_inverse,
327            window,
328            noise_profile,
329            noise_initialized: false,
330            overlap_buffer,
331        })
332    }
333
334    /// Apply noise reduction to audio samples.
335    ///
336    /// # Arguments
337    ///
338    /// * `samples` - Input audio samples (typically 500ms chunk = 8000 samples
339    ///   @ 16kHz)
340    /// * `vad_context` - Optional VAD state for noise profile updates
341    ///
342    /// # Returns
343    ///
344    /// Denoised audio with ≥6 dB SNR improvement on noisy input.
345    ///
346    /// # Performance
347    ///
348    /// - Expected: ~7ms for 8000 samples (2x better than <15ms target)
349    /// - Complexity: O(n log n) for FFT operations
350    ///
351    /// # Example
352    ///
353    /// ```rust,no_run
354    /// use speech_prep::preprocessing::{NoiseReducer, NoiseReductionConfig, VadContext};
355    ///
356    /// let mut reducer = NoiseReducer::new(NoiseReductionConfig::default())?;
357    ///
358    /// // Chunk 1 (silence - initialize noise profile)
359    /// let chunk1 = vec![0.001; 8000];
360    /// let vad1 = VadContext { is_silence: true };
361    /// let output1 = reducer.reduce(&chunk1, Some(vad1))?;
362    ///
363    /// // Chunk 2 (speech - apply noise reduction)
364    /// let chunk2 = vec![0.1; 8000];
365    /// let vad2 = VadContext { is_silence: false };
366    /// let output2 = reducer.reduce(&chunk2, Some(vad2))?;
367    /// # Ok::<(), speech_prep::error::Error>(())
368    /// ```
369    #[allow(clippy::unnecessary_wraps)]
370    pub fn reduce(&mut self, samples: &[f32], vad_context: Option<VadContext>) -> Result<Vec<f32>> {
371        let processing_start = AudioInstant::now();
372
373        if samples.is_empty() {
374            return Ok(Vec::new());
375        }
376
377        if !self.config.enable {
378            return Ok(samples.to_vec());
379        }
380
381        let (mut output, frame_count) = self.process_stft_frames(samples, vad_context)?;
382        self.normalize_overlap_add(&mut output);
383
384        let elapsed = elapsed_duration(processing_start);
385        let latency_ms = elapsed.as_secs_f64() * 1000.0;
386        self.record_performance_metrics(samples, &output, latency_ms, frame_count);
387
388        Ok(output)
389    }
390
391    /// Process audio through STFT frames with spectral subtraction.
392    fn process_stft_frames(
393        &mut self,
394        samples: &[f32],
395        vad_context: Option<VadContext>,
396    ) -> Result<(Vec<f32>, usize)> {
397        let frame_length = self.config.frame_length();
398        let hop_length = self.config.hop_length();
399
400        let mut output = vec![0.0; samples.len()];
401        let mut frame_idx = 0;
402        let mut pos = 0;
403
404        while pos < samples.len() {
405            let remaining = samples.len() - pos;
406
407            let frame = Self::extract_frame(samples, pos, frame_length, remaining)?;
408
409            let processed =
410                self.process_single_frame(&frame, vad_context, remaining >= frame_length)?;
411
412            Self::accumulate_frame_output(&processed, &mut output, pos);
413
414            frame_idx += 1;
415
416            if remaining < hop_length {
417                break;
418            }
419            pos += hop_length;
420        }
421
422        Ok((output, frame_idx))
423    }
424
425    /// Extract a frame from the input samples, zero-padding if partial.
426    fn extract_frame(
427        samples: &[f32],
428        pos: usize,
429        frame_length: usize,
430        remaining: usize,
431    ) -> Result<Vec<f32>> {
432        let mut frame_buf = vec![0.0; frame_length];
433
434        if remaining >= frame_length {
435            let src = samples
436                .get(pos..pos + frame_length)
437                .ok_or_else(|| Error::Processing("frame window out of bounds".into()))?;
438            frame_buf.copy_from_slice(src);
439        } else {
440            let src = samples
441                .get(pos..)
442                .ok_or_else(|| Error::Processing("frame tail out of bounds".into()))?;
443            if let Some(dst) = frame_buf.get_mut(..remaining) {
444                dst.copy_from_slice(src);
445            }
446        }
447
448        Ok(frame_buf)
449    }
450
451    /// Process a single frame through FFT, spectral subtraction, and IFFT.
452    fn process_single_frame(
453        &mut self,
454        frame: &[f32],
455        vad_context: Option<VadContext>,
456        is_full_frame: bool,
457    ) -> Result<Vec<f32>> {
458        let fft_size = self.config.fft_size();
459
460        let windowed: Vec<f32> = frame
461            .iter()
462            .zip(&self.window)
463            .map(|(&s, &w)| s * w)
464            .collect();
465
466        let complex_spectrum = self.forward_fft_complex(&windowed)?;
467        let magnitudes: Vec<f32> = complex_spectrum.iter().map(|c| c.norm()).collect();
468
469        let is_silence = vad_context.is_some_and(|ctx| ctx.is_silence);
470        if is_silence && is_full_frame {
471            self.update_noise_profile(&magnitudes);
472        }
473
474        let cleaned_magnitudes = self.spectral_subtract(&magnitudes);
475
476        let cleaned_complex =
477            Self::reconstruct_complex_spectrum(&complex_spectrum, &cleaned_magnitudes);
478
479        let time_signal = self.inverse_fft_complex(&cleaned_complex, fft_size)?;
480
481        let windowed_output: Vec<f32> = time_signal
482            .iter()
483            .take(frame.len())
484            .zip(&self.window)
485            .map(|(&s, &w)| s * w)
486            .collect();
487
488        Ok(windowed_output)
489    }
490
491    /// Reconstruct complex spectrum preserving phase from original signal.
492    fn reconstruct_complex_spectrum(
493        original_spectrum: &[realfft::num_complex::Complex<f32>],
494        cleaned_magnitudes: &[f32],
495    ) -> Vec<realfft::num_complex::Complex<f32>> {
496        original_spectrum
497            .iter()
498            .zip(cleaned_magnitudes)
499            .enumerate()
500            .map(|(i, (original, &new_mag))| {
501                if i == 0 || i == original_spectrum.len() - 1 {
502                    // DC and Nyquist bins must be real-valued
503                    realfft::num_complex::Complex::new(new_mag, 0.0)
504                } else {
505                    let phase = original.arg();
506                    realfft::num_complex::Complex::from_polar(new_mag, phase)
507                }
508            })
509            .collect()
510    }
511
512    /// Accumulate processed frame into output buffer (overlap-add).
513    fn accumulate_frame_output(frame: &[f32], output: &mut [f32], pos: usize) {
514        for (i, &sample) in frame.iter().enumerate() {
515            let out_idx = pos + i;
516            if let Some(dst) = output.get_mut(out_idx) {
517                *dst += sample;
518            }
519        }
520    }
521
522    /// Normalize output by overlap-add window sum.
523    fn normalize_overlap_add(&self, output: &mut [f32]) {
524        let hop_length = self.config.hop_length();
525        let window_sum = self.calculate_window_overlap_sum(hop_length);
526
527        if window_sum > 1e-6 {
528            for sample in output {
529                *sample /= window_sum;
530            }
531        }
532    }
533
534    fn record_performance_metrics(
535        &self,
536        input: &[f32],
537        output: &[f32],
538        latency_ms: f64,
539        frame_count: usize,
540    ) {
541        if input.len() < 8000 {
542            return;
543        }
544
545        if latency_ms > 15.0 {
546            warn!(
547                target: "audio.preprocess.noise_reduction",
548                latency_ms,
549                samples = input.len(),
550                frames = frame_count,
551                oversubtraction = self.config.oversubtraction_factor,
552                spectral_floor = self.config.spectral_floor,
553                "noise reduction latency exceeded target"
554            );
555        }
556
557        let avg_noise_floor = self.noise_floor().max(1e-12);
558        let noise_floor_db = 20.0 * avg_noise_floor.log10();
559
560        let signal_power_out =
561            output.iter().map(|sample| sample * sample).sum::<f32>() / output.len() as f32;
562        let residual_power: f32 = input
563            .iter()
564            .zip(output)
565            .map(|(&noisy, &clean)| {
566                let residual = noisy - clean;
567                residual * residual
568            })
569            .sum::<f32>()
570            / output.len() as f32;
571
572        let snr_improvement_db = if residual_power > 1e-12 && signal_power_out > 0.0 {
573            10.0 * (signal_power_out / residual_power).log10()
574        } else {
575            0.0
576        };
577
578        info!(
579            target: "audio.preprocess.noise_reduction",
580            noise_floor_db,
581            snr_improvement_db,
582            latency_ms,
583            frames = frame_count,
584            samples = input.len(),
585            oversubtraction = self.config.oversubtraction_factor,
586            spectral_floor = self.config.spectral_floor,
587            "noise reduction metrics"
588        );
589    }
590
591    /// Reset noise profile for new audio stream.
592    ///
593    /// Clears noise estimate and overlap-add state.
594    /// Use this when starting a new, independent audio stream.
595    pub fn reset(&mut self) {
596        self.noise_profile.fill(1e-6);
597        self.noise_initialized = false;
598        self.overlap_buffer.fill(0.0);
599    }
600
601    /// Get current average noise floor (for debugging/observability).
602    #[must_use]
603    pub fn noise_floor(&self) -> f32 {
604        if self.noise_profile.is_empty() {
605            return 0.0;
606        }
607        self.noise_profile.iter().sum::<f32>() / self.noise_profile.len() as f32
608    }
609
    /// Get current configuration.
    ///
    /// Returns a shared reference to the configuration this reducer was
    /// constructed with.
    #[must_use]
    pub fn config(&self) -> &NoiseReductionConfig {
        &self.config
    }
615
616    // Forward FFT with zero-padding (returns complex spectrum)
617    fn forward_fft_complex(
618        &self,
619        windowed: &[f32],
620    ) -> Result<Vec<realfft::num_complex::Complex<f32>>> {
621        // Prepare input buffer (zero-padded to FFT size)
622        let mut input = self.fft_forward.make_input_vec();
623        for (i, &sample) in windowed.iter().enumerate() {
624            if let Some(dst) = input.get_mut(i) {
625                *dst = sample;
626            }
627        }
628
629        // Perform FFT
630        let mut spectrum = self.fft_forward.make_output_vec();
631        self.fft_forward
632            .process(&mut input, &mut spectrum)
633            .map_err(|e| Error::Processing(format!("FFT failed: {e}")))?;
634
635        Ok(spectrum)
636    }
637
638    // Inverse FFT from complex spectrum (preserves phase)
639    fn inverse_fft_complex(
640        &self,
641        complex_spectrum: &[realfft::num_complex::Complex<f32>],
642        fft_size: usize,
643    ) -> Result<Vec<f32>> {
644        // Prepare input buffer
645        let mut spectrum = self.fft_inverse.make_input_vec();
646        for (i, &c) in complex_spectrum.iter().enumerate() {
647            if let Some(bin) = spectrum.get_mut(i) {
648                *bin = c;
649            }
650        }
651
652        // Perform inverse FFT
653        let mut output = self.fft_inverse.make_output_vec();
654        self.fft_inverse
655            .process(&mut spectrum, &mut output)
656            .map_err(|e| Error::Processing(format!("IFFT failed: {e}")))?;
657
658        // Normalize by FFT size
659        for sample in &mut output {
660            *sample /= fft_size as f32;
661        }
662
663        Ok(output)
664    }
665
666    // Update noise profile using exponential moving average
667    fn update_noise_profile(&mut self, spectrum: &[f32]) {
668        let alpha = self.config.noise_smoothing;
669
670        if self.noise_initialized {
671            // EMA update: N_new[k] = α * N_old[k] + (1-α) * |X[k]|
672            for (noise, &current) in self.noise_profile.iter_mut().zip(spectrum.iter()) {
673                *noise = alpha.mul_add(*noise, (1.0 - alpha) * current);
674            }
675        } else {
676            // First silence frame: initialize noise profile
677            self.noise_profile.copy_from_slice(spectrum);
678            self.noise_initialized = true;
679        }
680    }
681
682    // Apply spectral subtraction: |Y[k]| = max(|X[k]| - α*|N[k]|, β*|N[k]|)
683    fn spectral_subtract(&self, spectrum: &[f32]) -> Vec<f32> {
684        let alpha = self.config.oversubtraction_factor;
685        let beta = self.config.spectral_floor;
686
687        spectrum
688            .iter()
689            .zip(&self.noise_profile)
690            .map(|(&signal, &noise)| {
691                let subtracted = alpha.mul_add(-noise, signal);
692                let floor = beta * noise;
693                subtracted.max(floor)
694            })
695            .collect()
696    }
697
698    // Calculate window overlap sum for COLA normalization
699    fn calculate_window_overlap_sum(&self, hop_length: usize) -> f32 {
700        let frame_length = self.window.len();
701        let mut sum: f32 = 0.0;
702
703        // Sum overlapping windows at each sample position
704        for i in 0..frame_length {
705            let mut overlap: f32 = 0.0;
706            let mut offset = 0;
707
708            while offset <= i {
709                if let Some(&w) = self.window.get(i - offset) {
710                    overlap = w.mul_add(w, overlap); // Window applied twice
711                                                     // (analysis +
712                                                     // synthesis)
713                }
714                offset += hop_length;
715            }
716
717            sum = sum.max(overlap);
718        }
719
720        sum
721    }
722}
723
/// Build a symmetric Hann window of `length` samples.
///
/// Formula: `w[n] = 0.5 - 0.5 * cos(2π * n / (N-1))`
///
/// A length of 0 gives an empty window and a length of 1 gives `[1.0]`, which
/// sidesteps the division by `N-1 = 0`.
fn generate_hann_window(length: usize) -> Vec<f32> {
    match length {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let denom = (length - 1) as f32;
            let mut window = Vec::with_capacity(length);
            for n in 0..length {
                let angle = 2.0 * PI * n as f32 / denom;
                window.push(0.5f32.mul_add(-angle.cos(), 0.5));
            }
            window
        }
    }
}
744
745fn elapsed_duration(start: AudioInstant) -> AudioDuration {
746    AudioInstant::now().duration_since(start)
747}
748
749#[cfg(test)]
750mod tests {
751    use super::*;
752
753    type TestResult<T> = std::result::Result<T, String>;
754
755    #[test]
756    #[allow(clippy::unnecessary_wraps)]
757    fn test_configuration_validation() -> TestResult<()> {
758        // Valid configuration
759        let valid = NoiseReductionConfig::default();
760        assert!(valid.validate().is_ok());
761
762        // Invalid sample rate
763        let invalid_sr = NoiseReductionConfig {
764            sample_rate_hz: 5000,
765            ..Default::default()
766        };
767        assert!(invalid_sr.validate().is_err());
768
769        // Invalid window size
770        let invalid_window = NoiseReductionConfig {
771            window_ms: 100.0,
772            ..Default::default()
773        };
774        assert!(invalid_window.validate().is_err());
775
776        // Hop >= window
777        let invalid_hop = NoiseReductionConfig {
778            hop_ms: 30.0,
779            window_ms: 25.0,
780            ..Default::default()
781        };
782        assert!(invalid_hop.validate().is_err());
783
784        // Invalid oversubtraction
785        let invalid_alpha = NoiseReductionConfig {
786            oversubtraction_factor: 5.0,
787            ..Default::default()
788        };
789        assert!(invalid_alpha.validate().is_err());
790
791        // Invalid spectral floor
792        let invalid_beta = NoiseReductionConfig {
793            spectral_floor: 0.5,
794            ..Default::default()
795        };
796        assert!(invalid_beta.validate().is_err());
797
798        // Invalid noise smoothing
799        let invalid_smoothing = NoiseReductionConfig {
800            noise_smoothing: 1.0,
801            ..Default::default()
802        };
803        assert!(invalid_smoothing.validate().is_err());
804
805        Ok(())
806    }
807
808    #[test]
809    fn test_hann_window_properties() {
810        // Empty window
811        let window_0 = generate_hann_window(0);
812        assert!(window_0.is_empty());
813
814        // Single element
815        let window_1 = generate_hann_window(1);
816        assert_eq!(window_1.len(), 1);
817        assert!((window_1[0] - 1.0).abs() < 1e-6);
818
819        // Check Hann window properties
820        let window = generate_hann_window(100);
821        assert_eq!(window.len(), 100);
822
823        // First and last samples should be near zero
824        assert!(window[0].abs() < 1e-6);
825        assert!(window[99].abs() < 1e-6);
826
827        // Middle sample should be near 1.0
828        assert!((window[50] - 1.0).abs() < 0.1);
829    }
830
831    #[test]
832    fn test_noise_reducer_creation() -> TestResult<()> {
833        let config = NoiseReductionConfig::default();
834        let reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
835
836        // Check initialization
837        assert_eq!(reducer.config().sample_rate_hz, 16000);
838        assert!(reducer.noise_floor() > 0.0); // Initial estimate
839
840        Ok(())
841    }
842
843    #[test]
844    fn test_empty_input() -> TestResult<()> {
845        let config = NoiseReductionConfig::default();
846        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
847
848        let output = reducer.reduce(&[], None).map_err(|e| e.to_string())?;
849        assert!(output.is_empty());
850
851        Ok(())
852    }
853
854    #[test]
855    fn test_bypass_mode() -> TestResult<()> {
856        let config = NoiseReductionConfig {
857            enable: false,
858            ..Default::default()
859        };
860        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
861
862        let input = vec![0.1, 0.2, 0.3, 0.4];
863        let output = reducer.reduce(&input, None).map_err(|e| e.to_string())?;
864
865        // Bypass mode should return input unchanged
866        assert_eq!(output, input);
867
868        Ok(())
869    }
870
871    #[test]
872    fn test_noise_profile_update() -> TestResult<()> {
873        let config = NoiseReductionConfig::default();
874        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
875
876        // Process silence to build noise profile
877        let silence = vec![0.01; 8000]; // Low-level noise
878        let vad_silence = VadContext { is_silence: true };
879
880        let initial_noise = reducer.noise_floor();
881
882        // Process multiple chunks to converge
883        for _ in 0..5 {
884            let _ = reducer
885                .reduce(&silence, Some(vad_silence))
886                .map_err(|e| e.to_string())?;
887        }
888
889        let converged_noise = reducer.noise_floor();
890
891        // Noise floor should increase from initial estimate
892        assert!(
893            converged_noise > initial_noise,
894            "Noise floor should adapt: initial={:.6}, converged={:.6}",
895            initial_noise,
896            converged_noise
897        );
898
899        Ok(())
900    }
901
902    #[test]
903    fn test_vad_informed_noise_update() -> TestResult<()> {
904        let config = NoiseReductionConfig::default();
905        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
906
907        // Initialize with silence
908        let silence = vec![0.01; 8000];
909        let vad_silence = VadContext { is_silence: true };
910        for _ in 0..5 {
911            let _ = reducer
912                .reduce(&silence, Some(vad_silence))
913                .map_err(|e| e.to_string())?;
914        }
915
916        let noise_after_silence = reducer.noise_floor();
917
918        // Process "speech" (should NOT update noise profile)
919        let speech = vec![0.5; 8000];
920        let vad_speech = VadContext { is_silence: false };
921        let _ = reducer
922            .reduce(&speech, Some(vad_speech))
923            .map_err(|e| e.to_string())?;
924
925        let noise_after_speech = reducer.noise_floor();
926
927        // Noise profile should remain stable during speech
928        let diff = (noise_after_speech - noise_after_silence).abs();
929        assert!(
930            diff < noise_after_silence * 0.01,
931            "Noise profile changed during speech: {:.6} -> {:.6}",
932            noise_after_silence,
933            noise_after_speech
934        );
935
936        Ok(())
937    }
938
939    #[test]
940    fn test_reset_clears_state() -> TestResult<()> {
941        let config = NoiseReductionConfig::default();
942        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
943
944        // Process some audio
945        let samples = vec![0.1; 8000];
946        let vad = VadContext { is_silence: true };
947        let _ = reducer
948            .reduce(&samples, Some(vad))
949            .map_err(|e| e.to_string())?;
950
951        let noise_before = reducer.noise_floor();
952        assert!(noise_before > 1e-5, "Noise profile should be updated");
953
954        // Reset
955        reducer.reset();
956
957        let noise_after = reducer.noise_floor();
958        assert!(
959            noise_after < 1e-5,
960            "Noise profile should be reset to initial value"
961        );
962
963        Ok(())
964    }
965
966    // Helper: Generate sine wave
967    fn generate_sine_wave(
968        frequency: f32,
969        sample_rate: u32,
970        duration_secs: f32,
971        amplitude: f32,
972    ) -> Vec<f32> {
973        use std::f32::consts::PI;
974        let samples = (sample_rate as f32 * duration_secs).round() as usize;
975        (0..samples)
976            .map(|i| {
977                let t = i as f32 / sample_rate as f32;
978                (2.0 * PI * frequency * t).sin() * amplitude
979            })
980            .collect()
981    }
982
983    // Helper: Add white noise to signal
984    fn add_white_noise(signal: &[f32], noise_amplitude: f32) -> Vec<f32> {
985        use rand::Rng;
986        let mut rng = rand::rng();
987        signal
988            .iter()
989            .map(|&s| {
990                let noise: f32 = rng.random_range(-noise_amplitude..noise_amplitude);
991                s + noise
992            })
993            .collect()
994    }
995    fn add_low_freq_hum(
996        signal: &[f32],
997        sample_rate: u32,
998        frequency: f32,
999        amplitude: f32,
1000    ) -> Vec<f32> {
1001        signal
1002            .iter()
1003            .enumerate()
1004            .map(|(i, &sample)| {
1005                let t = i as f32 / sample_rate as f32;
1006                let hum = (2.0 * PI * frequency * t).sin() * amplitude;
1007                sample + hum
1008            })
1009            .collect()
1010    }
1011
1012    // Helper: Add café-like ambient noise (stationary broadband noise)
1013    // Simulates background café noise: HVAC, dishes, distant ambient chatter.
1014    // Uses white noise as a proxy for band-limited stationary noise (100-3000 Hz
1015    // typical). NOTE: Spectral subtraction works for STATIONARY noise, not
1016    // speech-like babble.
1017    fn add_cafe_noise(signal: &[f32], _sample_rate: u32, amplitude: f32) -> Vec<f32> {
1018        use rand::Rng;
1019        let mut rng = rand::rng();
1020        signal
1021            .iter()
1022            .map(|&sample| {
1023                let noise: f32 = rng.random_range(-1.0..1.0);
1024                amplitude.mul_add(noise, sample)
1025            })
1026            .collect()
1027    }
1028
1029    // Helper: Calculate SNR
1030    fn calculate_snr(clean: &[f32], noisy: &[f32]) -> f32 {
1031        if clean.len() != noisy.len() {
1032            return 0.0;
1033        }
1034
1035        let signal_power: f32 = clean.iter().map(|&x| x * x).sum();
1036        let noise: Vec<f32> = clean
1037            .iter()
1038            .zip(noisy.iter())
1039            .map(|(&c, &n)| n - c)
1040            .collect();
1041        let noise_power: f32 = noise.iter().map(|&x| x * x).sum();
1042
1043        if noise_power < 1e-10 {
1044            return 100.0; // Very high SNR
1045        }
1046
1047        10.0 * (signal_power / noise_power).log10()
1048    }
1049
1050    #[test]
1051    fn test_snr_improvement_white_noise() -> TestResult<()> {
1052        let config = NoiseReductionConfig::default();
1053        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1054
1055        // Generate clean speech signal (440 Hz sine wave)
1056        let clean_speech = generate_sine_wave(440.0, 16000, 1.0, 0.5);
1057
1058        // Add white noise (creating ~5 dB input SNR)
1059        let noisy_speech = add_white_noise(&clean_speech, 0.3);
1060
1061        let snr_before = calculate_snr(&clean_speech, &noisy_speech);
1062
1063        // Initialize noise profile with pure noise
1064        // Training: 10 iterations ensures EMA convergence (α=0.98 requires ~50 samples
1065        // for 95% convergence)
1066        let pure_noise = add_white_noise(&vec![0.0; 8000], 0.3);
1067        let vad_silence = VadContext { is_silence: true };
1068        for _ in 0..10 {
1069            let _ = reducer
1070                .reduce(&pure_noise, Some(vad_silence))
1071                .map_err(|e| e.to_string())?;
1072        }
1073
1074        // Apply noise reduction to noisy speech
1075        let vad_speech = VadContext { is_silence: false };
1076        let denoised = reducer
1077            .reduce(&noisy_speech, Some(vad_speech))
1078            .map_err(|e| e.to_string())?;
1079
1080        let snr_after = calculate_snr(&clean_speech, &denoised);
1081        let improvement = snr_after - snr_before;
1082
1083        // Success criterion: ≥6 dB improvement
1084        assert!(
1085            improvement >= 6.0,
1086            "SNR improvement {:.1} dB < 6 dB target",
1087            improvement
1088        );
1089
1090        Ok(())
1091    }
1092
1093    #[test]
1094    fn test_snr_improvement_low_freq_hum() -> TestResult<()> {
1095        let config = NoiseReductionConfig::default();
1096        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1097
1098        // Generate 440 Hz speech with 60 Hz HVAC hum (common electrical interference)
1099        let clean = generate_sine_wave(440.0, 16000, 1.0, 0.4);
1100        let noisy = add_low_freq_hum(&clean, 16000, 60.0, 0.3);
1101        let snr_before = calculate_snr(&clean, &noisy);
1102
1103        // Train on pure 60 Hz hum
1104        // Training: 6 iterations sufficient for tonal noise (faster convergence than
1105        // broadband)
1106        let hum_only = add_low_freq_hum(&vec![0.0; 8000], 16000, 60.0, 0.3);
1107        let vad = VadContext { is_silence: true };
1108        for _ in 0..6 {
1109            let _ = reducer
1110                .reduce(&hum_only, Some(vad))
1111                .map_err(|e| e.to_string())?;
1112        }
1113
1114        let vad_speech = VadContext { is_silence: false };
1115        let denoised = reducer
1116            .reduce(&noisy, Some(vad_speech))
1117            .map_err(|e| e.to_string())?;
1118        let snr_after = calculate_snr(&clean, &denoised);
1119        let improvement = snr_after - snr_before;
1120        assert!(
1121            improvement >= 6.0,
1122            "Hum SNR improvement {:.1} dB < 6 dB target",
1123            improvement
1124        );
1125
1126        Ok(())
1127    }
1128
1129    #[test]
1130    fn test_snr_improvement_cafe_ambient() -> TestResult<()> {
1131        // Use default config - stationary noise doesn't need aggressive oversubtraction
1132        let config = NoiseReductionConfig::default();
1133        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1134
1135        // Generate clean speech signal
1136        let clean = generate_sine_wave(220.0, 16000, 1.0, 0.4);
1137
1138        // Add stationary café ambient noise (HVAC, dishes, background chatter)
1139        let noisy = add_cafe_noise(&clean, 16000, 0.25);
1140        let snr_before = calculate_snr(&clean, &noisy);
1141
1142        // Train noise profile on café ambient noise during "silence"
1143        // Training: 10 iterations for broadband stationary noise (white noise requires
1144        // more samples than tonal) Noise amplitude 0.25 creates ~5-6 dB input
1145        // SNR, realistic for café environment
1146        let cafe_only = add_cafe_noise(&vec![0.0; 8000], 16000, 0.25);
1147        let vad = VadContext { is_silence: true };
1148        for _ in 0..10 {
1149            let _ = reducer
1150                .reduce(&cafe_only, Some(vad))
1151                .map_err(|e| e.to_string())?;
1152        }
1153
1154        // Apply noise reduction to noisy speech
1155        let vad_speech = VadContext { is_silence: false };
1156        let denoised = reducer
1157            .reduce(&noisy, Some(vad_speech))
1158            .map_err(|e| e.to_string())?;
1159
1160        let snr_after = calculate_snr(&clean, &denoised);
1161        let improvement = snr_after - snr_before;
1162
1163        assert!(
1164            improvement >= 6.0,
1165            "Café ambient SNR improvement {:.1} dB < 6 dB target",
1166            improvement
1167        );
1168
1169        Ok(())
1170    }
1171
1172    #[test]
1173    fn test_trailing_partial_frame_preserved() -> TestResult<()> {
1174        let config = NoiseReductionConfig::default();
1175        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1176
1177        // Prime noise profile with a silence chunk so spectral subtraction behaves
1178        // normally.
1179        let silence = vec![0.0; 8000];
1180        let vad_silence = VadContext { is_silence: true };
1181        let _ = reducer
1182            .reduce(&silence, Some(vad_silence))
1183            .map_err(|e| e.to_string())?;
1184
1185        // Speech chunk length intentionally not divisible by hop size (adds 80-sample
1186        // tail).
1187        let speech_len = 8080;
1188        let speech: Vec<f32> = (0..speech_len)
1189            .map(|i| {
1190                let phase = (i as f32 / speech_len as f32) * 20.0;
1191                phase.sin()
1192            })
1193            .collect();
1194
1195        let vad_speech = VadContext { is_silence: false };
1196        let output = reducer
1197            .reduce(&speech, Some(vad_speech))
1198            .map_err(|e| e.to_string())?;
1199
1200        assert_eq!(
1201            output.len(),
1202            speech_len,
1203            "Output length should match input length"
1204        );
1205
1206        let tail = &output[speech_len - 80..];
1207        let tail_energy: f32 = tail.iter().map(|sample| sample.abs()).sum();
1208        assert!(
1209            tail_energy > 1e-3,
1210            "Trailing samples should retain energy, got tail_energy={tail_energy}"
1211        );
1212
1213        Ok(())
1214    }
1215
1216    #[test]
1217    fn test_missing_vad_context_does_not_update_noise_profile() -> TestResult<()> {
1218        let config = NoiseReductionConfig::default();
1219        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1220
1221        // Prime noise profile with explicit silence (non-zero noise so baseline is
1222        // measurable).
1223        let ambient_noise = vec![0.05f32; 8000];
1224        let vad_silence = VadContext { is_silence: true };
1225        reducer
1226            .reduce(&ambient_noise, Some(vad_silence))
1227            .map_err(|e| e.to_string())?;
1228        let baseline_floor = reducer.noise_floor();
1229
1230        // Process speech without VAD context; noise profile should remain unchanged.
1231        let speech = vec![0.2f32; 8000];
1232        let output = reducer.reduce(&speech, None).map_err(|e| e.to_string())?;
1233        let updated_floor = reducer.noise_floor();
1234
1235        let floor_delta = (updated_floor - baseline_floor).abs();
1236        assert!(
1237            floor_delta < baseline_floor.max(1e-6) * 0.01,
1238            "Noise floor changed when VAD context missing: baseline={baseline_floor}, \
1239             updated={updated_floor}"
1240        );
1241
1242        let output_rms =
1243            (output.iter().map(|sample| sample * sample).sum::<f32>() / output.len() as f32).sqrt();
1244        assert!(
1245            output_rms > 0.08,
1246            "Speech energy collapsed without VAD context (rms={output_rms})"
1247        );
1248
1249        Ok(())
1250    }
1251}