scirs2_transform/signal_transforms/
mfcc.rs1use crate::error::{Result, TransformError};
6use crate::signal_transforms::stft::{STFTConfig, WindowType, STFT};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
8use scirs2_core::numeric::Complex;
9use std::f64::consts::PI;
10
11#[derive(Debug, Clone)]
13pub struct MelFilterbank {
14 pub n_filters: usize,
16 pub nfft: usize,
18 pub sample_rate: f64,
20 pub fmin: f64,
22 pub fmax: f64,
24 filters: Array2<f64>,
26}
27
28impl MelFilterbank {
29 pub fn new(
31 n_filters: usize,
32 nfft: usize,
33 sample_rate: f64,
34 fmin: f64,
35 fmax: f64,
36 ) -> Result<Self> {
37 if fmin >= fmax {
38 return Err(TransformError::InvalidInput(
39 "fmin must be less than fmax".to_string(),
40 ));
41 }
42
43 if fmax > sample_rate / 2.0 {
44 return Err(TransformError::InvalidInput(
45 "fmax must be <= sample_rate/2".to_string(),
46 ));
47 }
48
49 let filters = Self::compute_filters(n_filters, nfft, sample_rate, fmin, fmax);
50
51 Ok(MelFilterbank {
52 n_filters,
53 nfft,
54 sample_rate,
55 fmin,
56 fmax,
57 filters,
58 })
59 }
60
61 fn hz_to_mel(hz: f64) -> f64 {
63 2595.0 * (1.0 + hz / 700.0).log10()
64 }
65
66 fn mel_to_hz(mel: f64) -> f64 {
68 700.0 * (10.0_f64.powf(mel / 2595.0) - 1.0)
69 }
70
71 fn compute_filters(
73 n_filters: usize,
74 nfft: usize,
75 sample_rate: f64,
76 fmin: f64,
77 fmax: f64,
78 ) -> Array2<f64> {
79 let n_freqs = nfft / 2 + 1;
80 let mut filters = Array2::zeros((n_filters, n_freqs));
81
82 let mel_min = Self::hz_to_mel(fmin);
84 let mel_max = Self::hz_to_mel(fmax);
85
86 let mel_points: Vec<f64> = (0..=n_filters + 1)
88 .map(|i| mel_min + (mel_max - mel_min) * i as f64 / (n_filters + 1) as f64)
89 .collect();
90
91 let hz_points: Vec<f64> = mel_points.iter().map(|&m| Self::mel_to_hz(m)).collect();
92
93 let bin_points: Vec<usize> = hz_points
95 .iter()
96 .map(|&f| ((nfft + 1) as f64 * f / sample_rate).floor() as usize)
97 .collect();
98
99 for i in 0..n_filters {
101 let left = bin_points[i];
102 let center = bin_points[i + 1];
103 let right = bin_points[i + 2];
104
105 for j in left..center {
107 if center > left && j < n_freqs {
108 filters[[i, j]] = (j - left) as f64 / (center - left) as f64;
109 }
110 }
111
112 for j in center..right {
114 if right > center && j < n_freqs {
115 filters[[i, j]] = (right - j) as f64 / (right - center) as f64;
116 }
117 }
118 }
119
120 filters
121 }
122
123 pub fn apply(&self, power_spectrum: &ArrayView1<f64>) -> Result<Array1<f64>> {
125 let n_freqs = power_spectrum.len();
126 if n_freqs != self.nfft / 2 + 1 {
127 return Err(TransformError::InvalidInput(format!(
128 "Expected {} frequency bins, got {}",
129 self.nfft / 2 + 1,
130 n_freqs
131 )));
132 }
133
134 let mut mel_energies = Array1::zeros(self.n_filters);
135
136 for i in 0..self.n_filters {
137 let mut energy = 0.0;
138 for j in 0..n_freqs {
139 energy += self.filters[[i, j]] * power_spectrum[j];
140 }
141 mel_energies[i] = energy;
142 }
143
144 Ok(mel_energies)
145 }
146
147 pub fn filters(&self) -> &Array2<f64> {
149 &self.filters
150 }
151
152 pub fn center_frequencies(&self) -> Vec<f64> {
154 let mel_min = Self::hz_to_mel(self.fmin);
155 let mel_max = Self::hz_to_mel(self.fmax);
156
157 (0..self.n_filters)
158 .map(|i| {
159 let mel =
160 mel_min + (mel_max - mel_min) * (i + 1) as f64 / (self.n_filters + 1) as f64;
161 Self::mel_to_hz(mel)
162 })
163 .collect()
164 }
165}
166
/// Parameters for MFCC extraction.
///
/// The defaults target 16 kHz speech:
/// * a 400-sample window with a 160-sample hop;
/// * a 512-point FFT;
/// * 40 mel bands over 0–8000 Hz;
/// * 13 coefficients with lifter 22 and mean normalisation.
#[derive(Debug, Clone)]
pub struct MFCCConfig {
    /// Number of cepstral coefficients kept per frame.
    pub n_mfcc: usize,
    /// Number of mel filters.
    pub n_mels: usize,
    /// FFT size.
    pub nfft: usize,
    /// Step between successive analysis frames (in samples).
    pub hop_size: usize,
    /// Analysis window length (in samples).
    pub window_size: usize,
    /// Sampling rate in Hz.
    pub sample_rate: f64,
    /// Lowest frequency covered by the mel filters, in Hz.
    pub fmin: f64,
    /// Highest frequency covered by the mel filters, in Hz.
    pub fmax: f64,
    /// Sinusoidal cepstral lifter parameter; `None` disables liftering.
    pub lifter: Option<usize>,
    /// When `true`, subtract each coefficient's mean over all frames.
    pub normalize: bool,
}

