Skip to main content

voirs_recognizer/preprocessing/
advanced_spectral.rs

//! Advanced spectral processing for audio enhancement
//!
//! This module provides sophisticated spectral processing techniques including:
//! - Spectral noise gating
//! - Harmonic enhancement
//! - Spectral subtraction with oversubtraction factor
//! - Multi-band dynamic range compression
//! - Perceptual spectral shaping
9
10use crate::RecognitionError;
11use std::f32::consts::PI;
12use voirs_sdk::AudioBuffer;
13
14/// Advanced spectral processing configuration
15#[derive(Debug, Clone)]
16/// Advanced Spectral Config
17pub struct AdvancedSpectralConfig {
18    /// FFT size for spectral analysis
19    pub fft_size: usize,
20    /// Hop length for STFT
21    pub hop_length: usize,
22    /// Window type for STFT
23    pub window_type: WindowType,
24    /// Enable spectral noise gating
25    pub spectral_noise_gate: bool,
26    /// Noise gate threshold (dB)
27    pub noise_gate_threshold: f32,
28    /// Enable harmonic enhancement
29    pub harmonic_enhancement: bool,
30    /// Harmonic enhancement factor
31    pub harmonic_factor: f32,
32    /// Enable multi-band compression
33    pub multiband_compression: bool,
34    /// Number of frequency bands for compression
35    pub num_bands: usize,
36    /// Compression ratios for each band
37    pub compression_ratios: Vec<f32>,
38    /// Enable perceptual shaping
39    pub perceptual_shaping: bool,
40    /// Sample rate
41    pub sample_rate: u32,
42}
43
44impl Default for AdvancedSpectralConfig {
45    fn default() -> Self {
46        Self {
47            fft_size: 2048,
48            hop_length: 512,
49            window_type: WindowType::Hann,
50            spectral_noise_gate: true,
51            noise_gate_threshold: -40.0,
52            harmonic_enhancement: true,
53            harmonic_factor: 1.2,
54            multiband_compression: true,
55            num_bands: 4,
56            compression_ratios: vec![2.0, 3.0, 4.0, 2.5],
57            perceptual_shaping: true,
58            sample_rate: 16000,
59        }
60    }
61}
62
/// Analysis window shapes available for the STFT.
#[derive(Debug, Clone, Copy)]
pub enum WindowType {
    /// Raised-cosine window: `0.5 * (1 - cos(phase))`.
    Hann,
    /// Hamming window: `0.54 - 0.46 * cos(phase)`.
    Hamming,
    /// Three-term Blackman window (coefficients 0.42, 0.5, 0.08).
    Blackman,
    /// Kaiser window; generated with a fixed beta of 8.6.
    Kaiser,
    /// Tukey (tapered cosine) window; generated with a fixed alpha of 0.5.
    Tukey,
}
78
/// Minimal complex number in rectangular form, used by the FFT buffers.
#[derive(Debug, Clone, Copy)]
struct Complex {
    /// Real component.
    real: f32,
    /// Imaginary component.
    imag: f32,
}

impl Complex {
    /// Builds a value from its rectangular components.
    fn new(real: f32, imag: f32) -> Self {
        Complex { real, imag }
    }

    /// Euclidean norm, `sqrt(real² + imag²)`.
    fn magnitude(&self) -> f32 {
        let squared_norm = self.real * self.real + self.imag * self.imag;
        squared_norm.sqrt()
    }

    /// Argument in radians, computed as `atan2(imag, real)`.
    fn phase(&self) -> f32 {
        f32::atan2(self.imag, self.real)
    }

