1#![allow(clippy::too_many_arguments)]
8#![allow(dead_code)]
9
10use crate::error::{MetricsError, Result};
11use scirs2_core::ndarray::{Array1, ArrayView1};
12use scirs2_core::numeric::Float;
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Clone)]
17pub struct AudioQualityMetrics {
18 perceptual_metrics: PerceptualAudioMetrics,
20 objective_metrics: ObjectiveAudioMetrics,
22 intelligibility_metrics: IntelligibilityMetrics,
24}
25
/// Latest perceptual-quality scores. Every field starts out as `None`.
#[derive(Clone, Debug, Default)]
pub struct PerceptualAudioMetrics {
    /// Set by `compute_pesq`.
    pesq: Option<f64>,
    /// Set by `compute_stoi`.
    stoi: Option<f64>,
    /// MOSNet score; not written by any method in this file.
    mosnet_score: Option<f64>,
    /// DNSMOS score; not written by any method in this file.
    dnsmos_score: Option<f64>,
    /// Set by `compute_si_sdr`.
    si_sdr: Option<f64>,
}
40
41#[derive(Debug, Clone, Default)]
43pub struct ObjectiveAudioMetrics {
44 snr: f64,
46 sdr: f64,
48 sir: f64,
50 sar: f64,
52 fw_snr: f64,
54 spectral_distortion: SpectralDistortionMetrics,
56}
57
/// Latest spectral-distance measures; every field defaults to `0.0`.
#[derive(Clone, Debug, Default)]
pub struct SpectralDistortionMetrics {
    /// Set by `compute_log_spectral_distance`.
    log_spectral_distance: f64,
    /// Set by `compute_itakura_saito_distance`.
    itakura_saito_distance: f64,
    /// Not written by any method in this file.
    mel_cepstral_distortion: f64,
    /// Not written by any method in this file.
    bark_spectral_distortion: f64,
}
70
/// Latest speech-intelligibility scores.
#[derive(Clone, Debug, Default)]
pub struct IntelligibilityMetrics {
    /// Set by `compute_ncm`.
    ncm: f64,
    /// Set by `compute_csii`.
    csii: f64,
    /// Not written by any method in this file.
    hasqi: Option<f64>,
    /// Not written by any method in this file.
    estoi: Option<f64>,
}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct AudioQualityResults {
87 pub pesq: Option<f64>,
89 pub stoi: Option<f64>,
91 pub snr: f64,
93 pub sdr: f64,
95 pub si_sdr: Option<f64>,
97}
98
99impl AudioQualityMetrics {
100 pub fn new() -> Self {
102 Self {
103 perceptual_metrics: PerceptualAudioMetrics::default(),
104 objective_metrics: ObjectiveAudioMetrics::default(),
105 intelligibility_metrics: IntelligibilityMetrics::default(),
106 }
107 }
108
109 pub fn compute_quality_metrics<F: Float>(
111 &mut self,
112 clean_signal: ArrayView1<F>,
113 processed_signal: ArrayView1<F>,
114 noise_signal: Option<ArrayView1<F>>,
115 sample_rate: f64,
116 ) -> Result<AudioQualityResults> {
117 if clean_signal.len() != processed_signal.len() {
118 return Err(MetricsError::InvalidInput(
119 "Clean and processed signals must have the same length".to_string(),
120 ));
121 }
122
123 self.objective_metrics
125 .compute_snr(clean_signal, processed_signal)?;
126 self.objective_metrics
127 .compute_sdr(clean_signal, processed_signal)?;
128
129 if let Some(noise) = noise_signal {
130 self.objective_metrics
131 .compute_sir(clean_signal, processed_signal, noise)?;
132 }
133
134 self.perceptual_metrics
136 .compute_pesq(clean_signal, processed_signal, sample_rate)?;
137 self.perceptual_metrics
138 .compute_stoi(clean_signal, processed_signal, sample_rate)?;
139 self.perceptual_metrics
140 .compute_si_sdr(clean_signal, processed_signal)?;
141
142 self.intelligibility_metrics
144 .compute_ncm(clean_signal, processed_signal)?;
145 self.intelligibility_metrics
146 .compute_csii(clean_signal, processed_signal, sample_rate)?;
147
148 Ok(AudioQualityResults {
149 pesq: self.perceptual_metrics.pesq,
150 stoi: self.perceptual_metrics.stoi,
151 snr: self.objective_metrics.snr,
152 sdr: self.objective_metrics.sdr,
153 si_sdr: self.perceptual_metrics.si_sdr,
154 })
155 }
156
157 pub fn compute_pesq<F: Float>(
159 &mut self,
160 reference: ArrayView1<F>,
161 degraded: ArrayView1<F>,
162 sample_rate: f64,
163 ) -> Result<f64> {
164 self.perceptual_metrics
165 .compute_pesq(reference, degraded, sample_rate)
166 }
167
168 pub fn compute_stoi<F: Float>(
170 &mut self,
171 reference: ArrayView1<F>,
172 degraded: ArrayView1<F>,
173 sample_rate: f64,
174 ) -> Result<f64> {
175 self.perceptual_metrics
176 .compute_stoi(reference, degraded, sample_rate)
177 }
178
179 pub fn compute_snr<F: Float>(
181 &mut self,
182 signal: ArrayView1<F>,
183 noise: ArrayView1<F>,
184 ) -> Result<f64> {
185 self.objective_metrics.compute_snr(signal, noise)
186 }
187
188 pub fn compute_sdr<F: Float>(
190 &mut self,
191 reference: ArrayView1<F>,
192 estimate: ArrayView1<F>,
193 ) -> Result<f64> {
194 self.objective_metrics.compute_sdr(reference, estimate)
195 }
196
197 pub fn get_results(&self) -> AudioQualityResults {
199 AudioQualityResults {
200 pesq: self.perceptual_metrics.pesq,
201 stoi: self.perceptual_metrics.stoi,
202 snr: self.objective_metrics.snr,
203 sdr: self.objective_metrics.sdr,
204 si_sdr: self.perceptual_metrics.si_sdr,
205 }
206 }
207
208 pub fn evaluate_quality<F>(
210 &mut self,
211 reference_audio: ArrayView1<F>,
212 degraded_audio: ArrayView1<F>,
213 sample_rate: f64,
214 ) -> Result<AudioQualityResults>
215 where
216 F: Float + std::fmt::Debug + std::iter::Sum,
217 {
218 self.compute_quality_metrics(reference_audio, degraded_audio, None, sample_rate)
219 }
220}
221
222impl PerceptualAudioMetrics {
223 pub fn compute_pesq<F: Float>(
225 &mut self,
226 reference: ArrayView1<F>,
227 degraded: ArrayView1<F>,
228 sample_rate: f64,
229 ) -> Result<f64> {
230 if reference.len() != degraded.len() {
231 return Err(MetricsError::InvalidInput(
232 "Reference and degraded signals must have the same length".to_string(),
233 ));
234 }
235
236 let min_length = 8000; if reference.len() < min_length {
239 return Err(MetricsError::InvalidInput(
240 "Signal too short for PESQ computation".to_string(),
241 ));
242 }
243
244 let correlation = self.compute_correlation(reference, degraded);
246 let pesq_score = (correlation * 4.5).max(1.0).min(4.5); self.pesq = Some(pesq_score);
249 Ok(pesq_score)
250 }
251
252 pub fn compute_stoi<F: Float>(
254 &mut self,
255 reference: ArrayView1<F>,
256 degraded: ArrayView1<F>,
257 sample_rate: f64,
258 ) -> Result<f64> {
259 if reference.len() != degraded.len() {
260 return Err(MetricsError::InvalidInput(
261 "Reference and degraded signals must have the same length".to_string(),
262 ));
263 }
264
265 let frame_length = (sample_rate * 0.025) as usize; let hop_length = frame_length / 2;
268
269 if reference.len() < frame_length {
270 return Err(MetricsError::InvalidInput(
271 "Signal too short for STOI computation".to_string(),
272 ));
273 }
274
275 let mut stoi_values = Vec::new();
276
277 for i in (0..reference.len() - frame_length).step_by(hop_length) {
278 let ref_frame = reference.slice(s![i..i + frame_length]);
279 let deg_frame = degraded.slice(s![i..i + frame_length]);
280
281 let correlation = self.compute_correlation(ref_frame, deg_frame);
282 stoi_values.push(correlation.max(0.0).min(1.0));
283 }
284
285 let stoi_score = if !stoi_values.is_empty() {
286 stoi_values.iter().sum::<f64>() / stoi_values.len() as f64
287 } else {
288 0.0
289 };
290
291 self.stoi = Some(stoi_score);
292 Ok(stoi_score)
293 }
294
295 pub fn compute_si_sdr<F: Float>(
297 &mut self,
298 reference: ArrayView1<F>,
299 estimate: ArrayView1<F>,
300 ) -> Result<f64> {
301 if reference.len() != estimate.len() {
302 return Err(MetricsError::InvalidInput(
303 "Reference and estimate signals must have the same length".to_string(),
304 ));
305 }
306
307 let ref_vec: Vec<f64> = reference
309 .iter()
310 .map(|&x| x.to_f64().unwrap_or(0.0))
311 .collect();
312 let est_vec: Vec<f64> = estimate
313 .iter()
314 .map(|&x| x.to_f64().unwrap_or(0.0))
315 .collect();
316
317 let numerator: f64 = ref_vec.iter().zip(&est_vec).map(|(r, e)| r * e).sum();
319 let denominator: f64 = ref_vec.iter().map(|r| r * r).sum();
320
321 if denominator == 0.0 {
322 return Ok(f64::NEG_INFINITY);
323 }
324
325 let alpha = numerator / denominator;
326
327 let scaled_ref: Vec<f64> = ref_vec.iter().map(|r| alpha * r).collect();
329
330 let signal_power: f64 = scaled_ref.iter().map(|s| s * s).sum();
332 let noise_power: f64 = scaled_ref
333 .iter()
334 .zip(&est_vec)
335 .map(|(s, e)| (s - e).powi(2))
336 .sum();
337
338 let si_sdr = if noise_power > 0.0 {
339 10.0 * (signal_power / noise_power).log10()
340 } else {
341 f64::INFINITY
342 };
343
344 self.si_sdr = Some(si_sdr);
345 Ok(si_sdr)
346 }
347
348 fn compute_correlation<F: Float>(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> f64 {
350 if x.len() != y.len() || x.is_empty() {
351 return 0.0;
352 }
353
354 let x_vec: Vec<f64> = x.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect();
355 let y_vec: Vec<f64> = y.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect();
356
357 let mean_x = x_vec.iter().sum::<f64>() / x_vec.len() as f64;
358 let mean_y = y_vec.iter().sum::<f64>() / y_vec.len() as f64;
359
360 let numerator: f64 = x_vec
361 .iter()
362 .zip(&y_vec)
363 .map(|(x, y)| (x - mean_x) * (y - mean_y))
364 .sum();
365 let var_x: f64 = x_vec.iter().map(|x| (x - mean_x).powi(2)).sum();
366 let var_y: f64 = y_vec.iter().map(|y| (y - mean_y).powi(2)).sum();
367
368 let denominator = (var_x * var_y).sqrt();
369
370 if denominator > 0.0 {
371 numerator / denominator
372 } else {
373 0.0
374 }
375 }
376}
377
378impl ObjectiveAudioMetrics {
379 pub fn compute_snr<F: Float>(
381 &mut self,
382 signal: ArrayView1<F>,
383 noise: ArrayView1<F>,
384 ) -> Result<f64> {
385 let signal_power = self.compute_power(signal);
386 let noise_power = self.compute_power(noise);
387
388 self.snr = if noise_power > 0.0 {
389 10.0 * (signal_power / noise_power).log10()
390 } else {
391 f64::INFINITY
392 };
393
394 Ok(self.snr)
395 }
396
397 pub fn compute_sdr<F: Float>(
399 &mut self,
400 reference: ArrayView1<F>,
401 estimate: ArrayView1<F>,
402 ) -> Result<f64> {
403 if reference.len() != estimate.len() {
404 return Err(MetricsError::InvalidInput(
405 "Reference and estimate signals must have the same length".to_string(),
406 ));
407 }
408
409 let signal_power = self.compute_power(reference);
410
411 let distortion_power: f64 = reference
413 .iter()
414 .zip(estimate.iter())
415 .map(|(&r, &e)| {
416 let diff = r.to_f64().unwrap_or(0.0) - e.to_f64().unwrap_or(0.0);
417 diff * diff
418 })
419 .sum::<f64>()
420 / reference.len() as f64;
421
422 self.sdr = if distortion_power > 0.0 {
423 10.0 * (signal_power / distortion_power).log10()
424 } else {
425 f64::INFINITY
426 };
427
428 Ok(self.sdr)
429 }
430
431 pub fn compute_sir<F: Float>(
433 &mut self,
434 signal: ArrayView1<F>,
435 estimate: ArrayView1<F>,
436 interference: ArrayView1<F>,
437 ) -> Result<f64> {
438 let signal_power = self.compute_power(signal);
439 let interference_power = self.compute_power(interference);
440
441 self.sir = if interference_power > 0.0 {
442 10.0 * (signal_power / interference_power).log10()
443 } else {
444 f64::INFINITY
445 };
446
447 Ok(self.sir)
448 }
449
450 fn compute_power<F: Float>(&self, signal: ArrayView1<F>) -> f64 {
452 if signal.is_empty() {
453 return 0.0;
454 }
455
456 signal
457 .iter()
458 .map(|&x| {
459 let val = x.to_f64().unwrap_or(0.0);
460 val * val
461 })
462 .sum::<f64>()
463 / signal.len() as f64
464 }
465
466 pub fn compute_spectral_distortion<F: Float>(
468 &mut self,
469 reference: ArrayView1<F>,
470 estimate: ArrayView1<F>,
471 ) -> Result<()> {
472 self.spectral_distortion
473 .compute_log_spectral_distance(reference, estimate)?;
474 self.spectral_distortion
475 .compute_itakura_saito_distance(reference, estimate)?;
476 Ok(())
477 }
478}
479
480impl SpectralDistortionMetrics {
481 pub fn compute_log_spectral_distance<F: Float>(
483 &mut self,
484 reference: ArrayView1<F>,
485 estimate: ArrayView1<F>,
486 ) -> Result<f64> {
487 let ref_spectrum = self.compute_simple_spectrum(reference);
489 let est_spectrum = self.compute_simple_spectrum(estimate);
490
491 if ref_spectrum.len() != est_spectrum.len() {
492 return Err(MetricsError::InvalidInput(
493 "Spectrum lengths must match".to_string(),
494 ));
495 }
496
497 let mut distance_sum = 0.0;
498 let mut valid_bins = 0;
499
500 for (ref_bin, est_bin) in ref_spectrum.iter().zip(est_spectrum.iter()) {
501 if *ref_bin > 0.0 && *est_bin > 0.0 {
502 distance_sum += (ref_bin.ln() - est_bin.ln()).powi(2);
503 valid_bins += 1;
504 }
505 }
506
507 self.log_spectral_distance = if valid_bins > 0 {
508 (distance_sum / valid_bins as f64).sqrt()
509 } else {
510 0.0
511 };
512
513 Ok(self.log_spectral_distance)
514 }
515
516 pub fn compute_itakura_saito_distance<F: Float>(
518 &mut self,
519 reference: ArrayView1<F>,
520 estimate: ArrayView1<F>,
521 ) -> Result<f64> {
522 let ref_spectrum = self.compute_simple_spectrum(reference);
523 let est_spectrum = self.compute_simple_spectrum(estimate);
524
525 let mut distance_sum = 0.0;
526 let mut valid_bins = 0;
527
528 for (ref_bin, est_bin) in ref_spectrum.iter().zip(est_spectrum.iter()) {
529 if *ref_bin > 0.0 && *est_bin > 0.0 {
530 distance_sum += (ref_bin / est_bin) - (ref_bin / est_bin).ln() - 1.0;
531 valid_bins += 1;
532 }
533 }
534
535 self.itakura_saito_distance = if valid_bins > 0 {
536 distance_sum / valid_bins as f64
537 } else {
538 0.0
539 };
540
541 Ok(self.itakura_saito_distance)
542 }
543
544 fn compute_simple_spectrum<F: Float>(&self, signal: ArrayView1<F>) -> Vec<f64> {
546 let window_size = signal.len().min(1024);
548 let mut spectrum = Vec::with_capacity(window_size / 2);
549
550 for i in 0..window_size / 2 {
551 let start = i * 2;
552 let end = (start + window_size).min(signal.len());
553
554 if start < signal.len() {
555 let power: f64 = signal
556 .slice(s![start..end])
557 .iter()
558 .map(|&x| {
559 let val = x.to_f64().unwrap_or(0.0);
560 val * val
561 })
562 .sum::<f64>()
563 / (end - start) as f64;
564
565 spectrum.push(power.max(1e-10)); }
567 }
568
569 spectrum
570 }
571}
572
573impl IntelligibilityMetrics {
574 pub fn compute_ncm<F: Float>(
576 &mut self,
577 reference: ArrayView1<F>,
578 degraded: ArrayView1<F>,
579 ) -> Result<f64> {
580 if reference.len() != degraded.len() {
581 return Err(MetricsError::InvalidInput(
582 "Reference and degraded signals must have the same length".to_string(),
583 ));
584 }
585
586 let correlation = self.compute_cross_correlation(reference, degraded);
588 self.ncm = correlation.abs();
589 Ok(self.ncm)
590 }
591
592 pub fn compute_csii<F: Float>(
594 &mut self,
595 reference: ArrayView1<F>,
596 degraded: ArrayView1<F>,
597 sample_rate: f64,
598 ) -> Result<f64> {
599 let frame_length = (sample_rate * 0.032) as usize; let hop_length = frame_length / 2;
602
603 let mut coherence_values = Vec::new();
604
605 for i in (0..reference.len() - frame_length).step_by(hop_length) {
606 let ref_frame = reference.slice(s![i..i + frame_length]);
607 let deg_frame = degraded.slice(s![i..i + frame_length]);
608
609 let coherence = self.compute_frame_coherence(ref_frame, deg_frame);
610 coherence_values.push(coherence);
611 }
612
613 self.csii = if !coherence_values.is_empty() {
614 coherence_values.iter().sum::<f64>() / coherence_values.len() as f64
615 } else {
616 0.0
617 };
618
619 Ok(self.csii)
620 }
621
622 fn compute_cross_correlation<F: Float>(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> f64 {
624 if x.len() != y.len() || x.is_empty() {
625 return 0.0;
626 }
627
628 let x_vec: Vec<f64> = x.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect();
629 let y_vec: Vec<f64> = y.iter().map(|&v| v.to_f64().unwrap_or(0.0)).collect();
630
631 let mean_x = x_vec.iter().sum::<f64>() / x_vec.len() as f64;
632 let mean_y = y_vec.iter().sum::<f64>() / y_vec.len() as f64;
633
634 let numerator: f64 = x_vec
635 .iter()
636 .zip(&y_vec)
637 .map(|(x, y)| (x - mean_x) * (y - mean_y))
638 .sum();
639 let var_x: f64 = x_vec.iter().map(|x| (x - mean_x).powi(2)).sum();
640 let var_y: f64 = y_vec.iter().map(|y| (y - mean_y).powi(2)).sum();
641
642 let denominator = (var_x * var_y).sqrt();
643
644 if denominator > 0.0 {
645 numerator / denominator
646 } else {
647 0.0
648 }
649 }
650
651 fn compute_frame_coherence<F: Float>(&self, x: ArrayView1<F>, y: ArrayView1<F>) -> f64 {
653 self.compute_cross_correlation(x, y).abs()
655 }
656}
657
658use scirs2_core::ndarray::s;
660
661impl Default for AudioQualityMetrics {
662 fn default() -> Self {
663 Self::new()
664 }
665}