Skip to main content

wavekat_vad/preprocessing/
denoise.rs

1//! RNNoise-based noise suppression wrapper.
2//!
3//! This module wraps the `nnnoiseless` crate (a pure Rust port of RNNoise)
4//! to provide stationary noise suppression for audio streams.
5//!
6//! # Sample Rate Handling
7//!
8//! RNNoise internally requires 48kHz audio. This module handles resampling
9//! automatically:
10//! - At 48kHz: audio is processed directly (most efficient)
11//! - At other rates (e.g., 16kHz): audio is upsampled to 48kHz, processed,
12//!   then downsampled back to the original rate
13//!
14//! # Frame Size
15//!
16//! RNNoise processes 480 samples (10ms at 48kHz) at a time.
17//! The denoiser handles frame buffering internally, so you can pass
18//! any chunk size and it will accumulate/process accordingly.
19
20use super::resample::AudioResampler;
21use nnnoiseless::DenoiseState;
22
/// Sample rate, in Hz, at which RNNoise operates internally (48 kHz).
pub const DENOISE_SAMPLE_RATE: u32 = 48_000;

/// Number of samples RNNoise consumes per frame (480 samples = 10 ms at 48 kHz).
const FRAME_SIZE: usize = 480;
28
29/// RNNoise-based noise suppressor.
30///
31/// Wraps `nnnoiseless::DenoiseState` with frame buffering to handle
32/// arbitrary input chunk sizes. Automatically resamples to/from 48kHz
33/// when the input sample rate differs.
34pub struct Denoiser {
35    state: Box<DenoiseState<'static>>,
36    /// Input sample rate.
37    sample_rate: u32,
38    /// Upsampler: input rate → 48kHz (None if already 48kHz).
39    upsampler: Option<AudioResampler>,
40    /// Downsampler: 48kHz → input rate (None if already 48kHz).
41    downsampler: Option<AudioResampler>,
42    /// Input buffer accumulating samples until we have FRAME_SIZE (at 48kHz).
43    input_buffer: Vec<f32>,
44    /// Output buffer holding processed samples (at 48kHz).
45    output_buffer: Vec<f32>,
46    /// Whether this is the first frame (discard due to fade-in artifacts).
47    first_frame: bool,
48}
49
50impl std::fmt::Debug for Denoiser {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("Denoiser")
53            .field("sample_rate", &self.sample_rate)
54            .field("resampling", &self.upsampler.is_some())
55            .field("input_buffer_len", &self.input_buffer.len())
56            .field("output_buffer_len", &self.output_buffer.len())
57            .field("first_frame", &self.first_frame)
58            .finish_non_exhaustive()
59    }
60}
61
62impl Denoiser {
63    /// Create a new denoiser.
64    ///
65    /// # Arguments
66    /// * `sample_rate` - Input sample rate in Hz. Common values: 16000, 48000.
67    ///
68    /// If the sample rate is not 48kHz, the denoiser will automatically
69    /// resample to 48kHz for processing and back to the original rate.
70    ///
71    /// # Panics
72    ///
73    /// Panics if resamplers cannot be created (should not happen with valid rates).
74    pub fn new(sample_rate: u32) -> Self {
75        let (upsampler, downsampler) = if sample_rate == DENOISE_SAMPLE_RATE {
76            (None, None)
77        } else {
78            let up = AudioResampler::new(sample_rate, DENOISE_SAMPLE_RATE)
79                .expect("failed to create upsampler");
80            let down = AudioResampler::new(DENOISE_SAMPLE_RATE, sample_rate)
81                .expect("failed to create downsampler");
82            (Some(up), Some(down))
83        };
84
85        Self {
86            state: DenoiseState::new(),
87            sample_rate,
88            upsampler,
89            downsampler,
90            input_buffer: Vec::with_capacity(FRAME_SIZE),
91            output_buffer: Vec::new(),
92            first_frame: true,
93        }
94    }
95
96    /// Returns the sample rate this denoiser was configured for.
97    pub fn sample_rate(&self) -> u32 {
98        self.sample_rate
99    }
100
101    /// Returns true if resampling is being used.
102    pub fn is_resampling(&self) -> bool {
103        self.upsampler.is_some()
104    }
105
106    /// Process audio samples through the noise suppressor.
107    ///
108    /// Input samples are i16 values at the configured sample rate.
109    /// Returns denoised samples at the same sample rate.
110    /// Due to frame buffering (and resampling if applicable), output length
111    /// may differ from input length.
112    pub fn process(&mut self, samples: &[i16]) -> Vec<i16> {
113        // Step 1: Upsample if needed (e.g., 16kHz → 48kHz)
114        let samples_48k: Vec<i16> = if let Some(ref mut upsampler) = self.upsampler {
115            upsampler.process(samples)
116        } else {
117            samples.to_vec()
118        };
119
120        // Step 2: Process through RNNoise at 48kHz
121        // Convert i16 to f32 and add to input buffer
122        for &sample in &samples_48k {
123            self.input_buffer.push(sample as f32);
124        }
125
126        // Process complete frames
127        while self.input_buffer.len() >= FRAME_SIZE {
128            let mut input_frame = [0.0f32; FRAME_SIZE];
129            let mut output_frame = [0.0f32; FRAME_SIZE];
130
131            // Copy frame from input buffer
132            input_frame.copy_from_slice(&self.input_buffer[..FRAME_SIZE]);
133            self.input_buffer.drain(..FRAME_SIZE);
134
135            // Process through RNNoise
136            let _vad_prob = self.state.process_frame(&mut output_frame, &input_frame);
137
138            // Skip first frame due to fade-in artifacts
139            if self.first_frame {
140                self.first_frame = false;
141                // Output zeros for the first frame to maintain timing
142                self.output_buffer
143                    .extend(std::iter::repeat_n(0.0, FRAME_SIZE));
144            } else {
145                self.output_buffer.extend_from_slice(&output_frame);
146            }
147        }
148
149        // Convert output buffer to i16
150        let denoised_48k: Vec<i16> = self
151            .output_buffer
152            .drain(..)
153            .map(|s| s.round().clamp(-32768.0, 32767.0) as i16)
154            .collect();
155
156        // Step 3: Downsample if needed (e.g., 48kHz → 16kHz)
157        if let Some(ref mut downsampler) = self.downsampler {
158            downsampler.process(&denoised_48k)
159        } else {
160            denoised_48k
161        }
162    }
163
164    /// Process a complete buffer of samples (must be multiple of FRAME_SIZE).
165    ///
166    /// This is more efficient when you know your input is frame-aligned.
167    pub fn process_aligned(&mut self, samples: &[i16]) -> Vec<i16> {
168        assert!(
169            samples.len().is_multiple_of(FRAME_SIZE),
170            "Input length {} is not a multiple of frame size {}",
171            samples.len(),
172            FRAME_SIZE
173        );
174
175        let mut output = Vec::with_capacity(samples.len());
176        let mut input_frame = [0.0f32; FRAME_SIZE];
177        let mut output_frame = [0.0f32; FRAME_SIZE];
178
179        for chunk in samples.chunks_exact(FRAME_SIZE) {
180            // Convert to f32
181            for (i, &sample) in chunk.iter().enumerate() {
182                input_frame[i] = sample as f32;
183            }
184
185            // Process
186            let _vad_prob = self.state.process_frame(&mut output_frame, &input_frame);
187
188            // Handle first frame
189            if self.first_frame {
190                self.first_frame = false;
191                output.extend(std::iter::repeat_n(0i16, FRAME_SIZE));
192            } else {
193                // Convert back to i16
194                for &s in &output_frame {
195                    output.push(s.round().clamp(-32768.0, 32767.0) as i16);
196                }
197            }
198        }
199
200        output
201    }
202
203    /// Reset the denoiser state.
204    pub fn reset(&mut self) {
205        self.state = DenoiseState::new();
206        self.input_buffer.clear();
207        self.output_buffer.clear();
208        self.first_frame = true;
209        if let Some(ref mut upsampler) = self.upsampler {
210            upsampler.reset();
211        }
212        if let Some(ref mut downsampler) = self.downsampler {
213            downsampler.reset();
214        }
215    }
216
217    /// Returns the number of samples currently buffered.
218    pub fn buffered_samples(&self) -> usize {
219        self.input_buffer.len()
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// A 48 kHz denoiser starts empty and needs no resampling.
    #[test]
    fn test_denoiser_creation_48k() {
        let denoiser = Denoiser::new(48000);
        assert_eq!(denoiser.buffered_samples(), 0);
        assert_eq!(denoiser.sample_rate(), 48000);
        assert!(!denoiser.is_resampling());
    }

    /// A 16 kHz denoiser starts empty and has resampling enabled.
    #[test]
    fn test_denoiser_creation_16k() {
        let denoiser = Denoiser::new(16000);
        assert_eq!(denoiser.buffered_samples(), 0);
        assert_eq!(denoiser.sample_rate(), 16000);
        assert!(denoiser.is_resampling());
    }

    /// Exactly one frame in yields exactly one frame out, and that first
    /// frame is silence.
    #[test]
    fn test_denoiser_process_single_frame_48k() {
        let mut denoiser = Denoiser::new(48000);

        // Process exactly one frame
        let input: Vec<i16> = vec![0; FRAME_SIZE];
        let output = denoiser.process(&input);

        // First frame outputs zeros (fade-in handling)
        assert_eq!(output.len(), FRAME_SIZE);
        assert!(output.iter().all(|&s| s == 0));
    }

    /// Two frames in a single call produce two frames of output.
    #[test]
    fn test_denoiser_process_multiple_frames_48k() {
        let mut denoiser = Denoiser::new(48000);

        // Process two frames
        let input: Vec<i16> = vec![0; FRAME_SIZE * 2];
        let output = denoiser.process(&input);

        // Should get both frames back
        assert_eq!(output.len(), FRAME_SIZE * 2);
    }

    /// Input shorter than a frame is buffered until the frame is completed.
    #[test]
    fn test_denoiser_process_partial_frame() {
        let mut denoiser = Denoiser::new(48000);

        // Process less than one frame
        let input: Vec<i16> = vec![0; 100];
        let output = denoiser.process(&input);

        // No output yet (buffering)
        assert!(output.is_empty());
        assert_eq!(denoiser.buffered_samples(), 100);

        // Complete the frame
        let input2: Vec<i16> = vec![0; FRAME_SIZE - 100];
        let output2 = denoiser.process(&input2);

        // Now we get output
        assert_eq!(output2.len(), FRAME_SIZE);
        assert_eq!(denoiser.buffered_samples(), 0);
    }

    /// `reset` discards samples that were buffered but not yet processed.
    #[test]
    fn test_denoiser_reset() {
        let mut denoiser = Denoiser::new(48000);

        // Buffer some samples
        let input: Vec<i16> = vec![0; 100];
        denoiser.process(&input);
        assert_eq!(denoiser.buffered_samples(), 100);

        // Reset
        denoiser.reset();
        assert_eq!(denoiser.buffered_samples(), 0);
    }

    /// Frame-aligned processing at 48 kHz returns as many samples as it was given.
    #[test]
    fn test_denoiser_aligned() {
        let mut denoiser = Denoiser::new(48000);

        let input: Vec<i16> = vec![0; FRAME_SIZE * 3];
        let output = denoiser.process_aligned(&input);

        assert_eq!(output.len(), FRAME_SIZE * 3);
    }

    /// With resampling, a large chunk must either produce output or leave
    /// samples buffered.
    #[test]
    fn test_denoiser_16k_produces_output() {
        let mut denoiser = Denoiser::new(16000);

        // Process enough samples to get output (need extra for resampling buffers)
        // At 16kHz, we need ~160 samples for 10ms, but resampling buffers need more
        let input: Vec<i16> = vec![0; 2048];
        let output = denoiser.process(&input);

        // Should produce some output (exact amount depends on resampler buffering)
        // Due to multiple buffering stages, first call may not produce full output
        assert!(
            !output.is_empty() || denoiser.buffered_samples() > 0,
            "Should either produce output or buffer samples"
        );
    }

    /// Streaming many small 16 kHz chunks eventually yields substantial output.
    #[test]
    fn test_denoiser_16k_continuous_processing() {
        let mut denoiser = Denoiser::new(16000);

        // Process several chunks to fill all buffers
        let chunk: Vec<i16> = vec![0; 320]; // 20ms at 16kHz
        let total_output: usize = (0..20).map(|_| denoiser.process(&chunk).len()).sum();

        // After 400ms of input, we should have substantial output
        assert!(
            total_output > 5000,
            "Expected significant output, got {total_output}"
        );
    }
}