speech_prep/preprocessing/
normalization.rs1use crate::error::{Error, Result};
23use crate::time::{AudioDuration, AudioInstant};
24use tracing::{debug, trace};
25
/// RMS loudness normalizer for `f32` sample buffers, with a cap on the applied
/// gain and hard clipping of the amplified output to `[-1.0, 1.0]`.
///
/// Construct with [`Normalizer::new`], which validates both parameters.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Normalizer {
    /// Desired RMS level of the output; `Normalizer::new` requires `[0.0, 1.0]`.
    target_rms: f32,
    /// Largest linear gain that may be applied to a buffer; `Normalizer::new`
    /// rejects non-positive values.
    max_gain: f32,
}
43
44impl Normalizer {
45 pub fn new(target_rms: f32, max_gain: f32) -> Result<Self> {
58 if !(0.0..=1.0).contains(&target_rms) {
59 return Err(Error::InvalidInput(
60 "target_rms must be in range [0.0, 1.0]".into(),
61 ));
62 }
63 if max_gain <= 0.0 {
64 return Err(Error::InvalidInput("max_gain must be positive".into()));
65 }
66 Ok(Self {
67 target_rms,
68 max_gain,
69 })
70 }
71
72 #[tracing::instrument(skip(samples), fields(sample_count = samples.len()))]
86 pub fn normalize(self, samples: &[f32]) -> Result<Vec<f32>> {
87 if samples.is_empty() {
88 return Err(Error::InvalidInput("cannot normalize empty audio".into()));
89 }
90
91 let processing_start = AudioInstant::now();
92 let current_rms = Self::calculate_rms(samples);
93
94 if current_rms < 1e-10 {
95 trace!(
96 current_rms,
97 "audio is silence or near-silence, no normalization applied"
98 );
99 let _elapsed = elapsed_duration(processing_start);
100 return Ok(samples.to_vec());
101 }
102
103 let raw_gain = self.target_rms / current_rms;
104 let gain = raw_gain.min(self.max_gain);
105 let gain_limited = raw_gain > self.max_gain;
106
107 let (output, clipped_samples) = Self::apply_gain_with_limiting(samples, gain);
108
109 Self::log_normalization_metrics(
110 current_rms,
111 self.target_rms,
112 gain,
113 gain_limited,
114 clipped_samples,
115 );
116
117 let _elapsed = elapsed_duration(processing_start);
118
119 Ok(output)
120 }
121
122 fn apply_gain_with_limiting(samples: &[f32], gain: f32) -> (Vec<f32>, usize) {
126 let mut clipped_samples = 0usize;
127 let output: Vec<f32> = samples
128 .iter()
129 .map(|&s| {
130 let amplified = s * gain;
131 if amplified.abs() > 1.0 {
132 clipped_samples += 1;
133 }
134 amplified.clamp(-1.0, 1.0)
135 })
136 .collect();
137 (output, clipped_samples)
138 }
139
140 fn log_normalization_metrics(
141 current_rms: f32,
142 target_rms: f32,
143 gain: f32,
144 gain_limited: bool,
145 clipped_samples: usize,
146 ) {
147 let gain_db = 20.0 * gain.log10();
148
149 if gain_db > 6.0 {
150 debug!(
151 current_rms,
152 target_rms,
153 gain,
154 gain_db,
155 gain_limited,
156 clipped_samples,
157 "high gain applied during normalization"
158 );
159 } else {
160 trace!(
161 current_rms,
162 target_rms,
163 gain,
164 gain_db,
165 gain_limited,
166 clipped_samples,
167 "normalization complete"
168 );
169 }
170 }
171
172 fn calculate_rms(samples: &[f32]) -> f32 {
174 let sum_squares: f32 = samples.iter().map(|&s| s * s).sum();
175 let mean_square = sum_squares / samples.len() as f32;
176 mean_square.sqrt()
177 }
178}
179
180fn elapsed_duration(start: AudioInstant) -> AudioDuration {
181 AudioInstant::now().duration_since(start)
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_to_target_rms() {
        let normalizer = Normalizer::new(0.5, 10.0).unwrap();
        let quiet_audio = vec![0.1f32; 1000];

        let normalized = normalizer.normalize(&quiet_audio).unwrap();
        let result_rms = Normalizer::calculate_rms(&normalized);

        // Required gain is 5x, well under max_gain, so the target must be reached.
        assert!(
            (result_rms - 0.5).abs() < 0.025,
            "RMS {result_rms} not within tolerance of 0.5"
        );
    }

    #[test]
    fn test_no_clipping() {
        let normalizer = Normalizer::new(0.9, 100.0).unwrap();
        let audio = vec![0.8f32; 1000];

        let normalized = normalizer.normalize(&audio).unwrap();

        for &sample in &normalized {
            assert!(
                (-1.0..=1.0).contains(&sample),
                "Sample {sample} outside [-1.0, 1.0]"
            );
        }
    }

    #[test]
    fn test_max_gain_limit() {
        let normalizer = Normalizer::new(0.5, 2.0).unwrap();
        let very_quiet = vec![0.01f32; 1000];

        let normalized = normalizer.normalize(&very_quiet).unwrap();

        // The unclamped gain would be 50x. The cap must bring it to exactly
        // max_gain, not merely to something at or below it (which zero gain would
        // also satisfy).
        let actual_gain = normalized[0] / very_quiet[0];
        assert!(
            (actual_gain - 2.0).abs() <= 1e-6,
            "Gain {actual_gain} should be capped at exactly max_gain 2.0"
        );
    }

    #[test]
    fn test_silence_handling() {
        let normalizer = Normalizer::new(0.5, 10.0).unwrap();
        let silence = vec![0.0f32; 1000];

        let normalized = normalizer.normalize(&silence).unwrap();

        assert_eq!(normalized, silence, "Silence should remain unchanged");
    }

    #[test]
    fn test_near_silence_handling() {
        let normalizer = Normalizer::new(0.5, 10.0).unwrap();
        let near_silence = vec![1e-11f32; 1000];

        let normalized = normalizer.normalize(&near_silence).unwrap();

        // RMS 1e-11 is below the silence threshold, so the buffer passes through.
        assert_eq!(
            normalized, near_silence,
            "Near-silence should remain unchanged"
        );
    }

    #[test]
    fn test_invalid_target_rms_above() {
        let result = Normalizer::new(1.5, 10.0);
        assert!(result.is_err(), "Should reject target_rms > 1.0");
    }

    #[test]
    fn test_invalid_target_rms_below() {
        let result = Normalizer::new(-0.1, 10.0);
        assert!(result.is_err(), "Should reject negative target_rms");
    }

    #[test]
    fn test_invalid_max_gain_zero() {
        let result = Normalizer::new(0.5, 0.0);
        assert!(result.is_err(), "Should reject zero max_gain");
    }

    #[test]
    fn test_invalid_max_gain_negative() {
        let result = Normalizer::new(0.5, -1.0);
        assert!(result.is_err(), "Should reject negative max_gain");
    }

    #[test]
    fn test_empty_audio() {
        let normalizer = Normalizer::new(0.5, 10.0).unwrap();
        let result = normalizer.normalize(&[]);
        assert!(result.is_err(), "Should reject empty audio");
    }

    #[test]
    fn test_loud_audio_reduction() {
        let normalizer = Normalizer::new(0.3, 10.0).unwrap();
        let loud_audio = vec![0.9f32; 1000];

        let normalized = normalizer.normalize(&loud_audio).unwrap();
        let result_rms = Normalizer::calculate_rms(&normalized);

        // Gain below 1.0 attenuates; no clipping can occur on the way down.
        assert!(
            (result_rms - 0.3).abs() < 0.02,
            "RMS {result_rms} not within tolerance of 0.3"
        );
    }

    #[test]
    fn test_varied_amplitude() {
        let normalizer = Normalizer::new(0.5, 10.0).unwrap();
        let varied: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 0.1).collect();

        let normalized = normalizer.normalize(&varied).unwrap();

        // A ramp peaking near 0.1 needs roughly 8.7x gain, so even the peak stays
        // below full scale.
        for &sample in &normalized {
            assert!(
                (-1.0..=1.0).contains(&sample),
                "Sample {sample} outside valid range"
            );
        }

        let result_rms = Normalizer::calculate_rms(&normalized);
        assert!(
            (result_rms - 0.5).abs() < 0.05,
            "RMS {result_rms} not within tolerance of 0.5"
        );
    }

    #[test]
    fn test_peak_limiting_preserves_bounds() {
        let normalizer = Normalizer::new(0.8, 20.0).unwrap();

        // One full-scale transient followed by quiet material. The gain needed to
        // lift the quiet part to the target pushes the transient past 1.0.
        let mut audio = vec![0.1f32; 999];
        audio.insert(0, 1.0);

        let normalized = normalizer.normalize(&audio).unwrap();
        let result_rms = Normalizer::calculate_rms(&normalized);

        assert!(
            normalized
                .iter()
                .all(|sample| (-1.0..=1.0).contains(sample)),
            "Samples exceeded normalized bounds: {:?}",
            normalized
        );
        assert!(
            normalized[0] <= 1.0 && normalized[0] >= 0.999,
            "Peak sample should be hard-limited to ~1.0, got {}",
            normalized[0]
        );

        // Clipping the transient removes some energy, so the RMS lands slightly
        // below the target rather than on it.
        assert!(
            result_rms > 0.7 && result_rms <= normalizer.target_rms + 0.05,
            "RMS {result_rms} outside expected post-limiting range"
        );
    }

    #[test]
    fn test_clipped_sample_count() {
        let (output, clipped) =
            Normalizer::apply_gain_with_limiting(&[0.5, -0.8, 0.1, 0.0], 2.0);

        // 0.5 * 2 lands exactly on full scale and must not count as clipped;
        // -0.8 * 2 overshoots and is both clamped and counted.
        assert_eq!(output, vec![1.0, -1.0, 0.2, 0.0]);
        assert_eq!(
            clipped, 1,
            "only samples strictly beyond full scale count as clipped"
        );
    }
}