impl Default for MFCCConfig {
    fn default() -> Self {
        Self {
            sample_rate: 16000.0,
            window_size: 400,
            hop_size: 160,
            nfft: 512,
            n_mels: 40,
            fmin: 0.0,
            fmax: 8000.0,
            n_mfcc: 13,
            lifter: Some(22),
            normalize: true,
        }
    }
}
208
209#[derive(Debug, Clone)]
211pub struct MFCC {
212 config: MFCCConfig,
213 mel_filterbank: MelFilterbank,
214 stft: STFT,
215 dct_matrix: Array2<f64>,
216}
217
218impl MFCC {
219 pub fn new(config: MFCCConfig) -> Result<Self> {
221 let mel_filterbank = MelFilterbank::new(
222 config.n_mels,
223 config.nfft,
224 config.sample_rate,
225 config.fmin,
226 config.fmax,
227 )?;
228
229 let stft_config = STFTConfig {
230 window_size: config.window_size,
231 hop_size: config.hop_size,
232 window_type: WindowType::Hamming,
233 nfft: Some(config.nfft),
234 onesided: true,
235 padding: crate::signal_transforms::stft::PaddingMode::Zero,
236 };
237
238 let stft = STFT::new(stft_config);
239 let dct_matrix = Self::compute_dct_matrix(config.n_mfcc, config.n_mels);
240
241 Ok(MFCC {
242 config,
243 mel_filterbank,
244 stft,
245 dct_matrix,
246 })
247 }
248
249 pub fn default() -> Result<Self> {
251 Self::new(MFCCConfig::default())
252 }
253
254 fn compute_dct_matrix(n_mfcc: usize, n_mels: usize) -> Array2<f64> {
256 let mut dct = Array2::zeros((n_mfcc, n_mels));
257 let norm = (2.0 / n_mels as f64).sqrt();
258
259 for i in 0..n_mfcc {
260 for j in 0..n_mels {
261 dct[[i, j]] = norm * (PI * i as f64 * (j as f64 + 0.5) / n_mels as f64).cos();
262 }
263 }
264
265 dct
266 }
267
268 pub fn extract(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
270 let stft = self.stft.transform(signal)?;
272 let (n_freqs, n_frames) = stft.dim();
273
274 let mut power_spec = Array2::zeros((n_freqs, n_frames));
276 for i in 0..n_freqs {
277 for j in 0..n_frames {
278 let mag = stft[[i, j]].norm();
279 power_spec[[i, j]] = mag * mag;
280 }
281 }
282
283 let mut mfccs = Array2::zeros((self.config.n_mfcc, n_frames));
285
286 for frame_idx in 0..n_frames {
287 let power_frame = power_spec.column(frame_idx);
288 let mel_energies = self.mel_filterbank.apply(&power_frame)?;
289
290 let log_mel_energies: Array1<f64> = mel_energies
292 .iter()
293 .map(|&e| {
294 if e > 1e-10 {
295 e.ln()
296 } else {
297 -23.025850929940457 }
299 })
300 .collect();
301
302 let mfcc_frame = self.dct_matrix.dot(&log_mel_energies);
304
305 let mfcc_frame = if let Some(lifter) = self.config.lifter {
307 self.apply_lifter(&mfcc_frame, lifter)
308 } else {
309 mfcc_frame
310 };
311
312 for (i, &val) in mfcc_frame.iter().enumerate() {
314 mfccs[[i, frame_idx]] = val;
315 }
316 }
317
318 if self.config.normalize {
320 self.normalize_mfccs(&mut mfccs);
321 }
322
323 Ok(mfccs)
324 }
325
326 fn apply_lifter(&self, mfcc: &Array1<f64>, lifter: usize) -> Array1<f64> {
328 let n = mfcc.len();
329 let mut lifted = Array1::zeros(n);
330
331 for i in 0..n {
332 let lift_weight = 1.0 + (lifter as f64 / 2.0) * (PI * i as f64 / lifter as f64).sin();
333 lifted[i] = mfcc[i] * lift_weight;
334 }
335
336 lifted
337 }
338
339 fn normalize_mfccs(&self, mfccs: &mut Array2<f64>) {
341 let (n_mfcc, n_frames) = mfccs.dim();
342
343 for i in 0..n_mfcc {
344 let mut sum = 0.0;
345 for j in 0..n_frames {
346 sum += mfccs[[i, j]];
347 }
348 let mean = sum / n_frames as f64;
349
350 for j in 0..n_frames {
351 mfccs[[i, j]] -= mean;
352 }
353 }
354 }
355
356 pub fn delta(features: &Array2<f64>, width: usize) -> Array2<f64> {
358 let (n_features, n_frames) = features.dim();
359 let mut deltas = Array2::zeros((n_features, n_frames));
360
361 let width = width as i64;
362 let denominator: f64 = (1..=width).map(|i| i * i).sum::<i64>() as f64 * 2.0;
363
364 for i in 0..n_features {
365 for j in 0..n_frames {
366 let mut delta = 0.0;
367
368 for t in 1..=width {
369 let t_f64 = t as f64;
370
371 let idx_forward = (j as i64 + t).min(n_frames as i64 - 1) as usize;
373 let idx_backward = (j as i64 - t).max(0) as usize;
375
376 delta += t_f64 * (features[[i, idx_forward]] - features[[i, idx_backward]]);
377 }
378
379 deltas[[i, j]] = delta / denominator;
380 }
381 }
382
383 deltas
384 }
385
386 pub fn delta_delta(features: &Array2<f64>, width: usize) -> Array2<f64> {
388 let deltas = Self::delta(features, width);
389 Self::delta(&deltas, width)
390 }
391
392 pub fn extract_with_deltas(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
394 let mfccs = self.extract(signal)?;
395 let deltas = Self::delta(&mfccs, 2);
396 let delta_deltas = Self::delta_delta(&mfccs, 2);
397
398 let (n_mfcc, n_frames) = mfccs.dim();
400 let mut combined = Array2::zeros((n_mfcc * 3, n_frames));
401
402 for i in 0..n_mfcc {
403 for j in 0..n_frames {
404 combined[[i, j]] = mfccs[[i, j]];
405 combined[[i + n_mfcc, j]] = deltas[[i, j]];
406 combined[[i + 2 * n_mfcc, j]] = delta_deltas[[i, j]];
407 }
408 }
409
410 Ok(combined)
411 }
412
413 pub fn config(&self) -> &MFCCConfig {
415 &self.config
416 }
417
418 pub fn mel_filterbank(&self) -> &MelFilterbank {
420 &self.mel_filterbank
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// One second of a deterministic sinusoid at the default 16 kHz rate.
    fn test_signal() -> Array1<f64> {
        Array1::from_vec((0..16000).map(|i| (i as f64 * 0.01).sin()).collect())
    }

    #[test]
    fn test_hz_mel_conversion() {
        for &hz in &[0.0, 100.0, 1000.0, 4000.0, 8000.0] {
            let mel = MelFilterbank::hz_to_mel(hz);
            let hz_back = MelFilterbank::mel_to_hz(mel);
            assert_abs_diff_eq!(hz, hz_back, epsilon = 1e-6);
        }

        // HTK mel scale: 700 Hz maps to 2595 * log10(2) mel.
        assert_abs_diff_eq!(
            MelFilterbank::hz_to_mel(700.0),
            2595.0 * 2.0_f64.log10(),
            epsilon = 1e-9
        );
    }

    #[test]
    fn test_mel_filterbank() -> Result<()> {
        let filterbank = MelFilterbank::new(40, 512, 16000.0, 0.0, 8000.0)?;

        assert_eq!(filterbank.filters.dim(), (40, 257));

        // Triangular weights lie in [0, 1], and at this resolution every filter
        // covers at least one bin.
        for row in filterbank.filters().outer_iter() {
            assert!(row.iter().all(|&w| (0.0..=1.0).contains(&w)));
            assert!(row.sum() > 0.0);
        }

        let center_freqs = filterbank.center_frequencies();
        assert_eq!(center_freqs.len(), 40);
        assert!(center_freqs[0] > 0.0);
        assert!(center_freqs[39] < 8000.0);
        assert!(center_freqs.windows(2).all(|w| w[0] < w[1]));

        Ok(())
    }

    #[test]
    fn test_mel_filterbank_rejects_invalid_range() {
        assert!(MelFilterbank::new(40, 512, 16000.0, 4000.0, 1000.0).is_err());
        assert!(MelFilterbank::new(40, 512, 16000.0, 0.0, 9000.0).is_err());
    }

    #[test]
    fn test_filterbank_apply_checks_length() -> Result<()> {
        let filterbank = MelFilterbank::new(10, 256, 8000.0, 0.0, 4000.0)?;

        let wrong = Array1::<f64>::ones(10);
        assert!(filterbank.apply(&wrong.view()).is_err());

        let flat = Array1::<f64>::ones(129);
        let energies = filterbank.apply(&flat.view())?;
        assert_eq!(energies.len(), 10);

        Ok(())
    }

    #[test]
    fn test_mfcc_extraction() -> Result<()> {
        let signal = test_signal();
        let mfcc = MFCC::default()?;

        let features = mfcc.extract(&signal.view())?;

        assert_eq!(features.dim().0, 13);
        assert!(features.dim().1 > 0);
        assert!(features.iter().all(|v| v.is_finite()));

        // The default config mean-normalises each coefficient over frames.
        for row in features.outer_iter() {
            let mean = row.sum() / row.len() as f64;
            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-6);
        }

        Ok(())
    }

    #[test]
    fn test_mfcc_with_deltas() -> Result<()> {
        let signal = test_signal();
        let mfcc = MFCC::default()?;

        let features = mfcc.extract_with_deltas(&signal.view())?;
        let base = mfcc.extract(&signal.view())?;

        assert_eq!(features.dim(), (39, base.dim().1));
        assert!(features.dim().1 > 0);

        // Layout: [static; delta; delta-delta], each block 13 rows tall.
        let deltas = MFCC::delta(&base, 2);
        let delta_deltas = MFCC::delta_delta(&base, 2);
        for (a, b) in features.outer_iter().zip(base.outer_iter()) {
            assert_eq!(a, b);
        }
        for (a, b) in features.outer_iter().skip(13).zip(deltas.outer_iter()) {
            assert_eq!(a, b);
        }
        for (a, b) in features.outer_iter().skip(26).zip(delta_deltas.outer_iter()) {
            assert_eq!(a, b);
        }

        Ok(())
    }

    #[test]
    fn test_delta_features() {
        let features = Array2::from_shape_vec(
            (2, 5),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 0.5, 1.0, 1.5, 2.0, 2.5],
        )
        .expect("Failed to create array");

        let deltas = MFCC::delta(&features, 2);

        assert_eq!(deltas.dim(), (2, 5));

        // Width 2 => denominator 2 * (1 + 4) = 10. Edges are clamped, so the
        // slope estimate drops toward the ends of the ramp.
        let expected = [[0.5, 0.8, 1.0, 0.8, 0.5], [0.25, 0.4, 0.5, 0.4, 0.25]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_abs_diff_eq!(deltas[[i, j]], value, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_delta_of_constant_is_zero() {
        let features = Array2::from_elem((3, 6), 2.5);

        let deltas = MFCC::delta(&features, 3);
        assert!(deltas.iter().all(|&d| d == 0.0));

        let delta_deltas = MFCC::delta_delta(&features, 3);
        assert_eq!(delta_deltas.dim(), (3, 6));
    }

    #[test]
    fn test_dct_matrix() {
        let dct = MFCC::compute_dct_matrix(13, 40);

        assert_eq!(dct.dim(), (13, 40));

        // Rows are mutually orthogonal. Row 0 has squared norm 2 and the rest
        // have squared norm 1, because the same sqrt(2 / N) factor scales every row.
        let gram = dct.dot(&dct.t());
        for i in 0..13 {
            for j in 0..13 {
                let expected = match (i, j) {
                    (0, 0) => 2.0,
                    _ if i == j => 1.0,
                    _ => 0.0,
                };
                assert_abs_diff_eq!(gram[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_mfcc_config() {
        let config = MFCCConfig::default();
        assert_eq!(config.n_mfcc, 13);
        assert_eq!(config.n_mels, 40);
        assert_eq!(config.nfft, 512);
        assert_eq!(config.sample_rate, 16000.0);
        assert_eq!(config.fmax, 8000.0);
        assert_eq!(config.lifter, Some(22));
        assert!(config.normalize);
    }
}