// wavekat_vad/preprocessing/normalize.rs

/// Full-scale magnitude of a 16-bit PCM sample; RMS values and dBFS
/// conversions in this file are expressed relative to it.
const FULL_SCALE: f64 = 32768.0;

/// Fraction of [`FULL_SCALE`] below which a frame's RMS is treated as silence:
/// such frames are returned unchanged and do not move the gain.
const MIN_RMS_THRESHOLD: f64 = 0.001;

/// Frame-by-frame RMS normalizer for 16-bit PCM audio.
///
/// Holds a target level and a gain that is eased toward the value needed to
/// reach that level, one step per processed frame.
#[derive(Debug, Clone)]
pub struct Normalizer {
    /// Desired output RMS in linear sample units (derived from a dBFS target).
    target_rms: f64,
    /// Gain currently multiplied into every sample; starts at 1.0.
    current_gain: f64,
    /// Fraction of the remaining gain error corrected on each processed frame.
    smoothing: f64,
    /// Whether amplified samples above 90% of full scale are soft-limited.
    peak_limit: bool,
}

29impl Normalizer {
30 pub fn new(target_dbfs: f32) -> Self {
38 assert!(
39 target_dbfs <= 0.0,
40 "Target dBFS must be <= 0, got {target_dbfs}"
41 );
42
43 let target_rms = FULL_SCALE * 10_f64.powf(target_dbfs as f64 / 20.0);
47
48 Self {
49 target_rms,
50 current_gain: 1.0,
51 smoothing: 0.1, peak_limit: true,
53 }
54 }
55
56 pub fn with_settings(target_dbfs: f32, smoothing: f64, peak_limit: bool) -> Self {
58 let mut normalizer = Self::new(target_dbfs);
59 normalizer.smoothing = smoothing.clamp(0.01, 1.0);
60 normalizer.peak_limit = peak_limit;
61 normalizer
62 }
63
64 fn calculate_rms(samples: &[i16]) -> f64 {
66 if samples.is_empty() {
67 return 0.0;
68 }
69
70 let sum_squares: f64 = samples.iter().map(|&s| (s as f64).powi(2)).sum();
71 (sum_squares / samples.len() as f64).sqrt()
72 }
73
74 #[allow(dead_code)]
76 pub fn rms_to_dbfs(rms: f64) -> f64 {
77 if rms <= 0.0 {
78 return -96.0; }
80 20.0 * (rms / FULL_SCALE).log10()
81 }
82
83 pub fn process(&mut self, samples: &[i16]) -> Vec<i16> {
87 if samples.is_empty() {
88 return Vec::new();
89 }
90
91 let input_rms = Self::calculate_rms(samples);
92
93 if input_rms < MIN_RMS_THRESHOLD * FULL_SCALE {
95 return samples.to_vec();
96 }
97
98 let target_gain = self.target_rms / input_rms;
100
101 self.current_gain += self.smoothing * (target_gain - self.current_gain);
103
104 samples
106 .iter()
107 .map(|&s| {
108 let amplified = s as f64 * self.current_gain;
109
110 if self.peak_limit {
111 let normalized = amplified / FULL_SCALE;
113 let limited = if normalized.abs() > 0.9 {
114 let sign = normalized.signum();
116 let magnitude = normalized.abs();
117 let compressed = 0.9 + 0.1 * ((magnitude - 0.9) / 0.1).tanh();
118 sign * compressed * FULL_SCALE
119 } else {
120 amplified
121 };
122 limited.round().clamp(-32768.0, 32767.0) as i16
123 } else {
124 amplified.round().clamp(-32768.0, 32767.0) as i16
126 }
127 })
128 .collect()
129 }
130
131 pub fn reset(&mut self) {
133 self.current_gain = 1.0;
134 }
135
136 pub fn current_gain(&self) -> f64 {
138 self.current_gain
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// A -20 dBFS target maps to a linear RMS strictly between 0 and full scale.
    #[test]
    fn test_normalizer_creation() {
        let norm = Normalizer::new(-20.0);
        assert!(norm.target_rms > 0.0);
        assert!(norm.target_rms < FULL_SCALE);
    }

    /// Targets above 0 dBFS are rejected by the constructor.
    #[test]
    #[should_panic(expected = "Target dBFS must be <= 0")]
    fn test_normalizer_invalid_target() {
        Normalizer::new(6.0);
    }

    /// RMS of a constant signal equals its magnitude; RMS of zeros is zero.
    #[test]
    fn test_rms_calculation() {
        let dc: Vec<i16> = vec![1000; 100];
        let rms = Normalizer::calculate_rms(&dc);
        assert!((rms - 1000.0).abs() < 1.0);

        let silence: Vec<i16> = vec![0; 100];
        let rms = Normalizer::calculate_rms(&silence);
        assert_eq!(rms, 0.0);
    }

    /// A frame below the target level comes out louder.
    #[test]
    fn test_normalizer_amplifies_quiet() {
        let mut norm = Normalizer::new(-20.0);

        let quiet: Vec<i16> = vec![100; 480];
        let output = norm.process(&quiet);

        let input_rms = Normalizer::calculate_rms(&quiet);
        let output_rms = Normalizer::calculate_rms(&output);
        assert!(
            output_rms > input_rms,
            "Output RMS {output_rms} should be > input RMS {input_rms}"
        );
    }

    /// A frame above the target level comes out quieter.
    #[test]
    fn test_normalizer_attenuates_loud() {
        let mut norm = Normalizer::new(-20.0);

        let loud: Vec<i16> = vec![16000; 480];
        let output = norm.process(&loud);

        let input_rms = Normalizer::calculate_rms(&loud);
        let output_rms = Normalizer::calculate_rms(&output);
        assert!(
            output_rms < input_rms,
            "Output RMS {output_rms} should be < input RMS {input_rms}"
        );
    }

    /// Frames under the silence threshold are returned untouched.
    #[test]
    fn test_normalizer_skips_silence() {
        let mut norm = Normalizer::new(-20.0);

        let silence: Vec<i16> = vec![1; 480];
        let output = norm.process(&silence);

        assert_eq!(output, silence);
    }

    /// The soft limiter compresses samples that land above 90% of full scale
    /// and leaves samples below the knee identical to the unlimited output.
    #[test]
    fn test_normalizer_peak_limiting() {
        // 0.9 * FULL_SCALE, truncated to an integer sample value.
        const KNEE: i16 = 29_491;
        const PEAK: i16 = 9000;
        const BASE: i16 = 4000;

        // smoothing = 1.0 makes the gain jump straight to its target on the
        // first frame, so both normalizers apply the same gain.
        let mut limited = Normalizer::with_settings(-6.0, 1.0, true);
        let mut unlimited = Normalizer::with_settings(-6.0, 1.0, false);

        // One peak in ten: after gain the peaks land at roughly 95% of full
        // scale while the base samples stay well under the knee.
        let input: Vec<i16> = (0..480)
            .map(|i| if i % 10 == 0 { PEAK } else { BASE })
            .collect();

        let soft = limited.process(&input);
        let hard = unlimited.process(&input);

        assert_eq!(soft.len(), input.len());
        assert_eq!(hard.len(), input.len());

        for ((&x, &s), &h) in input.iter().zip(&soft).zip(&hard) {
            if x == PEAK {
                assert!(h > KNEE, "unlimited peak {h} should exceed the knee");
                assert!(s < h, "limited peak {s} should be below unlimited {h}");
                assert!(s >= KNEE, "limited peak {s} should not drop below the knee");
            } else {
                assert_eq!(s, h);
            }
        }
    }

    /// `reset` restores unity gain after processing has moved it.
    #[test]
    fn test_normalizer_reset() {
        let mut norm = Normalizer::new(-20.0);

        let samples: Vec<i16> = vec![1000; 480];
        norm.process(&samples);
        assert!(norm.current_gain() != 1.0);

        norm.reset();
        assert_eq!(norm.current_gain(), 1.0);
    }

    /// Full scale is 0 dBFS and half of full scale is about -6.02 dBFS.
    #[test]
    fn test_dbfs_conversion() {
        let dbfs = Normalizer::rms_to_dbfs(FULL_SCALE);
        assert!((dbfs - 0.0).abs() < 0.01);

        let dbfs = Normalizer::rms_to_dbfs(FULL_SCALE / 2.0);
        assert!((dbfs - (-6.02)).abs() < 0.1);
    }
}