Skip to main content

voirs_recognizer/preprocessing/
realtime_features.rs

//! Real-time Feature Extraction Module
//!
//! Extracts acoustic features from audio streams in real-time for improved
//! recognition accuracy and analysis.
5
6use crate::RecognitionError;
7use std::collections::HashMap;
8use voirs_sdk::AudioBuffer;
9
/// Configuration for real-time feature extraction.
///
/// Sizes are expressed in samples: the extractor cuts the input sample buffer
/// into frames of `window_size` samples and advances `hop_length` samples
/// between consecutive frames. Each `extract_*` flag enables one entry in the
/// feature map of the extraction result.
#[derive(Debug, Clone)]
pub struct RealTimeFeatureConfig {
    /// Window size for feature extraction (samples per analysis frame)
    pub window_size: usize,
    /// Hop length between windows (samples between consecutive frame starts)
    pub hop_length: usize,
    /// Number of mel filterbank channels; also the number of MFCC values
    /// emitted per frame
    pub n_mels: usize,
    /// Enable MFCC extraction
    pub extract_mfcc: bool,
    /// Enable spectral centroid
    pub extract_spectral_centroid: bool,
    /// Enable zero crossing rate
    pub extract_zcr: bool,
    /// Enable spectral rolloff
    pub extract_spectral_rolloff: bool,
    /// Enable energy features (RMS energy per frame)
    pub extract_energy: bool,
}
30
31impl Default for RealTimeFeatureConfig {
32    fn default() -> Self {
33        Self {
34            window_size: 512,
35            hop_length: 256,
36            n_mels: 13,
37            extract_mfcc: true,
38            extract_spectral_centroid: true,
39            extract_zcr: true,
40            extract_spectral_rolloff: true,
41            extract_energy: true,
42        }
43    }
44}
45
/// Feature types that can be extracted.
///
/// Used as the key of the feature map in the extraction result. The
/// extractor code visible in this file only ever inserts `MFCC`,
/// `SpectralCentroid`, `ZeroCrossingRate`, `SpectralRolloff` and `Energy`;
/// `Pitch` and `SpectralBandwidth` are declared but not produced here.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FeatureType {
    /// Mel-frequency cepstral coefficients (stored flattened: one run of
    /// `n_mels` values per frame, frames appended in order)
    MFCC,
    /// Spectral centroid (one value per frame)
    SpectralCentroid,
    /// Zero crossing rate (one value per frame)
    ZeroCrossingRate,
    /// Spectral rolloff (one value per frame)
    SpectralRolloff,
    /// RMS energy (one value per frame)
    Energy,
    /// Pitch/F0
    Pitch,
    /// Spectral bandwidth
    SpectralBandwidth,
}
64
65/// Result of real-time feature extraction
66#[derive(Debug, Clone)]
67pub struct RealTimeFeatureResult {
68    /// Extracted features by type
69    pub features: HashMap<FeatureType, Vec<f32>>,
70    /// Number of frames processed
71    pub num_frames: usize,
72    /// Processing time in milliseconds
73    pub processing_time_ms: f32,
74    /// Feature quality metrics
75    pub quality_metrics: HashMap<String, f32>,
76}
77
78impl Default for RealTimeFeatureResult {
79    fn default() -> Self {
80        Self {
81            features: HashMap::new(),
82            num_frames: 0,
83            processing_time_ms: 0.0,
84            quality_metrics: HashMap::new(),
85        }
86    }
87}
88
/// Real-time feature extractor.
///
/// Holds the active configuration together with the lookup tables that are
/// derived from it, so they are computed once per configuration rather than
/// once per frame.
#[derive(Debug)]
pub struct RealTimeFeatureExtractor {
    /// Active extraction settings.
    config: RealTimeFeatureConfig,
    /// Hann window coefficients, `window_size` entries.
    window: Vec<f32>,
    /// Filterbank weights: `n_mels` rows of `window_size / 2 + 1` columns.
    mel_filterbank: Vec<Vec<f32>>,
    /// DCT coefficients: `n_mels` rows of `n_mels` columns.
    dct_matrix: Vec<Vec<f32>>,
}
97
98impl RealTimeFeatureExtractor {
99    /// Create a new real-time feature extractor
100    pub fn new(config: RealTimeFeatureConfig) -> Result<Self, RecognitionError> {
101        let window = Self::create_hann_window(config.window_size);
102        let mel_filterbank = Self::create_mel_filterbank(config.n_mels, config.window_size / 2 + 1);
103        let dct_matrix = Self::create_dct_matrix(config.n_mels);
104
105        Ok(Self {
106            config,
107            window,
108            mel_filterbank,
109            dct_matrix,
110        })
111    }
112
113    /// Create Hann window
114    fn create_hann_window(size: usize) -> Vec<f32> {
115        (0..size)
116            .map(|i| {
117                0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / (size - 1) as f32).cos())
118            })
119            .collect()
120    }
121
122    /// Create mel filterbank
123    fn create_mel_filterbank(n_mels: usize, n_fft: usize) -> Vec<Vec<f32>> {
124        // Simplified mel filterbank creation
125        (0..n_mels)
126            .map(|i| {
127                (0..n_fft)
128                    .map(|j| {
129                        let mel_freq = 2595.0 * (1.0 + j as f32 / n_fft as f32).ln();
130
131                        if i == 0 {
132                            1.0 - (j as f32 / n_fft as f32)
133                        } else {
134                            (mel_freq / (i + 1) as f32).sin().abs()
135                        }
136                    })
137                    .collect()
138            })
139            .collect()
140    }
141
142    /// Create DCT matrix for MFCC
143    fn create_dct_matrix(n_mels: usize) -> Vec<Vec<f32>> {
144        (0..n_mels)
145            .map(|i| {
146                (0..n_mels)
147                    .map(|j| {
148                        ((2.0 * j as f32 + 1.0) * i as f32 * std::f32::consts::PI
149                            / (2.0 * n_mels as f32))
150                            .cos()
151                    })
152                    .collect()
153            })
154            .collect()
155    }
156
157    /// Extract features from audio buffer
158    pub fn extract_features(
159        &self,
160        audio: &AudioBuffer,
161    ) -> Result<RealTimeFeatureResult, RecognitionError> {
162        let start_time = std::time::Instant::now();
163        let mut result = RealTimeFeatureResult::default();
164
165        let samples = audio.samples();
166        let num_frames = (samples.len() - self.config.window_size) / self.config.hop_length + 1;
167        result.num_frames = num_frames;
168
169        for frame_idx in 0..num_frames {
170            let start = frame_idx * self.config.hop_length;
171            let end = (start + self.config.window_size).min(samples.len());
172            let frame = &samples[start..end];
173
174            if frame.len() == self.config.window_size {
175                // Apply window
176                let windowed: Vec<f32> = frame
177                    .iter()
178                    .zip(self.window.iter())
179                    .map(|(s, w)| s * w)
180                    .collect();
181
182                // Extract requested features
183                if self.config.extract_mfcc {
184                    let mfcc = self.extract_mfcc(&windowed)?;
185                    result
186                        .features
187                        .entry(FeatureType::MFCC)
188                        .or_insert_with(Vec::new)
189                        .extend(mfcc);
190                }
191
192                if self.config.extract_spectral_centroid {
193                    let centroid = self.extract_spectral_centroid(&windowed)?;
194                    result
195                        .features
196                        .entry(FeatureType::SpectralCentroid)
197                        .or_insert_with(Vec::new)
198                        .push(centroid);
199                }
200
201                if self.config.extract_zcr {
202                    let zcr = self.extract_zero_crossing_rate(frame)?;
203                    result
204                        .features
205                        .entry(FeatureType::ZeroCrossingRate)
206                        .or_insert_with(Vec::new)
207                        .push(zcr);
208                }
209
210                if self.config.extract_spectral_rolloff {
211                    let rolloff = self.extract_spectral_rolloff(&windowed)?;
212                    result
213                        .features
214                        .entry(FeatureType::SpectralRolloff)
215                        .or_insert_with(Vec::new)
216                        .push(rolloff);
217                }
218
219                if self.config.extract_energy {
220                    let energy = self.extract_energy(frame)?;
221                    result
222                        .features
223                        .entry(FeatureType::Energy)
224                        .or_insert_with(Vec::new)
225                        .push(energy);
226                }
227            }
228        }
229
230        result.processing_time_ms = start_time.elapsed().as_secs_f32() * 1000.0;
231
232        // Calculate quality metrics
233        result
234            .quality_metrics
235            .insert("snr_estimate".to_string(), self.estimate_snr(samples));
236        result.quality_metrics.insert(
237            "spectral_flatness".to_string(),
238            self.calculate_spectral_flatness(samples),
239        );
240
241        Ok(result)
242    }
243
244    /// Extract MFCC features
245    fn extract_mfcc(&self, windowed_frame: &[f32]) -> Result<Vec<f32>, RecognitionError> {
246        // Simplified MFCC extraction
247        let fft = self.simple_fft(windowed_frame);
248        let power_spectrum: Vec<f32> = fft.iter().map(scirs2_core::Complex::norm_sqr).collect();
249
250        // Apply mel filterbank
251        let mel_energies: Vec<f32> = self
252            .mel_filterbank
253            .iter()
254            .map(|filter| {
255                filter
256                    .iter()
257                    .zip(power_spectrum.iter())
258                    .map(|(f, p)| f * p)
259                    .sum::<f32>()
260                    .max(1e-10)
261                    .ln()
262            })
263            .collect();
264
265        // Apply DCT
266        let mfcc: Vec<f32> = self
267            .dct_matrix
268            .iter()
269            .map(|dct_row| {
270                dct_row
271                    .iter()
272                    .zip(mel_energies.iter())
273                    .map(|(d, m)| d * m)
274                    .sum()
275            })
276            .collect();
277
278        Ok(mfcc)
279    }
280
281    /// Extract spectral centroid
282    fn extract_spectral_centroid(&self, windowed_frame: &[f32]) -> Result<f32, RecognitionError> {
283        let fft = self.simple_fft(windowed_frame);
284        let power_spectrum: Vec<f32> = fft.iter().map(scirs2_core::Complex::norm_sqr).collect();
285
286        let numerator: f32 = power_spectrum
287            .iter()
288            .enumerate()
289            .map(|(i, p)| i as f32 * p)
290            .sum();
291
292        let denominator: f32 = power_spectrum.iter().sum();
293
294        Ok(if denominator > 0.0 {
295            numerator / denominator
296        } else {
297            0.0
298        })
299    }
300
301    /// Extract zero crossing rate
302    fn extract_zero_crossing_rate(&self, frame: &[f32]) -> Result<f32, RecognitionError> {
303        let crossings = frame
304            .windows(2)
305            .filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
306            .count();
307
308        Ok(crossings as f32 / frame.len() as f32)
309    }
310
311    /// Extract spectral rolloff
312    fn extract_spectral_rolloff(&self, windowed_frame: &[f32]) -> Result<f32, RecognitionError> {
313        let fft = self.simple_fft(windowed_frame);
314        let power_spectrum: Vec<f32> = fft.iter().map(scirs2_core::Complex::norm_sqr).collect();
315
316        let total_energy: f32 = power_spectrum.iter().sum();
317        let threshold = 0.85 * total_energy;
318
319        let mut cumsum = 0.0;
320        for (i, power) in power_spectrum.iter().enumerate() {
321            cumsum += power;
322            if cumsum >= threshold {
323                return Ok(i as f32 / power_spectrum.len() as f32);
324            }
325        }
326
327        Ok(1.0)
328    }
329
330    /// Extract RMS energy
331    fn extract_energy(&self, frame: &[f32]) -> Result<f32, RecognitionError> {
332        let energy: f32 = frame.iter().map(|s| s * s).sum();
333        Ok((energy / frame.len() as f32).sqrt())
334    }
335
336    /// Simple FFT implementation (for demonstration)
337    fn simple_fft(&self, input: &[f32]) -> Vec<scirs2_core::Complex<f32>> {
338        // Very simplified FFT - in practice, use a proper FFT library
339        input
340            .iter()
341            .enumerate()
342            .map(|(i, &sample)| {
343                let angle = -2.0 * std::f32::consts::PI * i as f32 / input.len() as f32;
344                scirs2_core::Complex::new(sample * angle.cos(), sample * angle.sin())
345            })
346            .collect()
347    }
348
349    /// Estimate signal-to-noise ratio
350    fn estimate_snr(&self, samples: &[f32]) -> f32 {
351        let signal_power: f32 = samples.iter().map(|s| s * s).sum();
352        let mean_power = signal_power / samples.len() as f32;
353
354        // Simple noise floor estimation
355        let sorted_powers: Vec<f32> = samples.iter().map(|s| s * s).collect::<Vec<_>>();
356
357        let noise_floor = sorted_powers.iter().take(samples.len() / 10).sum::<f32>()
358            / (samples.len() / 10) as f32;
359
360        if noise_floor > 0.0 {
361            10.0 * (mean_power / noise_floor).log10()
362        } else {
363            60.0 // High SNR if no detectable noise
364        }
365    }
366
367    /// Calculate spectral flatness
368    fn calculate_spectral_flatness(&self, samples: &[f32]) -> f32 {
369        let fft = self.simple_fft(samples);
370        let power_spectrum: Vec<f32> = fft.iter().map(scirs2_core::Complex::norm_sqr).collect();
371
372        let geometric_mean = power_spectrum
373            .iter()
374            .map(|p| p.max(1e-10).ln())
375            .sum::<f32>()
376            / power_spectrum.len() as f32;
377
378        let arithmetic_mean = power_spectrum.iter().sum::<f32>() / power_spectrum.len() as f32;
379
380        if arithmetic_mean > 0.0 {
381            geometric_mean.exp() / arithmetic_mean
382        } else {
383            0.0
384        }
385    }
386
387    /// Get current configuration
388    #[must_use]
389    pub fn config(&self) -> &RealTimeFeatureConfig {
390        &self.config
391    }
392
393    /// Update configuration
394    pub fn set_config(&mut self, config: RealTimeFeatureConfig) -> Result<(), RecognitionError> {
395        self.window = Self::create_hann_window(config.window_size);
396        self.mel_filterbank =
397            Self::create_mel_filterbank(config.n_mels, config.window_size / 2 + 1);
398        self.dct_matrix = Self::create_dct_matrix(config.n_mels);
399        self.config = config;
400        Ok(())
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented sizes and has the
    /// checked feature toggles enabled.
    #[test]
    fn test_realtime_feature_config_default() {
        let RealTimeFeatureConfig {
            window_size,
            hop_length,
            n_mels,
            extract_mfcc,
            extract_spectral_centroid,
            ..
        } = RealTimeFeatureConfig::default();

        assert_eq!(window_size, 512);
        assert_eq!(hop_length, 256);
        assert_eq!(n_mels, 13);
        assert!(extract_mfcc);
        assert!(extract_spectral_centroid);
    }

    /// Building an extractor from the default configuration succeeds.
    #[test]
    fn test_feature_extractor_creation() {
        assert!(RealTimeFeatureExtractor::new(RealTimeFeatureConfig::default()).is_ok());
    }

    /// Two windows' worth of constant audio produces frames and the enabled
    /// spectral features.
    #[test]
    fn test_feature_extraction() {
        let extractor = RealTimeFeatureExtractor::new(RealTimeFeatureConfig::default()).unwrap();

        // 1024 samples, 16 kHz, mono
        let audio = AudioBuffer::new(vec![0.1; 1024], 16000, 1);

        let outcome = extractor.extract_features(&audio);
        assert!(outcome.is_ok());

        let extracted = outcome.unwrap();
        for expected_key in [FeatureType::MFCC, FeatureType::SpectralCentroid] {
            assert!(extracted.features.contains_key(&expected_key));
        }
        assert!(extracted.num_frames > 0);
        assert!(extracted.processing_time_ms >= 0.0);
    }

    /// Every feature type compares equal to its own clone.
    #[test]
    fn test_feature_types() {
        let all_types = [
            FeatureType::MFCC,
            FeatureType::SpectralCentroid,
            FeatureType::ZeroCrossingRate,
            FeatureType::SpectralRolloff,
            FeatureType::Energy,
            FeatureType::Pitch,
            FeatureType::SpectralBandwidth,
        ];

        for feature_type in all_types {
            // Test that feature types are properly comparable
            assert_eq!(feature_type.clone(), feature_type);
        }
    }

    /// A default result is completely empty.
    #[test]
    fn test_feature_result_default() {
        let empty = RealTimeFeatureResult::default();

        assert!(empty.features.is_empty());
        assert!(empty.quality_metrics.is_empty());
        assert_eq!(empty.num_frames, 0);
        assert!((empty.processing_time_ms - 0.0).abs() < f32::EPSILON);
    }
}