Skip to main content

scirs2_transform/signal_transforms/
mfcc.rs

1//! Mel-Frequency Cepstral Coefficients (MFCC) Implementation
2//!
3//! Provides MFCC extraction for audio and speech processing.
4
5use crate::error::{Result, TransformError};
6use crate::signal_transforms::stft::{STFTConfig, WindowType, STFT};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
8use scirs2_core::numeric::Complex;
9use std::f64::consts::PI;
10
11/// Mel filterbank configuration
12#[derive(Debug, Clone)]
13pub struct MelFilterbank {
14    /// Number of mel filters
15    pub n_filters: usize,
16    /// FFT size
17    pub nfft: usize,
18    /// Sampling rate in Hz
19    pub sample_rate: f64,
20    /// Lower frequency bound in Hz
21    pub fmin: f64,
22    /// Upper frequency bound in Hz
23    pub fmax: f64,
24    /// Filter weights (n_filters x n_freqs)
25    filters: Array2<f64>,
26}
27
28impl MelFilterbank {
29    /// Create a new mel filterbank
30    pub fn new(
31        n_filters: usize,
32        nfft: usize,
33        sample_rate: f64,
34        fmin: f64,
35        fmax: f64,
36    ) -> Result<Self> {
37        if fmin >= fmax {
38            return Err(TransformError::InvalidInput(
39                "fmin must be less than fmax".to_string(),
40            ));
41        }
42
43        if fmax > sample_rate / 2.0 {
44            return Err(TransformError::InvalidInput(
45                "fmax must be <= sample_rate/2".to_string(),
46            ));
47        }
48
49        let filters = Self::compute_filters(n_filters, nfft, sample_rate, fmin, fmax);
50
51        Ok(MelFilterbank {
52            n_filters,
53            nfft,
54            sample_rate,
55            fmin,
56            fmax,
57            filters,
58        })
59    }
60
61    /// Convert Hz to Mel scale
62    fn hz_to_mel(hz: f64) -> f64 {
63        2595.0 * (1.0 + hz / 700.0).log10()
64    }
65
66    /// Convert Mel to Hz scale
67    fn mel_to_hz(mel: f64) -> f64 {
68        700.0 * (10.0_f64.powf(mel / 2595.0) - 1.0)
69    }
70
71    /// Compute mel filter weights
72    fn compute_filters(
73        n_filters: usize,
74        nfft: usize,
75        sample_rate: f64,
76        fmin: f64,
77        fmax: f64,
78    ) -> Array2<f64> {
79        let n_freqs = nfft / 2 + 1;
80        let mut filters = Array2::zeros((n_filters, n_freqs));
81
82        // Convert frequency bounds to mel scale
83        let mel_min = Self::hz_to_mel(fmin);
84        let mel_max = Self::hz_to_mel(fmax);
85
86        // Create mel-spaced frequency points
87        let mel_points: Vec<f64> = (0..=n_filters + 1)
88            .map(|i| mel_min + (mel_max - mel_min) * i as f64 / (n_filters + 1) as f64)
89            .collect();
90
91        let hz_points: Vec<f64> = mel_points.iter().map(|&m| Self::mel_to_hz(m)).collect();
92
93        // Convert Hz points to FFT bin indices
94        let bin_points: Vec<usize> = hz_points
95            .iter()
96            .map(|&f| ((nfft + 1) as f64 * f / sample_rate).floor() as usize)
97            .collect();
98
99        // Create triangular filters
100        for i in 0..n_filters {
101            let left = bin_points[i];
102            let center = bin_points[i + 1];
103            let right = bin_points[i + 2];
104
105            // Rising slope
106            for j in left..center {
107                if center > left && j < n_freqs {
108                    filters[[i, j]] = (j - left) as f64 / (center - left) as f64;
109                }
110            }
111
112            // Falling slope
113            for j in center..right {
114                if right > center && j < n_freqs {
115                    filters[[i, j]] = (right - j) as f64 / (right - center) as f64;
116                }
117            }
118        }
119
120        filters
121    }
122
123    /// Apply mel filterbank to power spectrum
124    pub fn apply(&self, power_spectrum: &ArrayView1<f64>) -> Result<Array1<f64>> {
125        let n_freqs = power_spectrum.len();
126        if n_freqs != self.nfft / 2 + 1 {
127            return Err(TransformError::InvalidInput(format!(
128                "Expected {} frequency bins, got {}",
129                self.nfft / 2 + 1,
130                n_freqs
131            )));
132        }
133
134        let mut mel_energies = Array1::zeros(self.n_filters);
135
136        for i in 0..self.n_filters {
137            let mut energy = 0.0;
138            for j in 0..n_freqs {
139                energy += self.filters[[i, j]] * power_spectrum[j];
140            }
141            mel_energies[i] = energy;
142        }
143
144        Ok(mel_energies)
145    }
146
147    /// Get the filter weights
148    pub fn filters(&self) -> &Array2<f64> {
149        &self.filters
150    }
151
152    /// Get center frequencies in Hz
153    pub fn center_frequencies(&self) -> Vec<f64> {
154        let mel_min = Self::hz_to_mel(self.fmin);
155        let mel_max = Self::hz_to_mel(self.fmax);
156
157        (0..self.n_filters)
158            .map(|i| {
159                let mel =
160                    mel_min + (mel_max - mel_min) * (i + 1) as f64 / (self.n_filters + 1) as f64;
161                Self::mel_to_hz(mel)
162            })
163            .collect()
164    }
165}
166
/// Parameters controlling MFCC extraction.
///
/// The `Default` implementation yields 13 coefficients from 40 mel bands at a 16 kHz
/// sampling rate.
#[derive(Clone, Debug)]
pub struct MFCCConfig {
    /// Number of cepstral coefficients kept per frame.
    pub n_mfcc: usize,
    /// Number of bands in the mel filterbank.
    pub n_mels: usize,
    /// FFT length used by the STFT.
    pub nfft: usize,
    /// Hop size between consecutive STFT frames.
    pub hop_size: usize,
    /// Length of the STFT analysis window.
    pub window_size: usize,
    /// Sampling rate of the input signal, in Hz.
    pub sample_rate: f64,
    /// Lower edge of the mel filterbank, in Hz.
    pub fmin: f64,
    /// Upper edge of the mel filterbank, in Hz.
    pub fmax: f64,
    /// Sinusoidal liftering parameter applied to each frame's coefficients;
    /// `None` disables liftering.
    pub lifter: Option<usize>,
    /// When `true`, subtract each coefficient's mean across all frames.
    pub normalize: bool,
}
191
192impl Default for MFCCConfig {
193    fn default() -> Self {
194        MFCCConfig {
195            n_mfcc: 13,
196            n_mels: 40,
197            nfft: 512,
198            hop_size: 160,
199            window_size: 400,
200            sample_rate: 16000.0,
201            fmin: 0.0,
202            fmax: 8000.0,
203            lifter: Some(22),
204            normalize: true,
205        }
206    }
207}
208
209/// MFCC extractor
210#[derive(Debug, Clone)]
211pub struct MFCC {
212    config: MFCCConfig,
213    mel_filterbank: MelFilterbank,
214    stft: STFT,
215    dct_matrix: Array2<f64>,
216}
217
218impl MFCC {
219    /// Create a new MFCC extractor
220    pub fn new(config: MFCCConfig) -> Result<Self> {
221        let mel_filterbank = MelFilterbank::new(
222            config.n_mels,
223            config.nfft,
224            config.sample_rate,
225            config.fmin,
226            config.fmax,
227        )?;
228
229        let stft_config = STFTConfig {
230            window_size: config.window_size,
231            hop_size: config.hop_size,
232            window_type: WindowType::Hamming,
233            nfft: Some(config.nfft),
234            onesided: true,
235            padding: crate::signal_transforms::stft::PaddingMode::Zero,
236        };
237
238        let stft = STFT::new(stft_config);
239        let dct_matrix = Self::compute_dct_matrix(config.n_mfcc, config.n_mels);
240
241        Ok(MFCC {
242            config,
243            mel_filterbank,
244            stft,
245            dct_matrix,
246        })
247    }
248
249    /// Create with default configuration
250    pub fn default() -> Result<Self> {
251        Self::new(MFCCConfig::default())
252    }
253
254    /// Compute DCT-II matrix
255    fn compute_dct_matrix(n_mfcc: usize, n_mels: usize) -> Array2<f64> {
256        let mut dct = Array2::zeros((n_mfcc, n_mels));
257        let norm = (2.0 / n_mels as f64).sqrt();
258
259        for i in 0..n_mfcc {
260            for j in 0..n_mels {
261                dct[[i, j]] = norm * (PI * i as f64 * (j as f64 + 0.5) / n_mels as f64).cos();
262            }
263        }
264
265        dct
266    }
267
268    /// Extract MFCCs from audio signal
269    pub fn extract(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
270        // Compute STFT
271        let stft = self.stft.transform(signal)?;
272        let (n_freqs, n_frames) = stft.dim();
273
274        // Compute power spectrum
275        let mut power_spec = Array2::zeros((n_freqs, n_frames));
276        for i in 0..n_freqs {
277            for j in 0..n_frames {
278                let mag = stft[[i, j]].norm();
279                power_spec[[i, j]] = mag * mag;
280            }
281        }
282
283        // Apply mel filterbank and extract MFCCs for each frame
284        let mut mfccs = Array2::zeros((self.config.n_mfcc, n_frames));
285
286        for frame_idx in 0..n_frames {
287            let power_frame = power_spec.column(frame_idx);
288            let mel_energies = self.mel_filterbank.apply(&power_frame)?;
289
290            // Log mel energies
291            let log_mel_energies: Array1<f64> = mel_energies
292                .iter()
293                .map(|&e| {
294                    if e > 1e-10 {
295                        e.ln()
296                    } else {
297                        -23.025850929940457 // ln(1e-10)
298                    }
299                })
300                .collect();
301
302            // Apply DCT
303            let mfcc_frame = self.dct_matrix.dot(&log_mel_energies);
304
305            // Apply liftering if configured
306            let mfcc_frame = if let Some(lifter) = self.config.lifter {
307                self.apply_lifter(&mfcc_frame, lifter)
308            } else {
309                mfcc_frame
310            };
311
312            // Store MFCCs
313            for (i, &val) in mfcc_frame.iter().enumerate() {
314                mfccs[[i, frame_idx]] = val;
315            }
316        }
317
318        // Apply mean normalization if configured
319        if self.config.normalize {
320            self.normalize_mfccs(&mut mfccs);
321        }
322
323        Ok(mfccs)
324    }
325
326    /// Apply liftering (cepstral filtering)
327    fn apply_lifter(&self, mfcc: &Array1<f64>, lifter: usize) -> Array1<f64> {
328        let n = mfcc.len();
329        let mut lifted = Array1::zeros(n);
330
331        for i in 0..n {
332            let lift_weight = 1.0 + (lifter as f64 / 2.0) * (PI * i as f64 / lifter as f64).sin();
333            lifted[i] = mfcc[i] * lift_weight;
334        }
335
336        lifted
337    }
338
339    /// Apply mean normalization to MFCCs
340    fn normalize_mfccs(&self, mfccs: &mut Array2<f64>) {
341        let (n_mfcc, n_frames) = mfccs.dim();
342
343        for i in 0..n_mfcc {
344            let mut sum = 0.0;
345            for j in 0..n_frames {
346                sum += mfccs[[i, j]];
347            }
348            let mean = sum / n_frames as f64;
349
350            for j in 0..n_frames {
351                mfccs[[i, j]] -= mean;
352            }
353        }
354    }
355
356    /// Extract delta (first derivative) features
357    pub fn delta(features: &Array2<f64>, width: usize) -> Array2<f64> {
358        let (n_features, n_frames) = features.dim();
359        let mut deltas = Array2::zeros((n_features, n_frames));
360
361        let width = width as i64;
362        let denominator: f64 = (1..=width).map(|i| i * i).sum::<i64>() as f64 * 2.0;
363
364        for i in 0..n_features {
365            for j in 0..n_frames {
366                let mut delta = 0.0;
367
368                for t in 1..=width {
369                    let t_f64 = t as f64;
370
371                    // Forward difference
372                    let idx_forward = (j as i64 + t).min(n_frames as i64 - 1) as usize;
373                    // Backward difference
374                    let idx_backward = (j as i64 - t).max(0) as usize;
375
376                    delta += t_f64 * (features[[i, idx_forward]] - features[[i, idx_backward]]);
377                }
378
379                deltas[[i, j]] = delta / denominator;
380            }
381        }
382
383        deltas
384    }
385
386    /// Extract delta-delta (second derivative) features
387    pub fn delta_delta(features: &Array2<f64>, width: usize) -> Array2<f64> {
388        let deltas = Self::delta(features, width);
389        Self::delta(&deltas, width)
390    }
391
392    /// Extract MFCCs with delta and delta-delta features
393    pub fn extract_with_deltas(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
394        let mfccs = self.extract(signal)?;
395        let deltas = Self::delta(&mfccs, 2);
396        let delta_deltas = Self::delta_delta(&mfccs, 2);
397
398        // Stack features vertically
399        let (n_mfcc, n_frames) = mfccs.dim();
400        let mut combined = Array2::zeros((n_mfcc * 3, n_frames));
401
402        for i in 0..n_mfcc {
403            for j in 0..n_frames {
404                combined[[i, j]] = mfccs[[i, j]];
405                combined[[i + n_mfcc, j]] = deltas[[i, j]];
406                combined[[i + 2 * n_mfcc, j]] = delta_deltas[[i, j]];
407            }
408        }
409
410        Ok(combined)
411    }
412
413    /// Get the configuration
414    pub fn config(&self) -> &MFCCConfig {
415        &self.config
416    }
417
418    /// Get the mel filterbank
419    pub fn mel_filterbank(&self) -> &MelFilterbank {
420        &self.mel_filterbank
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// 16000 samples of a slow sinusoid (one second at the default 16 kHz sampling rate).
    fn test_signal() -> Array1<f64> {
        Array1::from_vec((0..16000).map(|i| (i as f64 * 0.01).sin()).collect())
    }

    #[test]
    fn test_hz_mel_conversion() {
        // 0 Hz maps to 0 mel, and 1000 Hz is (by construction of the scale) ~1000 mel.
        assert_abs_diff_eq!(MelFilterbank::hz_to_mel(0.0), 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(MelFilterbank::hz_to_mel(1000.0), 1000.0, epsilon = 0.1);

        for &hz in &[0.0, 100.0, 1000.0, 4000.0, 8000.0] {
            let hz_back = MelFilterbank::mel_to_hz(MelFilterbank::hz_to_mel(hz));
            assert_abs_diff_eq!(hz, hz_back, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_mel_filterbank() -> Result<()> {
        let filterbank = MelFilterbank::new(40, 512, 16000.0, 0.0, 8000.0)?;
        let filters = filterbank.filters();

        assert_eq!(filters.dim(), (40, 257));

        // Every weight lies in [0, 1]. With these parameters all edge bins are distinct, so
        // each triangle reaches exactly 1.0 at its center bin.
        for row in filters.outer_iter() {
            assert!(row.iter().all(|&w| (0.0..=1.0).contains(&w)));
            let peak = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            assert_abs_diff_eq!(peak, 1.0, epsilon = 1e-12);
        }

        // Center frequencies are strictly increasing and lie strictly inside (fmin, fmax).
        let center_freqs = filterbank.center_frequencies();
        assert_eq!(center_freqs.len(), 40);
        assert!(center_freqs[0] > 0.0);
        assert!(center_freqs[39] < 8000.0);
        assert!(center_freqs.windows(2).all(|w| w[0] < w[1]));

        Ok(())
    }

    #[test]
    fn test_mel_filterbank_rejects_invalid_ranges() {
        // fmin >= fmax
        assert!(MelFilterbank::new(40, 512, 16000.0, 8000.0, 100.0).is_err());
        // fmax above Nyquist
        assert!(MelFilterbank::new(40, 512, 16000.0, 0.0, 9000.0).is_err());
    }

    #[test]
    fn test_mfcc_extraction() -> Result<()> {
        let signal = test_signal();
        let mfcc = MFCC::default()?;

        let features = mfcc.extract(&signal.view())?;

        assert_eq!(features.dim().0, 13); // 13 MFCCs
        assert!(features.dim().1 > 0); // Multiple frames
        assert!(features.iter().all(|v| v.is_finite()));

        // Mean normalization is enabled by default: each coefficient averages ~0 over frames.
        for row in features.outer_iter() {
            let mean = row.iter().sum::<f64>() / row.len() as f64;
            assert!(mean.abs() < 1e-6, "row mean {} not ~0", mean);
        }

        Ok(())
    }

    #[test]
    fn test_mfcc_with_deltas() -> Result<()> {
        let signal = test_signal();
        let mfcc = MFCC::default()?;

        let features = mfcc.extract_with_deltas(&signal.view())?;

        assert_eq!(features.dim().0, 39); // 13 + 13 + 13
        assert!(features.dim().1 > 0);

        // The stacked groups must match the individually computed features.
        let base = mfcc.extract(&signal.view())?;
        let deltas = MFCC::delta(&base, 2);
        let delta_deltas = MFCC::delta_delta(&base, 2);
        for i in 0..13 {
            assert_eq!(features.row(i), base.row(i));
            assert_eq!(features.row(13 + i), deltas.row(i));
            assert_eq!(features.row(26 + i), delta_deltas.row(i));
        }

        Ok(())
    }

    #[test]
    fn test_delta_features() {
        let features = Array2::from_shape_vec(
            (2, 5),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 0.5, 1.0, 1.5, 2.0, 2.5],
        )
        .expect("Failed to create array");

        let deltas = MFCC::delta(&features, 2);

        assert_eq!(deltas.dim(), (2, 5));

        // Deltas should capture the rate of change
        for i in 1..4 {
            assert!(deltas[[0, i]].abs() > 0.0);
        }

        // Away from the clamped edges, the regression recovers the exact slope of a line.
        assert_abs_diff_eq!(deltas[[0, 2]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(deltas[[1, 2]], 0.5, epsilon = 1e-12);

        // Constant features have zero derivative everywhere.
        let constant = Array2::from_elem((3, 6), 4.2);
        assert!(MFCC::delta(&constant, 2).iter().all(|&d| d == 0.0));
    }

    #[test]
    fn test_dct_matrix() {
        let dct = MFCC::compute_dct_matrix(13, 40);

        assert_eq!(dct.dim(), (13, 40));

        // Rows are mutually orthogonal. With the uniform sqrt(2/N) scaling, row 0 has squared
        // norm 2 and every other row has squared norm 1.
        let product = dct.dot(&dct.t());
        for i in 0..13 {
            for j in 0..13 {
                let expected = if i != j {
                    0.0
                } else if i == 0 {
                    2.0
                } else {
                    1.0
                };
                assert_abs_diff_eq!(product[[i, j]], expected, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_mfcc_config() {
        let config = MFCCConfig::default();
        assert_eq!(config.n_mfcc, 13);
        assert_eq!(config.n_mels, 40);
        assert_eq!(config.nfft, 512);
        assert_eq!(config.sample_rate, 16000.0);
        assert_eq!(config.fmax, 8000.0);
        assert_eq!(config.lifter, Some(22));
        assert!(config.normalize);
    }
}