    /// Builds a value from polar coordinates (`phase` in radians).
    fn from_polar(magnitude: f32, phase: f32) -> Self {
        let (sine, cosine) = phase.sin_cos();
        Self::new(magnitude * cosine, magnitude * sine)
    }
}
106
/// Statistics gathered over one run of the advanced spectral processor.
#[derive(Debug, Clone)]
pub struct AdvancedSpectralStats {
    /// Percentage of processed frames in which the noise gate was active.
    pub noise_gate_activation: f32,
    /// Harmonic enhancement gain in dB, derived from the per-frame gain
    /// averaged over all processed frames.
    pub harmonic_gain_db: f32,
    /// Compression ratio averaged over all processed frames.
    pub avg_compression_ratio: f32,
    /// Spectral centroid in Hz, averaged over all processed frames.
    pub spectral_centroid: f32,
    /// Wall-clock time spent processing, in milliseconds.
    pub processing_latency_ms: f64,
}
122
123/// Result of advanced spectral processing
124#[derive(Debug, Clone)]
125/// Advanced Spectral Result
126pub struct AdvancedSpectralResult {
127    /// Enhanced audio buffer
128    pub enhanced_audio: AudioBuffer,
129    /// Processing statistics
130    pub stats: AdvancedSpectralStats,
131}
132
133/// Advanced spectral processor
134#[derive(Debug)]
135/// Advanced Spectral Processor
136pub struct AdvancedSpectralProcessor {
137    config: AdvancedSpectralConfig,
138    window: Vec<f32>,
139    fft_buffer: Vec<Complex>,
140    prev_phase: Vec<f32>,
141    overlap_buffer: Vec<f32>,
142    bark_scale_weights: Vec<f32>,
143}
144
145impl AdvancedSpectralProcessor {
146    /// Create a new advanced spectral processor
147    pub fn new(config: AdvancedSpectralConfig) -> Result<Self, RecognitionError> {
148        let window = Self::generate_window(config.fft_size, config.window_type);
149        let fft_buffer = vec![Complex::new(0.0, 0.0); config.fft_size];
150        let prev_phase = vec![0.0; config.fft_size / 2 + 1];
151        let overlap_buffer = vec![0.0; config.fft_size];
152        let bark_scale_weights =
153            Self::generate_bark_scale_weights(config.fft_size, config.sample_rate);
154
155        Ok(Self {
156            config,
157            window,
158            fft_buffer,
159            prev_phase,
160            overlap_buffer,
161            bark_scale_weights,
162        })
163    }
164
165    /// Process audio with advanced spectral techniques
166    pub fn process(
167        &mut self,
168        audio: &AudioBuffer,
169    ) -> Result<AdvancedSpectralResult, RecognitionError> {
170        let start_time = std::time::Instant::now();
171
172        let samples = audio.samples();
173        let mut enhanced_samples = Vec::with_capacity(samples.len());
174
175        // Initialize statistics
176        let mut noise_gate_activations = 0;
177        let mut total_frames = 0;
178        let mut total_harmonic_gain = 0.0;
179        let mut total_compression = 0.0;
180        let mut spectral_centroid_sum = 0.0;
181
182        // Process in overlapping frames
183        let mut pos = 0;
184        while pos + self.config.fft_size <= samples.len() {
185            // Extract frame and apply window
186            let mut frame = vec![0.0; self.config.fft_size];
187            for i in 0..self.config.fft_size {
188                if pos + i < samples.len() {
189                    frame[i] = samples[pos + i] * self.window[i];
190                }
191            }
192
193            // Forward FFT
194            self.real_fft(&frame);
195
196            // Extract magnitude and phase
197            let mut magnitudes = vec![0.0; self.config.fft_size / 2 + 1];
198            let mut phases = vec![0.0; self.config.fft_size / 2 + 1];
199
200            for i in 0..magnitudes.len() {
201                magnitudes[i] = self.fft_buffer[i].magnitude();
202                phases[i] = self.fft_buffer[i].phase();
203            }
204
205            // Apply spectral noise gate
206            let gate_active = if self.config.spectral_noise_gate {
207                self.apply_spectral_noise_gate(&mut magnitudes)?
208            } else {
209                false
210            };
211
212            if gate_active {
213                noise_gate_activations += 1;
214            }
215
216            // Apply harmonic enhancement
217            let harmonic_gain = if self.config.harmonic_enhancement {
218                self.apply_harmonic_enhancement(&mut magnitudes)?
219            } else {
220                0.0
221            };
222            total_harmonic_gain += harmonic_gain;
223
224            // Apply multi-band compression
225            let compression_ratio = if self.config.multiband_compression {
226                self.apply_multiband_compression(&mut magnitudes)?
227            } else {
228                1.0
229            };
230            total_compression += compression_ratio;
231
232            // Apply perceptual shaping
233            if self.config.perceptual_shaping {
234                self.apply_perceptual_shaping(&mut magnitudes)?;
235            }
236
237            // Calculate spectral centroid
238            spectral_centroid_sum += self.calculate_spectral_centroid(&magnitudes);
239
240            // Reconstruct complex spectrum
241            for i in 0..magnitudes.len() {
242                self.fft_buffer[i] = Complex::from_polar(magnitudes[i], phases[i]);
243            }
244
245            // Inverse FFT
246            let enhanced_frame = self.inverse_real_fft();
247
248            // Overlap-add synthesis
249            self.overlap_add(&enhanced_frame, &mut enhanced_samples, pos);
250
251            pos += self.config.hop_length;
252            total_frames += 1;
253        }
254
255        // Handle remaining samples
256        while enhanced_samples.len() < samples.len() {
257            enhanced_samples.push(0.0);
258        }
259        enhanced_samples.truncate(samples.len());
260
261        let processing_time = start_time.elapsed().as_secs_f64() * 1000.0;
262
263        // Calculate final statistics
264        let stats = AdvancedSpectralStats {
265            noise_gate_activation: if total_frames > 0 {
266                noise_gate_activations as f32 / total_frames as f32 * 100.0
267            } else {
268                0.0
269            },
270            harmonic_gain_db: if total_frames > 0 {
271                20.0 * (total_harmonic_gain / total_frames as f32).log10()
272            } else {
273                0.0
274            },
275            avg_compression_ratio: if total_frames > 0 {
276                total_compression / total_frames as f32
277            } else {
278                1.0
279            },
280            spectral_centroid: if total_frames > 0 {
281                spectral_centroid_sum / total_frames as f32
282            } else {
283                0.0
284            },
285            processing_latency_ms: processing_time,
286        };
287
288        let enhanced_audio =
289            AudioBuffer::new(enhanced_samples, audio.sample_rate(), audio.channels());
290
291        Ok(AdvancedSpectralResult {
292            enhanced_audio,
293            stats,
294        })
295    }
296
297    /// Generate window function
298    fn generate_window(size: usize, window_type: WindowType) -> Vec<f32> {
299        let mut window = vec![0.0; size];
300
301        match window_type {
302            WindowType::Hann => {
303                for i in 0..size {
304                    let phase = 2.0 * PI * i as f32 / (size - 1) as f32;
305                    window[i] = 0.5 * (1.0 - phase.cos());
306                }
307            }
308            WindowType::Hamming => {
309                for i in 0..size {
310                    let phase = 2.0 * PI * i as f32 / (size - 1) as f32;
311                    window[i] = 0.54 - 0.46 * phase.cos();
312                }
313            }
314            WindowType::Blackman => {
315                for i in 0..size {
316                    let phase = 2.0 * PI * i as f32 / (size - 1) as f32;
317                    window[i] = 0.42 - 0.5 * phase.cos() + 0.08 * (2.0 * phase).cos();
318                }
319            }
320            WindowType::Kaiser => {
321                // Simplified Kaiser window (beta = 8.6)
322                let beta = 8.6;
323                let i0_beta = Self::modified_bessel_i0(beta);
324                for i in 0..size {
325                    let x = 2.0 * i as f32 / (size - 1) as f32 - 1.0;
326                    let arg = beta * (1.0 - x * x).sqrt();
327                    window[i] = Self::modified_bessel_i0(arg) / i0_beta;
328                }
329            }
330            WindowType::Tukey => {
331                let alpha = 0.5;
332                let transition_width = (alpha * size as f32 / 2.0) as usize;
333
334                for i in 0..size {
335                    if i < transition_width {
336                        let phase = PI * i as f32 / transition_width as f32;
337                        window[i] = 0.5 * (1.0 - phase.cos());
338                    } else if i >= size - transition_width {
339                        let phase = PI * (size - 1 - i) as f32 / transition_width as f32;
340                        window[i] = 0.5 * (1.0 - phase.cos());
341                    } else {
342                        window[i] = 1.0;
343                    }
344                }
345            }
346        }
347
348        window
349    }
350
351    /// Modified Bessel function I0 (for Kaiser window)
352    fn modified_bessel_i0(x: f32) -> f32 {
353        let mut sum = 1.0;
354        let mut term = 1.0;
355        let x_half_squared = (x / 2.0) * (x / 2.0);
356
357        for k in 1..=20 {
358            term *= x_half_squared / (k as f32 * k as f32);
359            sum += term;
360            if term < 1e-8 {
361                break;
362            }
363        }
364
365        sum
366    }
367
368    /// Generate bark scale weights for perceptual processing
369    fn generate_bark_scale_weights(fft_size: usize, sample_rate: u32) -> Vec<f32> {
370        let num_bins = fft_size / 2 + 1;
371        let mut weights = vec![1.0; num_bins];
372
373        for i in 0..num_bins {
374            let freq = i as f32 * sample_rate as f32 / fft_size as f32;
375            let bark =
376                13.0 * (freq / 1315.8).atan() + 3.5 * ((freq / 7518.0) * (freq / 7518.0)).atan();
377
378            // Apply bark scale weighting (emphasize perceptually important frequencies)
379            weights[i] = 1.0 + 0.3 * (-((bark - 8.0) / 4.0).powi(2)).exp();
380        }
381
382        weights
383    }
384
385    /// Apply spectral noise gate
386    fn apply_spectral_noise_gate(&self, magnitudes: &mut [f32]) -> Result<bool, RecognitionError> {
387        let threshold_linear = 10.0_f32.powf(self.config.noise_gate_threshold / 20.0);
388        let mut gate_active = false;
389
390        for magnitude in magnitudes.iter_mut() {
391            if *magnitude < threshold_linear {
392                *magnitude *= 0.1; // Attenuate by 20dB
393                gate_active = true;
394            }
395        }
396
397        Ok(gate_active)
398    }
399
400    /// Apply harmonic enhancement
401    fn apply_harmonic_enhancement(&self, magnitudes: &mut [f32]) -> Result<f32, RecognitionError> {
402        let mut total_enhancement = 0.0;
403        let num_bins = magnitudes.len();
404
405        // Find harmonic peaks and enhance them
406        for i in 1..num_bins - 1 {
407            let is_peak = magnitudes[i] > magnitudes[i - 1] && magnitudes[i] > magnitudes[i + 1];
408
409            if is_peak {
410                let enhancement = self.config.harmonic_factor;
411                magnitudes[i] *= enhancement;
412                total_enhancement += enhancement;
413            }
414        }
415
416        Ok(total_enhancement / num_bins as f32)
417    }
418
419    /// Apply multi-band dynamic range compression
420    fn apply_multiband_compression(&self, magnitudes: &mut [f32]) -> Result<f32, RecognitionError> {
421        let num_bins = magnitudes.len();
422        let band_size = num_bins / self.config.num_bands;
423        let mut total_compression = 0.0;
424
425        for band in 0..self.config.num_bands {
426            let start_bin = band * band_size;
427            let end_bin = ((band + 1) * band_size).min(num_bins);
428            let compression_ratio = self.config.compression_ratios.get(band).unwrap_or(&2.0);
429
430            // Calculate band energy
431            let mut band_energy = 0.0;
432            for i in start_bin..end_bin {
433                band_energy += magnitudes[i] * magnitudes[i];
434            }
435            band_energy = (band_energy / (end_bin - start_bin) as f32).sqrt();
436
437            // Apply compression
438            if band_energy > 0.0 {
439                let compressed_energy = band_energy.powf(1.0 / compression_ratio);
440                let gain = compressed_energy / band_energy;
441
442                for i in start_bin..end_bin {
443                    magnitudes[i] *= gain;
444                }
445
446                total_compression += *compression_ratio;
447            }
448        }
449
450        Ok(total_compression / self.config.num_bands as f32)
451    }
452
453    /// Apply perceptual shaping based on bark scale
454    fn apply_perceptual_shaping(&self, magnitudes: &mut [f32]) -> Result<(), RecognitionError> {
455        for (i, magnitude) in magnitudes.iter_mut().enumerate() {
456            if i < self.bark_scale_weights.len() {
457                *magnitude *= self.bark_scale_weights[i];
458            }
459        }
460        Ok(())
461    }
462
463    /// Calculate spectral centroid
464    fn calculate_spectral_centroid(&self, magnitudes: &[f32]) -> f32 {
465        let mut weighted_sum = 0.0;
466        let mut magnitude_sum = 0.0;
467
468        for (i, &magnitude) in magnitudes.iter().enumerate() {
469            let freq = i as f32 * self.config.sample_rate as f32 / self.config.fft_size as f32;
470            weighted_sum += freq * magnitude;
471            magnitude_sum += magnitude;
472        }
473
474        if magnitude_sum > 0.0 {
475            weighted_sum / magnitude_sum
476        } else {
477            0.0
478        }
479    }
480
481    /// Simplified real FFT (placeholder - in production use a proper FFT library)
482    fn real_fft(&mut self, input: &[f32]) {
483        // This is a simplified implementation
484        // In production, use rustfft or similar
485        for k in 0..=(self.config.fft_size / 2) {
486            let mut real_sum = 0.0;
487            let mut imag_sum = 0.0;
488
489            for n in 0..self.config.fft_size {
490                let angle = -2.0 * PI * k as f32 * n as f32 / self.config.fft_size as f32;
491                real_sum += input[n] * angle.cos();
492                imag_sum += input[n] * angle.sin();
493            }
494
495            self.fft_buffer[k] = Complex::new(real_sum, imag_sum);
496        }
497    }
498
499    /// Simplified inverse real FFT
500    fn inverse_real_fft(&self) -> Vec<f32> {
501        let mut output = vec![0.0; self.config.fft_size];
502
503        for n in 0..self.config.fft_size {
504            for k in 0..=(self.config.fft_size / 2) {
505                let angle = 2.0 * PI * k as f32 * n as f32 / self.config.fft_size as f32;
506                let weight = if k == 0 || k == self.config.fft_size / 2 {
507                    1.0
508                } else {
509                    2.0
510                };
511                output[n] += weight
512                    * (self.fft_buffer[k].real * angle.cos()
513                        - self.fft_buffer[k].imag * angle.sin());
514            }
515            output[n] /= self.config.fft_size as f32;
516        }
517
518        output
519    }
520
521    /// Overlap-add synthesis
522    fn overlap_add(&mut self, frame: &[f32], output: &mut Vec<f32>, pos: usize) {
523        let output_start = output.len();
524        let required_length = pos + self.config.fft_size;
525
526        // Extend output buffer if necessary
527        while output.len() < required_length {
528            output.push(0.0);
529        }
530
531        // Add with overlap
532        for i in 0..self.config.fft_size {
533            if pos + i < output.len() {
534                output[pos + i] += frame[i];
535            }
536        }
537    }
538
539    /// Reset processor state
540    pub fn reset(&mut self) -> Result<(), RecognitionError> {
541        self.prev_phase.fill(0.0);
542        self.overlap_buffer.fill(0.0);
543        Ok(())
544    }
545
546    /// Get current configuration
547    #[must_use]
548    pub fn config(&self) -> &AdvancedSpectralConfig {
549        &self.config
550    }
551}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// Processor built from the default configuration.
    fn default_processor() -> AdvancedSpectralProcessor {
        AdvancedSpectralProcessor::new(AdvancedSpectralConfig::default()).unwrap()
    }

    #[test]
    fn test_advanced_spectral_processor_creation() {
        let created = AdvancedSpectralProcessor::new(AdvancedSpectralConfig::default());
        assert!(created.is_ok());
    }

    #[test]
    fn test_window_generation() {
        let hann = AdvancedSpectralProcessor::generate_window(1024, WindowType::Hann);
        assert_eq!(hann.len(), 1024);

        // A Hann window is near zero at both ends and near one in the middle.
        let (first, middle, last) = (hann[0], hann[512], hann[1023]);
        assert!(first < 0.1);
        assert!(middle > 0.9);
        assert!(last < 0.1);
    }

    #[test]
    fn test_spectral_processing() {
        let mut processor = default_processor();
        let audio = AudioBuffer::mono(vec![0.1; 4096], 16000);

        let outcome = processor.process(&audio);
        assert!(outcome.is_ok());

        let outcome = outcome.unwrap();
        // Output keeps the input length and the run takes measurable time.
        assert_eq!(outcome.enhanced_audio.samples().len(), 4096);
        assert!(outcome.stats.processing_latency_ms > 0.0);
    }

    #[test]
    fn test_bark_scale_weights() {
        let weights = AdvancedSpectralProcessor::generate_bark_scale_weights(2048, 16000);
        assert_eq!(weights.len(), 1025); // FFT size / 2 + 1
        assert!(weights.iter().all(|&weight| weight > 0.0));
    }

    #[test]
    fn test_spectral_centroid() {
        let processor = default_processor();
        let spectrum = vec![1.0, 2.0, 3.0, 2.0, 1.0];

        assert!(processor.calculate_spectral_centroid(&spectrum) > 0.0);
    }
}