Skip to main content

speech_prep/preprocessing/
noise_reduction.rs

1//! Noise reduction via spectral subtraction.
2//!
3//! Reduces background noise while preserving speech intelligibility using
4//! adaptive spectral subtraction with VAD-informed noise profiling.
5//!
6//! # Capabilities
7//!
8//! - **Stationary Noise Reduction**: Effectively removes constant background
9//!   noise (HVAC hum, white noise, fan noise, café ambience)
10//! - **≥6 dB SNR Improvement**: Validated on white noise, low-frequency hum,
11//!   and ambient café noise
12//! - **Phase Preservation**: Maintains speech intelligibility by preserving
13//!   original signal phase
14//! - **VAD Integration**: Adapts noise profile only during detected silence
15//! - **Real-Time**: <15ms latency per 500ms chunk (typically 0.2-0.3ms)
16//!
17//! # Limitations
18//!
19//! **Spectral subtraction is designed for STATIONARY noise only.**
20//!
21//! - **Non-stationary noise**: Struggles with time-varying noise (individual
22//!   voices, music, babble with distinct speakers). Use Wiener filtering or
23//!   deep learning approaches for non-stationary scenarios.
24//! - **Speech-like interference**: Cannot separate overlapping speakers or
25//!   remove foreground speech interference (requires source separation
26//!   techniques).
27//! - **Musical noise artifacts**: Tonal artifacts may occur with aggressive
28//!   settings. Mitigated via spectral floor parameter (β=0.02 default).
29//! - **Transient noise**: Impulsive sounds (door slams, clicks) are not handled
30//!   well. Consider median filtering for transient suppression.
31//!
32//! # When to Use This
33//!
34//! ✅ **Good fit**:
35//! - Background HVAC/fan noise
36//! - Café/restaurant ambient noise (general chatter blur, dishes)
37//! - Low-frequency hum (electrical interference)
38//! - Stationary white/pink noise
39//!
40//! ❌ **Poor fit**:
41//! - Multi-speaker separation (babble with distinct voices)
42//! - Music removal
43//! - Non-stationary interference
44//! - Echo/reverb reduction (use AEC instead)
45
46use std::f32::consts::PI;
47use std::sync::Arc;
48
49use crate::error::{Error, Result};
50use crate::time::{AudioDuration, AudioInstant};
51use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
52use tracing::{info, warn};
53
54use super::artifacts::WavArtifactWriter;
55use super::VadContext;
56
57/// Configuration for noise reduction via spectral subtraction.
58///
59/// # Examples
60///
61/// ```rust,no_run
62/// use speech_prep::preprocessing::NoiseReductionConfig;
63///
64/// // Default: 25ms window, 10ms hop, α=2.0, β=0.02
65/// let config = NoiseReductionConfig::default();
66///
67/// // Aggressive noise removal for noisy environment
68/// let config = NoiseReductionConfig {
69///     oversubtraction_factor: 2.5,
70///     spectral_floor: 0.01,
71///     ..Default::default()
72/// };
73/// # Ok::<(), speech_prep::error::Error>(())
74/// ```
75#[derive(Debug, Clone)]
76#[allow(missing_copy_implementations)]
77pub struct NoiseReductionConfig {
78    /// Sample rate in Hz.
79    ///
80    /// **Default**: 16000
81    /// **Range**: 8000-48000
82    pub sample_rate_hz: u32,
83
84    /// STFT window duration in milliseconds.
85    ///
86    /// **Default**: 25.0
87    /// **Range**: 10.0-50.0
88    ///
89    /// **Effect**: Longer windows improve frequency resolution but reduce
90    /// time resolution. 25ms captures 1-3 pitch periods for typical speech.
91    pub window_ms: f32,
92
93    /// STFT hop duration in milliseconds.
94    ///
95    /// **Default**: 10.0
96    /// **Range**: 5.0-25.0
97    ///
98    /// **Effect**: Smaller hops increase overlap (smoother reconstruction)
99    /// but require more computation. 10ms hop = 60% overlap with 25ms window.
100    pub hop_ms: f32,
101
102    /// Oversubtraction factor (α).
103    ///
104    /// **Default**: 2.0
105    /// **Range**: 1.0-3.0
106    ///
107    /// **Effect**: Multiplier for noise estimate in spectral subtraction.
108    /// - Higher: More aggressive noise removal, more artifacts
109    /// - Lower: Conservative removal, less SNR gain
110    pub oversubtraction_factor: f32,
111
112    /// Spectral floor (β) as fraction of noise estimate.
113    ///
114    /// **Default**: 0.02 (2% of noise estimate)
115    /// **Range**: 0.001-0.1
116    ///
117    /// **Effect**: Minimum magnitude after subtraction to prevent musical
118    /// noise. Acts as a noise gate.
119    pub spectral_floor: f32,
120
121    /// Noise profile smoothing factor (`α_noise`).
122    ///
123    /// **Default**: 0.98
124    /// **Range**: 0.9-0.999
125    ///
126    /// **Effect**: Exponential moving average smoothing for noise profile.
127    /// Higher values = slower adaptation, more stable estimate.
128    pub noise_smoothing: f32,
129
130    /// Enable noise reduction.
131    ///
132    /// **Default**: true
133    ///
134    /// **Effect**: When false, audio passes through unmodified (bypass mode).
135    pub enable: bool,
136}
137
138impl Default for NoiseReductionConfig {
139    fn default() -> Self {
140        Self {
141            sample_rate_hz: 16_000,
142            window_ms: 25.0,
143            hop_ms: 10.0,
144            oversubtraction_factor: 2.0,
145            spectral_floor: 0.02,
146            noise_smoothing: 0.98,
147            enable: true,
148        }
149    }
150}
151
152impl NoiseReductionConfig {
153    /// Validate configuration parameters.
154    ///
155    /// # Errors
156    ///
157    /// Returns `Error::Configuration` if:
158    /// - `sample_rate_hz` not in 8000-48000 Hz
159    /// - `window_ms` not in 10.0-50.0 ms
160    /// - `hop_ms` >= `window_ms` (overlap required)
161    /// - `oversubtraction_factor` not in 1.0-3.0
162    /// - `spectral_floor` not in 0.001-0.1
163    /// - `noise_smoothing` not in 0.9-0.999
164    #[allow(clippy::trivially_copy_pass_by_ref)]
165    pub fn validate(&self) -> Result<()> {
166        if !(8000..=48_000).contains(&self.sample_rate_hz) {
167            return Err(Error::Configuration(format!(
168                "Invalid sample rate: {} Hz (range: 8000-48000)",
169                self.sample_rate_hz
170            )));
171        }
172
173        if !(10.0..=50.0).contains(&self.window_ms) {
174            return Err(Error::Configuration(format!(
175                "Invalid window size: {:.1} ms (range: 10-50)",
176                self.window_ms
177            )));
178        }
179
180        if self.hop_ms >= self.window_ms {
181            return Err(Error::Configuration(format!(
182                "Hop {:.1} ms must be < window {:.1} ms",
183                self.hop_ms, self.window_ms
184            )));
185        }
186
187        if !(1.0..=3.0).contains(&self.oversubtraction_factor) {
188            return Err(Error::Configuration(format!(
189                "Invalid oversubtraction factor: {:.2} (range: 1.0-3.0)",
190                self.oversubtraction_factor
191            )));
192        }
193
194        if !(0.001..=0.1).contains(&self.spectral_floor) {
195            return Err(Error::Configuration(format!(
196                "Invalid spectral floor: {:.3} (range: 0.001-0.1)",
197                self.spectral_floor
198            )));
199        }
200
201        if !(0.9..1.0).contains(&self.noise_smoothing) {
202            return Err(Error::Configuration(format!(
203                "Invalid noise smoothing: {:.3} (range: 0.9-0.999)",
204                self.noise_smoothing
205            )));
206        }
207
208        Ok(())
209    }
210
211    /// Calculate frame length in samples.
212    pub fn frame_length(&self) -> usize {
213        ((self.window_ms / 1000.0) * self.sample_rate_hz as f32).round() as usize
214    }
215
216    /// Calculate hop length in samples.
217    pub fn hop_length(&self) -> usize {
218        ((self.hop_ms / 1000.0) * self.sample_rate_hz as f32).round() as usize
219    }
220
221    /// Calculate FFT size (next power of 2 >= frame length).
222    pub fn fft_size(&self) -> usize {
223        self.frame_length().next_power_of_two()
224    }
225}
226
227/// Noise reduction via spectral subtraction with adaptive noise profiling.
228///
229/// Implements the noise reduction specification:
230/// - STFT-based processing (25ms window, 10ms hop)
231/// - Adaptive noise profile estimation during VAD-detected silence
232/// - Magnitude-only spectral subtraction (preserves phase)
233/// - Achieves ≥6 dB SNR improvement target
234///
235/// # Performance
236///
237/// - **Target**: <15ms per 500ms chunk (8000 samples @ 16kHz)
238/// - **Expected**: ~7ms (2x headroom)
239/// - **Optimization**: Precomputed FFT plans, reused buffers
240///
241/// # Example
242///
243/// ```rust,no_run
244/// use speech_prep::preprocessing::{NoiseReducer, NoiseReductionConfig, VadContext};
245///
246/// # fn main() -> speech_prep::error::Result<()> {
247/// let config = NoiseReductionConfig::default();
248/// let mut reducer = NoiseReducer::new(config)?;
249/// let audio_stream = vec![vec![0.0; 8000], vec![0.05; 8080]];
250///
251/// // Process streaming chunks with VAD context
252/// for chunk in audio_stream {
253///     let vad_ctx = VadContext { is_silence: detect_silence(&chunk) };
254///     let _denoised = reducer.reduce(&chunk, Some(vad_ctx))?;
255/// }
256/// # Ok(())
257/// # }
258/// #
259/// # fn detect_silence(chunk: &[f32]) -> bool {
260/// #     chunk.iter().all(|sample| sample.abs() < 1e-3)
261/// # }
262/// ```
263#[allow(missing_copy_implementations)]
264pub struct NoiseReducer {
265    config: NoiseReductionConfig,
266    fft_forward: Arc<dyn RealToComplex<f32>>,
267    fft_inverse: Arc<dyn ComplexToReal<f32>>,
268    window: Vec<f32>,
269    noise_profile: Vec<f32>,
270    noise_initialized: bool,
271    overlap_buffer: Vec<f32>,
272}
273
274impl std::fmt::Debug for NoiseReducer {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        f.debug_struct("NoiseReducer")
277            .field("config", &self.config)
278            .field("window_length", &self.window.len())
279            .field("noise_profile_bins", &self.noise_profile.len())
280            .field("noise_initialized", &self.noise_initialized)
281            .finish_non_exhaustive()
282    }
283}
284
285impl NoiseReducer {
286    /// Create a new noise reducer.
287    ///
288    /// # Arguments
289    ///
290    /// * `config` - Configuration parameters (window size, hop, α, β)
291    ///
292    /// # Errors
293    ///
294    /// Returns `Error::Configuration` if configuration is invalid.
295    ///
296    /// # Example
297    ///
298    /// ```rust,no_run
299    /// use speech_prep::preprocessing::{NoiseReducer, NoiseReductionConfig};
300    ///
301    /// let config = NoiseReductionConfig {
302    ///     oversubtraction_factor: 2.5, // Aggressive
303    ///     ..Default::default()
304    /// };
305    /// let reducer = NoiseReducer::new(config)?;
306    /// # Ok::<(), speech_prep::error::Error>(())
307    /// ```
308    pub fn new(config: NoiseReductionConfig) -> Result<Self> {
309        config.validate()?;
310
311        let fft_size = config.fft_size();
312        let frame_length = config.frame_length();
313
314        let mut planner = RealFftPlanner::<f32>::new();
315        let fft_forward = planner.plan_fft_forward(fft_size);
316        let fft_inverse = planner.plan_fft_inverse(fft_size);
317
318        let window = generate_hann_window(frame_length);
319
320        let num_bins = fft_size / 2 + 1;
321        let noise_profile = vec![1e-6; num_bins];
322
323        let overlap_buffer = vec![0.0; frame_length];
324
325        Ok(Self {
326            config,
327            fft_forward,
328            fft_inverse,
329            window,
330            noise_profile,
331            noise_initialized: false,
332            overlap_buffer,
333        })
334    }
335
336    /// Apply noise reduction to audio samples.
337    ///
338    /// # Arguments
339    ///
340    /// * `samples` - Input audio samples (typically 500ms chunk = 8000 samples
341    ///   @ 16kHz)
342    /// * `vad_context` - Optional VAD state for noise profile updates
343    ///
344    /// # Returns
345    ///
346    /// Denoised audio with ≥6 dB SNR improvement on noisy input.
347    ///
348    /// # Performance
349    ///
350    /// - Expected: ~7ms for 8000 samples (2x better than <15ms target)
351    /// - Complexity: O(n log n) for FFT operations
352    ///
353    /// # Example
354    ///
355    /// ```rust,no_run
356    /// use speech_prep::preprocessing::{NoiseReducer, NoiseReductionConfig, VadContext};
357    ///
358    /// let mut reducer = NoiseReducer::new(NoiseReductionConfig::default())?;
359    ///
360    /// // Chunk 1 (silence - initialize noise profile)
361    /// let chunk1 = vec![0.001; 8000];
362    /// let vad1 = VadContext { is_silence: true };
363    /// let output1 = reducer.reduce(&chunk1, Some(vad1))?;
364    ///
365    /// // Chunk 2 (speech - apply noise reduction)
366    /// let chunk2 = vec![0.1; 8000];
367    /// let vad2 = VadContext { is_silence: false };
368    /// let output2 = reducer.reduce(&chunk2, Some(vad2))?;
369    /// # Ok::<(), speech_prep::error::Error>(())
370    /// ```
371    #[allow(clippy::unnecessary_wraps)]
372    pub fn reduce(&mut self, samples: &[f32], vad_context: Option<VadContext>) -> Result<Vec<f32>> {
373        self.reduce_with_artifacts(samples, vad_context, None)
374    }
375
376    /// Apply noise reduction and optionally capture QA artifacts.
377    pub fn reduce_with_artifacts(
378        &mut self,
379        samples: &[f32],
380        vad_context: Option<VadContext>,
381        mut artifacts: Option<&mut NoiseReductionArtifacts<'_>>,
382    ) -> Result<Vec<f32>> {
383        let processing_start = AudioInstant::now();
384
385        if samples.is_empty() {
386            return Ok(Vec::new());
387        }
388
389        if let Some(ref mut recorder) = artifacts {
390            recorder.record_input(samples);
391        }
392
393        if !self.config.enable {
394            return Ok(samples.to_vec());
395        }
396
397        let (mut output, frame_count) = self.process_stft_frames(samples, vad_context)?;
398        self.normalize_overlap_add(&mut output);
399
400        if let Some(ref mut recorder) = artifacts {
401            recorder.record_output(&output);
402        }
403
404        let elapsed = elapsed_duration(processing_start);
405        let latency_ms = elapsed.as_secs_f64() * 1000.0;
406        self.record_performance_metrics(samples, &output, latency_ms, frame_count);
407
408        Ok(output)
409    }
410
411    /// Process audio through STFT frames with spectral subtraction.
412    fn process_stft_frames(
413        &mut self,
414        samples: &[f32],
415        vad_context: Option<VadContext>,
416    ) -> Result<(Vec<f32>, usize)> {
417        let frame_length = self.config.frame_length();
418        let hop_length = self.config.hop_length();
419
420        let mut output = vec![0.0; samples.len()];
421        let mut frame_idx = 0;
422        let mut pos = 0;
423
424        while pos < samples.len() {
425            let remaining = samples.len() - pos;
426
427            let frame = Self::extract_frame(samples, pos, frame_length, remaining)?;
428
429            let processed =
430                self.process_single_frame(&frame, vad_context, remaining >= frame_length)?;
431
432            Self::accumulate_frame_output(&processed, &mut output, pos);
433
434            frame_idx += 1;
435
436            if remaining < hop_length {
437                break;
438            }
439            pos += hop_length;
440        }
441
442        Ok((output, frame_idx))
443    }
444
445    /// Extract a frame from the input samples, zero-padding if partial.
446    fn extract_frame(
447        samples: &[f32],
448        pos: usize,
449        frame_length: usize,
450        remaining: usize,
451    ) -> Result<Vec<f32>> {
452        let mut frame_buf = vec![0.0; frame_length];
453
454        if remaining >= frame_length {
455            let src = samples
456                .get(pos..pos + frame_length)
457                .ok_or_else(|| Error::Processing("frame window out of bounds".into()))?;
458            frame_buf.copy_from_slice(src);
459        } else {
460            let src = samples
461                .get(pos..)
462                .ok_or_else(|| Error::Processing("frame tail out of bounds".into()))?;
463            if let Some(dst) = frame_buf.get_mut(..remaining) {
464                dst.copy_from_slice(src);
465            }
466        }
467
468        Ok(frame_buf)
469    }
470
471    /// Process a single frame through FFT, spectral subtraction, and IFFT.
472    fn process_single_frame(
473        &mut self,
474        frame: &[f32],
475        vad_context: Option<VadContext>,
476        is_full_frame: bool,
477    ) -> Result<Vec<f32>> {
478        let fft_size = self.config.fft_size();
479
480        let windowed: Vec<f32> = frame
481            .iter()
482            .zip(&self.window)
483            .map(|(&s, &w)| s * w)
484            .collect();
485
486        let complex_spectrum = self.forward_fft_complex(&windowed)?;
487        let magnitudes: Vec<f32> = complex_spectrum.iter().map(|c| c.norm()).collect();
488
489        let is_silence = vad_context.is_some_and(|ctx| ctx.is_silence);
490        if is_silence && is_full_frame {
491            self.update_noise_profile(&magnitudes);
492        }
493
494        let cleaned_magnitudes = self.spectral_subtract(&magnitudes);
495
496        let cleaned_complex =
497            Self::reconstruct_complex_spectrum(&complex_spectrum, &cleaned_magnitudes);
498
499        let time_signal = self.inverse_fft_complex(&cleaned_complex, fft_size)?;
500
501        let windowed_output: Vec<f32> = time_signal
502            .iter()
503            .take(frame.len())
504            .zip(&self.window)
505            .map(|(&s, &w)| s * w)
506            .collect();
507
508        Ok(windowed_output)
509    }
510
511    /// Reconstruct complex spectrum preserving phase from original signal.
512    fn reconstruct_complex_spectrum(
513        original_spectrum: &[realfft::num_complex::Complex<f32>],
514        cleaned_magnitudes: &[f32],
515    ) -> Vec<realfft::num_complex::Complex<f32>> {
516        original_spectrum
517            .iter()
518            .zip(cleaned_magnitudes)
519            .enumerate()
520            .map(|(i, (original, &new_mag))| {
521                if i == 0 || i == original_spectrum.len() - 1 {
522                    // DC and Nyquist bins must be real-valued
523                    realfft::num_complex::Complex::new(new_mag, 0.0)
524                } else {
525                    let phase = original.arg();
526                    realfft::num_complex::Complex::from_polar(new_mag, phase)
527                }
528            })
529            .collect()
530    }
531
532    /// Accumulate processed frame into output buffer (overlap-add).
533    fn accumulate_frame_output(frame: &[f32], output: &mut [f32], pos: usize) {
534        for (i, &sample) in frame.iter().enumerate() {
535            let out_idx = pos + i;
536            if let Some(dst) = output.get_mut(out_idx) {
537                *dst += sample;
538            }
539        }
540    }
541
542    /// Normalize output by overlap-add window sum.
543    fn normalize_overlap_add(&self, output: &mut [f32]) {
544        let hop_length = self.config.hop_length();
545        let window_sum = self.calculate_window_overlap_sum(hop_length);
546
547        if window_sum > 1e-6 {
548            for sample in output {
549                *sample /= window_sum;
550            }
551        }
552    }
553
554    fn record_performance_metrics(
555        &self,
556        input: &[f32],
557        output: &[f32],
558        latency_ms: f64,
559        frame_count: usize,
560    ) {
561        if input.len() < 8000 {
562            return;
563        }
564
565        if latency_ms > 15.0 {
566            warn!(
567                target: "audio.preprocess.noise_reduction",
568                latency_ms,
569                samples = input.len(),
570                frames = frame_count,
571                oversubtraction = self.config.oversubtraction_factor,
572                spectral_floor = self.config.spectral_floor,
573                "noise reduction latency exceeded target"
574            );
575        }
576
577        let avg_noise_floor = self.noise_floor().max(1e-12);
578        let noise_floor_db = 20.0 * avg_noise_floor.log10();
579
580        let signal_power_out =
581            output.iter().map(|sample| sample * sample).sum::<f32>() / output.len() as f32;
582        let residual_power: f32 = input
583            .iter()
584            .zip(output)
585            .map(|(&noisy, &clean)| {
586                let residual = noisy - clean;
587                residual * residual
588            })
589            .sum::<f32>()
590            / output.len() as f32;
591
592        let snr_improvement_db = if residual_power > 1e-12 && signal_power_out > 0.0 {
593            10.0 * (signal_power_out / residual_power).log10()
594        } else {
595            0.0
596        };
597
598        info!(
599            target: "audio.preprocess.noise_reduction",
600            noise_floor_db,
601            snr_improvement_db,
602            latency_ms,
603            frames = frame_count,
604            samples = input.len(),
605            oversubtraction = self.config.oversubtraction_factor,
606            spectral_floor = self.config.spectral_floor,
607            "noise reduction metrics"
608        );
609    }
610
611    /// Reset noise profile for new audio stream.
612    ///
613    /// Clears noise estimate and overlap-add state.
614    /// Use this when starting a new, independent audio stream.
615    pub fn reset(&mut self) {
616        self.noise_profile.fill(1e-6);
617        self.noise_initialized = false;
618        self.overlap_buffer.fill(0.0);
619    }
620
621    /// Get current average noise floor (for debugging/observability).
622    #[must_use]
623    pub fn noise_floor(&self) -> f32 {
624        if self.noise_profile.is_empty() {
625            return 0.0;
626        }
627        self.noise_profile.iter().sum::<f32>() / self.noise_profile.len() as f32
628    }
629
630    /// Get current configuration.
631    #[must_use]
632    pub fn config(&self) -> &NoiseReductionConfig {
633        &self.config
634    }
635
636    // Forward FFT with zero-padding (returns complex spectrum)
637    fn forward_fft_complex(
638        &self,
639        windowed: &[f32],
640    ) -> Result<Vec<realfft::num_complex::Complex<f32>>> {
641        // Prepare input buffer (zero-padded to FFT size)
642        let mut input = self.fft_forward.make_input_vec();
643        for (i, &sample) in windowed.iter().enumerate() {
644            if let Some(dst) = input.get_mut(i) {
645                *dst = sample;
646            }
647        }
648
649        // Perform FFT
650        let mut spectrum = self.fft_forward.make_output_vec();
651        self.fft_forward
652            .process(&mut input, &mut spectrum)
653            .map_err(|e| Error::Processing(format!("FFT failed: {e}")))?;
654
655        Ok(spectrum)
656    }
657
658    // Inverse FFT from complex spectrum (preserves phase)
659    fn inverse_fft_complex(
660        &self,
661        complex_spectrum: &[realfft::num_complex::Complex<f32>],
662        fft_size: usize,
663    ) -> Result<Vec<f32>> {
664        // Prepare input buffer
665        let mut spectrum = self.fft_inverse.make_input_vec();
666        for (i, &c) in complex_spectrum.iter().enumerate() {
667            if let Some(bin) = spectrum.get_mut(i) {
668                *bin = c;
669            }
670        }
671
672        // Perform inverse FFT
673        let mut output = self.fft_inverse.make_output_vec();
674        self.fft_inverse
675            .process(&mut spectrum, &mut output)
676            .map_err(|e| Error::Processing(format!("IFFT failed: {e}")))?;
677
678        // Normalize by FFT size
679        for sample in &mut output {
680            *sample /= fft_size as f32;
681        }
682
683        Ok(output)
684    }
685
686    // Update noise profile using exponential moving average
687    fn update_noise_profile(&mut self, spectrum: &[f32]) {
688        let alpha = self.config.noise_smoothing;
689
690        if self.noise_initialized {
691            // EMA update: N_new[k] = α * N_old[k] + (1-α) * |X[k]|
692            for (noise, &current) in self.noise_profile.iter_mut().zip(spectrum.iter()) {
693                *noise = alpha.mul_add(*noise, (1.0 - alpha) * current);
694            }
695        } else {
696            // First silence frame: initialize noise profile
697            self.noise_profile.copy_from_slice(spectrum);
698            self.noise_initialized = true;
699        }
700    }
701
702    // Apply spectral subtraction: |Y[k]| = max(|X[k]| - α*|N[k]|, β*|N[k]|)
703    fn spectral_subtract(&self, spectrum: &[f32]) -> Vec<f32> {
704        let alpha = self.config.oversubtraction_factor;
705        let beta = self.config.spectral_floor;
706
707        spectrum
708            .iter()
709            .zip(&self.noise_profile)
710            .map(|(&signal, &noise)| {
711                let subtracted = alpha.mul_add(-noise, signal);
712                let floor = beta * noise;
713                subtracted.max(floor)
714            })
715            .collect()
716    }
717
718    // Calculate window overlap sum for COLA normalization
719    fn calculate_window_overlap_sum(&self, hop_length: usize) -> f32 {
720        let frame_length = self.window.len();
721        let mut sum: f32 = 0.0;
722
723        // Sum overlapping windows at each sample position
724        for i in 0..frame_length {
725            let mut overlap: f32 = 0.0;
726            let mut offset = 0;
727
728            while offset <= i {
729                if let Some(&w) = self.window.get(i - offset) {
730                    overlap = w.mul_add(w, overlap); // Window applied twice
731                                                     // (analysis +
732                                                     // synthesis)
733                }
734                offset += hop_length;
735            }
736
737            sum = sum.max(overlap);
738        }
739
740        sum
741    }
742}
743
744/// Optional artifact capture for QA workflows.
745pub struct NoiseReductionArtifacts<'a> {
746    before: Option<&'a mut WavArtifactWriter>,
747    after: Option<&'a mut WavArtifactWriter>,
748}
749
750impl<'a> NoiseReductionArtifacts<'a> {
751    /// Create a recorder for the provided before/after writers.
752    pub fn new(
753        before: Option<&'a mut WavArtifactWriter>,
754        after: Option<&'a mut WavArtifactWriter>,
755    ) -> Self {
756        Self { before, after }
757    }
758
759    /// Persist the noisy input samples for QA review.
760    pub fn record_input(&mut self, samples: &[f32]) {
761        if let Some(writer) = self.before.as_deref_mut() {
762            if let Err(err) = writer.write_samples(samples) {
763                warn!(target: "audio.preprocess.noise_reduction", error = %err, "failed to write input artifact");
764            }
765        }
766    }
767
768    /// Persist the denoised output samples for QA review.
769    pub fn record_output(&mut self, samples: &[f32]) {
770        if let Some(writer) = self.after.as_deref_mut() {
771            if let Err(err) = writer.write_samples(samples) {
772                warn!(target: "audio.preprocess.noise_reduction", error = %err, "failed to write output artifact");
773            }
774        }
775    }
776}
777
778impl std::fmt::Debug for NoiseReductionArtifacts<'_> {
779    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
780        f.debug_struct("NoiseReductionArtifacts").finish()
781    }
782}
783
784/// Generate Hann window.
785///
786/// Formula: `w[n] = 0.5 - 0.5 * cos(2π * n / (N-1))`
787fn generate_hann_window(length: usize) -> Vec<f32> {
788    if length == 0 {
789        return Vec::new();
790    }
791
792    if length == 1 {
793        return vec![1.0];
794    }
795
796    let denom = (length - 1) as f32;
797    (0..length)
798        .map(|n| {
799            let angle = 2.0 * PI * n as f32 / denom;
800            0.5f32.mul_add(-angle.cos(), 0.5)
801        })
802        .collect()
803}
804
805fn elapsed_duration(start: AudioInstant) -> AudioDuration {
806    AudioInstant::now().duration_since(start)
807}
808
809#[cfg(test)]
810mod tests {
811    use super::*;
812
813    type TestResult<T> = std::result::Result<T, String>;
814
815    #[test]
816    #[allow(clippy::unnecessary_wraps)]
817    fn test_configuration_validation() -> TestResult<()> {
818        // Valid configuration
819        let valid = NoiseReductionConfig::default();
820        assert!(valid.validate().is_ok());
821
822        // Invalid sample rate
823        let invalid_sr = NoiseReductionConfig {
824            sample_rate_hz: 5000,
825            ..Default::default()
826        };
827        assert!(invalid_sr.validate().is_err());
828
829        // Invalid window size
830        let invalid_window = NoiseReductionConfig {
831            window_ms: 100.0,
832            ..Default::default()
833        };
834        assert!(invalid_window.validate().is_err());
835
836        // Hop >= window
837        let invalid_hop = NoiseReductionConfig {
838            hop_ms: 30.0,
839            window_ms: 25.0,
840            ..Default::default()
841        };
842        assert!(invalid_hop.validate().is_err());
843
844        // Invalid oversubtraction
845        let invalid_alpha = NoiseReductionConfig {
846            oversubtraction_factor: 5.0,
847            ..Default::default()
848        };
849        assert!(invalid_alpha.validate().is_err());
850
851        // Invalid spectral floor
852        let invalid_beta = NoiseReductionConfig {
853            spectral_floor: 0.5,
854            ..Default::default()
855        };
856        assert!(invalid_beta.validate().is_err());
857
858        // Invalid noise smoothing
859        let invalid_smoothing = NoiseReductionConfig {
860            noise_smoothing: 1.0,
861            ..Default::default()
862        };
863        assert!(invalid_smoothing.validate().is_err());
864
865        Ok(())
866    }
867
868    #[test]
869    fn test_hann_window_properties() {
870        // Empty window
871        let window_0 = generate_hann_window(0);
872        assert!(window_0.is_empty());
873
874        // Single element
875        let window_1 = generate_hann_window(1);
876        assert_eq!(window_1.len(), 1);
877        assert!((window_1[0] - 1.0).abs() < 1e-6);
878
879        // Check Hann window properties
880        let window = generate_hann_window(100);
881        assert_eq!(window.len(), 100);
882
883        // First and last samples should be near zero
884        assert!(window[0].abs() < 1e-6);
885        assert!(window[99].abs() < 1e-6);
886
887        // Middle sample should be near 1.0
888        assert!((window[50] - 1.0).abs() < 0.1);
889    }
890
891    #[test]
892    fn test_noise_reducer_creation() -> TestResult<()> {
893        let config = NoiseReductionConfig::default();
894        let reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
895
896        // Check initialization
897        assert_eq!(reducer.config().sample_rate_hz, 16000);
898        assert!(reducer.noise_floor() > 0.0); // Initial estimate
899
900        Ok(())
901    }
902
903    #[test]
904    fn test_empty_input() -> TestResult<()> {
905        let config = NoiseReductionConfig::default();
906        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
907
908        let output = reducer.reduce(&[], None).map_err(|e| e.to_string())?;
909        assert!(output.is_empty());
910
911        Ok(())
912    }
913
914    #[test]
915    fn test_bypass_mode() -> TestResult<()> {
916        let config = NoiseReductionConfig {
917            enable: false,
918            ..Default::default()
919        };
920        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
921
922        let input = vec![0.1, 0.2, 0.3, 0.4];
923        let output = reducer.reduce(&input, None).map_err(|e| e.to_string())?;
924
925        // Bypass mode should return input unchanged
926        assert_eq!(output, input);
927
928        Ok(())
929    }
930
931    #[test]
932    fn test_noise_profile_update() -> TestResult<()> {
933        let config = NoiseReductionConfig::default();
934        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
935
936        // Process silence to build noise profile
937        let silence = vec![0.01; 8000]; // Low-level noise
938        let vad_silence = VadContext { is_silence: true };
939
940        let initial_noise = reducer.noise_floor();
941
942        // Process multiple chunks to converge
943        for _ in 0..5 {
944            let _ = reducer
945                .reduce(&silence, Some(vad_silence))
946                .map_err(|e| e.to_string())?;
947        }
948
949        let converged_noise = reducer.noise_floor();
950
951        // Noise floor should increase from initial estimate
952        assert!(
953            converged_noise > initial_noise,
954            "Noise floor should adapt: initial={:.6}, converged={:.6}",
955            initial_noise,
956            converged_noise
957        );
958
959        Ok(())
960    }
961
962    #[test]
963    fn test_vad_informed_noise_update() -> TestResult<()> {
964        let config = NoiseReductionConfig::default();
965        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
966
967        // Initialize with silence
968        let silence = vec![0.01; 8000];
969        let vad_silence = VadContext { is_silence: true };
970        for _ in 0..5 {
971            let _ = reducer
972                .reduce(&silence, Some(vad_silence))
973                .map_err(|e| e.to_string())?;
974        }
975
976        let noise_after_silence = reducer.noise_floor();
977
978        // Process "speech" (should NOT update noise profile)
979        let speech = vec![0.5; 8000];
980        let vad_speech = VadContext { is_silence: false };
981        let _ = reducer
982            .reduce(&speech, Some(vad_speech))
983            .map_err(|e| e.to_string())?;
984
985        let noise_after_speech = reducer.noise_floor();
986
987        // Noise profile should remain stable during speech
988        let diff = (noise_after_speech - noise_after_silence).abs();
989        assert!(
990            diff < noise_after_silence * 0.01,
991            "Noise profile changed during speech: {:.6} -> {:.6}",
992            noise_after_silence,
993            noise_after_speech
994        );
995
996        Ok(())
997    }
998
999    #[test]
1000    fn test_reset_clears_state() -> TestResult<()> {
1001        let config = NoiseReductionConfig::default();
1002        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1003
1004        // Process some audio
1005        let samples = vec![0.1; 8000];
1006        let vad = VadContext { is_silence: true };
1007        let _ = reducer
1008            .reduce(&samples, Some(vad))
1009            .map_err(|e| e.to_string())?;
1010
1011        let noise_before = reducer.noise_floor();
1012        assert!(noise_before > 1e-5, "Noise profile should be updated");
1013
1014        // Reset
1015        reducer.reset();
1016
1017        let noise_after = reducer.noise_floor();
1018        assert!(
1019            noise_after < 1e-5,
1020            "Noise profile should be reset to initial value"
1021        );
1022
1023        Ok(())
1024    }
1025
1026    // Helper: Generate sine wave
1027    fn generate_sine_wave(
1028        frequency: f32,
1029        sample_rate: u32,
1030        duration_secs: f32,
1031        amplitude: f32,
1032    ) -> Vec<f32> {
1033        use std::f32::consts::PI;
1034        let samples = (sample_rate as f32 * duration_secs).round() as usize;
1035        (0..samples)
1036            .map(|i| {
1037                let t = i as f32 / sample_rate as f32;
1038                (2.0 * PI * frequency * t).sin() * amplitude
1039            })
1040            .collect()
1041    }
1042
1043    // Helper: Add white noise to signal
1044    fn add_white_noise(signal: &[f32], noise_amplitude: f32) -> Vec<f32> {
1045        use rand::Rng;
1046        let mut rng = rand::rng();
1047        signal
1048            .iter()
1049            .map(|&s| {
1050                let noise: f32 = rng.random_range(-noise_amplitude..noise_amplitude);
1051                s + noise
1052            })
1053            .collect()
1054    }
1055    fn add_low_freq_hum(
1056        signal: &[f32],
1057        sample_rate: u32,
1058        frequency: f32,
1059        amplitude: f32,
1060    ) -> Vec<f32> {
1061        signal
1062            .iter()
1063            .enumerate()
1064            .map(|(i, &sample)| {
1065                let t = i as f32 / sample_rate as f32;
1066                let hum = (2.0 * PI * frequency * t).sin() * amplitude;
1067                sample + hum
1068            })
1069            .collect()
1070    }
1071
1072    // Helper: Add café-like ambient noise (stationary broadband noise)
1073    // Simulates background café noise: HVAC, dishes, distant ambient chatter.
1074    // Uses white noise as a proxy for band-limited stationary noise (100-3000 Hz
1075    // typical). NOTE: Spectral subtraction works for STATIONARY noise, not
1076    // speech-like babble.
1077    fn add_cafe_noise(signal: &[f32], _sample_rate: u32, amplitude: f32) -> Vec<f32> {
1078        use rand::Rng;
1079        let mut rng = rand::rng();
1080        signal
1081            .iter()
1082            .map(|&sample| {
1083                let noise: f32 = rng.random_range(-1.0..1.0);
1084                amplitude.mul_add(noise, sample)
1085            })
1086            .collect()
1087    }
1088
1089    // Helper: Calculate SNR
1090    fn calculate_snr(clean: &[f32], noisy: &[f32]) -> f32 {
1091        if clean.len() != noisy.len() {
1092            return 0.0;
1093        }
1094
1095        let signal_power: f32 = clean.iter().map(|&x| x * x).sum();
1096        let noise: Vec<f32> = clean
1097            .iter()
1098            .zip(noisy.iter())
1099            .map(|(&c, &n)| n - c)
1100            .collect();
1101        let noise_power: f32 = noise.iter().map(|&x| x * x).sum();
1102
1103        if noise_power < 1e-10 {
1104            return 100.0; // Very high SNR
1105        }
1106
1107        10.0 * (signal_power / noise_power).log10()
1108    }
1109
1110    #[test]
1111    fn test_snr_improvement_white_noise() -> TestResult<()> {
1112        let config = NoiseReductionConfig::default();
1113        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1114
1115        // Generate clean speech signal (440 Hz sine wave)
1116        let clean_speech = generate_sine_wave(440.0, 16000, 1.0, 0.5);
1117
1118        // Add white noise (creating ~5 dB input SNR)
1119        let noisy_speech = add_white_noise(&clean_speech, 0.3);
1120
1121        let snr_before = calculate_snr(&clean_speech, &noisy_speech);
1122
1123        // Initialize noise profile with pure noise
1124        // Training: 10 iterations ensures EMA convergence (α=0.98 requires ~50 samples
1125        // for 95% convergence)
1126        let pure_noise = add_white_noise(&vec![0.0; 8000], 0.3);
1127        let vad_silence = VadContext { is_silence: true };
1128        for _ in 0..10 {
1129            let _ = reducer
1130                .reduce(&pure_noise, Some(vad_silence))
1131                .map_err(|e| e.to_string())?;
1132        }
1133
1134        // Apply noise reduction to noisy speech
1135        let vad_speech = VadContext { is_silence: false };
1136        let denoised = reducer
1137            .reduce(&noisy_speech, Some(vad_speech))
1138            .map_err(|e| e.to_string())?;
1139
1140        let snr_after = calculate_snr(&clean_speech, &denoised);
1141        let improvement = snr_after - snr_before;
1142
1143        eprintln!(
1144            "SNR improvement: Before={:.1} dB, After={:.1} dB, Improvement={:.1} dB",
1145            snr_before, snr_after, improvement
1146        );
1147
1148        // Success criterion: ≥6 dB improvement
1149        assert!(
1150            improvement >= 6.0,
1151            "SNR improvement {:.1} dB < 6 dB target",
1152            improvement
1153        );
1154
1155        Ok(())
1156    }
1157
1158    #[test]
1159    fn test_snr_improvement_low_freq_hum() -> TestResult<()> {
1160        let config = NoiseReductionConfig::default();
1161        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1162
1163        // Generate 440 Hz speech with 60 Hz HVAC hum (common electrical interference)
1164        let clean = generate_sine_wave(440.0, 16000, 1.0, 0.4);
1165        let noisy = add_low_freq_hum(&clean, 16000, 60.0, 0.3);
1166        let snr_before = calculate_snr(&clean, &noisy);
1167
1168        // Train on pure 60 Hz hum
1169        // Training: 6 iterations sufficient for tonal noise (faster convergence than
1170        // broadband)
1171        let hum_only = add_low_freq_hum(&vec![0.0; 8000], 16000, 60.0, 0.3);
1172        let vad = VadContext { is_silence: true };
1173        for _ in 0..6 {
1174            let _ = reducer
1175                .reduce(&hum_only, Some(vad))
1176                .map_err(|e| e.to_string())?;
1177        }
1178
1179        let vad_speech = VadContext { is_silence: false };
1180        let denoised = reducer
1181            .reduce(&noisy, Some(vad_speech))
1182            .map_err(|e| e.to_string())?;
1183        let snr_after = calculate_snr(&clean, &denoised);
1184        let improvement = snr_after - snr_before;
1185        assert!(
1186            improvement >= 6.0,
1187            "Hum SNR improvement {:.1} dB < 6 dB target",
1188            improvement
1189        );
1190
1191        Ok(())
1192    }
1193
1194    #[test]
1195    fn test_snr_improvement_cafe_ambient() -> TestResult<()> {
1196        // Use default config - stationary noise doesn't need aggressive oversubtraction
1197        let config = NoiseReductionConfig::default();
1198        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1199
1200        // Generate clean speech signal
1201        let clean = generate_sine_wave(220.0, 16000, 1.0, 0.4);
1202
1203        // Add stationary café ambient noise (HVAC, dishes, background chatter)
1204        let noisy = add_cafe_noise(&clean, 16000, 0.25);
1205        let snr_before = calculate_snr(&clean, &noisy);
1206
1207        // Train noise profile on café ambient noise during "silence"
1208        // Training: 10 iterations for broadband stationary noise (white noise requires
1209        // more samples than tonal) Noise amplitude 0.25 creates ~5-6 dB input
1210        // SNR, realistic for café environment
1211        let cafe_only = add_cafe_noise(&vec![0.0; 8000], 16000, 0.25);
1212        let vad = VadContext { is_silence: true };
1213        for _ in 0..10 {
1214            let _ = reducer
1215                .reduce(&cafe_only, Some(vad))
1216                .map_err(|e| e.to_string())?;
1217        }
1218
1219        // Apply noise reduction to noisy speech
1220        let vad_speech = VadContext { is_silence: false };
1221        let denoised = reducer
1222            .reduce(&noisy, Some(vad_speech))
1223            .map_err(|e| e.to_string())?;
1224
1225        let snr_after = calculate_snr(&clean, &denoised);
1226        let improvement = snr_after - snr_before;
1227
1228        assert!(
1229            improvement >= 6.0,
1230            "Café ambient SNR improvement {:.1} dB < 6 dB target",
1231            improvement
1232        );
1233
1234        Ok(())
1235    }
1236
1237    #[test]
1238    fn test_trailing_partial_frame_preserved() -> TestResult<()> {
1239        let config = NoiseReductionConfig::default();
1240        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1241
1242        // Prime noise profile with a silence chunk so spectral subtraction behaves
1243        // normally.
1244        let silence = vec![0.0; 8000];
1245        let vad_silence = VadContext { is_silence: true };
1246        let _ = reducer
1247            .reduce(&silence, Some(vad_silence))
1248            .map_err(|e| e.to_string())?;
1249
1250        // Speech chunk length intentionally not divisible by hop size (adds 80-sample
1251        // tail).
1252        let speech_len = 8080;
1253        let speech: Vec<f32> = (0..speech_len)
1254            .map(|i| {
1255                let phase = (i as f32 / speech_len as f32) * 20.0;
1256                phase.sin()
1257            })
1258            .collect();
1259
1260        let vad_speech = VadContext { is_silence: false };
1261        let output = reducer
1262            .reduce(&speech, Some(vad_speech))
1263            .map_err(|e| e.to_string())?;
1264
1265        assert_eq!(
1266            output.len(),
1267            speech_len,
1268            "Output length should match input length"
1269        );
1270
1271        let tail = &output[speech_len - 80..];
1272        let tail_energy: f32 = tail.iter().map(|sample| sample.abs()).sum();
1273        assert!(
1274            tail_energy > 1e-3,
1275            "Trailing samples should retain energy, got tail_energy={tail_energy}"
1276        );
1277
1278        Ok(())
1279    }
1280
1281    #[test]
1282    fn test_missing_vad_context_does_not_update_noise_profile() -> TestResult<()> {
1283        let config = NoiseReductionConfig::default();
1284        let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1285
1286        // Prime noise profile with explicit silence (non-zero noise so baseline is
1287        // measurable).
1288        let ambient_noise = vec![0.05f32; 8000];
1289        let vad_silence = VadContext { is_silence: true };
1290        reducer
1291            .reduce(&ambient_noise, Some(vad_silence))
1292            .map_err(|e| e.to_string())?;
1293        let baseline_floor = reducer.noise_floor();
1294
1295        // Process speech without VAD context; noise profile should remain unchanged.
1296        let speech = vec![0.2f32; 8000];
1297        let output = reducer.reduce(&speech, None).map_err(|e| e.to_string())?;
1298        let updated_floor = reducer.noise_floor();
1299
1300        let floor_delta = (updated_floor - baseline_floor).abs();
1301        assert!(
1302            floor_delta < baseline_floor.max(1e-6) * 0.01,
1303            "Noise floor changed when VAD context missing: baseline={baseline_floor}, \
1304             updated={updated_floor}"
1305        );
1306
1307        let output_rms =
1308            (output.iter().map(|sample| sample * sample).sum::<f32>() / output.len() as f32).sqrt();
1309        assert!(
1310            output_rms > 0.08,
1311            "Speech energy collapsed without VAD context (rms={output_rms})"
1312        );
1313
1314        Ok(())
1315    }
1316}