// voirs_recognizer/preprocessing/bandwidth_extension.rs
use crate::RecognitionError;
use voirs_sdk::AudioBuffer;
11
12#[derive(Debug, Clone)]
14pub struct BandwidthExtensionConfig {
15 pub target_bandwidth: f32,
17 pub method: ExtensionMethod,
19 pub quality: QualityLevel,
21 pub spectral_replication: bool,
23 pub hf_emphasis: f32,
25}
26
27impl Default for BandwidthExtensionConfig {
28 fn default() -> Self {
29 Self {
30 target_bandwidth: 8000.0,
31 method: ExtensionMethod::SpectralReplication,
32 quality: QualityLevel::Medium,
33 spectral_replication: true,
34 hf_emphasis: 1.2,
35 }
36 }
37}
38
/// Algorithms that can be selected for bandwidth extension.
///
/// The enum carries no data, so it also derives `Eq` and `Hash`; this lets it
/// be used as a key in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExtensionMethod {
    /// Spectral replication (the default method).
    SpectralReplication,
    /// Linear-prediction based extension.
    LinearPrediction,
    /// Neural-network based extension.
    Neural,
    /// Harmonic extension.
    Harmonic,
}
51
/// Processing quality levels.
///
/// The enum carries no data, so it also derives `Eq` and `Hash`; this lets it
/// be used as a key in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum QualityLevel {
    /// Lowest quality; the processor builds 4 filter banks.
    Low,
    /// Medium quality (the default); the processor builds 8 filter banks.
    Medium,
    /// Highest quality; the processor builds 16 filter banks.
    High,
}
62
/// Figures describing a bandwidth-extension run.
///
/// Every field defaults to `0.0`.
#[derive(Debug, Clone, Default)]
pub struct BandwidthExtensionStats {
    /// Half the sample rate of the most recently processed audio.
    pub original_bandwidth: f32,
    /// Target bandwidth that was configured for the most recent run.
    pub extended_bandwidth: f32,
    /// Spectral centroid shift.
    ///
    /// NOTE(review): the processor in this file never writes this field.
    pub spectral_centroid_shift: f32,
    /// Energy attributed to the extended band.
    ///
    /// NOTE(review): the processor in this file never writes this field.
    pub extended_energy: f32,
    /// Wall-clock duration of the most recent run, in milliseconds.
    pub processing_time_ms: f32,
}
77
78pub struct BandwidthExtensionProcessor {
80 config: BandwidthExtensionConfig,
81 stats: BandwidthExtensionStats,
82 filter_banks: Vec<Vec<f32>>,
83}
84
85impl std::fmt::Debug for BandwidthExtensionProcessor {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("BandwidthExtensionProcessor")
88 .field("config", &self.config)
89 .field("stats", &self.stats)
90 .field("filter_banks", &self.filter_banks)
91 .finish()
92 }
93}
94
95impl BandwidthExtensionProcessor {
96 pub fn new(config: BandwidthExtensionConfig) -> Result<Self, RecognitionError> {
98 let filter_banks = Self::create_filter_banks(&config);
100
101 Ok(Self {
102 config,
103 stats: BandwidthExtensionStats::default(),
104 filter_banks,
105 })
106 }
107
108 fn create_filter_banks(config: &BandwidthExtensionConfig) -> Vec<Vec<f32>> {
110 let num_bands = match config.quality {
111 QualityLevel::Low => 4,
112 QualityLevel::Medium => 8,
113 QualityLevel::High => 16,
114 };
115
116 (0..num_bands)
118 .map(|i| {
119 let center_freq = (i + 1) as f32 * 1000.0;
120 vec![
122 0.1 * (center_freq / 1000.0).sin(),
123 0.2 * (center_freq / 1000.0).cos(),
124 0.1 * (center_freq / 2000.0).sin(),
125 ]
126 })
127 .collect()
128 }
129
130 pub fn process(&mut self, audio: &AudioBuffer) -> Result<AudioBuffer, RecognitionError> {
132 let start_time = std::time::Instant::now();
133
134 let samples = audio.samples();
135 let mut extended_samples = samples.to_vec();
136
137 if self.config.spectral_replication {
139 self.apply_spectral_replication(&mut extended_samples, audio.sample_rate())?;
140 }
141
142 if self.config.hf_emphasis != 1.0 {
144 self.apply_hf_emphasis(&mut extended_samples, audio.sample_rate())?;
145 }
146
147 self.stats.processing_time_ms = start_time.elapsed().as_secs_f32() * 1000.0;
149 self.stats.original_bandwidth = audio.sample_rate() as f32 / 2.0;
150 self.stats.extended_bandwidth = self.config.target_bandwidth;
151
152 Ok(AudioBuffer::new(
153 extended_samples,
154 audio.sample_rate(),
155 audio.channels(),
156 ))
157 }
158
159 fn apply_spectral_replication(
161 &self,
162 samples: &mut [f32],
163 sample_rate: u32,
164 ) -> Result<(), RecognitionError> {
165 let nyquist = sample_rate as f32 / 2.0;
167 let extension_factor = self.config.target_bandwidth / nyquist;
168
169 if extension_factor > 1.0 {
170 let len = samples.len();
171 let inv_len = 1.0 / len as f32;
172 let freq_scale = self.config.hf_emphasis * extension_factor;
173
174 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
176 {
177 use std::arch::x86_64::*;
178 let chunks = samples.chunks_exact_mut(8);
179 let remainder = chunks.into_remainder();
180
181 for (chunk_idx, chunk) in samples.chunks_exact_mut(8).enumerate() {
182 unsafe {
183 let orig = _mm256_loadu_ps(chunk.as_ptr());
185
186 let indices: [f32; 8] = std::array::from_fn(|i| {
188 ((chunk_idx * 8 + i) as f32 * inv_len * freq_scale).sin()
189 });
190 let freq_comp = _mm256_loadu_ps(indices.as_ptr());
191
192 let mask = _mm256_set1_ps(-0.0);
194 let abs_orig = _mm256_andnot_ps(mask, orig);
195
196 let scale = _mm256_set1_ps(0.1);
198 let product = _mm256_mul_ps(freq_comp, abs_orig);
199 let scaled = _mm256_mul_ps(product, scale);
200
201 let result = _mm256_add_ps(orig, scaled);
203
204 _mm256_storeu_ps(chunk.as_mut_ptr(), result);
206 }
207 }
208
209 for (i, sample) in remainder.iter_mut().enumerate() {
211 let idx = (len / 8) * 8 + i;
212 let original_sample = *sample;
213 let freq_component = (idx as f32 * inv_len * freq_scale).sin();
214 *sample += 0.1 * freq_component * original_sample.abs();
215 }
216 }
217
218 #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
220 {
221 for (i, sample) in samples.iter_mut().enumerate() {
222 let original_sample = *sample;
223 let freq_component = (i as f32 * inv_len * freq_scale).sin();
224 *sample += 0.1 * freq_component * original_sample.abs();
225 }
226 }
227 }
228
229 Ok(())
230 }
231
232 fn apply_hf_emphasis(
234 &self,
235 samples: &mut [f32],
236 _sample_rate: u32,
237 ) -> Result<(), RecognitionError> {
238 if samples.is_empty() {
239 return Ok(());
240 }
241
242 let emphasis = self.config.hf_emphasis;
243 let diff_scale = (emphasis - 1.0) * 0.1;
244
245 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
247 {
248 use std::arch::x86_64::*;
249
250 if samples.len() >= 9 {
251 let scale_vec = unsafe { _mm256_set1_ps(diff_scale) };
252
253 for i in (1..samples.len() - 7).step_by(8) {
254 unsafe {
255 let current = _mm256_loadu_ps(samples[i..].as_ptr());
257 let previous = _mm256_loadu_ps(samples[i - 1..].as_ptr());
258
259 let diff = _mm256_sub_ps(current, previous);
261
262 let scaled_diff = _mm256_mul_ps(diff, scale_vec);
264
265 let result = _mm256_add_ps(current, scaled_diff);
267
268 _mm256_storeu_ps(samples[i..].as_mut_ptr(), result);
270 }
271 }
272
273 let remainder_start = ((samples.len() - 1) / 8) * 8;
275 let mut prev = samples[remainder_start - 1];
276 for sample in &mut samples[remainder_start..] {
277 let current = *sample;
278 let diff = current - prev;
279 *sample += diff * diff_scale;
280 prev = current;
281 }
282 } else {
283 let mut prev = 0.0;
285 for (i, sample) in samples.iter_mut().enumerate() {
286 let current = *sample;
287 if i > 0 {
288 let diff = current - prev;
289 *sample += diff * diff_scale;
290 }
291 prev = current;
292 }
293 }
294 }
295
296 #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
298 {
299 let mut prev = 0.0;
300 for (i, sample) in samples.iter_mut().enumerate() {
301 let current = *sample;
302 if i > 0 {
303 let diff = current - prev;
304 *sample += diff * diff_scale;
305 }
306 prev = current;
307 }
308 }
309
310 Ok(())
311 }
312
313 #[must_use]
315 pub fn get_stats(&self) -> &BandwidthExtensionStats {
316 &self.stats
317 }
318
319 pub fn set_config(&mut self, config: BandwidthExtensionConfig) -> Result<(), RecognitionError> {
321 self.filter_banks = Self::create_filter_banks(&config);
322 self.config = config;
323 Ok(())
324 }
325
326 #[must_use]
328 pub fn config(&self) -> &BandwidthExtensionConfig {
329 &self.config
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference for the emphasis step: every sample but the first gains its
    /// first difference scaled by `(hf_emphasis - 1.0) * 0.1`.
    fn reference_hf_emphasis(input: &[f32], hf_emphasis: f32) -> Vec<f32> {
        let diff_scale = (hf_emphasis - 1.0) * 0.1;
        input
            .iter()
            .enumerate()
            .map(|(i, &x)| {
                if i == 0 {
                    x
                } else {
                    x + (x - input[i - 1]) * diff_scale
                }
            })
            .collect()
    }

    #[test]
    fn test_bandwidth_extension_config_default() {
        let config = BandwidthExtensionConfig::default();
        assert!((config.target_bandwidth - 8000.0).abs() < f32::EPSILON);
        assert_eq!(config.method, ExtensionMethod::SpectralReplication);
        assert_eq!(config.quality, QualityLevel::Medium);
        assert!(config.spectral_replication);
    }

    #[test]
    fn test_bandwidth_extension_processor_creation() {
        let config = BandwidthExtensionConfig::default();
        let processor = BandwidthExtensionProcessor::new(config);
        assert!(processor.is_ok());
    }

    #[test]
    fn test_bandwidth_extension_processing() {
        let config = BandwidthExtensionConfig::default();
        let mut processor = BandwidthExtensionProcessor::new(config).unwrap();

        let samples = vec![0.1, 0.2, 0.3, 0.4, 0.3, 0.2, 0.1];
        let audio = AudioBuffer::new(samples, 16000, 1);

        let result = processor.process(&audio);
        assert!(result.is_ok());

        let extended = result.unwrap();
        assert_eq!(extended.sample_rate(), audio.sample_rate());
        assert_eq!(extended.channels(), audio.channels());
        assert_eq!(extended.samples().len(), audio.samples().len());
    }

    /// Uses a buffer longer than one 8-sample chunk plus a tail and checks the
    /// actual output values, not just the output shape.
    #[test]
    fn test_hf_emphasis_matches_reference_on_long_buffer() {
        let config = BandwidthExtensionConfig {
            spectral_replication: false,
            ..BandwidthExtensionConfig::default()
        };
        let hf_emphasis = config.hf_emphasis;
        let mut processor = BandwidthExtensionProcessor::new(config).unwrap();

        let samples: Vec<f32> = (0..37)
            .map(|i| ((i * 7 % 11) as f32 - 5.0) * 0.05)
            .collect();
        let audio = AudioBuffer::new(samples.clone(), 16000, 1);

        let extended = processor.process(&audio).unwrap();
        let expected = reference_hf_emphasis(&samples, hf_emphasis);

        assert_eq!(extended.samples().len(), expected.len());
        for (index, (got, want)) in extended.samples().iter().zip(&expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-6,
                "sample {index}: got {got}, want {want}"
            );
        }
    }

    #[test]
    fn test_process_updates_bandwidth_stats() {
        let config = BandwidthExtensionConfig::default();
        let mut processor = BandwidthExtensionProcessor::new(config).unwrap();

        let audio = AudioBuffer::new(vec![0.0, 0.1, 0.2, 0.1], 16000, 1);
        processor.process(&audio).unwrap();

        let stats = processor.get_stats();
        assert!((stats.original_bandwidth - 8000.0).abs() < f32::EPSILON);
        assert!((stats.extended_bandwidth - 8000.0).abs() < f32::EPSILON);
        assert!(stats.processing_time_ms >= 0.0);
    }

    #[test]
    fn test_extension_methods() {
        let methods = vec![
            ExtensionMethod::SpectralReplication,
            ExtensionMethod::LinearPrediction,
            ExtensionMethod::Neural,
            ExtensionMethod::Harmonic,
        ];

        for method in methods {
            assert_eq!(method.clone(), method);
        }
    }

    #[test]
    fn test_quality_levels() {
        let levels = vec![QualityLevel::Low, QualityLevel::Medium, QualityLevel::High];

        for level in levels {
            assert_eq!(level.clone(), level);
        }
    }

    #[test]
    fn test_stats_default() {
        let stats = BandwidthExtensionStats::default();
        assert!((stats.original_bandwidth - 0.0).abs() < f32::EPSILON);
        assert!((stats.extended_bandwidth - 0.0).abs() < f32::EPSILON);
        assert!((stats.spectral_centroid_shift - 0.0).abs() < f32::EPSILON);
        assert!((stats.extended_energy - 0.0).abs() < f32::EPSILON);
        assert!((stats.processing_time_ms - 0.0).abs() < f32::EPSILON);
    }
}