wavekat_vad/preprocessing/
biquad.rs1use std::f64::consts::PI;
4
/// Second-order IIR ("biquad") filter evaluated in Direct Form I.
///
/// The feed-forward coefficients are `b0`, `b1`, `b2` and the feedback
/// coefficients are `a1`, `a2`; the constructor stores them already divided
/// by `a0`. The remaining fields hold the two most recent inputs (`x1`, `x2`)
/// and outputs (`y1`, `y2`) that the recurrence needs.
#[derive(Clone, Debug)]
pub struct BiquadFilter {
    // Numerator (feed-forward) coefficients.
    b0: f64,
    b1: f64,
    b2: f64,
    // Denominator (feedback) coefficients, normalized so that a0 == 1.
    a1: f64,
    a2: f64,
    // Input history: x[n-1], x[n-2].
    x1: f64,
    x2: f64,
    // Output history: y[n-1], y[n-2].
    y1: f64,
    y2: f64,
}

28impl BiquadFilter {
29 pub fn highpass_butterworth(cutoff_hz: f32, sample_rate: u32) -> Self {
38 let fs = sample_rate as f64;
39 let fc = cutoff_hz as f64;
40
41 assert!(
42 fc < fs / 2.0,
43 "cutoff frequency must be below Nyquist frequency"
44 );
45
46 let q = std::f64::consts::FRAC_1_SQRT_2; let omega = 2.0 * PI * fc / fs;
51 let cos_omega = omega.cos();
52 let sin_omega = omega.sin();
53 let alpha = sin_omega / (2.0 * q);
54
55 let b0 = (1.0 + cos_omega) / 2.0;
57 let b1 = -(1.0 + cos_omega);
58 let b2 = (1.0 + cos_omega) / 2.0;
59 let a0 = 1.0 + alpha;
60 let a1 = -2.0 * cos_omega;
61 let a2 = 1.0 - alpha;
62
63 Self {
65 b0: b0 / a0,
66 b1: b1 / a0,
67 b2: b2 / a0,
68 a1: a1 / a0,
69 a2: a2 / a0,
70 x1: 0.0,
71 x2: 0.0,
72 y1: 0.0,
73 y2: 0.0,
74 }
75 }
76
77 #[inline]
79 pub fn process_sample(&mut self, x: f64) -> f64 {
80 let y = self.b0 * x + self.b1 * self.x1 + self.b2 * self.x2
81 - self.a1 * self.y1
82 - self.a2 * self.y2;
83
84 self.x2 = self.x1;
86 self.x1 = x;
87 self.y2 = self.y1;
88 self.y1 = y;
89
90 y
91 }
92
93 pub fn process_i16(&mut self, samples: &mut [i16]) {
95 for sample in samples.iter_mut() {
96 let x = *sample as f64;
97 let y = self.process_sample(x);
98 *sample = y.round().clamp(-32768.0, 32767.0) as i16;
100 }
101 }
102
103 pub fn process_i16_to_vec(&mut self, samples: &[i16]) -> Vec<i16> {
105 samples
106 .iter()
107 .map(|&s| {
108 let y = self.process_sample(s as f64);
109 y.round().clamp(-32768.0, 32767.0) as i16
110 })
111 .collect()
112 }
113
114 pub fn reset(&mut self) {
116 self.x1 = 0.0;
117 self.x2 = 0.0;
118 self.y1 = 0.0;
119 self.y2 = 0.0;
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Root-mean-square level of `samples`; 0.0 for an empty slice.
    fn rms(samples: &[i16]) -> f64 {
        if samples.is_empty() {
            return 0.0;
        }
        let sum_sq: f64 = samples.iter().map(|&s| f64::from(s).powi(2)).sum();
        (sum_sq / samples.len() as f64).sqrt()
    }

    /// `len` samples of a sine wave at `freq_hz` with a peak amplitude of 10000.
    fn sine(freq_hz: f64, sample_rate: f64, len: usize) -> Vec<i16> {
        (0..len)
            .map(|i| {
                let t = i as f64 / sample_rate;
                (10000.0 * (2.0 * PI * freq_hz * t).sin()) as i16
            })
            .collect()
    }

    #[test]
    fn test_highpass_creation() {
        let filter = BiquadFilter::highpass_butterworth(80.0, 16000);
        for c in [filter.b0, filter.b1, filter.b2, filter.a1, filter.a2] {
            assert!(c.is_finite());
        }
        // A second-order high-pass numerator is b0 · (1 − z⁻¹)².
        assert_eq!(filter.b1, -2.0 * filter.b0);
        assert_eq!(filter.b2, filter.b0);
        // Stability triangle for a second-order denominator.
        assert!(filter.a2.abs() < 1.0);
        assert!(filter.a1.abs() < 1.0 + filter.a2);
    }

    #[test]
    #[should_panic(expected = "cutoff frequency must be below Nyquist")]
    fn test_highpass_invalid_cutoff() {
        BiquadFilter::highpass_butterworth(8000.0, 16000);
    }

    #[test]
    fn test_highpass_attenuates_dc() {
        let mut filter = BiquadFilter::highpass_butterworth(100.0, 16000);

        let dc_samples: Vec<i16> = vec![10000; 1000];
        let output = filter.process_i16_to_vec(&dc_samples);

        // Skip the start-up transient; judge only the settled tail. Widen to
        // i32 before `abs` so i16::MIN cannot overflow.
        let last_100: i32 = output[900..].iter().map(|&s| i32::from(s).abs()).sum();
        let avg = last_100 / 100;
        assert!(
            avg < 100,
            "DC should be heavily attenuated, got avg abs: {avg}"
        );
    }

    #[test]
    fn test_highpass_passes_high_frequencies() {
        let mut filter = BiquadFilter::highpass_butterworth(100.0, 16000);

        let samples = sine(1000.0, 16000.0, 1600);
        let output = filter.process_i16_to_vec(&samples);

        let ratio = rms(&output[800..]) / rms(&samples[800..]);
        assert!(
            ratio > 0.9,
            "1kHz should pass with >90% amplitude, got {:.1}%",
            ratio * 100.0
        );
    }

    #[test]
    fn test_highpass_attenuates_low_frequencies() {
        let mut filter = BiquadFilter::highpass_butterworth(200.0, 16000);

        let samples = sine(50.0, 16000.0, 3200);
        let output = filter.process_i16_to_vec(&samples);

        let ratio = rms(&output[1600..]) / rms(&samples[1600..]);
        assert!(
            ratio < 0.35,
            "50Hz should be attenuated to <35% at 200Hz cutoff, got {:.1}%",
            ratio * 100.0
        );
    }

    #[test]
    fn test_process_i16_matches_to_vec() {
        let input = sine(1000.0, 16000.0, 400);
        let mut by_vec = BiquadFilter::highpass_butterworth(100.0, 16000);
        let mut by_slice = by_vec.clone();

        let expected = by_vec.process_i16_to_vec(&input);
        let mut in_place = input.clone();
        by_slice.process_i16(&mut in_place);

        assert_eq!(in_place, expected);
    }

    #[test]
    fn test_output_saturates() {
        let mut filter = BiquadFilter::highpass_butterworth(100.0, 16000);

        // Let the filter settle on full-scale negative DC, then jump to full-scale
        // positive: the raw output overshoots i16::MAX and must clip, not wrap.
        let mut samples = vec![i16::MIN; 2000];
        samples.push(i16::MAX);
        let output = filter.process_i16_to_vec(&samples);

        assert_eq!(output.last().copied(), Some(i16::MAX));
    }

    #[test]
    fn test_reset() {
        let mut filter = BiquadFilter::highpass_butterworth(100.0, 16000);

        let samples: Vec<i16> = vec![1000; 100];
        let fresh_output = filter.process_i16_to_vec(&samples);

        filter.reset();
        assert_eq!(filter.x1, 0.0);
        assert_eq!(filter.x2, 0.0);
        assert_eq!(filter.y1, 0.0);
        assert_eq!(filter.y2, 0.0);

        // A reset filter must reproduce a freshly constructed filter's output.
        assert_eq!(filter.process_i16_to_vec(&samples), fresh_output);
    }
}