1use std::f32::consts::PI;
47use std::sync::Arc;
48
49use crate::error::{Error, Result};
50use crate::time::{AudioDuration, AudioInstant};
51use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
52use tracing::{info, warn};
53
54use super::artifacts::WavArtifactWriter;
55use super::VadContext;
56
/// Tunable parameters for the spectral-subtraction noise reducer.
///
/// The accepted ranges quoted on each field are the ones enforced by
/// `NoiseReductionConfig::validate`; the quoted defaults come from the
/// `Default` implementation.
#[derive(Debug, Clone)]
#[allow(missing_copy_implementations)]
pub struct NoiseReductionConfig {
    /// Sample rate in Hz. `validate` accepts 8000..=48000; default 16 000.
    pub sample_rate_hz: u32,

    /// Analysis window length in milliseconds. `validate` accepts
    /// 10.0..=50.0; default 25.0.
    pub window_ms: f32,

    /// Hop between successive frames in milliseconds. `validate` requires it
    /// to be smaller than `window_ms`; default 10.0.
    pub hop_ms: f32,

    /// Multiplier applied to the noise estimate before it is subtracted from
    /// the signal magnitude. `validate` accepts 1.0..=3.0; default 2.0.
    pub oversubtraction_factor: f32,

    /// Lower bound on each cleaned magnitude, expressed as a fraction of the
    /// noise estimate for that bin. `validate` accepts 0.001..=0.1; default 0.02.
    pub spectral_floor: f32,

    /// Weight given to the previous noise estimate when the profile is
    /// updated. `validate` accepts 0.9..1.0 (upper bound excluded); default 0.98.
    pub noise_smoothing: f32,

