wavekat_vad/preprocessing/
denoise.rs1use super::resample::AudioResampler;
21use nnnoiseless::DenoiseState;
22
/// Sample rate (Hz) at which nnnoiseless operates. Audio at any other rate is
/// resampled to and from this rate by [`Denoiser`].
pub const DENOISE_SAMPLE_RATE: u32 = 48_000;

/// Number of samples per nnnoiseless frame (10 ms at [`DENOISE_SAMPLE_RATE`]).
const FRAME_SIZE: usize = 480;
28
29pub struct Denoiser {
35 state: Box<DenoiseState<'static>>,
36 sample_rate: u32,
38 upsampler: Option<AudioResampler>,
40 downsampler: Option<AudioResampler>,
42 input_buffer: Vec<f32>,
44 output_buffer: Vec<f32>,
46 first_frame: bool,
48}
49
50impl std::fmt::Debug for Denoiser {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("Denoiser")
53 .field("sample_rate", &self.sample_rate)
54 .field("resampling", &self.upsampler.is_some())
55 .field("input_buffer_len", &self.input_buffer.len())
56 .field("output_buffer_len", &self.output_buffer.len())
57 .field("first_frame", &self.first_frame)
58 .finish_non_exhaustive()
59 }
60}
61
62impl Denoiser {
63 pub fn new(sample_rate: u32) -> Self {
75 let (upsampler, downsampler) = if sample_rate == DENOISE_SAMPLE_RATE {
76 (None, None)
77 } else {
78 let up = AudioResampler::new(sample_rate, DENOISE_SAMPLE_RATE)
79 .expect("failed to create upsampler");
80 let down = AudioResampler::new(DENOISE_SAMPLE_RATE, sample_rate)
81 .expect("failed to create downsampler");
82 (Some(up), Some(down))
83 };
84
85 Self {
86 state: DenoiseState::new(),
87 sample_rate,
88 upsampler,
89 downsampler,
90 input_buffer: Vec::with_capacity(FRAME_SIZE),
91 output_buffer: Vec::new(),
92 first_frame: true,
93 }
94 }
95
96 pub fn sample_rate(&self) -> u32 {
98 self.sample_rate
99 }
100
101 pub fn is_resampling(&self) -> bool {
103 self.upsampler.is_some()
104 }
105
106 pub fn process(&mut self, samples: &[i16]) -> Vec<i16> {
113 let samples_48k: Vec<i16> = if let Some(ref mut upsampler) = self.upsampler {
115 upsampler.process(samples)
116 } else {
117 samples.to_vec()
118 };
119
120 for &sample in &samples_48k {
123 self.input_buffer.push(sample as f32);
124 }
125
126 while self.input_buffer.len() >= FRAME_SIZE {
128 let mut input_frame = [0.0f32; FRAME_SIZE];
129 let mut output_frame = [0.0f32; FRAME_SIZE];
130
131 input_frame.copy_from_slice(&self.input_buffer[..FRAME_SIZE]);
133 self.input_buffer.drain(..FRAME_SIZE);
134
135 let _vad_prob = self.state.process_frame(&mut output_frame, &input_frame);
137
138 if self.first_frame {
140 self.first_frame = false;
141 self.output_buffer
143 .extend(std::iter::repeat_n(0.0, FRAME_SIZE));
144 } else {
145 self.output_buffer.extend_from_slice(&output_frame);
146 }
147 }
148
149 let denoised_48k: Vec<i16> = self
151 .output_buffer
152 .drain(..)
153 .map(|s| s.round().clamp(-32768.0, 32767.0) as i16)
154 .collect();
155
156 if let Some(ref mut downsampler) = self.downsampler {
158 downsampler.process(&denoised_48k)
159 } else {
160 denoised_48k
161 }
162 }
163
164 pub fn process_aligned(&mut self, samples: &[i16]) -> Vec<i16> {
168 assert!(
169 samples.len().is_multiple_of(FRAME_SIZE),
170 "Input length {} is not a multiple of frame size {}",
171 samples.len(),
172 FRAME_SIZE
173 );
174
175 let mut output = Vec::with_capacity(samples.len());
176 let mut input_frame = [0.0f32; FRAME_SIZE];
177 let mut output_frame = [0.0f32; FRAME_SIZE];
178
179 for chunk in samples.chunks_exact(FRAME_SIZE) {
180 for (i, &sample) in chunk.iter().enumerate() {
182 input_frame[i] = sample as f32;
183 }
184
185 let _vad_prob = self.state.process_frame(&mut output_frame, &input_frame);
187
188 if self.first_frame {
190 self.first_frame = false;
191 output.extend(std::iter::repeat_n(0i16, FRAME_SIZE));
192 } else {
193 for &s in &output_frame {
195 output.push(s.round().clamp(-32768.0, 32767.0) as i16);
196 }
197 }
198 }
199
200 output
201 }
202
203 pub fn reset(&mut self) {
205 self.state = DenoiseState::new();
206 self.input_buffer.clear();
207 self.output_buffer.clear();
208 self.first_frame = true;
209 if let Some(ref mut upsampler) = self.upsampler {
210 upsampler.reset();
211 }
212 if let Some(ref mut downsampler) = self.downsampler {
213 downsampler.reset();
214 }
215 }
216
217 pub fn buffered_samples(&self) -> usize {
219 self.input_buffer.len()
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-silent test signal: a wrapping ramp in `-1000..1000`.
    fn test_signal(len: usize) -> Vec<i16> {
        (0..len).map(|i| ((i * 37) % 2000) as i16 - 1000).collect()
    }

    #[test]
    fn test_denoiser_creation_48k() {
        let denoiser = Denoiser::new(48000);
        assert_eq!(denoiser.buffered_samples(), 0);
        assert_eq!(denoiser.sample_rate(), 48000);
        assert!(!denoiser.is_resampling());
    }

    #[test]
    fn test_denoiser_creation_16k() {
        let denoiser = Denoiser::new(16000);
        assert_eq!(denoiser.buffered_samples(), 0);
        assert_eq!(denoiser.sample_rate(), 16000);
        assert!(denoiser.is_resampling());
    }

    #[test]
    fn test_denoiser_process_single_frame_48k() {
        let mut denoiser = Denoiser::new(48000);

        let input: Vec<i16> = vec![0; FRAME_SIZE];
        let output = denoiser.process(&input);

        assert_eq!(output.len(), FRAME_SIZE);
    }

    #[test]
    fn test_denoiser_process_multiple_frames_48k() {
        let mut denoiser = Denoiser::new(48000);

        let input: Vec<i16> = vec![0; FRAME_SIZE * 2];
        let output = denoiser.process(&input);

        assert_eq!(output.len(), FRAME_SIZE * 2);
    }

    #[test]
    fn test_denoiser_process_partial_frame() {
        let mut denoiser = Denoiser::new(48000);

        // Less than one frame: everything is buffered, nothing is emitted.
        let input: Vec<i16> = vec![0; 100];
        let output = denoiser.process(&input);

        assert_eq!(output.len(), 0);
        assert_eq!(denoiser.buffered_samples(), 100);

        // Completing the frame emits exactly one frame and empties the buffer.
        let input2: Vec<i16> = vec![0; FRAME_SIZE - 100];
        let output2 = denoiser.process(&input2);

        assert_eq!(output2.len(), FRAME_SIZE);
        assert_eq!(denoiser.buffered_samples(), 0);
    }

    #[test]
    fn test_denoiser_reset() {
        let mut denoiser = Denoiser::new(48000);

        let input: Vec<i16> = vec![0; 100];
        denoiser.process(&input);
        assert_eq!(denoiser.buffered_samples(), 100);

        denoiser.reset();
        assert_eq!(denoiser.buffered_samples(), 0);
    }

    #[test]
    fn test_denoiser_aligned() {
        let mut denoiser = Denoiser::new(48000);

        let input: Vec<i16> = vec![0; FRAME_SIZE * 3];
        let output = denoiser.process_aligned(&input);

        assert_eq!(output.len(), FRAME_SIZE * 3);
    }

    #[test]
    fn test_denoiser_16k_produces_output() {
        let mut denoiser = Denoiser::new(16000);

        let input: Vec<i16> = vec![0; 2048];
        let output = denoiser.process(&input);

        assert!(
            !output.is_empty() || denoiser.buffered_samples() > 0,
            "Should either produce output or buffer samples"
        );
    }

    #[test]
    fn test_denoiser_16k_continuous_processing() {
        let mut denoiser = Denoiser::new(16000);

        // 20 ms chunks at 16 kHz.
        let chunk: Vec<i16> = vec![0; 320];
        let mut total_output = 0;

        for _ in 0..20 {
            let output = denoiser.process(&chunk);
            total_output += output.len();
        }

        assert!(
            total_output > 5000,
            "Expected significant output, got {total_output}"
        );
    }

    #[test]
    fn test_denoiser_first_frame_is_silenced() {
        let mut denoiser = Denoiser::new(48000);

        let output = denoiser.process(&test_signal(FRAME_SIZE * 2));

        assert_eq!(output.len(), FRAME_SIZE * 2);
        assert!(output[..FRAME_SIZE].iter().all(|&s| s == 0));
    }

    #[test]
    fn test_denoiser_reset_silences_first_frame_again() {
        let mut denoiser = Denoiser::new(48000);
        denoiser.process(&test_signal(FRAME_SIZE * 2));

        denoiser.reset();
        let output = denoiser.process(&test_signal(FRAME_SIZE));

        assert_eq!(output.len(), FRAME_SIZE);
        assert!(output.iter().all(|&s| s == 0));
    }

    #[test]
    fn test_denoiser_chunking_does_not_change_output() {
        let input = test_signal(FRAME_SIZE * 5 + 123);

        let mut whole = Denoiser::new(48000);
        let expected = whole.process(&input);

        let mut streamed = Denoiser::new(48000);
        let mut actual = Vec::new();
        for chunk in input.chunks(101) {
            actual.extend(streamed.process(chunk));
        }

        assert_eq!(expected.len(), FRAME_SIZE * 5);
        assert_eq!(actual, expected);
        assert_eq!(whole.buffered_samples(), 123);
        assert_eq!(streamed.buffered_samples(), 123);
    }

    #[test]
    fn test_denoiser_aligned_matches_streaming_at_48k() {
        let input = test_signal(FRAME_SIZE * 3);

        let mut streaming = Denoiser::new(48000);
        let mut aligned = Denoiser::new(48000);

        assert_eq!(streaming.process(&input), aligned.process_aligned(&input));
    }

    #[test]
    #[should_panic(expected = "is not a multiple of frame size")]
    fn test_denoiser_aligned_rejects_misaligned_input() {
        let mut denoiser = Denoiser::new(48000);
        denoiser.process_aligned(&[0; FRAME_SIZE + 1]);
    }
}