voirs_recognizer/preprocessing/
realtime_features.rs1use crate::RecognitionError;
7use std::collections::HashMap;
8use voirs_sdk::AudioBuffer;
9
/// Settings that control how audio is framed and which per-frame features
/// are computed by the extractor.
#[derive(Debug, Clone)]
pub struct RealTimeFeatureConfig {
    /// Number of samples in each analysis frame.
    pub window_size: usize,
    /// Number of samples between the starts of consecutive frames.
    pub hop_length: usize,
    /// Number of filterbank rows; the DCT matrix is `n_mels x n_mels`, so
    /// this is also the number of MFCC values produced per frame.
    pub n_mels: usize,
    /// Compute MFCC values for every frame.
    pub extract_mfcc: bool,
    /// Compute the spectral centroid for every frame.
    pub extract_spectral_centroid: bool,
    /// Compute the zero-crossing rate for every frame.
    pub extract_zcr: bool,
    /// Compute the spectral roll-off for every frame.
    pub extract_spectral_rolloff: bool,
    /// Compute the RMS energy for every frame.
    pub extract_energy: bool,
}

impl Default for RealTimeFeatureConfig {
    /// 512-sample frames advanced by 256 samples, 13 mel bands, and every
    /// feature switched on.
    fn default() -> Self {
        const ENABLED: bool = true;
        const WINDOW_SIZE: usize = 512;

        Self {
            window_size: WINDOW_SIZE,
            hop_length: WINDOW_SIZE / 2,
            n_mels: 13,
            extract_mfcc: ENABLED,
            extract_spectral_centroid: ENABLED,
            extract_zcr: ENABLED,
            extract_spectral_rolloff: ENABLED,
            extract_energy: ENABLED,
        }
    }
}
45
/// Identifies one kind of feature; used as the key of the per-feature map
/// in the extraction result.
///
/// `Pitch` and `SpectralBandwidth` are declared here, but no code visible in
/// this file produces values for them.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FeatureType {
    /// Mel-frequency cepstral coefficients.
    MFCC,
    /// Power-weighted mean bin index of the spectrum.
    SpectralCentroid,
    /// Fraction of adjacent sample pairs whose sign differs.
    ZeroCrossingRate,
    /// Relative bin position below which 85% of the spectral power lies.
    SpectralRolloff,
    /// Root-mean-square amplitude of the frame.
    Energy,
    /// Pitch estimate.
    Pitch,
    /// Spectral bandwidth.
    SpectralBandwidth,
}
64
65#[derive(Debug, Clone)]
67pub struct RealTimeFeatureResult {
68 pub features: HashMap<FeatureType, Vec<f32>>,
70 pub num_frames: usize,
72 pub processing_time_ms: f32,
74 pub quality_metrics: HashMap<String, f32>,
76}
77
78impl Default for RealTimeFeatureResult {
79 fn default() -> Self {
80 Self {
81 features: HashMap::new(),
82 num_frames: 0,
83 processing_time_ms: 0.0,
84 quality_metrics: HashMap::new(),
85 }
86 }
87}
88
89#[derive(Debug)]
91pub struct RealTimeFeatureExtractor {
92 config: RealTimeFeatureConfig,
93 window: Vec<f32>,
94 mel_filterbank: Vec<Vec<f32>>,
95 dct_matrix: Vec<Vec<f32>>,
96}
97
98impl RealTimeFeatureExtractor {
99 pub fn new(config: RealTimeFeatureConfig) -> Result<Self, RecognitionError> {
101 let window = Self::create_hann_window(config.window_size);
102 let mel_filterbank = Self::create_mel_filterbank(config.n_mels, config.window_size / 2 + 1);
103 let dct_matrix = Self::create_dct_matrix(config.n_mels);
104
105 Ok(Self {
106 config,
107 window,
108 mel_filterbank,
109 dct_matrix,
110 })
111 }
112
/// Builds a symmetric Hann window with `size` coefficients:
/// `w[i] = 0.5 * (1 - cos(2*pi*i / (size - 1)))`.
///
/// A window of size 0 is empty and a window of size 1 is `[1.0]`. The
/// general formula divides by `size - 1`, which for a single point would
/// produce `NaN`, so short windows are handled explicitly.
fn create_hann_window(size: usize) -> Vec<f32> {
    if size < 2 {
        return vec![1.0; size];
    }

    let span = (size - 1) as f32;
    (0..size)
        .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / span).cos()))
        .collect()
}
121
/// Builds `n_mels` weight vectors of `n_fft` entries each.
///
/// Row 0 is a linear ramp `1 - j / n_fft`. Every other row `i` is
/// `|sin(2595 * ln(1 + j / n_fft) / (i + 1))|`.
///
/// Callers in this file pass the number of half-spectrum bins
/// (`window_size / 2 + 1`) as `n_fft`.
///
/// NOTE(review): these weights are not the usual triangular mel filters;
/// this looks like a simplified stand-in. Confirm before relying on the
/// values as true mel energies.
fn create_mel_filterbank(n_mels: usize, n_fft: usize) -> Vec<Vec<f32>> {
    let bin_count = n_fft as f32;
    let mut bank = Vec::with_capacity(n_mels);

    for filter_idx in 0..n_mels {
        let mut weights = Vec::with_capacity(n_fft);

        for bin in 0..n_fft {
            let weight = if filter_idx == 0 {
                1.0 - (bin as f32 / bin_count)
            } else {
                let warped = 2595.0 * (1.0 + bin as f32 / bin_count).ln();
                (warped / (filter_idx + 1) as f32).sin().abs()
            };
            weights.push(weight);
        }

        bank.push(weights);
    }

    bank
}
141
/// Builds an unnormalised `n_mels x n_mels` DCT-II basis:
/// entry `(i, j)` is `cos((2j + 1) * i * pi / (2 * n_mels))`.
fn create_dct_matrix(n_mels: usize) -> Vec<Vec<f32>> {
    let mut rows = Vec::with_capacity(n_mels);

    for order in 0..n_mels {
        let row: Vec<f32> = (0..n_mels)
            .map(|band| {
                let phase = (2.0 * band as f32 + 1.0) * order as f32 * std::f32::consts::PI
                    / (2.0 * n_mels as f32);
                phase.cos()
            })
            .collect();
        rows.push(row);
    }

    rows
}
156
157 pub fn extract_features(
159 &self,
160 audio: &AudioBuffer,
161 ) -> Result<RealTimeFeatureResult, RecognitionError> {
162 let start_time = std::time::Instant::now();
163 let mut result = RealTimeFeatureResult::default();
164
165 let samples = audio.samples();
166 let num_frames = (samples.len() - self.config.window_size) / self.config.hop_length + 1;
167 result.num_frames = num_frames;
168
169 for frame_idx in 0..num_frames {
170 let start = frame_idx * self.config.hop_length;
171 let end = (start + self.config.window_size).min(samples.len());
172 let frame = &samples[start..end];
173
174 if frame.len() == self.config.window_size {
175 let windowed: Vec<f32> = frame
177 .iter()
178 .zip(self.window.iter())
179 .map(|(s, w)| s * w)
180 .collect();
181
182 if self.config.extract_mfcc {
184 let mfcc = self.extract_mfcc(&windowed)?;
185 result
186 .features
187 .entry(FeatureType::MFCC)
188 .or_insert_with(Vec::new)
189 .extend(mfcc);
190 }
191
192 if self.config.extract_spectral_centroid {
193 let centroid = self.extract_spectral_centroid(&windowed)?;
194 result
195 .features
196 .entry(FeatureType::SpectralCentroid)
197 .or_insert_with(Vec::new)
198 .push(centroid);
199 }
200
201 if self.config.extract_zcr {
202 let zcr = self.extract_zero_crossing_rate(frame)?;
203 result
204 .features
205 .entry(FeatureType::ZeroCrossingRate)
206 .or_insert_with(Vec::new)
207 .push(zcr);
208 }
209
210 if self.config.extract_spectral_rolloff {
211 let rolloff = self.extract_spectral_rolloff(&windowed)?;
212 result
213 .features
214 .entry(FeatureType::SpectralRolloff)
215 .or_insert_with(Vec::new)
216 .push(rolloff);
217 }
218
219 if self.config.extract_energy {
220 let energy = self.extract_energy(frame)?;
221 result
222 .features
223 .entry(FeatureType::Energy)
224 .or_insert_with(Vec::new)
225 .push(energy);
226 }
227 }
228 }
229
230 result.processing_time_ms = start_time.elapsed().as_secs_f32() * 1000.0;
231
232 result
234 .quality_metrics
235 .insert("snr_estimate".to_string(), self.estimate_snr(samples));
236 result.quality_metrics.insert(
237 "spectral_flatness".to_string(),
238 self.calculate_spectral_flatness(samples),
239 );
240
241 Ok(result)
242 }
243
244 fn extract_mfcc(&self, windowed_frame: &[f32]) -> Result<Vec<f32>, RecognitionError> {
246 let fft = self.simple_fft(windowed_frame);
248 let power_spectrum: Vec<f32> = fft.iter().map(scirs2_core::Complex::norm_sqr).collect();
249
250 let mel_energies: Vec<f32> = self
252 .mel_filterbank
253 .iter()
254 .map(|filter| {
255 filter
256 .iter()
257 .zip(power_spectrum.iter())
258 .map(|(f, p)| f * p)
259 .sum::<f32>()
260 .max(1e-10)
261 .ln()
262 })
263 .collect();
264
265 let mfcc: Vec<f32> = self
267 .dct_matrix
268 .iter()
269 .map(|dct_row| {
270 dct_row
271 .iter()
272 .zip(mel_energies.iter())
273 .map(|(d, m)| d * m)
274 .sum()
275 })
276 .collect();
277
278 Ok(mfcc)
279 }
280
281 fn extract_spectral_centroid(&self, windowed_frame: &[f32]) -> Result<f32, RecognitionError> {
283 let fft = self.simple_fft(windowed_frame);
284 let power_spectrum: Vec<f32> = fft.iter().map(scirs2_core::Complex::norm_sqr).collect();
285
286 let numerator: f32 = power_spectrum
287 .iter()
288 .enumerate()
289 .map(|(i, p)| i as f32 * p)
290 .sum();
291
292 let denominator: f32 = power_spectrum.iter().sum();
293
294 Ok(if denominator > 0.0 {
295 numerator / denominator
296 } else {
297 0.0
298 })
299 }
300
301 fn extract_zero_crossing_rate(&self, frame: &[f32]) -> Result<f32, RecognitionError> {
303 let crossings = frame
304 .windows(2)
305 .filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
306 .count();
307
308 Ok(crossings as f32 / frame.len() as f32)
309 }
310
311 fn extract_spectral_rolloff(&self, windowed_frame: &[f32]) -> Result<f32, RecognitionError> {
313 let fft = self.simple_fft(windowed_frame);
314 let power_spectrum: Vec<f32> = fft.iter().map(scirs2_core::Complex::norm_sqr).collect();
315
316 let total_energy: f32 = power_spectrum.iter().sum();
317 let threshold = 0.85 * total_energy;
318
319 let mut cumsum = 0.0;
320 for (i, power) in power_spectrum.iter().enumerate() {
321 cumsum += power;
322 if cumsum >= threshold {
323 return Ok(i as f32 / power_spectrum.len() as f32);
324 }
325 }
326
327 Ok(1.0)
328 }
329
330 fn extract_energy(&self, frame: &[f32]) -> Result<f32, RecognitionError> {
332 let energy: f32 = frame.iter().map(|s| s * s).sum();
333 Ok((energy / frame.len() as f32).sqrt())
334 }
335
336 fn simple_fft(&self, input: &[f32]) -> Vec<scirs2_core::Complex<f32>> {
338 input
340 .iter()
341 .enumerate()
342 .map(|(i, &sample)| {
343 let angle = -2.0 * std::f32::consts::PI * i as f32 / input.len() as f32;
344 scirs2_core::Complex::new(sample * angle.cos(), sample * angle.sin())
345 })
346 .collect()
347 }
348
349 fn estimate_snr(&self, samples: &[f32]) -> f32 {
351 let signal_power: f32 = samples.iter().map(|s| s * s).sum();
352 let mean_power = signal_power / samples.len() as f32;
353
354 let sorted_powers: Vec<f32> = samples.iter().map(|s| s * s).collect::<Vec<_>>();
356
357 let noise_floor = sorted_powers.iter().take(samples.len() / 10).sum::<f32>()
358 / (samples.len() / 10) as f32;
359
360 if noise_floor > 0.0 {
361 10.0 * (mean_power / noise_floor).log10()
362 } else {
363 60.0 }
365 }
366
367 fn calculate_spectral_flatness(&self, samples: &[f32]) -> f32 {
369 let fft = self.simple_fft(samples);
370 let power_spectrum: Vec<f32> = fft.iter().map(scirs2_core::Complex::norm_sqr).collect();
371
372 let geometric_mean = power_spectrum
373 .iter()
374 .map(|p| p.max(1e-10).ln())
375 .sum::<f32>()
376 / power_spectrum.len() as f32;
377
378 let arithmetic_mean = power_spectrum.iter().sum::<f32>() / power_spectrum.len() as f32;
379
380 if arithmetic_mean > 0.0 {
381 geometric_mean.exp() / arithmetic_mean
382 } else {
383 0.0
384 }
385 }
386
387 #[must_use]
389 pub fn config(&self) -> &RealTimeFeatureConfig {
390 &self.config
391 }
392
393 pub fn set_config(&mut self, config: RealTimeFeatureConfig) -> Result<(), RecognitionError> {
395 self.window = Self::create_hann_window(config.window_size);
396 self.mel_filterbank =
397 Self::create_mel_filterbank(config.n_mels, config.window_size / 2 + 1);
398 self.dct_matrix = Self::create_dct_matrix(config.n_mels);
399 self.config = config;
400 Ok(())
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses 512/256 framing, 13 mel bands, and
    /// has MFCC and centroid extraction enabled.
    #[test]
    fn test_realtime_feature_config_default() {
        let RealTimeFeatureConfig {
            window_size,
            hop_length,
            n_mels,
            extract_mfcc,
            extract_spectral_centroid,
            ..
        } = RealTimeFeatureConfig::default();

        assert_eq!(window_size, 512);
        assert_eq!(hop_length, 256);
        assert_eq!(n_mels, 13);
        assert!(extract_mfcc);
        assert!(extract_spectral_centroid);
    }

    /// Building an extractor from the default configuration succeeds.
    #[test]
    fn test_feature_extractor_creation() {
        let extractor = RealTimeFeatureExtractor::new(RealTimeFeatureConfig::default());
        assert!(extractor.is_ok());
    }

    /// A constant 1024-sample signal yields at least one frame and the
    /// expected feature keys.
    #[test]
    fn test_feature_extraction() {
        let extractor = RealTimeFeatureExtractor::new(RealTimeFeatureConfig::default()).unwrap();
        let audio = AudioBuffer::new(vec![0.1; 1024], 16000, 1);

        let outcome = extractor.extract_features(&audio);
        assert!(outcome.is_ok());

        let extracted = outcome.unwrap();
        assert!(extracted.features.contains_key(&FeatureType::MFCC));
        assert!(extracted
            .features
            .contains_key(&FeatureType::SpectralCentroid));
        assert!(extracted.num_frames > 0);
        assert!(extracted.processing_time_ms >= 0.0);
    }

    /// Every variant compares equal to its own clone.
    #[test]
    fn test_feature_types() {
        let variants = [
            FeatureType::MFCC,
            FeatureType::SpectralCentroid,
            FeatureType::ZeroCrossingRate,
            FeatureType::SpectralRolloff,
            FeatureType::Energy,
            FeatureType::Pitch,
            FeatureType::SpectralBandwidth,
        ];

        for variant in variants.iter() {
            assert_eq!(&variant.clone(), variant);
        }
    }

    /// A default result is completely empty.
    #[test]
    fn test_feature_result_default() {
        let empty = RealTimeFeatureResult::default();
        assert!(empty.features.is_empty());
        assert_eq!(empty.num_frames, 0);
        assert!((empty.processing_time_ms - 0.0).abs() < f32::EPSILON);
        assert!(empty.quality_metrics.is_empty());
    }
}