    /// When `false`, `reduce_with_artifacts` returns a copy of the input
    /// without processing it. Default `true`.
    pub enable: bool,
}
137
138impl Default for NoiseReductionConfig {
139 fn default() -> Self {
140 Self {
141 sample_rate_hz: 16_000,
142 window_ms: 25.0,
143 hop_ms: 10.0,
144 oversubtraction_factor: 2.0,
145 spectral_floor: 0.02,
146 noise_smoothing: 0.98,
147 enable: true,
148 }
149 }
150}
151
152impl NoiseReductionConfig {
153 #[allow(clippy::trivially_copy_pass_by_ref)]
165 pub fn validate(&self) -> Result<()> {
166 if !(8000..=48_000).contains(&self.sample_rate_hz) {
167 return Err(Error::Configuration(format!(
168 "Invalid sample rate: {} Hz (range: 8000-48000)",
169 self.sample_rate_hz
170 )));
171 }
172
173 if !(10.0..=50.0).contains(&self.window_ms) {
174 return Err(Error::Configuration(format!(
175 "Invalid window size: {:.1} ms (range: 10-50)",
176 self.window_ms
177 )));
178 }
179
180 if self.hop_ms >= self.window_ms {
181 return Err(Error::Configuration(format!(
182 "Hop {:.1} ms must be < window {:.1} ms",
183 self.hop_ms, self.window_ms
184 )));
185 }
186
187 if !(1.0..=3.0).contains(&self.oversubtraction_factor) {
188 return Err(Error::Configuration(format!(
189 "Invalid oversubtraction factor: {:.2} (range: 1.0-3.0)",
190 self.oversubtraction_factor
191 )));
192 }
193
194 if !(0.001..=0.1).contains(&self.spectral_floor) {
195 return Err(Error::Configuration(format!(
196 "Invalid spectral floor: {:.3} (range: 0.001-0.1)",
197 self.spectral_floor
198 )));
199 }
200
201 if !(0.9..1.0).contains(&self.noise_smoothing) {
202 return Err(Error::Configuration(format!(
203 "Invalid noise smoothing: {:.3} (range: 0.9-0.999)",
204 self.noise_smoothing
205 )));
206 }
207
208 Ok(())
209 }
210
211 pub fn frame_length(&self) -> usize {
213 ((self.window_ms / 1000.0) * self.sample_rate_hz as f32).round() as usize
214 }
215
216 pub fn hop_length(&self) -> usize {
218 ((self.hop_ms / 1000.0) * self.sample_rate_hz as f32).round() as usize
219 }
220
221 pub fn fft_size(&self) -> usize {
223 self.frame_length().next_power_of_two()
224 }
225}
226
/// Stateful spectral-subtraction noise reducer.
///
/// Holds the FFT plans, the analysis/synthesis window and the running noise
/// magnitude estimate that persists across calls to `reduce`.
#[allow(missing_copy_implementations)]
pub struct NoiseReducer {
    /// Configuration validated in `new`.
    config: NoiseReductionConfig,
    /// Forward real-to-complex FFT plan of size `config.fft_size()`.
    fft_forward: Arc<dyn RealToComplex<f32>>,
    /// Inverse complex-to-real FFT plan of size `config.fft_size()`.
    fft_inverse: Arc<dyn ComplexToReal<f32>>,
    /// Hann window of `config.frame_length()` samples, applied both before
    /// the forward FFT and after the inverse FFT.
    window: Vec<f32>,
    /// Per-bin noise magnitude estimate (`fft_size / 2 + 1` bins), seeded
    /// with `1e-6`.
    noise_profile: Vec<f32>,
    /// `false` until the first noise update, which copies the spectrum
    /// verbatim instead of smoothing it.
    noise_initialized: bool,
    /// Zero-filled buffer of `frame_length` samples. In the code visible
    /// here it is only allocated in `new` and cleared in `reset`.
    overlap_buffer: Vec<f32>,
}
273
274impl std::fmt::Debug for NoiseReducer {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 f.debug_struct("NoiseReducer")
277 .field("config", &self.config)
278 .field("window_length", &self.window.len())
279 .field("noise_profile_bins", &self.noise_profile.len())
280 .field("noise_initialized", &self.noise_initialized)
281 .finish_non_exhaustive()
282 }
283}
284
285impl NoiseReducer {
286 pub fn new(config: NoiseReductionConfig) -> Result<Self> {
309 config.validate()?;
310
311 let fft_size = config.fft_size();
312 let frame_length = config.frame_length();
313
314 let mut planner = RealFftPlanner::<f32>::new();
315 let fft_forward = planner.plan_fft_forward(fft_size);
316 let fft_inverse = planner.plan_fft_inverse(fft_size);
317
318 let window = generate_hann_window(frame_length);
319
320 let num_bins = fft_size / 2 + 1;
321 let noise_profile = vec![1e-6; num_bins];
322
323 let overlap_buffer = vec![0.0; frame_length];
324
325 Ok(Self {
326 config,
327 fft_forward,
328 fft_inverse,
329 window,
330 noise_profile,
331 noise_initialized: false,
332 overlap_buffer,
333 })
334 }
335
336 #[allow(clippy::unnecessary_wraps)]
372 pub fn reduce(&mut self, samples: &[f32], vad_context: Option<VadContext>) -> Result<Vec<f32>> {
373 self.reduce_with_artifacts(samples, vad_context, None)
374 }
375
376 pub fn reduce_with_artifacts(
378 &mut self,
379 samples: &[f32],
380 vad_context: Option<VadContext>,
381 mut artifacts: Option<&mut NoiseReductionArtifacts<'_>>,
382 ) -> Result<Vec<f32>> {
383 let processing_start = AudioInstant::now();
384
385 if samples.is_empty() {
386 return Ok(Vec::new());
387 }
388
389 if let Some(ref mut recorder) = artifacts {
390 recorder.record_input(samples);
391 }
392
393 if !self.config.enable {
394 return Ok(samples.to_vec());
395 }
396
397 let (mut output, frame_count) = self.process_stft_frames(samples, vad_context)?;
398 self.normalize_overlap_add(&mut output);
399
400 if let Some(ref mut recorder) = artifacts {
401 recorder.record_output(&output);
402 }
403
404 let elapsed = elapsed_duration(processing_start);
405 let latency_ms = elapsed.as_secs_f64() * 1000.0;
406 self.record_performance_metrics(samples, &output, latency_ms, frame_count);
407
408 Ok(output)
409 }
410
411 fn process_stft_frames(
413 &mut self,
414 samples: &[f32],
415 vad_context: Option<VadContext>,
416 ) -> Result<(Vec<f32>, usize)> {
417 let frame_length = self.config.frame_length();
418 let hop_length = self.config.hop_length();
419
420 let mut output = vec![0.0; samples.len()];
421 let mut frame_idx = 0;
422 let mut pos = 0;
423
424 while pos < samples.len() {
425 let remaining = samples.len() - pos;
426
427 let frame = Self::extract_frame(samples, pos, frame_length, remaining)?;
428
429 let processed =
430 self.process_single_frame(&frame, vad_context, remaining >= frame_length)?;
431
432 Self::accumulate_frame_output(&processed, &mut output, pos);
433
434 frame_idx += 1;
435
436 if remaining < hop_length {
437 break;
438 }
439 pos += hop_length;
440 }
441
442 Ok((output, frame_idx))
443 }
444
445 fn extract_frame(
447 samples: &[f32],
448 pos: usize,
449 frame_length: usize,
450 remaining: usize,
451 ) -> Result<Vec<f32>> {
452 let mut frame_buf = vec![0.0; frame_length];
453
454 if remaining >= frame_length {
455 let src = samples
456 .get(pos..pos + frame_length)
457 .ok_or_else(|| Error::Processing("frame window out of bounds".into()))?;
458 frame_buf.copy_from_slice(src);
459 } else {
460 let src = samples
461 .get(pos..)
462 .ok_or_else(|| Error::Processing("frame tail out of bounds".into()))?;
463 if let Some(dst) = frame_buf.get_mut(..remaining) {
464 dst.copy_from_slice(src);
465 }
466 }
467
468 Ok(frame_buf)
469 }
470
471 fn process_single_frame(
473 &mut self,
474 frame: &[f32],
475 vad_context: Option<VadContext>,
476 is_full_frame: bool,
477 ) -> Result<Vec<f32>> {
478 let fft_size = self.config.fft_size();
479
480 let windowed: Vec<f32> = frame
481 .iter()
482 .zip(&self.window)
483 .map(|(&s, &w)| s * w)
484 .collect();
485
486 let complex_spectrum = self.forward_fft_complex(&windowed)?;
487 let magnitudes: Vec<f32> = complex_spectrum.iter().map(|c| c.norm()).collect();
488
489 let is_silence = vad_context.is_some_and(|ctx| ctx.is_silence);
490 if is_silence && is_full_frame {
491 self.update_noise_profile(&magnitudes);
492 }
493
494 let cleaned_magnitudes = self.spectral_subtract(&magnitudes);
495
496 let cleaned_complex =
497 Self::reconstruct_complex_spectrum(&complex_spectrum, &cleaned_magnitudes);
498
499 let time_signal = self.inverse_fft_complex(&cleaned_complex, fft_size)?;
500
501 let windowed_output: Vec<f32> = time_signal
502 .iter()
503 .take(frame.len())
504 .zip(&self.window)
505 .map(|(&s, &w)| s * w)
506 .collect();
507
508 Ok(windowed_output)
509 }
510
511 fn reconstruct_complex_spectrum(
513 original_spectrum: &[realfft::num_complex::Complex<f32>],
514 cleaned_magnitudes: &[f32],
515 ) -> Vec<realfft::num_complex::Complex<f32>> {
516 original_spectrum
517 .iter()
518 .zip(cleaned_magnitudes)
519 .enumerate()
520 .map(|(i, (original, &new_mag))| {
521 if i == 0 || i == original_spectrum.len() - 1 {
522 realfft::num_complex::Complex::new(new_mag, 0.0)
524 } else {
525 let phase = original.arg();
526 realfft::num_complex::Complex::from_polar(new_mag, phase)
527 }
528 })
529 .collect()
530 }
531
532 fn accumulate_frame_output(frame: &[f32], output: &mut [f32], pos: usize) {
534 for (i, &sample) in frame.iter().enumerate() {
535 let out_idx = pos + i;
536 if let Some(dst) = output.get_mut(out_idx) {
537 *dst += sample;
538 }
539 }
540 }
541
542 fn normalize_overlap_add(&self, output: &mut [f32]) {
544 let hop_length = self.config.hop_length();
545 let window_sum = self.calculate_window_overlap_sum(hop_length);
546
547 if window_sum > 1e-6 {
548 for sample in output {
549 *sample /= window_sum;
550 }
551 }
552 }
553
554 fn record_performance_metrics(
555 &self,
556 input: &[f32],
557 output: &[f32],
558 latency_ms: f64,
559 frame_count: usize,
560 ) {
561 if input.len() < 8000 {
562 return;
563 }
564
565 if latency_ms > 15.0 {
566 warn!(
567 target: "audio.preprocess.noise_reduction",
568 latency_ms,
569 samples = input.len(),
570 frames = frame_count,
571 oversubtraction = self.config.oversubtraction_factor,
572 spectral_floor = self.config.spectral_floor,
573 "noise reduction latency exceeded target"
574 );
575 }
576
577 let avg_noise_floor = self.noise_floor().max(1e-12);
578 let noise_floor_db = 20.0 * avg_noise_floor.log10();
579
580 let signal_power_out =
581 output.iter().map(|sample| sample * sample).sum::<f32>() / output.len() as f32;
582 let residual_power: f32 = input
583 .iter()
584 .zip(output)
585 .map(|(&noisy, &clean)| {
586 let residual = noisy - clean;
587 residual * residual
588 })
589 .sum::<f32>()
590 / output.len() as f32;
591
592 let snr_improvement_db = if residual_power > 1e-12 && signal_power_out > 0.0 {
593 10.0 * (signal_power_out / residual_power).log10()
594 } else {
595 0.0
596 };
597
598 info!(
599 target: "audio.preprocess.noise_reduction",
600 noise_floor_db,
601 snr_improvement_db,
602 latency_ms,
603 frames = frame_count,
604 samples = input.len(),
605 oversubtraction = self.config.oversubtraction_factor,
606 spectral_floor = self.config.spectral_floor,
607 "noise reduction metrics"
608 );
609 }
610
611 pub fn reset(&mut self) {
616 self.noise_profile.fill(1e-6);
617 self.noise_initialized = false;
618 self.overlap_buffer.fill(0.0);
619 }
620
621 #[must_use]
623 pub fn noise_floor(&self) -> f32 {
624 if self.noise_profile.is_empty() {
625 return 0.0;
626 }
627 self.noise_profile.iter().sum::<f32>() / self.noise_profile.len() as f32
628 }
629
630 #[must_use]
632 pub fn config(&self) -> &NoiseReductionConfig {
633 &self.config
634 }
635
636 fn forward_fft_complex(
638 &self,
639 windowed: &[f32],
640 ) -> Result<Vec<realfft::num_complex::Complex<f32>>> {
641 let mut input = self.fft_forward.make_input_vec();
643 for (i, &sample) in windowed.iter().enumerate() {
644 if let Some(dst) = input.get_mut(i) {
645 *dst = sample;
646 }
647 }
648
649 let mut spectrum = self.fft_forward.make_output_vec();
651 self.fft_forward
652 .process(&mut input, &mut spectrum)
653 .map_err(|e| Error::Processing(format!("FFT failed: {e}")))?;
654
655 Ok(spectrum)
656 }
657
658 fn inverse_fft_complex(
660 &self,
661 complex_spectrum: &[realfft::num_complex::Complex<f32>],
662 fft_size: usize,
663 ) -> Result<Vec<f32>> {
664 let mut spectrum = self.fft_inverse.make_input_vec();
666 for (i, &c) in complex_spectrum.iter().enumerate() {
667 if let Some(bin) = spectrum.get_mut(i) {
668 *bin = c;
669 }
670 }
671
672 let mut output = self.fft_inverse.make_output_vec();
674 self.fft_inverse
675 .process(&mut spectrum, &mut output)
676 .map_err(|e| Error::Processing(format!("IFFT failed: {e}")))?;
677
678 for sample in &mut output {
680 *sample /= fft_size as f32;
681 }
682
683 Ok(output)
684 }
685
686 fn update_noise_profile(&mut self, spectrum: &[f32]) {
688 let alpha = self.config.noise_smoothing;
689
690 if self.noise_initialized {
691 for (noise, ¤t) in self.noise_profile.iter_mut().zip(spectrum.iter()) {
693 *noise = alpha.mul_add(*noise, (1.0 - alpha) * current);
694 }
695 } else {
696 self.noise_profile.copy_from_slice(spectrum);
698 self.noise_initialized = true;
699 }
700 }
701
702 fn spectral_subtract(&self, spectrum: &[f32]) -> Vec<f32> {
704 let alpha = self.config.oversubtraction_factor;
705 let beta = self.config.spectral_floor;
706
707 spectrum
708 .iter()
709 .zip(&self.noise_profile)
710 .map(|(&signal, &noise)| {
711 let subtracted = alpha.mul_add(-noise, signal);
712 let floor = beta * noise;
713 subtracted.max(floor)
714 })
715 .collect()
716 }
717
718 fn calculate_window_overlap_sum(&self, hop_length: usize) -> f32 {
720 let frame_length = self.window.len();
721 let mut sum: f32 = 0.0;
722
723 for i in 0..frame_length {
725 let mut overlap: f32 = 0.0;
726 let mut offset = 0;
727
728 while offset <= i {
729 if let Some(&w) = self.window.get(i - offset) {
730 overlap = w.mul_add(w, overlap); }
734 offset += hop_length;
735 }
736
737 sum = sum.max(overlap);
738 }
739
740 sum
741 }
742}
743
/// Optional pair of borrowed WAV writers for capturing audio around noise
/// reduction.
pub struct NoiseReductionArtifacts<'a> {
    /// Receives the samples passed to `record_input`; `None` disables it.
    before: Option<&'a mut WavArtifactWriter>,
    /// Receives the samples passed to `record_output`; `None` disables it.
    after: Option<&'a mut WavArtifactWriter>,
}
749
750impl<'a> NoiseReductionArtifacts<'a> {
751 pub fn new(
753 before: Option<&'a mut WavArtifactWriter>,
754 after: Option<&'a mut WavArtifactWriter>,
755 ) -> Self {
756 Self { before, after }
757 }
758
759 pub fn record_input(&mut self, samples: &[f32]) {
761 if let Some(writer) = self.before.as_deref_mut() {
762 if let Err(err) = writer.write_samples(samples) {
763 warn!(target: "audio.preprocess.noise_reduction", error = %err, "failed to write input artifact");
764 }
765 }
766 }
767
768 pub fn record_output(&mut self, samples: &[f32]) {
770 if let Some(writer) = self.after.as_deref_mut() {
771 if let Err(err) = writer.write_samples(samples) {
772 warn!(target: "audio.preprocess.noise_reduction", error = %err, "failed to write output artifact");
773 }
774 }
775 }
776}
777
778impl std::fmt::Debug for NoiseReductionArtifacts<'_> {
779 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
780 f.debug_struct("NoiseReductionArtifacts").finish()
781 }
782}
783
/// Builds a symmetric Hann window of `length` samples.
///
/// Sample `n` is `0.5 - 0.5 * cos(2 * PI * n / (length - 1))`, so both end
/// points are zero. A length of zero gives an empty vector and a length of
/// one gives `[1.0]`.
fn generate_hann_window(length: usize) -> Vec<f32> {
    match length {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let denom = (length - 1) as f32;
            let mut window = Vec::with_capacity(length);
            for n in 0..length {
                let angle = 2.0 * PI * n as f32 / denom;
                window.push(0.5f32.mul_add(-angle.cos(), 0.5));
            }
            window
        }
    }
}
804
805fn elapsed_duration(start: AudioInstant) -> AudioDuration {
806 AudioInstant::now().duration_since(start)
807}
808
809#[cfg(test)]
810mod tests {
811 use super::*;
812
    /// Result alias for tests: errors are carried as `String`s so that `?`
    /// can be used after `map_err(|e| e.to_string())`.
    type TestResult<T> = std::result::Result<T, String>;
814
815 #[test]
816 #[allow(clippy::unnecessary_wraps)]
817 fn test_configuration_validation() -> TestResult<()> {
818 let valid = NoiseReductionConfig::default();
820 assert!(valid.validate().is_ok());
821
822 let invalid_sr = NoiseReductionConfig {
824 sample_rate_hz: 5000,
825 ..Default::default()
826 };
827 assert!(invalid_sr.validate().is_err());
828
829 let invalid_window = NoiseReductionConfig {
831 window_ms: 100.0,
832 ..Default::default()
833 };
834 assert!(invalid_window.validate().is_err());
835
836 let invalid_hop = NoiseReductionConfig {
838 hop_ms: 30.0,
839 window_ms: 25.0,
840 ..Default::default()
841 };
842 assert!(invalid_hop.validate().is_err());
843
844 let invalid_alpha = NoiseReductionConfig {
846 oversubtraction_factor: 5.0,
847 ..Default::default()
848 };
849 assert!(invalid_alpha.validate().is_err());
850
851 let invalid_beta = NoiseReductionConfig {
853 spectral_floor: 0.5,
854 ..Default::default()
855 };
856 assert!(invalid_beta.validate().is_err());
857
858 let invalid_smoothing = NoiseReductionConfig {
860 noise_smoothing: 1.0,
861 ..Default::default()
862 };
863 assert!(invalid_smoothing.validate().is_err());
864
865 Ok(())
866 }
867
868 #[test]
869 fn test_hann_window_properties() {
870 let window_0 = generate_hann_window(0);
872 assert!(window_0.is_empty());
873
874 let window_1 = generate_hann_window(1);
876 assert_eq!(window_1.len(), 1);
877 assert!((window_1[0] - 1.0).abs() < 1e-6);
878
879 let window = generate_hann_window(100);
881 assert_eq!(window.len(), 100);
882
883 assert!(window[0].abs() < 1e-6);
885 assert!(window[99].abs() < 1e-6);
886
887 assert!((window[50] - 1.0).abs() < 0.1);
889 }
890
891 #[test]
892 fn test_noise_reducer_creation() -> TestResult<()> {
893 let config = NoiseReductionConfig::default();
894 let reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
895
896 assert_eq!(reducer.config().sample_rate_hz, 16000);
898 assert!(reducer.noise_floor() > 0.0); Ok(())
901 }
902
903 #[test]
904 fn test_empty_input() -> TestResult<()> {
905 let config = NoiseReductionConfig::default();
906 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
907
908 let output = reducer.reduce(&[], None).map_err(|e| e.to_string())?;
909 assert!(output.is_empty());
910
911 Ok(())
912 }
913
914 #[test]
915 fn test_bypass_mode() -> TestResult<()> {
916 let config = NoiseReductionConfig {
917 enable: false,
918 ..Default::default()
919 };
920 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
921
922 let input = vec![0.1, 0.2, 0.3, 0.4];
923 let output = reducer.reduce(&input, None).map_err(|e| e.to_string())?;
924
925 assert_eq!(output, input);
927
928 Ok(())
929 }
930
931 #[test]
932 fn test_noise_profile_update() -> TestResult<()> {
933 let config = NoiseReductionConfig::default();
934 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
935
936 let silence = vec![0.01; 8000]; let vad_silence = VadContext { is_silence: true };
939
940 let initial_noise = reducer.noise_floor();
941
942 for _ in 0..5 {
944 let _ = reducer
945 .reduce(&silence, Some(vad_silence))
946 .map_err(|e| e.to_string())?;
947 }
948
949 let converged_noise = reducer.noise_floor();
950
951 assert!(
953 converged_noise > initial_noise,
954 "Noise floor should adapt: initial={:.6}, converged={:.6}",
955 initial_noise,
956 converged_noise
957 );
958
959 Ok(())
960 }
961
962 #[test]
963 fn test_vad_informed_noise_update() -> TestResult<()> {
964 let config = NoiseReductionConfig::default();
965 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
966
967 let silence = vec![0.01; 8000];
969 let vad_silence = VadContext { is_silence: true };
970 for _ in 0..5 {
971 let _ = reducer
972 .reduce(&silence, Some(vad_silence))
973 .map_err(|e| e.to_string())?;
974 }
975
976 let noise_after_silence = reducer.noise_floor();
977
978 let speech = vec![0.5; 8000];
980 let vad_speech = VadContext { is_silence: false };
981 let _ = reducer
982 .reduce(&speech, Some(vad_speech))
983 .map_err(|e| e.to_string())?;
984
985 let noise_after_speech = reducer.noise_floor();
986
987 let diff = (noise_after_speech - noise_after_silence).abs();
989 assert!(
990 diff < noise_after_silence * 0.01,
991 "Noise profile changed during speech: {:.6} -> {:.6}",
992 noise_after_silence,
993 noise_after_speech
994 );
995
996 Ok(())
997 }
998
999 #[test]
1000 fn test_reset_clears_state() -> TestResult<()> {
1001 let config = NoiseReductionConfig::default();
1002 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1003
1004 let samples = vec![0.1; 8000];
1006 let vad = VadContext { is_silence: true };
1007 let _ = reducer
1008 .reduce(&samples, Some(vad))
1009 .map_err(|e| e.to_string())?;
1010
1011 let noise_before = reducer.noise_floor();
1012 assert!(noise_before > 1e-5, "Noise profile should be updated");
1013
1014 reducer.reset();
1016
1017 let noise_after = reducer.noise_floor();
1018 assert!(
1019 noise_after < 1e-5,
1020 "Noise profile should be reset to initial value"
1021 );
1022
1023 Ok(())
1024 }
1025
1026 fn generate_sine_wave(
1028 frequency: f32,
1029 sample_rate: u32,
1030 duration_secs: f32,
1031 amplitude: f32,
1032 ) -> Vec<f32> {
1033 use std::f32::consts::PI;
1034 let samples = (sample_rate as f32 * duration_secs).round() as usize;
1035 (0..samples)
1036 .map(|i| {
1037 let t = i as f32 / sample_rate as f32;
1038 (2.0 * PI * frequency * t).sin() * amplitude
1039 })
1040 .collect()
1041 }
1042
1043 fn add_white_noise(signal: &[f32], noise_amplitude: f32) -> Vec<f32> {
1045 use rand::Rng;
1046 let mut rng = rand::rng();
1047 signal
1048 .iter()
1049 .map(|&s| {
1050 let noise: f32 = rng.random_range(-noise_amplitude..noise_amplitude);
1051 s + noise
1052 })
1053 .collect()
1054 }
1055 fn add_low_freq_hum(
1056 signal: &[f32],
1057 sample_rate: u32,
1058 frequency: f32,
1059 amplitude: f32,
1060 ) -> Vec<f32> {
1061 signal
1062 .iter()
1063 .enumerate()
1064 .map(|(i, &sample)| {
1065 let t = i as f32 / sample_rate as f32;
1066 let hum = (2.0 * PI * frequency * t).sin() * amplitude;
1067 sample + hum
1068 })
1069 .collect()
1070 }
1071
1072 fn add_cafe_noise(signal: &[f32], _sample_rate: u32, amplitude: f32) -> Vec<f32> {
1078 use rand::Rng;
1079 let mut rng = rand::rng();
1080 signal
1081 .iter()
1082 .map(|&sample| {
1083 let noise: f32 = rng.random_range(-1.0..1.0);
1084 amplitude.mul_add(noise, sample)
1085 })
1086 .collect()
1087 }
1088
1089 fn calculate_snr(clean: &[f32], noisy: &[f32]) -> f32 {
1091 if clean.len() != noisy.len() {
1092 return 0.0;
1093 }
1094
1095 let signal_power: f32 = clean.iter().map(|&x| x * x).sum();
1096 let noise: Vec<f32> = clean
1097 .iter()
1098 .zip(noisy.iter())
1099 .map(|(&c, &n)| n - c)
1100 .collect();
1101 let noise_power: f32 = noise.iter().map(|&x| x * x).sum();
1102
1103 if noise_power < 1e-10 {
1104 return 100.0; }
1106
1107 10.0 * (signal_power / noise_power).log10()
1108 }
1109
1110 #[test]
1111 fn test_snr_improvement_white_noise() -> TestResult<()> {
1112 let config = NoiseReductionConfig::default();
1113 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1114
1115 let clean_speech = generate_sine_wave(440.0, 16000, 1.0, 0.5);
1117
1118 let noisy_speech = add_white_noise(&clean_speech, 0.3);
1120
1121 let snr_before = calculate_snr(&clean_speech, &noisy_speech);
1122
1123 let pure_noise = add_white_noise(&vec![0.0; 8000], 0.3);
1127 let vad_silence = VadContext { is_silence: true };
1128 for _ in 0..10 {
1129 let _ = reducer
1130 .reduce(&pure_noise, Some(vad_silence))
1131 .map_err(|e| e.to_string())?;
1132 }
1133
1134 let vad_speech = VadContext { is_silence: false };
1136 let denoised = reducer
1137 .reduce(&noisy_speech, Some(vad_speech))
1138 .map_err(|e| e.to_string())?;
1139
1140 let snr_after = calculate_snr(&clean_speech, &denoised);
1141 let improvement = snr_after - snr_before;
1142
1143 eprintln!(
1144 "SNR improvement: Before={:.1} dB, After={:.1} dB, Improvement={:.1} dB",
1145 snr_before, snr_after, improvement
1146 );
1147
1148 assert!(
1150 improvement >= 6.0,
1151 "SNR improvement {:.1} dB < 6 dB target",
1152 improvement
1153 );
1154
1155 Ok(())
1156 }
1157
1158 #[test]
1159 fn test_snr_improvement_low_freq_hum() -> TestResult<()> {
1160 let config = NoiseReductionConfig::default();
1161 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1162
1163 let clean = generate_sine_wave(440.0, 16000, 1.0, 0.4);
1165 let noisy = add_low_freq_hum(&clean, 16000, 60.0, 0.3);
1166 let snr_before = calculate_snr(&clean, &noisy);
1167
1168 let hum_only = add_low_freq_hum(&vec![0.0; 8000], 16000, 60.0, 0.3);
1172 let vad = VadContext { is_silence: true };
1173 for _ in 0..6 {
1174 let _ = reducer
1175 .reduce(&hum_only, Some(vad))
1176 .map_err(|e| e.to_string())?;
1177 }
1178
1179 let vad_speech = VadContext { is_silence: false };
1180 let denoised = reducer
1181 .reduce(&noisy, Some(vad_speech))
1182 .map_err(|e| e.to_string())?;
1183 let snr_after = calculate_snr(&clean, &denoised);
1184 let improvement = snr_after - snr_before;
1185 assert!(
1186 improvement >= 6.0,
1187 "Hum SNR improvement {:.1} dB < 6 dB target",
1188 improvement
1189 );
1190
1191 Ok(())
1192 }
1193
1194 #[test]
1195 fn test_snr_improvement_cafe_ambient() -> TestResult<()> {
1196 let config = NoiseReductionConfig::default();
1198 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1199
1200 let clean = generate_sine_wave(220.0, 16000, 1.0, 0.4);
1202
1203 let noisy = add_cafe_noise(&clean, 16000, 0.25);
1205 let snr_before = calculate_snr(&clean, &noisy);
1206
1207 let cafe_only = add_cafe_noise(&vec![0.0; 8000], 16000, 0.25);
1212 let vad = VadContext { is_silence: true };
1213 for _ in 0..10 {
1214 let _ = reducer
1215 .reduce(&cafe_only, Some(vad))
1216 .map_err(|e| e.to_string())?;
1217 }
1218
1219 let vad_speech = VadContext { is_silence: false };
1221 let denoised = reducer
1222 .reduce(&noisy, Some(vad_speech))
1223 .map_err(|e| e.to_string())?;
1224
1225 let snr_after = calculate_snr(&clean, &denoised);
1226 let improvement = snr_after - snr_before;
1227
1228 assert!(
1229 improvement >= 6.0,
1230 "Café ambient SNR improvement {:.1} dB < 6 dB target",
1231 improvement
1232 );
1233
1234 Ok(())
1235 }
1236
1237 #[test]
1238 fn test_trailing_partial_frame_preserved() -> TestResult<()> {
1239 let config = NoiseReductionConfig::default();
1240 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1241
1242 let silence = vec![0.0; 8000];
1245 let vad_silence = VadContext { is_silence: true };
1246 let _ = reducer
1247 .reduce(&silence, Some(vad_silence))
1248 .map_err(|e| e.to_string())?;
1249
1250 let speech_len = 8080;
1253 let speech: Vec<f32> = (0..speech_len)
1254 .map(|i| {
1255 let phase = (i as f32 / speech_len as f32) * 20.0;
1256 phase.sin()
1257 })
1258 .collect();
1259
1260 let vad_speech = VadContext { is_silence: false };
1261 let output = reducer
1262 .reduce(&speech, Some(vad_speech))
1263 .map_err(|e| e.to_string())?;
1264
1265 assert_eq!(
1266 output.len(),
1267 speech_len,
1268 "Output length should match input length"
1269 );
1270
1271 let tail = &output[speech_len - 80..];
1272 let tail_energy: f32 = tail.iter().map(|sample| sample.abs()).sum();
1273 assert!(
1274 tail_energy > 1e-3,
1275 "Trailing samples should retain energy, got tail_energy={tail_energy}"
1276 );
1277
1278 Ok(())
1279 }
1280
1281 #[test]
1282 fn test_missing_vad_context_does_not_update_noise_profile() -> TestResult<()> {
1283 let config = NoiseReductionConfig::default();
1284 let mut reducer = NoiseReducer::new(config).map_err(|e| e.to_string())?;
1285
1286 let ambient_noise = vec![0.05f32; 8000];
1289 let vad_silence = VadContext { is_silence: true };
1290 reducer
1291 .reduce(&ambient_noise, Some(vad_silence))
1292 .map_err(|e| e.to_string())?;
1293 let baseline_floor = reducer.noise_floor();
1294
1295 let speech = vec![0.2f32; 8000];
1297 let output = reducer.reduce(&speech, None).map_err(|e| e.to_string())?;
1298 let updated_floor = reducer.noise_floor();
1299
1300 let floor_delta = (updated_floor - baseline_floor).abs();
1301 assert!(
1302 floor_delta < baseline_floor.max(1e-6) * 0.01,
1303 "Noise floor changed when VAD context missing: baseline={baseline_floor}, \
1304 updated={updated_floor}"
1305 );
1306
1307 let output_rms =
1308 (output.iter().map(|sample| sample * sample).sum::<f32>() / output.len() as f32).sqrt();
1309 assert!(
1310 output_rms > 0.08,
1311 "Speech energy collapsed without VAD context (rms={output_rms})"
1312 );
1313
1314 Ok(())
1315 }
1316}