Skip to main content

voirs_recognizer/preprocessing/
bandwidth_extension.rs

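//! Bandwidth Extension Module
//!
//! Extends the frequency range of audio signals to improve recognition
//! accuracy for band-limited audio sources.
//!
//! This module is described as using SciRS2-Core for high-performance FFT
//! operations and SIMD-accelerated processing for real-time bandwidth
//! extension.
//!
//! NOTE(review): the code in this file operates directly on time-domain
//! samples and contains no FFT or SciRS2-Core calls; its explicit SIMD paths
//! are only compiled for x86_64 builds with AVX2 enabled. Confirm whether the
//! FFT-based implementation lives elsewhere or is still planned.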
8
9use crate::RecognitionError;
10use voirs_sdk::AudioBuffer;
11
12/// Configuration for bandwidth extension
13#[derive(Debug, Clone)]
14pub struct BandwidthExtensionConfig {
15    /// Target bandwidth in Hz
16    pub target_bandwidth: f32,
17    /// Extension method
18    pub method: ExtensionMethod,
19    /// Quality level
20    pub quality: QualityLevel,
21    /// Enable spectral replication
22    pub spectral_replication: bool,
23    /// High frequency emphasis factor
24    pub hf_emphasis: f32,
25}
26
27impl Default for BandwidthExtensionConfig {
28    fn default() -> Self {
29        Self {
30            target_bandwidth: 8000.0,
31            method: ExtensionMethod::SpectralReplication,
32            quality: QualityLevel::Medium,
33            spectral_replication: true,
34            hf_emphasis: 1.2,
35        }
36    }
37}
38
/// Bandwidth extension methods
///
/// The enum carries no data, so equality is total (`Eq`) and values can be
/// hashed, which allows them to be used as keys in maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExtensionMethod {
    /// Spectral replication from lower frequencies
    SpectralReplication,
    /// Linear prediction-based extension
    LinearPrediction,
    /// Neural network-based extension
    Neural,
    /// Harmonic extension
    Harmonic,
}
51
/// Quality levels for bandwidth extension
///
/// Variants are declared in ascending quality order, so the derived ordering
/// gives `Low < Medium < High`. The enum carries no data, which makes total
/// equality (`Eq`) and hashing available as well.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QualityLevel {
    /// Low quality (fast processing)
    Low,
    /// Medium quality (balanced)
    Medium,
    /// High quality (best results)
    High,
}
62
/// Measurements recorded while extending a buffer's bandwidth.
///
/// All fields start at zero (`Default`).
#[derive(Debug, Clone, Default)]
pub struct BandwidthExtensionStats {
    /// Bandwidth of the input in Hz, reported as half of its sample rate.
    pub original_bandwidth: f32,
    /// Bandwidth after extension in Hz, reported as the configured target.
    pub extended_bandwidth: f32,
    /// Shift of the spectral centroid caused by the extension.
    ///
    /// NOTE(review): not written by the processing code visible in this file.
    pub spectral_centroid_shift: f32,
    /// Energy added in the extended frequency range.
    ///
    /// NOTE(review): not written by the processing code visible in this file.
    pub extended_energy: f32,
    /// Wall-clock duration of the processing call, in milliseconds.
    pub processing_time_ms: f32,
}
77
78/// Bandwidth extension processor
79pub struct BandwidthExtensionProcessor {
80    config: BandwidthExtensionConfig,
81    stats: BandwidthExtensionStats,
82    filter_banks: Vec<Vec<f32>>,
83}
84
85impl std::fmt::Debug for BandwidthExtensionProcessor {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("BandwidthExtensionProcessor")
88            .field("config", &self.config)
89            .field("stats", &self.stats)
90            .field("filter_banks", &self.filter_banks)
91            .finish()
92    }
93}
94
95impl BandwidthExtensionProcessor {
96    /// Create a new bandwidth extension processor
97    pub fn new(config: BandwidthExtensionConfig) -> Result<Self, RecognitionError> {
98        // Initialize filter banks for spectral replication
99        let filter_banks = Self::create_filter_banks(&config);
100
101        Ok(Self {
102            config,
103            stats: BandwidthExtensionStats::default(),
104            filter_banks,
105        })
106    }
107
108    /// Create filter banks for bandwidth extension
109    fn create_filter_banks(config: &BandwidthExtensionConfig) -> Vec<Vec<f32>> {
110        let num_bands = match config.quality {
111            QualityLevel::Low => 4,
112            QualityLevel::Medium => 8,
113            QualityLevel::High => 16,
114        };
115
116        // Create simple filter banks
117        (0..num_bands)
118            .map(|i| {
119                let center_freq = (i + 1) as f32 * 1000.0;
120                // Simplified filter coefficients
121                vec![
122                    0.1 * (center_freq / 1000.0).sin(),
123                    0.2 * (center_freq / 1000.0).cos(),
124                    0.1 * (center_freq / 2000.0).sin(),
125                ]
126            })
127            .collect()
128    }
129
130    /// Process audio to extend bandwidth
131    pub fn process(&mut self, audio: &AudioBuffer) -> Result<AudioBuffer, RecognitionError> {
132        let start_time = std::time::Instant::now();
133
134        let samples = audio.samples();
135        let mut extended_samples = samples.to_vec();
136
137        // Simple bandwidth extension using spectral replication
138        if self.config.spectral_replication {
139            self.apply_spectral_replication(&mut extended_samples, audio.sample_rate())?;
140        }
141
142        // Apply high-frequency emphasis
143        if self.config.hf_emphasis != 1.0 {
144            self.apply_hf_emphasis(&mut extended_samples, audio.sample_rate())?;
145        }
146
147        // Update statistics
148        self.stats.processing_time_ms = start_time.elapsed().as_secs_f32() * 1000.0;
149        self.stats.original_bandwidth = audio.sample_rate() as f32 / 2.0;
150        self.stats.extended_bandwidth = self.config.target_bandwidth;
151
152        Ok(AudioBuffer::new(
153            extended_samples,
154            audio.sample_rate(),
155            audio.channels(),
156        ))
157    }
158
159    /// Apply spectral replication with SIMD optimization
160    fn apply_spectral_replication(
161        &self,
162        samples: &mut [f32],
163        sample_rate: u32,
164    ) -> Result<(), RecognitionError> {
165        // Simplified spectral replication with SIMD
166        let nyquist = sample_rate as f32 / 2.0;
167        let extension_factor = self.config.target_bandwidth / nyquist;
168
169        if extension_factor > 1.0 {
170            let len = samples.len();
171            let inv_len = 1.0 / len as f32;
172            let freq_scale = self.config.hf_emphasis * extension_factor;
173
174            // SIMD-optimized processing for x86_64 with AVX2
175            #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
176            {
177                use std::arch::x86_64::*;
178                let chunks = samples.chunks_exact_mut(8);
179                let remainder = chunks.into_remainder();
180
181                for (chunk_idx, chunk) in samples.chunks_exact_mut(8).enumerate() {
182                    unsafe {
183                        // Load 8 samples
184                        let orig = _mm256_loadu_ps(chunk.as_ptr());
185
186                        // Compute frequency components for each sample
187                        let indices: [f32; 8] = std::array::from_fn(|i| {
188                            ((chunk_idx * 8 + i) as f32 * inv_len * freq_scale).sin()
189                        });
190                        let freq_comp = _mm256_loadu_ps(indices.as_ptr());
191
192                        // Compute abs(original_sample)
193                        let mask = _mm256_set1_ps(-0.0);
194                        let abs_orig = _mm256_andnot_ps(mask, orig);
195
196                        // Multiply and scale
197                        let scale = _mm256_set1_ps(0.1);
198                        let product = _mm256_mul_ps(freq_comp, abs_orig);
199                        let scaled = _mm256_mul_ps(product, scale);
200
201                        // Add to original
202                        let result = _mm256_add_ps(orig, scaled);
203
204                        // Store result
205                        _mm256_storeu_ps(chunk.as_mut_ptr(), result);
206                    }
207                }
208
209                // Process remainder
210                for (i, sample) in remainder.iter_mut().enumerate() {
211                    let idx = (len / 8) * 8 + i;
212                    let original_sample = *sample;
213                    let freq_component = (idx as f32 * inv_len * freq_scale).sin();
214                    *sample += 0.1 * freq_component * original_sample.abs();
215                }
216            }
217
218            // Fallback scalar implementation for other architectures
219            #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
220            {
221                for (i, sample) in samples.iter_mut().enumerate() {
222                    let original_sample = *sample;
223                    let freq_component = (i as f32 * inv_len * freq_scale).sin();
224                    *sample += 0.1 * freq_component * original_sample.abs();
225                }
226            }
227        }
228
229        Ok(())
230    }
231
232    /// Apply high-frequency emphasis with SIMD optimization
233    fn apply_hf_emphasis(
234        &self,
235        samples: &mut [f32],
236        _sample_rate: u32,
237    ) -> Result<(), RecognitionError> {
238        if samples.is_empty() {
239            return Ok(());
240        }
241
242        let emphasis = self.config.hf_emphasis;
243        let diff_scale = (emphasis - 1.0) * 0.1;
244
245        // SIMD-optimized high-frequency emphasis for x86_64 with AVX2
246        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
247        {
248            use std::arch::x86_64::*;
249
250            if samples.len() >= 9 {
251                let scale_vec = unsafe { _mm256_set1_ps(diff_scale) };
252
253                for i in (1..samples.len() - 7).step_by(8) {
254                    unsafe {
255                        // Load current and previous samples
256                        let current = _mm256_loadu_ps(samples[i..].as_ptr());
257                        let previous = _mm256_loadu_ps(samples[i - 1..].as_ptr());
258
259                        // Compute differences
260                        let diff = _mm256_sub_ps(current, previous);
261
262                        // Scale differences
263                        let scaled_diff = _mm256_mul_ps(diff, scale_vec);
264
265                        // Add to current samples
266                        let result = _mm256_add_ps(current, scaled_diff);
267
268                        // Store result
269                        _mm256_storeu_ps(samples[i..].as_mut_ptr(), result);
270                    }
271                }
272
273                // Process remainder
274                let remainder_start = ((samples.len() - 1) / 8) * 8;
275                let mut prev = samples[remainder_start - 1];
276                for sample in &mut samples[remainder_start..] {
277                    let current = *sample;
278                    let diff = current - prev;
279                    *sample += diff * diff_scale;
280                    prev = current;
281                }
282            } else {
283                // For small buffers, use scalar implementation
284                let mut prev = 0.0;
285                for (i, sample) in samples.iter_mut().enumerate() {
286                    let current = *sample;
287                    if i > 0 {
288                        let diff = current - prev;
289                        *sample += diff * diff_scale;
290                    }
291                    prev = current;
292                }
293            }
294        }
295
296        // Fallback scalar implementation for other architectures
297        #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
298        {
299            let mut prev = 0.0;
300            for (i, sample) in samples.iter_mut().enumerate() {
301                let current = *sample;
302                if i > 0 {
303                    let diff = current - prev;
304                    *sample += diff * diff_scale;
305                }
306                prev = current;
307            }
308        }
309
310        Ok(())
311    }
312
313    /// Get processing statistics
314    #[must_use]
315    pub fn get_stats(&self) -> &BandwidthExtensionStats {
316        &self.stats
317    }
318
319    /// Update configuration
320    pub fn set_config(&mut self, config: BandwidthExtensionConfig) -> Result<(), RecognitionError> {
321        self.filter_banks = Self::create_filter_banks(&config);
322        self.config = config;
323        Ok(())
324    }
325
326    /// Get current configuration
327    #[must_use]
328    pub fn config(&self) -> &BandwidthExtensionConfig {
329        &self.config
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration targets 8 kHz with medium-quality spectral
    /// replication enabled.
    #[test]
    fn test_bandwidth_extension_config_default() {
        let BandwidthExtensionConfig {
            target_bandwidth,
            method,
            quality,
            spectral_replication,
            ..
        } = BandwidthExtensionConfig::default();

        assert!((target_bandwidth - 8000.0).abs() < f32::EPSILON);
        assert_eq!(method, ExtensionMethod::SpectralReplication);
        assert_eq!(quality, QualityLevel::Medium);
        assert!(spectral_replication);
    }

    /// A processor can be built from the default configuration.
    #[test]
    fn test_bandwidth_extension_processor_creation() {
        let created = BandwidthExtensionProcessor::new(BandwidthExtensionConfig::default());
        assert!(created.is_ok());
    }

    /// Processing keeps the sample rate, channel count and sample count.
    #[test]
    fn test_bandwidth_extension_processing() {
        let mut processor =
            BandwidthExtensionProcessor::new(BandwidthExtensionConfig::default()).unwrap();
        let audio = AudioBuffer::new(vec![0.1, 0.2, 0.3, 0.4, 0.3, 0.2, 0.1], 16000, 1);

        let outcome = processor.process(&audio);
        assert!(outcome.is_ok());
        let extended = outcome.unwrap();

        assert_eq!(extended.sample_rate(), audio.sample_rate());
        assert_eq!(extended.channels(), audio.channels());
        assert_eq!(extended.samples().len(), audio.samples().len());
    }

    /// Every extension method compares equal to its own clone.
    #[test]
    fn test_extension_methods() {
        let methods = [
            ExtensionMethod::SpectralReplication,
            ExtensionMethod::LinearPrediction,
            ExtensionMethod::Neural,
            ExtensionMethod::Harmonic,
        ];

        for method in &methods {
            assert_eq!(method.clone(), *method);
        }
    }

    /// Every quality level compares equal to its own clone.
    #[test]
    fn test_quality_levels() {
        let levels = [QualityLevel::Low, QualityLevel::Medium, QualityLevel::High];

        for level in &levels {
            assert_eq!(level.clone(), *level);
        }
    }

    /// Default statistics are all zero.
    #[test]
    fn test_stats_default() {
        let stats = BandwidthExtensionStats::default();
        let fields = [
            stats.original_bandwidth,
            stats.extended_bandwidth,
            stats.spectral_centroid_shift,
            stats.extended_energy,
            stats.processing_time_ms,
        ];

        for value in &fields {
            assert!((*value - 0.0).abs() < f32::EPSILON);
        }
    }
}