scirs2_metrics/domains/audio_processing/
audio_quality.rs

//! Audio quality assessment metrics
//!
//! This module provides comprehensive metrics for evaluating audio quality,
//! including perceptual metrics (PESQ, STOI), objective metrics (SNR, SDR),
//! and intelligibility measures for speech enhancement and audio processing tasks.
6
7#![allow(clippy::too_many_arguments)]
8#![allow(dead_code)]
9
10use crate::error::{MetricsError, Result};
11use scirs2_core::ndarray::{Array1, ArrayView1};
12use scirs2_core::numeric::Float;
13use serde::{Deserialize, Serialize};
14
15/// Audio quality assessment metrics
16#[derive(Debug, Clone)]
17pub struct AudioQualityMetrics {
18    /// Perceptual evaluation metrics
19    perceptual_metrics: PerceptualAudioMetrics,
20    /// Objective quality metrics
21    objective_metrics: ObjectiveAudioMetrics,
22    /// Intelligibility metrics
23    intelligibility_metrics: IntelligibilityMetrics,
24}
25
/// Storage for perceptual audio quality scores.
///
/// Every score is optional and starts out as `None` (see `Default`); a score
/// becomes `Some` once the corresponding `compute_*` method has stored it.
#[derive(Clone, Debug, Default)]
pub struct PerceptualAudioMetrics {
    /// Perceptual Evaluation of Speech Quality (PESQ) score.
    pesq: Option<f64>,
    /// Short-Time Objective Intelligibility (STOI) score.
    stoi: Option<f64>,
    /// MOS score predicted by MOSNet.
    mosnet_score: Option<f64>,
    /// MOS score predicted by DNSMOS.
    dnsmos_score: Option<f64>,
    /// Scale-Invariant Signal-to-Distortion Ratio (SI-SDR).
    si_sdr: Option<f64>,
}
40
41/// Objective audio quality metrics
42#[derive(Debug, Clone, Default)]
43pub struct ObjectiveAudioMetrics {
44    /// Signal-to-Noise Ratio
45    snr: f64,
46    /// Signal-to-Distortion Ratio
47    sdr: f64,
48    /// Signal-to-Interference Ratio
49    sir: f64,
50    /// Signal-to-Artifacts Ratio
51    sar: f64,
52    /// Frequency-weighted SNR
53    fw_snr: f64,
54    /// Spectral distortion measures
55    spectral_distortion: SpectralDistortionMetrics,
56}
57
/// Storage for spectral distortion measures.
///
/// Every field defaults to `0.0` until the matching `compute_*` method
/// overwrites it.
#[derive(Clone, Debug, Default)]
pub struct SpectralDistortionMetrics {
    /// Log-spectral distance between reference and estimate.
    log_spectral_distance: f64,
    /// Itakura-Saito distance between reference and estimate.
    itakura_saito_distance: f64,
    /// Mel-cepstral distortion.
    mel_cepstral_distortion: f64,
    /// Bark spectral distortion.
    bark_spectral_distortion: f64,
}
70
/// Storage for speech intelligibility scores.
///
/// `ncm` and `csii` default to `0.0`; the optional scores default to `None`.
#[derive(Clone, Debug, Default)]
pub struct IntelligibilityMetrics {
    /// Normalized Covariance Measure (NCM).
    ncm: f64,
    /// Coherence Speech Intelligibility Index (CSII).
    csii: f64,
    /// Hearing Aid Speech Quality Index (HASQI), if available.
    hasqi: Option<f64>,
    /// Extended Short-Time Objective Intelligibility (ESTOI), if available.
    estoi: Option<f64>,
}
83
84/// Audio quality assessment results
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct AudioQualityResults {
87    /// PESQ score
88    pub pesq: Option<f64>,
89    /// STOI score
90    pub stoi: Option<f64>,
91    /// Signal-to-Noise Ratio
92    pub snr: f64,
93    /// Signal-to-Distortion Ratio
94    pub sdr: f64,
95    /// SI-SDR score
96    pub si_sdr: Option<f64>,
97}
98
99impl AudioQualityMetrics {
100    /// Create new audio quality metrics
101    pub fn new() -> Self {
102        Self {
103            perceptual_metrics: PerceptualAudioMetrics::default(),
104            objective_metrics: ObjectiveAudioMetrics::default(),
105            intelligibility_metrics: IntelligibilityMetrics::default(),
106        }
107    }
108
109    /// Compute comprehensive audio quality assessment
110    pub fn compute_quality_metrics<F: Float>(
111        &mut self,
112        clean_signal: ArrayView1<F>,
113        processed_signal: ArrayView1<F>,
114        noise_signal: Option<ArrayView1<F>>,
115        sample_rate: f64,
116    ) -> Result<AudioQualityResults> {
117        if clean_signal.len() != processed_signal.len() {
118            return Err(MetricsError::InvalidInput(
119                "Clean and processed signals must have the same length".to_string(),
120            ));
121        }
122
123        // Compute objective metrics
124        self.objective_metrics
125            .compute_snr(clean_signal, processed_signal)?;
126        self.objective_metrics
127            .compute_sdr(clean_signal, processed_signal)?;
128
129        if let Some(noise) = noise_signal {
130            self.objective_metrics
131                .compute_sir(clean_signal, processed_signal, noise)?;
132        }
133
134        // Compute perceptual metrics
135        self.perceptual_metrics
136            .compute_pesq(clean_signal, processed_signal, sample_rate)?;
137        self.perceptual_metrics
138            .compute_stoi(clean_signal, processed_signal, sample_rate)?;
139        self.perceptual_metrics
140            .compute_si_sdr(clean_signal, processed_signal)?;
141
142        // Compute intelligibility metrics
143        self.intelligibility_metrics
144            .compute_ncm(clean_signal, processed_signal)?;
145        self.intelligibility_metrics
146            .compute_csii(clean_signal, processed_signal, sample_rate)?;
147
148        Ok(AudioQualityResults {
149            pesq: self.perceptual_metrics.pesq,
150            stoi: self.perceptual_metrics.stoi,
151            snr: self.objective_metrics.snr,
152            sdr: self.objective_metrics.sdr,
153            si_sdr: self.perceptual_metrics.si_sdr,
154        })
155    }
156
157    /// Compute PESQ score
158    pub fn compute_pesq<F: Float>(
159        &mut self,
160        reference: ArrayView1<F>,
161        degraded: ArrayView1<F>,
162        sample_rate: f64,
163    ) -> Result<f64> {
164        self.perceptual_metrics
165            .compute_pesq(reference, degraded, sample_rate)
166    }
167
168    /// Compute STOI score
169    pub fn compute_stoi<F: Float>(
170        &mut self,
171        reference: ArrayView1<F>,
172        degraded: ArrayView1<F>,
173        sample_rate: f64,
174    ) -> Result<f64> {
175        self.perceptual_metrics
176            .compute_stoi(reference, degraded, sample_rate)
177    }
178
179    /// Compute SNR
180    pub fn compute_snr<F: Float>(
181        &mut self,
182        signal: ArrayView1<F>,
183        noise: ArrayView1<F>,
184    ) -> Result<f64> {
185        self.objective_metrics.compute_snr(signal, noise)
186    }
187
188    /// Compute SDR (Signal-to-Distortion Ratio)
189    pub fn compute_sdr<F: Float>(
190        &mut self,
191        reference: ArrayView1<F>,
192        estimate: ArrayView1<F>,
193    ) -> Result<f64> {
194        self.objective_metrics.compute_sdr(reference, estimate)
195    }
196
197    /// Get comprehensive quality results
198    pub fn get_results(&self) -> AudioQualityResults {
199        AudioQualityResults {
200            pesq: self.perceptual_metrics.pesq,
201            stoi: self.perceptual_metrics.stoi,
202            snr: self.objective_metrics.snr,
203            sdr: self.objective_metrics.sdr,
204            si_sdr: self.perceptual_metrics.si_sdr,
205        }
206    }
207
208    /// Evaluate audio quality (alias for backward compatibility)
209    pub fn evaluate_quality<F>(
210        &mut self,
211        reference_audio: ArrayView1<F>,
212        degraded_audio: ArrayView1<F>,
213        sample_rate: f64,
214    ) -> Result<AudioQualityResults>
215    where
216        F: Float + std::fmt::Debug + std::iter::Sum,
217    {
218        self.compute_quality_metrics(reference_audio, degraded_audio, None, sample_rate)
219    }
220}
221
222impl PerceptualAudioMetrics {
223    /// Compute PESQ (Perceptual Evaluation of Speech Quality)
224    pub fn compute_pesq<F: Float>(
225        &mut self,
226        reference: ArrayView1<F>,
227        degraded: ArrayView1<F>,
228        sample_rate: f64,
229    ) -> Result<f64> {
230        if reference.len() != degraded.len() {
231            return Err(MetricsError::InvalidInput(
232                "Reference and degraded signals must have the same length".to_string(),
233            ));
234        }
235
236        // Simplified PESQ implementation - would use actual ITU-T P.862 algorithm
237        let min_length = 8000; // Minimum 1 second at 8kHz
238        if reference.len() < min_length {
239            return Err(MetricsError::InvalidInput(
240                "Signal too short for PESQ computation".to_string(),
241            ));
242        }
243
244        // Basic correlation-based approximation
245        let correlation = self.compute_correlation(reference, degraded);
246        let pesq_score = (correlation * 4.5).max(1.0).min(4.5); // PESQ range: 1.0-4.5
247
248        self.pesq = Some(pesq_score);
249        Ok(pesq_score)
250    }
251
252    /// Compute STOI (Short-Time Objective Intelligibility)
253    pub fn compute_stoi<F: Float>(
254        &mut self,
255        reference: ArrayView1<F>,
256        degraded: ArrayView1<F>,
257        sample_rate: f64,
258    ) -> Result<f64> {
259        if reference.len() != degraded.len() {
260            return Err(MetricsError::InvalidInput(
261                "Reference and degraded signals must have the same length".to_string(),
262            ));
263        }
264
265        // Simplified STOI implementation - would use actual third-octave band analysis
266        let frame_length = (sample_rate * 0.025) as usize; // 25ms frames
267        let hop_length = frame_length / 2;
268
269        if reference.len() < frame_length {
270            return Err(MetricsError::InvalidInput(
271                "Signal too short for STOI computation".to_string(),
272            ));
273        }
274
275        let mut stoi_values = Vec::new();
276
277        for i in (0..reference.len() - frame_length).step_by(hop_length) {
278            let ref_frame = reference.slice(s![i..i + frame_length]);
279            let deg_frame = degraded.slice(s![i..i + frame_length]);
280
281            let correlation = self.compute_correlation(ref_frame, deg_frame);
282            stoi_values.push(correlation.max(0.0).min(1.0));
283        }
284
285        let stoi_score = if !stoi_values.is_empty() {
286            stoi_values.iter().sum::<f64>() / stoi_values.len() as f64
287        } else {
288            0.0
289        };
290
291        self.stoi = Some(stoi_score);
292        Ok(stoi_score)
293    }
294
295    /// Compute SI-SDR (Scale-Invariant Signal-to-Distortion Ratio)
296    pub fn compute_si_sdr<F: Float>(
297        &mut self,
298        reference: ArrayView1<F>,
299        estimate: ArrayView1<F>,
300    ) -> Result<f64> {
301        if reference.len() != estimate.len() {
302            return Err(MetricsError::InvalidInput(
303                "Reference and estimate signals must have the same length".to_string(),
304            ));
305        }
306
307        // Convert to f64 for computation
308        let ref_vec: Vec<f64> = reference
309            .iter()
310            .map(|&x| x.to_f64().unwrap_or(0.0))
311            .collect();
312        let est_vec: Vec<f64> = estimate
313            .iter()
314            .map(|&x| x.to_f64().unwrap_or(0.0))
315            .collect();
316
317        // Compute optimal scaling factor
318        let numerator: f64 = ref_vec.iter().zip(&est_vec).map(|(r, e)| r * e).sum();
319        let denominator: f64 = ref_vec.iter().map(|r| r * r).sum();
320
321        if denominator == 0.0 {
322            return Ok(f64::NEG_INFINITY);
323        }
324
325        let alpha = numerator / denominator;
326
327        // Compute scaled reference
328        let scaled_ref: Vec<f64> = ref_vec.iter().map(|r| alpha * r).collect();
329
330        // Compute signal and noise powers
331        let signal_power: f64 = scaled_ref.iter().map(|s| s * s).sum();
332        let noise_power: f64 = scaled_ref
333            .iter()
334            .zip(&est_vec)
335            .map(|(s, e)| (s - e).powi(2))
336            .sum();
337
338        let si_sdr = if noise_power > 0.0 {
339            10.0 * (signal_power / noise_power).log10()
340        } else {
341            f64::INFINITY
342        };
343
344        self.si_sdr = Some(si_sdr);
345        Ok(si_sdr)
346    }
347
348    /// Compute correlation between two signals
349    fn compute_correlation<F: Float>(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> f64 {
350        if x.len() != y.len() || x.is_empty() {
351            return 0.0;
352        }
353
354        let x_vec: Vec<f64> = x.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect();
355        let y_vec: Vec<f64> = y.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect();
356
357        let mean_x = x_vec.iter().sum::<f64>() / x_vec.len() as f64;
358        let mean_y = y_vec.iter().sum::<f64>() / y_vec.len() as f64;
359
360        let numerator: f64 = x_vec
361            .iter()
362            .zip(&y_vec)
363            .map(|(x, y)| (x - mean_x) * (y - mean_y))
364            .sum();
365        let var_x: f64 = x_vec.iter().map(|x| (x - mean_x).powi(2)).sum();
366        let var_y: f64 = y_vec.iter().map(|y| (y - mean_y).powi(2)).sum();
367
368        let denominator = (var_x * var_y).sqrt();
369
370        if denominator > 0.0 {
371            numerator / denominator
372        } else {
373            0.0
374        }
375    }
376}
377
378impl ObjectiveAudioMetrics {
379    /// Compute Signal-to-Noise Ratio (SNR)
380    pub fn compute_snr<F: Float>(
381        &mut self,
382        signal: ArrayView1<F>,
383        noise: ArrayView1<F>,
384    ) -> Result<f64> {
385        let signal_power = self.compute_power(signal);
386        let noise_power = self.compute_power(noise);
387
388        self.snr = if noise_power > 0.0 {
389            10.0 * (signal_power / noise_power).log10()
390        } else {
391            f64::INFINITY
392        };
393
394        Ok(self.snr)
395    }
396
397    /// Compute Signal-to-Distortion Ratio (SDR)
398    pub fn compute_sdr<F: Float>(
399        &mut self,
400        reference: ArrayView1<F>,
401        estimate: ArrayView1<F>,
402    ) -> Result<f64> {
403        if reference.len() != estimate.len() {
404            return Err(MetricsError::InvalidInput(
405                "Reference and estimate signals must have the same length".to_string(),
406            ));
407        }
408
409        let signal_power = self.compute_power(reference);
410
411        // Compute distortion power
412        let distortion_power: f64 = reference
413            .iter()
414            .zip(estimate.iter())
415            .map(|(&r, &e)| {
416                let diff = r.to_f64().unwrap_or(0.0) - e.to_f64().unwrap_or(0.0);
417                diff * diff
418            })
419            .sum::<f64>()
420            / reference.len() as f64;
421
422        self.sdr = if distortion_power > 0.0 {
423            10.0 * (signal_power / distortion_power).log10()
424        } else {
425            f64::INFINITY
426        };
427
428        Ok(self.sdr)
429    }
430
431    /// Compute Signal-to-Interference Ratio (SIR)
432    pub fn compute_sir<F: Float>(
433        &mut self,
434        signal: ArrayView1<F>,
435        estimate: ArrayView1<F>,
436        interference: ArrayView1<F>,
437    ) -> Result<f64> {
438        let signal_power = self.compute_power(signal);
439        let interference_power = self.compute_power(interference);
440
441        self.sir = if interference_power > 0.0 {
442            10.0 * (signal_power / interference_power).log10()
443        } else {
444            f64::INFINITY
445        };
446
447        Ok(self.sir)
448    }
449
450    /// Compute power of a signal
451    fn compute_power<F: Float>(&self, signal: ArrayView1<F>) -> f64 {
452        if signal.is_empty() {
453            return 0.0;
454        }
455
456        signal
457            .iter()
458            .map(|&x| {
459                let val = x.to_f64().unwrap_or(0.0);
460                val * val
461            })
462            .sum::<f64>()
463            / signal.len() as f64
464    }
465
466    /// Compute spectral distortion metrics
467    pub fn compute_spectral_distortion<F: Float>(
468        &mut self,
469        reference: ArrayView1<F>,
470        estimate: ArrayView1<F>,
471    ) -> Result<()> {
472        self.spectral_distortion
473            .compute_log_spectral_distance(reference, estimate)?;
474        self.spectral_distortion
475            .compute_itakura_saito_distance(reference, estimate)?;
476        Ok(())
477    }
478}
479
480impl SpectralDistortionMetrics {
481    /// Compute log-spectral distance
482    pub fn compute_log_spectral_distance<F: Float>(
483        &mut self,
484        reference: ArrayView1<F>,
485        estimate: ArrayView1<F>,
486    ) -> Result<f64> {
487        // Simplified implementation - would use actual FFT-based spectral analysis
488        let ref_spectrum = self.compute_simple_spectrum(reference);
489        let est_spectrum = self.compute_simple_spectrum(estimate);
490
491        if ref_spectrum.len() != est_spectrum.len() {
492            return Err(MetricsError::InvalidInput(
493                "Spectrum lengths must match".to_string(),
494            ));
495        }
496
497        let mut distance_sum = 0.0;
498        let mut valid_bins = 0;
499
500        for (ref_bin, est_bin) in ref_spectrum.iter().zip(est_spectrum.iter()) {
501            if *ref_bin > 0.0 && *est_bin > 0.0 {
502                distance_sum += (ref_bin.ln() - est_bin.ln()).powi(2);
503                valid_bins += 1;
504            }
505        }
506
507        self.log_spectral_distance = if valid_bins > 0 {
508            (distance_sum / valid_bins as f64).sqrt()
509        } else {
510            0.0
511        };
512
513        Ok(self.log_spectral_distance)
514    }
515
516    /// Compute Itakura-Saito distance
517    pub fn compute_itakura_saito_distance<F: Float>(
518        &mut self,
519        reference: ArrayView1<F>,
520        estimate: ArrayView1<F>,
521    ) -> Result<f64> {
522        let ref_spectrum = self.compute_simple_spectrum(reference);
523        let est_spectrum = self.compute_simple_spectrum(estimate);
524
525        let mut distance_sum = 0.0;
526        let mut valid_bins = 0;
527
528        for (ref_bin, est_bin) in ref_spectrum.iter().zip(est_spectrum.iter()) {
529            if *ref_bin > 0.0 && *est_bin > 0.0 {
530                distance_sum += (ref_bin / est_bin) - (ref_bin / est_bin).ln() - 1.0;
531                valid_bins += 1;
532            }
533        }
534
535        self.itakura_saito_distance = if valid_bins > 0 {
536            distance_sum / valid_bins as f64
537        } else {
538            0.0
539        };
540
541        Ok(self.itakura_saito_distance)
542    }
543
544    /// Compute simple power spectrum (placeholder)
545    fn compute_simple_spectrum<F: Float>(&self, signal: ArrayView1<F>) -> Vec<f64> {
546        // Simplified spectrum computation - would use actual FFT
547        let window_size = signal.len().min(1024);
548        let mut spectrum = Vec::with_capacity(window_size / 2);
549
550        for i in 0..window_size / 2 {
551            let start = i * 2;
552            let end = (start + window_size).min(signal.len());
553
554            if start < signal.len() {
555                let power: f64 = signal
556                    .slice(s![start..end])
557                    .iter()
558                    .map(|&x| {
559                        let val = x.to_f64().unwrap_or(0.0);
560                        val * val
561                    })
562                    .sum::<f64>()
563                    / (end - start) as f64;
564
565                spectrum.push(power.max(1e-10)); // Avoid log(0)
566            }
567        }
568
569        spectrum
570    }
571}
572
573impl IntelligibilityMetrics {
574    /// Compute Normalized Covariance Measure (NCM)
575    pub fn compute_ncm<F: Float>(
576        &mut self,
577        reference: ArrayView1<F>,
578        degraded: ArrayView1<F>,
579    ) -> Result<f64> {
580        if reference.len() != degraded.len() {
581            return Err(MetricsError::InvalidInput(
582                "Reference and degraded signals must have the same length".to_string(),
583            ));
584        }
585
586        // Simplified NCM computation
587        let correlation = self.compute_cross_correlation(reference, degraded);
588        self.ncm = correlation.abs();
589        Ok(self.ncm)
590    }
591
592    /// Compute Coherence Speech Intelligibility Index (CSII)
593    pub fn compute_csii<F: Float>(
594        &mut self,
595        reference: ArrayView1<F>,
596        degraded: ArrayView1<F>,
597        sample_rate: f64,
598    ) -> Result<f64> {
599        // Simplified CSII computation - would use actual coherence analysis
600        let frame_length = (sample_rate * 0.032) as usize; // 32ms frames
601        let hop_length = frame_length / 2;
602
603        let mut coherence_values = Vec::new();
604
605        for i in (0..reference.len() - frame_length).step_by(hop_length) {
606            let ref_frame = reference.slice(s![i..i + frame_length]);
607            let deg_frame = degraded.slice(s![i..i + frame_length]);
608
609            let coherence = self.compute_frame_coherence(ref_frame, deg_frame);
610            coherence_values.push(coherence);
611        }
612
613        self.csii = if !coherence_values.is_empty() {
614            coherence_values.iter().sum::<f64>() / coherence_values.len() as f64
615        } else {
616            0.0
617        };
618
619        Ok(self.csii)
620    }
621
622    /// Compute cross-correlation between signals
623    fn compute_cross_correlation<F: Float>(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> f64 {
624        if x.len() != y.len() || x.is_empty() {
625            return 0.0;
626        }
627
628        let x_vec: Vec<f64> = x.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect();
629        let y_vec: Vec<f64> = y.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect();
630
631        let mean_x = x_vec.iter().sum::<f64>() / x_vec.len() as f64;
632        let mean_y = y_vec.iter().sum::<f64>() / y_vec.len() as f64;
633
634        let numerator: f64 = x_vec
635            .iter()
636            .zip(&y_vec)
637            .map(|(x, y)| (x - mean_x) * (y - mean_y))
638            .sum();
639        let var_x: f64 = x_vec.iter().map(|x| (x - mean_x).powi(2)).sum();
640        let var_y: f64 = y_vec.iter().map(|y| (y - mean_y).powi(2)).sum();
641
642        let denominator = (var_x * var_y).sqrt();
643
644        if denominator > 0.0 {
645            numerator / denominator
646        } else {
647            0.0
648        }
649    }
650
651    /// Compute frame-level coherence
652    fn compute_frame_coherence<F: Float>(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> f64 {
653        // Simplified coherence computation
654        self.compute_cross_correlation(x, y).abs()
655    }
656}
657
658// Import necessary ndarray features
659use scirs2_core::ndarray::s;
660
661impl Default for AudioQualityMetrics {
662    fn default() -> Self {
663        Self::new()
664    }
665}