1use std::f32::consts::PI;
47use std::sync::Arc;
48
49use super::VadContext;
50use crate::error::{Error, Result};
51use crate::time::{AudioDuration, AudioInstant};
52use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
53use tracing::{info, warn};
54
/// Tunable parameters for the spectral-subtraction noise reducer.
///
/// The accepted range of each field is enforced by `validate`; the values
/// produced by `Default` satisfy those checks.
#[allow(missing_copy_implementations)]
#[derive(Debug, Clone)]
pub struct NoiseReductionConfig {
    /// Sample rate of the audio being processed, in Hz.
    pub sample_rate_hz: u32,

    /// Duration of one analysis frame, in milliseconds.
    pub window_ms: f32,

    /// Distance between the starts of consecutive frames, in milliseconds.
    pub hop_ms: f32,

    /// Multiplier applied to the noise estimate before it is subtracted from
    /// each magnitude bin.
    pub oversubtraction_factor: f32,

    /// Fraction of the noise estimate used as the lower bound of each cleaned
    /// magnitude bin.
    pub spectral_floor: f32,

    /// Weight given to the previous noise estimate when the noise profile is
    /// updated; the new observation receives `1 - noise_smoothing`.
    pub noise_smoothing: f32,

    /// When `false`, `NoiseReducer::reduce` returns its input unchanged.
    pub enable: bool,
}
135
136impl Default for NoiseReductionConfig {
137 fn default() -> Self {
138 Self {
139 sample_rate_hz: 16_000,
140 window_ms: 25.0,
141 hop_ms: 10.0,
142 oversubtraction_factor: 2.0,
143 spectral_floor: 0.02,
144 noise_smoothing: 0.98,
145 enable: true,
146 }
147 }
148}
149
150impl NoiseReductionConfig {
151 #[allow(clippy::trivially_copy_pass_by_ref)]
163 pub fn validate(&self) -> Result<()> {
164 if !(8000..=48_000).contains(&self.sample_rate_hz) {
165 return Err(Error::Configuration(format!(
166 "Invalid sample rate: {} Hz (range: 8000-48000)",
167 self.sample_rate_hz
168 )));
169 }
170
171 if !(10.0..=50.0).contains(&self.window_ms) {
172 return Err(Error::Configuration(format!(
173 "Invalid window size: {:.1} ms (range: 10-50)",
174 self.window_ms
175 )));
176 }
177
178 if self.hop_ms >= self.window_ms {
179 return Err(Error::Configuration(format!(
180 "Hop {:.1} ms must be < window {:.1} ms",
181 self.hop_ms, self.window_ms
182 )));
183 }
184
185 if !(1.0..=3.0).contains(&self.oversubtraction_factor) {
186 return Err(Error::Configuration(format!(
187 "Invalid oversubtraction factor: {:.2} (range: 1.0-3.0)",
188 self.oversubtraction_factor
189 )));
190 }
191
192 if !(0.001..=0.1).contains(&self.spectral_floor) {
193 return Err(Error::Configuration(format!(
194 "Invalid spectral floor: {:.3} (range: 0.001-0.1)",
195 self.spectral_floor
196 )));
197 }
198
199 if !(0.9..1.0).contains(&self.noise_smoothing) {
200 return Err(Error::Configuration(format!(
201 "Invalid noise smoothing: {:.3} (range: 0.9-0.999)",
202 self.noise_smoothing
203 )));
204 }
205
206 Ok(())
207 }
208
209 pub fn frame_length(&self) -> usize {
211 ((self.window_ms / 1000.0) * self.sample_rate_hz as f32).round() as usize
212 }
213
214 pub fn hop_length(&self) -> usize {
216 ((self.hop_ms / 1000.0) * self.sample_rate_hz as f32).round() as usize
217 }
218
219 pub fn fft_size(&self) -> usize {
221 self.frame_length().next_power_of_two()
222 }
223}
224
225#[allow(missing_copy_implementations)]
262pub struct NoiseReducer {
263 config: NoiseReductionConfig,
264 fft_forward: Arc<dyn RealToComplex<f32>>,
265 fft_inverse: Arc<dyn ComplexToReal<f32>>,
266 window: Vec<f32>,
267 noise_profile: Vec<f32>,
268 noise_initialized: bool,
269 overlap_buffer: Vec<f32>,
270}
271
272impl std::fmt::Debug for NoiseReducer {
273 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274 f.debug_struct("NoiseReducer")
275 .field("config", &self.config)
276 .field("window_length", &self.window.len())
277 .field("noise_profile_bins", &self.noise_profile.len())
278 .field("noise_initialized", &self.noise_initialized)
279 .finish_non_exhaustive()
280 }
281}
282
283impl NoiseReducer {
284 pub fn new(config: NoiseReductionConfig) -> Result<Self> {
307 config.validate()?;
308
309 let fft_size = config.fft_size();
310 let frame_length = config.frame_length();
311
312 let mut planner = RealFftPlanner::<f32>::new();
313 let fft_forward = planner.plan_fft_forward(fft_size);
314 let fft_inverse = planner.plan_fft_inverse(fft_size);
315
316 let window = generate_hann_window(frame_length);
317
318 let num_bins = fft_size / 2 + 1;
319 let noise_profile = vec![1e-6; num_bins];
320
321 let overlap_buffer = vec![0.0; frame_length];
322
323 Ok(Self {
324 config,
325 fft_forward,
326 fft_inverse,
327 window,
328 noise_profile,
329 noise_initialized: false,
330 overlap_buffer,
331 })
332 }
333
334 #[allow(clippy::unnecessary_wraps)]
370 pub fn reduce(&mut self, samples: &[f32], vad_context: Option<VadContext>) -> Result<Vec<f32>> {
371 let processing_start = AudioInstant::now();
372
373 if samples.is_empty() {
374 return Ok(Vec::new());
375 }
376
377 if !self.config.enable {
378 return Ok(samples.to_vec());
379 }
380
381 let (mut output, frame_count) = self.process_stft_frames(samples, vad_context)?;
382 self.normalize_overlap_add(&mut output);
383
384 let elapsed = elapsed_duration(processing_start);
385 let latency_ms = elapsed.as_secs_f64() * 1000.0;
386 self.record_performance_metrics(samples, &output, latency_ms, frame_count);
387
388 Ok(output)
389 }
390
391 fn process_stft_frames(
393 &mut self,
394 samples: &[f32],
395 vad_context: Option<VadContext>,
396 ) -> Result<(Vec<f32>, usize)> {
397 let frame_length = self.config.frame_length();
398 let hop_length = self.config.hop_length();
399
400 let mut output = vec![0.0; samples.len()];
401 let mut frame_idx = 0;
402 let mut pos = 0;
403
404 while pos < samples.len() {
405 let remaining = samples.len() - pos;
406
407 let frame = Self::extract_frame(samples, pos, frame_length, remaining)?;
408
409 let processed =
410 self.process_single_frame(&frame, vad_context, remaining >= frame_length)?;
411
412 Self::accumulate_frame_output(&processed, &mut output, pos);
413
414 frame_idx += 1;
415
416 if remaining < hop_length {
417 break;
418 }
419 pos += hop_length;
420 }
421
422 Ok((output, frame_idx))
423 }
424
425 fn extract_frame(
427 samples: &[f32],
428 pos: usize,
429 frame_length: usize,
430 remaining: usize,
431 ) -> Result<Vec<f32>> {
432 let mut frame_buf = vec![0.0; frame_length];
433
434 if remaining >= frame_length {
435 let src = samples
436 .get(pos..pos + frame_length)
437 .ok_or_else(|| Error::Processing("frame window out of bounds".into()))?;
438 frame_buf.copy_from_slice(src);
439 } else {
440 let src = samples
441 .get(pos..)
442 .ok_or_else(|| Error::Processing("frame tail out of bounds".into()))?;
443 if let Some(dst) = frame_buf.get_mut(..remaining) {
444 dst.copy_from_slice(src);
445 }
446 }
447
448 Ok(frame_buf)
449 }
450
451 fn process_single_frame(
453 &mut self,
454 frame: &[f32],
455 vad_context: Option<VadContext>,
456 is_full_frame: bool,
457 ) -> Result<Vec<f32>> {
458 let fft_size = self.config.fft_size();
459
460 let windowed: Vec<f32> = frame
461 .iter()
462 .zip(&self.window)
463 .map(|(&s, &w)| s * w)
464 .collect();
465
466 let complex_spectrum = self.forward_fft_complex(&windowed)?;
467 let magnitudes: Vec<f32> = complex_spectrum.iter().map(|c| c.norm()).collect();
468
469 let is_silence = vad_context.is_some_and(|ctx| ctx.is_silence);
470 if is_silence && is_full_frame {
471 self.update_noise_profile(&magnitudes);
472 }
473
474 let cleaned_magnitudes = self.spectral_subtract(&magnitudes);
475
476 let cleaned_complex =
477 Self::reconstruct_complex_spectrum(&complex_spectrum, &cleaned_magnitudes);
478
479 let time_signal = self.inverse_fft_complex(&cleaned_complex, fft_size)?;
480
481 let windowed_output: Vec<f32> = time_signal
482 .iter()
483 .take(frame.len())
484 .zip(&self.window)
485 .map(|(&s, &w)| s * w)
486 .collect();
487
488 Ok(windowed_output)
489 }
490
491 fn reconstruct_complex_spectrum(
493 original_spectrum: &[realfft::num_complex::Complex<f32>],
494 cleaned_magnitudes: &[f32],
495 ) -> Vec<realfft::num_complex::Complex<f32>> {
496 original_spectrum
497 .iter()
498 .zip(cleaned_magnitudes)
499 .enumerate()
500 .map(|(i, (original, &new_mag))| {
501 if i == 0 || i == original_spectrum.len() - 1 {
502 realfft::num_complex::Complex::new(new_mag, 0.0)
504 } else {
505 let phase = original.arg();
506 realfft::num_complex::Complex::from_polar(new_mag, phase)
507 }
508 })
509 .collect()
510 }
511
512 fn accumulate_frame_output(frame: &[f32], output: &mut [f32], pos: usize) {
514 for (i, &sample) in frame.iter().enumerate() {
515 let out_idx = pos + i;
516 if let Some(dst) = output.get_mut(out_idx) {
517 *dst += sample;
518 }
519 }
520 }
521
522 fn normalize_overlap_add(&self, output: &mut [f32]) {
524 let hop_length = self.config.hop_length();
525 let window_sum = self.calculate_window_overlap_sum(hop_length);
526
527 if window_sum > 1e-6 {
528 for sample in output {
529 *sample /= window_sum;
530 }
531 }
532 }
533
534 fn record_performance_metrics(
535 &self,
536 input: &[f32],
537 output: &[f32],
538 latency_ms: f64,
539 frame_count: usize,
540 ) {
541 if input.len() < 8000 {
542 return;
543 }
544
545 if latency_ms > 15.0 {
546 warn!(
547 target: "audio.preprocess.noise_reduction",
548 latency_ms,
549 samples = input.len(),
550 frames = frame_count,
551 oversubtraction = self.config.oversubtraction_factor,
552 spectral_floor = self.config.spectral_floor,
553 "noise reduction latency exceeded target"
554 );
555 }
556
557 let avg_noise_floor = self.noise_floor().max(1e-12);
558 let noise_floor_db = 20.0 * avg_noise_floor.log10();
559
560 let signal_power_out =
561 output.iter().map(|sample| sample * sample).sum::<f32>() / output.len() as f32;
562 let residual_power: f32 = input
563 .iter()
564 .zip(output)
565 .map(|(&noisy, &clean)| {
566 let residual = noisy - clean;
567 residual * residual
568 })
569 .sum::<f32>()
570 / output.len() as f32;
571
572 let snr_improvement_db = if residual_power > 1e-12 && signal_power_out > 0.0 {
573 10.0 * (signal_power_out / residual_power).log10()
574 } else {
575 0.0
576 };
577
578 info!(
579 target: "audio.preprocess.noise_reduction",
580 noise_floor_db,
581 snr_improvement_db,
582 latency_ms,
583 frames = frame_count,
584 samples = input.len(),
585 oversubtraction = self.config.oversubtraction_factor,
586 spectral_floor = self.config.spectral_floor,
587 "noise reduction metrics"
588 );
589 }
590
591 pub fn reset(&mut self) {
596 self.noise_profile.fill(1e-6);
597 self.noise_initialized = false;
598 self.overlap_buffer.fill(0.0);
599 }
600
601 #[must_use]
603 pub fn noise_floor(&self) -> f32 {
604 if self.noise_profile.is_empty() {
605 return 0.0;
606 }
607 self.noise_profile.iter().sum::<f32>() / self.noise_profile.len() as f32
608 }
609
610 #[must_use]
612 pub fn config(&self) -> &NoiseReductionConfig {
613 &self.config
614 }
615
616 fn forward_fft_complex(
618 &self,
619 windowed: &[f32],
620 ) -> Result<Vec<realfft::num_complex::Complex<f32>>> {
621 let mut input = self.fft_forward.make_input_vec();
623 for (i, &sample) in windowed.iter().enumerate() {
624 if let Some(dst) = input.get_mut(i) {
625 *dst = sample;
626 }
627 }
628
629 let mut spectrum = self.fft_forward.make_output_vec();
631 self.fft_forward
632 .process(&mut input, &mut spectrum)
633 .map_err(|e| Error::Processing(format!("FFT failed: {e}")))?;
634
635 Ok(spectrum)
636 }
637
638 fn inverse_fft_complex(
640 &self,
641 complex_spectrum: &[realfft::num_complex::Complex<f32>],
642 fft_size: usize,
643 ) -> Result<Vec<f32>> {
644 let mut spectrum = self.fft_inverse.make_input_vec();
646 for (i, &c) in complex_spectrum.iter().enumerate() {
647 if let Some(bin) = spectrum.get_mut(i) {
648 *bin = c;
649 }
650 }
651
652 let mut output = self.fft_inverse.make_output_vec();
654 self.fft_inverse
655 .process(&mut spectrum, &mut output)
656 .map_err(|e| Error::Processing(format!("IFFT failed: {e}")))?;
657
658 for sample in &mut output {
660 *sample /= fft_size as f32;
661 }
662
663 Ok(output)
664 }
665
666 fn update_noise_profile(&mut self, spectrum: &[f32]) {
668 let alpha = self.config.noise_smoothing;
669
670 if self.noise_initialized {
671 for (noise, ¤t) in self.noise_profile.iter_mut().zip(spectrum.iter()) {
673 *noise = alpha.mul_add(*noise, (1.0 - alpha) * current);
674 }
675 } else {
676 self.noise_profile.copy_from_slice(spectrum);
678 self.noise_initialized = true;
679 }
680 }
681
682 fn spectral_subtract(&self, spectrum: &[f32]) -> Vec<f32> {
684 let alpha = self.config.oversubtraction_factor;
685 let beta = self.config.spectral_floor;
686
687 spectrum
688 .iter()
689 .zip(&self.noise_profile)
690 .map(|(&signal, &noise)| {
691 let subtracted = alpha.mul_add(-noise, signal);
692 let floor = beta * noise;
693 subtracted.max(floor)
694 })
695 .collect()
696 }
697
698 fn calculate_window_overlap_sum(&self, hop_length: usize) -> f32 {
700 let frame_length = self.window.len();
701 let mut sum: f32 = 0.0;
702
703 for i in 0..frame_length {
705 let mut overlap: f32 = 0.0;
706 let mut offset = 0;
707
708 while offset <= i {
709 if let Some(&w) = self.window.get(i - offset) {
710 overlap = w.mul_add(w, overlap); }
714 offset += hop_length;
715 }
716
717 sum = sum.max(overlap);
718 }
719
720 sum
721 }
722}
723
/// Builds a symmetric Hann window of `length` samples.
///
/// Returns an empty vector for `length == 0` and `[1.0]` for `length == 1`.
/// Otherwise sample `n` is `0.5 - 0.5 * cos(2 * PI * n / (length - 1))`.
fn generate_hann_window(length: usize) -> Vec<f32> {
    match length {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let last_index = (length - 1) as f32;
            let mut window = Vec::with_capacity(length);
            for n in 0..length {
                let theta = 2.0 * PI * n as f32 / last_index;
                window.push(0.5f32.mul_add(-theta.cos(), 0.5));
            }
            window
        }
    }
}
744
745fn elapsed_duration(start: AudioInstant) -> AudioDuration {
746 AudioInstant::now().duration_since(start)
747}
748
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult<T> = std::result::Result<T, String>;

    /// Builds a reducer from `config`, turning a construction error into its
    /// display string.
    fn build_reducer(config: NoiseReductionConfig) -> TestResult<NoiseReducer> {
        NoiseReducer::new(config).map_err(|e| e.to_string())
    }

    /// Calls `reduce`, turning a processing error into its display string.
    fn run(
        reducer: &mut NoiseReducer,
        samples: &[f32],
        vad: Option<VadContext>,
    ) -> TestResult<Vec<f32>> {
        reducer.reduce(samples, vad).map_err(|e| e.to_string())
    }

    /// Feeds `samples` to the reducer `rounds` times, flagged as silence.
    fn learn_noise(reducer: &mut NoiseReducer, samples: &[f32], rounds: usize) -> TestResult<()> {
        let silence = VadContext { is_silence: true };
        for _ in 0..rounds {
            let _ = run(reducer, samples, Some(silence))?;
        }
        Ok(())
    }

    /// Trains a default reducer on `noise_only` for `rounds` passes, denoises
    /// `noisy` as speech, and returns the SNR change in dB relative to `clean`.
    fn snr_gain_db(
        clean: &[f32],
        noisy: &[f32],
        noise_only: &[f32],
        rounds: usize,
    ) -> TestResult<f32> {
        let mut reducer = build_reducer(NoiseReductionConfig::default())?;
        let snr_before = calculate_snr(clean, noisy);

        learn_noise(&mut reducer, noise_only, rounds)?;

        let speech = VadContext { is_silence: false };
        let denoised = run(&mut reducer, noisy, Some(speech))?;

        Ok(calculate_snr(clean, &denoised) - snr_before)
    }

    /// The default configuration validates; each out-of-range field is rejected.
    #[test]
    #[allow(clippy::unnecessary_wraps)]
    fn test_configuration_validation() -> TestResult<()> {
        assert!(NoiseReductionConfig::default().validate().is_ok());

        let rejected = [
            NoiseReductionConfig {
                sample_rate_hz: 5000,
                ..Default::default()
            },
            NoiseReductionConfig {
                window_ms: 100.0,
                ..Default::default()
            },
            NoiseReductionConfig {
                hop_ms: 30.0,
                window_ms: 25.0,
                ..Default::default()
            },
            NoiseReductionConfig {
                oversubtraction_factor: 5.0,
                ..Default::default()
            },
            NoiseReductionConfig {
                spectral_floor: 0.5,
                ..Default::default()
            },
            NoiseReductionConfig {
                noise_smoothing: 1.0,
                ..Default::default()
            },
        ];
        for config in &rejected {
            assert!(config.validate().is_err());
        }

        Ok(())
    }

    /// Edge lengths and end/centre values of the Hann window.
    #[test]
    fn test_hann_window_properties() {
        assert!(generate_hann_window(0).is_empty());

        let single = generate_hann_window(1);
        assert_eq!(single.len(), 1);
        assert!((single[0] - 1.0).abs() < 1e-6);

        let window = generate_hann_window(100);
        assert_eq!(window.len(), 100);

        for edge in [window[0], window[99]] {
            assert!(edge.abs() < 1e-6);
        }

        assert!((window[50] - 1.0).abs() < 0.1);
    }

    /// A default reducer reports its configuration and a positive noise floor.
    #[test]
    fn test_noise_reducer_creation() -> TestResult<()> {
        let reducer = build_reducer(NoiseReductionConfig::default())?;

        assert_eq!(reducer.config().sample_rate_hz, 16000);
        assert!(reducer.noise_floor() > 0.0);

        Ok(())
    }

    /// Empty input produces empty output.
    #[test]
    fn test_empty_input() -> TestResult<()> {
        let mut reducer = build_reducer(NoiseReductionConfig::default())?;

        let output = run(&mut reducer, &[], None)?;
        assert!(output.is_empty());

        Ok(())
    }

    /// With `enable: false` the samples pass through unchanged.
    #[test]
    fn test_bypass_mode() -> TestResult<()> {
        let mut reducer = build_reducer(NoiseReductionConfig {
            enable: false,
            ..Default::default()
        })?;

        let input: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
        let output = run(&mut reducer, &input, None)?;

        assert_eq!(output, input);

        Ok(())
    }

    /// Silence-flagged input raises the noise floor above its initial value.
    #[test]
    fn test_noise_profile_update() -> TestResult<()> {
        let mut reducer = build_reducer(NoiseReductionConfig::default())?;
        let silence = vec![0.01_f32; 8000];

        let initial_noise = reducer.noise_floor();
        learn_noise(&mut reducer, &silence, 5)?;
        let converged_noise = reducer.noise_floor();

        assert!(
            converged_noise > initial_noise,
            "Noise floor should adapt: initial={:.6}, converged={:.6}",
            initial_noise,
            converged_noise
        );

        Ok(())
    }

    /// Speech-flagged input must leave the learned noise floor (almost) unchanged.
    #[test]
    fn test_vad_informed_noise_update() -> TestResult<()> {
        let mut reducer = build_reducer(NoiseReductionConfig::default())?;

        let silence = vec![0.01_f32; 8000];
        learn_noise(&mut reducer, &silence, 5)?;
        let noise_after_silence = reducer.noise_floor();

        let speech = vec![0.5_f32; 8000];
        let vad_speech = VadContext { is_silence: false };
        let _ = run(&mut reducer, &speech, Some(vad_speech))?;
        let noise_after_speech = reducer.noise_floor();

        let diff = (noise_after_speech - noise_after_silence).abs();
        assert!(
            diff < noise_after_silence * 0.01,
            "Noise profile changed during speech: {:.6} -> {:.6}",
            noise_after_silence,
            noise_after_speech
        );

        Ok(())
    }

    /// `reset` brings the noise floor back below `1e-5` after it was learned.
    #[test]
    fn test_reset_clears_state() -> TestResult<()> {
        let mut reducer = build_reducer(NoiseReductionConfig::default())?;

        let samples = vec![0.1_f32; 8000];
        learn_noise(&mut reducer, &samples, 1)?;

        let noise_before = reducer.noise_floor();
        assert!(noise_before > 1e-5, "Noise profile should be updated");

        reducer.reset();

        let noise_after = reducer.noise_floor();
        assert!(
            noise_after < 1e-5,
            "Noise profile should be reset to initial value"
        );

        Ok(())
    }

    /// Sine wave of `frequency` Hz lasting `duration_secs` at `sample_rate`,
    /// scaled by `amplitude`.
    fn generate_sine_wave(
        frequency: f32,
        sample_rate: u32,
        duration_secs: f32,
        amplitude: f32,
    ) -> Vec<f32> {
        use std::f32::consts::PI;
        let total = (sample_rate as f32 * duration_secs).round() as usize;
        let mut wave = Vec::with_capacity(total);
        for i in 0..total {
            let t = i as f32 / sample_rate as f32;
            wave.push((2.0 * PI * frequency * t).sin() * amplitude);
        }
        wave
    }

    /// Adds a uniform random value from `-noise_amplitude..noise_amplitude`
    /// to every sample.
    fn add_white_noise(signal: &[f32], noise_amplitude: f32) -> Vec<f32> {
        use rand::Rng;
        let mut rng = rand::rng();
        let mut noisy = Vec::with_capacity(signal.len());
        for &sample in signal {
            let jitter: f32 = rng.random_range(-noise_amplitude..noise_amplitude);
            noisy.push(sample + jitter);
        }
        noisy
    }

    /// Adds a sine hum of `frequency` Hz and the given `amplitude` to `signal`.
    fn add_low_freq_hum(
        signal: &[f32],
        sample_rate: u32,
        frequency: f32,
        amplitude: f32,
    ) -> Vec<f32> {
        let mut hummed = Vec::with_capacity(signal.len());
        for (i, &sample) in signal.iter().enumerate() {
            let t = i as f32 / sample_rate as f32;
            let hum = (2.0 * PI * frequency * t).sin() * amplitude;
            hummed.push(sample + hum);
        }
        hummed
    }

    /// Adds a uniform random value from `-1.0..1.0`, scaled by `amplitude`,
    /// to every sample. The sample rate is not used.
    fn add_cafe_noise(signal: &[f32], _sample_rate: u32, amplitude: f32) -> Vec<f32> {
        use rand::Rng;
        let mut rng = rand::rng();
        let mut noisy = Vec::with_capacity(signal.len());
        for &sample in signal {
            let noise: f32 = rng.random_range(-1.0..1.0);
            noisy.push(amplitude.mul_add(noise, sample));
        }
        noisy
    }

    /// SNR of `noisy` relative to `clean`, in dB.
    ///
    /// Returns `0.0` when the lengths differ and `100.0` when the noise power
    /// is below `1e-10`.
    fn calculate_snr(clean: &[f32], noisy: &[f32]) -> f32 {
        if clean.len() != noisy.len() {
            return 0.0;
        }

        let signal_power: f32 = clean.iter().map(|&x| x * x).sum();
        let noise_power: f32 = clean
            .iter()
            .zip(noisy)
            .map(|(&c, &n)| {
                let error = n - c;
                error * error
            })
            .sum();

        if noise_power < 1e-10 {
            100.0
        } else {
            10.0 * (signal_power / noise_power).log10()
        }
    }

    /// White noise: at least 6 dB SNR gain after ten training passes.
    #[test]
    fn test_snr_improvement_white_noise() -> TestResult<()> {
        let clean_speech = generate_sine_wave(440.0, 16000, 1.0, 0.5);
        let noisy_speech = add_white_noise(&clean_speech, 0.3);

        let quiet = vec![0.0_f32; 8000];
        let pure_noise = add_white_noise(&quiet, 0.3);

        let improvement = snr_gain_db(&clean_speech, &noisy_speech, &pure_noise, 10)?;

        assert!(
            improvement >= 6.0,
            "SNR improvement {:.1} dB < 6 dB target",
            improvement
        );

        Ok(())
    }

    /// 60 Hz hum: at least 6 dB SNR gain after six training passes.
    #[test]
    fn test_snr_improvement_low_freq_hum() -> TestResult<()> {
        let clean = generate_sine_wave(440.0, 16000, 1.0, 0.4);
        let noisy = add_low_freq_hum(&clean, 16000, 60.0, 0.3);

        let quiet = vec![0.0_f32; 8000];
        let hum_only = add_low_freq_hum(&quiet, 16000, 60.0, 0.3);

        let improvement = snr_gain_db(&clean, &noisy, &hum_only, 6)?;

        assert!(
            improvement >= 6.0,
            "Hum SNR improvement {:.1} dB < 6 dB target",
            improvement
        );

        Ok(())
    }

    /// Café-style noise: at least 6 dB SNR gain after ten training passes.
    #[test]
    fn test_snr_improvement_cafe_ambient() -> TestResult<()> {
        let clean = generate_sine_wave(220.0, 16000, 1.0, 0.4);
        let noisy = add_cafe_noise(&clean, 16000, 0.25);

        let quiet = vec![0.0_f32; 8000];
        let cafe_only = add_cafe_noise(&quiet, 16000, 0.25);

        let improvement = snr_gain_db(&clean, &noisy, &cafe_only, 10)?;

        assert!(
            improvement >= 6.0,
            "Café ambient SNR improvement {:.1} dB < 6 dB target",
            improvement
        );

        Ok(())
    }

    /// Input whose length is not a multiple of the hop keeps its length and
    /// still carries energy in the last 80 samples.
    #[test]
    fn test_trailing_partial_frame_preserved() -> TestResult<()> {
        let mut reducer = build_reducer(NoiseReductionConfig::default())?;

        let silence = vec![0.0_f32; 8000];
        learn_noise(&mut reducer, &silence, 1)?;

        let speech_len = 8080;
        let speech: Vec<f32> = (0..speech_len)
            .map(|i| ((i as f32 / speech_len as f32) * 20.0).sin())
            .collect();

        let vad_speech = VadContext { is_silence: false };
        let output = run(&mut reducer, &speech, Some(vad_speech))?;

        assert_eq!(
            output.len(),
            speech_len,
            "Output length should match input length"
        );

        let tail_energy: f32 = output[speech_len - 80..]
            .iter()
            .map(|sample| sample.abs())
            .sum();
        assert!(
            tail_energy > 1e-3,
            "Trailing samples should retain energy, got tail_energy={tail_energy}"
        );

        Ok(())
    }

    /// Without a VAD context the noise floor must not move and the signal
    /// must keep its energy.
    #[test]
    fn test_missing_vad_context_does_not_update_noise_profile() -> TestResult<()> {
        let mut reducer = build_reducer(NoiseReductionConfig::default())?;

        let ambient_noise = vec![0.05f32; 8000];
        learn_noise(&mut reducer, &ambient_noise, 1)?;
        let baseline_floor = reducer.noise_floor();

        let speech = vec![0.2f32; 8000];
        let output = run(&mut reducer, &speech, None)?;
        let updated_floor = reducer.noise_floor();

        let floor_delta = (updated_floor - baseline_floor).abs();
        assert!(
            floor_delta < baseline_floor.max(1e-6) * 0.01,
            "Noise floor changed when VAD context missing: baseline={baseline_floor}, \
             updated={updated_floor}"
        );

        let mean_square =
            output.iter().map(|sample| sample * sample).sum::<f32>() / output.len() as f32;
        let output_rms = mean_square.sqrt();
        assert!(
            output_rms > 0.08,
            "Speech energy collapsed without VAD context (rms={output_rms})"
        );

        Ok(())
    }
}