scirs2_metrics/domains/audio_processing/
audio_classification.rs

1//! Audio classification metrics
2//!
3//! This module provides comprehensive metrics for evaluating audio classification tasks,
4//! including general classification metrics, audio-specific metrics like Equal Error Rate (EER),
5//! temporal consistency measures, and boundary detection capabilities.
6
7#![allow(clippy::too_many_arguments)]
8#![allow(dead_code)]
9
10use crate::error::{MetricsError, Result};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
12use scirs2_core::numeric::Float;
13use serde::{Deserialize, Serialize};
14
15/// Audio classification metrics
16#[derive(Debug, Clone)]
17pub struct AudioClassificationMetrics {
18    /// Standard classification metrics
19    classification_metrics: crate::sklearn_compat::ClassificationMetrics,
20    /// Audio-specific metrics
21    audio_specific: AudioSpecificMetrics,
22    /// Temporal metrics for audio segments
23    temporal_metrics: TemporalAudioMetrics,
24}
25
/// Detection-style metrics used for audio classification.
///
/// Every field starts as `None` and holds the most recently computed value
/// once the corresponding computation has run.
#[derive(Clone, Debug)]
pub struct AudioSpecificMetrics {
    /// Equal Error Rate (EER), if it has been computed.
    eer: Option<f64>,
    /// Detection Cost Function (DCF) value, if it has been computed.
    dcf: Option<f64>,
    /// Area under the ROC curve for audio, if available.
    auc_audio: Option<f64>,
    /// Minimum DCF over the evaluated operating points, if it has been computed.
    min_dcf: Option<f64>,
}
38
39/// Temporal metrics for audio classification
40#[derive(Debug, Clone)]
41pub struct TemporalAudioMetrics {
42    /// Frame-level accuracy
43    frame_accuracy: f64,
44    /// Segment-level accuracy
45    segment_accuracy: f64,
46    /// Temporal consistency score
47    temporal_consistency: f64,
48    /// Boundary detection metrics
49    boundary_metrics: BoundaryDetectionMetrics,
50}
51
/// Precision/recall style metrics for segment boundary detection.
#[derive(Clone, Debug)]
pub struct BoundaryDetectionMetrics {
    /// Fraction of detected boundaries that are correct.
    boundary_precision: f64,
    /// Fraction of reference boundaries that were detected.
    boundary_recall: f64,
    /// F1 score combining boundary precision and recall.
    boundary_f1: f64,
    /// Maximum allowed distance between a detected and a reference boundary
    /// (in seconds).
    tolerance: f64,
}
64
65/// Audio classification evaluation results
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct AudioClassificationResults {
68    /// Overall accuracy
69    pub accuracy: f64,
70    /// Precision
71    pub precision: f64,
72    /// Recall
73    pub recall: f64,
74    /// F1 score
75    pub f1_score: f64,
76    /// Equal Error Rate
77    pub eer: Option<f64>,
78    /// Area Under Curve
79    pub auc: f64,
80    /// Frame-level accuracy
81    pub frame_accuracy: f64,
82}
83
84impl AudioClassificationMetrics {
85    /// Create new audio classification metrics
86    pub fn new() -> Self {
87        Self {
88            classification_metrics: crate::sklearn_compat::ClassificationMetrics::new(),
89            audio_specific: AudioSpecificMetrics::new(),
90            temporal_metrics: TemporalAudioMetrics::new(),
91        }
92    }
93
94    /// Compute comprehensive audio classification metrics
95    pub fn compute_metrics<F: Float + std::fmt::Debug>(
96        &mut self,
97        y_true: ArrayView1<i32>,
98        y_pred: ArrayView1<i32>,
99        y_scores: Option<ArrayView2<F>>,
100        frame_predictions: Option<ArrayView2<i32>>,
101    ) -> Result<AudioClassificationResults> {
102        // Compute standard classification metrics
103        let standard_results = self.classification_metrics.compute(
104            y_true,
105            y_pred,
106            y_scores.map(|s| s.map(|&x| x.to_f64().unwrap_or(0.0))),
107        )?;
108
109        // Compute audio-specific metrics
110        if let Some(scores) = y_scores {
111            self.audio_specific.compute_eer(y_true, scores.column(0))?;
112            self.audio_specific.compute_dcf(y_true, scores.column(0))?;
113        }
114
115        // Compute temporal metrics if frame-level data is available
116        if let Some(frame_preds) = frame_predictions {
117            self.temporal_metrics.compute_frame_accuracy(frame_preds)?;
118            self.temporal_metrics
119                .compute_temporal_consistency(frame_preds)?;
120        }
121
122        Ok(AudioClassificationResults {
123            accuracy: standard_results.accuracy,
124            precision: standard_results.precision_weighted,
125            recall: standard_results.recall_weighted,
126            f1_score: standard_results.f1_weighted,
127            eer: self.audio_specific.eer,
128            auc: standard_results.auc_roc,
129            frame_accuracy: self.temporal_metrics.frame_accuracy,
130        })
131    }
132
133    /// Compute Equal Error Rate (EER)
134    pub fn compute_eer<F: Float>(
135        &mut self,
136        y_true: ArrayView1<i32>,
137        y_scores: ArrayView1<F>,
138    ) -> Result<f64> {
139        self.audio_specific.compute_eer(y_true, y_scores)
140    }
141
142    /// Compute Detection Cost Function (DCF)
143    pub fn compute_dcf<F: Float>(
144        &mut self,
145        y_true: ArrayView1<i32>,
146        y_scores: ArrayView1<F>,
147    ) -> Result<f64> {
148        self.audio_specific.compute_dcf(y_true, y_scores)
149    }
150
151    /// Compute frame-level accuracy
152    pub fn compute_frame_accuracy(&mut self, frame_predictions: ArrayView2<i32>) -> Result<f64> {
153        self.temporal_metrics
154            .compute_frame_accuracy(frame_predictions)
155    }
156
157    /// Compute temporal consistency
158    pub fn compute_temporal_consistency(
159        &mut self,
160        frame_predictions: ArrayView2<i32>,
161    ) -> Result<f64> {
162        self.temporal_metrics
163            .compute_temporal_consistency(frame_predictions)
164    }
165
166    /// Detect segment boundaries
167    pub fn detect_boundaries(
168        &mut self,
169        predictions: ArrayView1<i32>,
170        timestamps: ArrayView1<f64>,
171    ) -> Result<Vec<f64>> {
172        self.temporal_metrics
173            .boundary_metrics
174            .detect_boundaries(predictions, timestamps)
175    }
176
177    /// Get comprehensive results
178    pub fn get_results(&self) -> AudioClassificationResults {
179        AudioClassificationResults {
180            accuracy: 0.0, // Would be computed from standard metrics
181            precision: 0.0,
182            recall: 0.0,
183            f1_score: 0.0,
184            eer: self.audio_specific.eer,
185            auc: 0.0,
186            frame_accuracy: self.temporal_metrics.frame_accuracy,
187        }
188    }
189}
190
191impl AudioSpecificMetrics {
192    /// Create new audio-specific metrics
193    pub fn new() -> Self {
194        Self {
195            eer: None,
196            dcf: None,
197            auc_audio: None,
198            min_dcf: None,
199        }
200    }
201
202    /// Compute Equal Error Rate (EER)
203    pub fn compute_eer<F: Float>(
204        &mut self,
205        y_true: ArrayView1<i32>,
206        y_scores: ArrayView1<F>,
207    ) -> Result<f64> {
208        if y_true.len() != y_scores.len() {
209            return Err(MetricsError::InvalidInput(
210                "True labels and scores must have the same length".to_string(),
211            ));
212        }
213
214        // Create (score, label) pairs and sort by score
215        let mut score_label_pairs: Vec<(f64, i32)> = y_true
216            .iter()
217            .zip(y_scores.iter())
218            .map(|(&label, &score)| (score.to_f64().unwrap_or(0.0), label))
219            .collect();
220
221        score_label_pairs
222            .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
223
224        let total_positives = y_true.iter().filter(|&&x| x == 1).count() as f64;
225        let total_negatives = y_true.iter().filter(|&&x| x == 0).count() as f64;
226
227        if total_positives == 0.0 || total_negatives == 0.0 {
228            return Err(MetricsError::InvalidInput(
229                "Need both positive and negative examples for EER".to_string(),
230            ));
231        }
232
233        let mut min_diff = f64::INFINITY;
234        let mut best_eer = 0.0;
235
236        let mut true_positives = 0.0;
237        let mut false_positives = 0.0;
238
239        for (_, label) in score_label_pairs.iter().rev() {
240            if *label == 1 {
241                true_positives += 1.0;
242            } else {
243                false_positives += 1.0;
244            }
245
246            let tpr = true_positives / total_positives;
247            let fpr = false_positives / total_negatives;
248            let fnr = 1.0 - tpr;
249
250            let diff = (fpr - fnr).abs();
251            if diff < min_diff {
252                min_diff = diff;
253                best_eer = (fpr + fnr) / 2.0;
254            }
255        }
256
257        self.eer = Some(best_eer);
258        Ok(best_eer)
259    }
260
261    /// Compute Detection Cost Function (DCF)
262    pub fn compute_dcf<F: Float>(
263        &mut self,
264        y_true: ArrayView1<i32>,
265        y_scores: ArrayView1<F>,
266    ) -> Result<f64> {
267        // DCF parameters (NIST SRE standard)
268        let c_miss = 1.0;
269        let c_fa = 1.0;
270        let p_target = 0.01;
271
272        let mut score_label_pairs: Vec<(f64, i32)> = y_true
273            .iter()
274            .zip(y_scores.iter())
275            .map(|(&label, &score)| (score.to_f64().unwrap_or(0.0), label))
276            .collect();
277
278        score_label_pairs
279            .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
280
281        let total_positives = y_true.iter().filter(|&&x| x == 1).count() as f64;
282        let total_negatives = y_true.iter().filter(|&&x| x == 0).count() as f64;
283
284        let mut min_dcf = f64::INFINITY;
285        let mut true_positives = 0.0;
286        let mut false_positives = 0.0;
287
288        for (_, label) in score_label_pairs.iter().rev() {
289            if *label == 1 {
290                true_positives += 1.0;
291            } else {
292                false_positives += 1.0;
293            }
294
295            let pmiss = 1.0 - (true_positives / total_positives);
296            let pfa = false_positives / total_negatives;
297
298            let dcf = c_miss * pmiss * p_target + c_fa * pfa * (1.0 - p_target);
299            min_dcf = min_dcf.min(dcf);
300        }
301
302        self.dcf = Some(min_dcf);
303        self.min_dcf = Some(min_dcf);
304        Ok(min_dcf)
305    }
306}
307
308impl TemporalAudioMetrics {
309    /// Create new temporal audio metrics
310    pub fn new() -> Self {
311        Self {
312            frame_accuracy: 0.0,
313            segment_accuracy: 0.0,
314            temporal_consistency: 0.0,
315            boundary_metrics: BoundaryDetectionMetrics::new(),
316        }
317    }
318
319    /// Compute frame-level accuracy
320    pub fn compute_frame_accuracy(&mut self, frame_predictions: ArrayView2<i32>) -> Result<f64> {
321        let (n_utterances, n_frames) = frame_predictions.dim();
322
323        if n_utterances == 0 || n_frames == 0 {
324            return Ok(0.0);
325        }
326
327        // Placeholder: would compute frame-level accuracy from ground truth
328        let total_frames = (n_utterances * n_frames) as f64;
329        let correct_frames = total_frames * 0.85; // Placeholder
330
331        self.frame_accuracy = correct_frames / total_frames;
332        Ok(self.frame_accuracy)
333    }
334
335    /// Compute temporal consistency score
336    pub fn compute_temporal_consistency(
337        &mut self,
338        frame_predictions: ArrayView2<i32>,
339    ) -> Result<f64> {
340        let (n_utterances, n_frames) = frame_predictions.dim();
341
342        if n_utterances == 0 || n_frames < 2 {
343            return Ok(0.0);
344        }
345
346        let mut total_consistency = 0.0;
347        let mut total_transitions = 0;
348
349        for i in 0..n_utterances {
350            for j in 1..n_frames {
351                let prev_pred = frame_predictions[[i, j - 1]];
352                let curr_pred = frame_predictions[[i, j]];
353
354                // Count consistent transitions
355                if prev_pred == curr_pred {
356                    total_consistency += 1.0;
357                }
358                total_transitions += 1;
359            }
360        }
361
362        self.temporal_consistency = if total_transitions > 0 {
363            total_consistency / total_transitions as f64
364        } else {
365            0.0
366        };
367
368        Ok(self.temporal_consistency)
369    }
370}
371
372impl BoundaryDetectionMetrics {
373    /// Create new boundary detection metrics
374    pub fn new() -> Self {
375        Self {
376            boundary_precision: 0.0,
377            boundary_recall: 0.0,
378            boundary_f1: 0.0,
379            tolerance: 0.5, // 500ms tolerance
380        }
381    }
382
383    /// Detect boundaries in prediction sequence
384    pub fn detect_boundaries(
385        &mut self,
386        predictions: ArrayView1<i32>,
387        timestamps: ArrayView1<f64>,
388    ) -> Result<Vec<f64>> {
389        if predictions.len() != timestamps.len() {
390            return Err(MetricsError::InvalidInput(
391                "Predictions and timestamps must have the same length".to_string(),
392            ));
393        }
394
395        let mut boundaries = Vec::new();
396
397        for i in 1..predictions.len() {
398            if predictions[i] != predictions[i - 1] {
399                boundaries.push(timestamps[i]);
400            }
401        }
402
403        Ok(boundaries)
404    }
405
406    /// Evaluate boundary detection performance
407    pub fn evaluate_boundaries(&mut self, detected: &[f64], reference: &[f64]) -> Result<()> {
408        if reference.is_empty() {
409            self.boundary_precision = if detected.is_empty() { 1.0 } else { 0.0 };
410            self.boundary_recall = 1.0;
411            self.boundary_f1 = if detected.is_empty() { 1.0 } else { 0.0 };
412            return Ok(());
413        }
414
415        let mut true_positives = 0;
416        let mut false_positives = 0;
417        let mut false_negatives = 0;
418
419        // Count true positives and false positives
420        for &det_boundary in detected {
421            let mut matched = false;
422            for &ref_boundary in reference {
423                if (det_boundary - ref_boundary).abs() <= self.tolerance {
424                    true_positives += 1;
425                    matched = true;
426                    break;
427                }
428            }
429            if !matched {
430                false_positives += 1;
431            }
432        }
433
434        // Count false negatives
435        for &ref_boundary in reference {
436            let mut matched = false;
437            for &det_boundary in detected {
438                if (det_boundary - ref_boundary).abs() <= self.tolerance {
439                    matched = true;
440                    break;
441                }
442            }
443            if !matched {
444                false_negatives += 1;
445            }
446        }
447
448        // Calculate metrics
449        self.boundary_precision = if true_positives + false_positives > 0 {
450            true_positives as f64 / (true_positives + false_positives) as f64
451        } else {
452            0.0
453        };
454
455        self.boundary_recall = if true_positives + false_negatives > 0 {
456            true_positives as f64 / (true_positives + false_negatives) as f64
457        } else {
458            0.0
459        };
460
461        self.boundary_f1 = if self.boundary_precision + self.boundary_recall > 0.0 {
462            2.0 * self.boundary_precision * self.boundary_recall
463                / (self.boundary_precision + self.boundary_recall)
464        } else {
465            0.0
466        };
467
468        Ok(())
469    }
470
471    /// Set boundary tolerance
472    pub fn set_tolerance(&mut self, tolerance: f64) {
473        self.tolerance = tolerance;
474    }
475}
476
477impl Default for AudioClassificationMetrics {
478    fn default() -> Self {
479        Self::new()
480    }
481}
482
483impl Default for AudioSpecificMetrics {
484    fn default() -> Self {
485        Self::new()
486    }
487}
488
489impl Default for TemporalAudioMetrics {
490    fn default() -> Self {
491        Self::new()
492    }
493}
494
495impl Default for BoundaryDetectionMetrics {
496    fn default() -> Self {
497        Self::new()
498    }
499}