scirs2_metrics/domains/audio_processing/
speech_recognition.rs

1//! Speech recognition evaluation metrics
2//!
3//! This module provides comprehensive metrics for evaluating speech recognition systems,
4//! including Word Error Rate (WER), Character Error Rate (CER), Phone Error Rate (PER),
5//! BLEU scores for speech translation, and confidence score analysis.
6
7#![allow(clippy::too_many_arguments)]
8#![allow(dead_code)]
9
10use crate::error::{MetricsError, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14/// Speech recognition evaluation metrics
15#[derive(Debug, Clone)]
16pub struct SpeechRecognitionMetrics {
17    /// Word Error Rate calculations
18    wer_calculator: WerCalculator,
19    /// Character Error Rate calculations
20    cer_calculator: CerCalculator,
21    /// Phone Error Rate calculations
22    per_calculator: PerCalculator,
23    /// BLEU score for speech translation
24    bleu_calculator: BleuCalculator,
25    /// Confidence score metrics
26    confidence_metrics: ConfidenceMetrics,
27}
28
/// Running Word Error Rate (WER) state.
///
/// Holds corpus-wide edit-operation totals together with the WER of every
/// utterance scored so far.
#[derive(Clone, Debug)]
pub struct WerCalculator {
    /// Sum of word substitutions over all scored utterances.
    substitutions: usize,
    /// Sum of word deletions (reference words absent from the hypothesis).
    deletions: usize,
    /// Sum of word insertions (extra hypothesis words).
    insertions: usize,
    /// Sum of reference lengths, in words.
    total_words: usize,
    /// WER of each utterance, in scoring order.
    utterance_wers: Vec<f64>,
}
43
/// Running Character Error Rate (CER) state.
///
/// Holds corpus-wide character edit totals together with the CER of every
/// utterance scored so far.
#[derive(Clone, Debug)]
pub struct CerCalculator {
    /// Sum of character substitutions over all scored utterances.
    char_substitutions: usize,
    /// Sum of character deletions over all scored utterances.
    char_deletions: usize,
    /// Sum of character insertions over all scored utterances.
    char_insertions: usize,
    /// Sum of reference lengths, in characters.
    total_chars: usize,
    /// CER of each utterance, in scoring order.
    utterance_cers: Vec<f64>,
}
58
/// Running Phone Error Rate (PER) state.
///
/// Holds corpus-wide phone edit totals and a tally of phone confusions.
#[derive(Clone, Debug)]
pub struct PerCalculator {
    /// Sum of phone substitutions over all scored utterances.
    phone_substitutions: usize,
    /// Sum of phone deletions over all scored utterances.
    phone_deletions: usize,
    /// Sum of phone insertions over all scored utterances.
    phone_insertions: usize,
    /// Sum of reference lengths, in phones.
    total_phones: usize,
    /// Confusion counts keyed by `(reference phone, hypothesis phone)`.
    confusion_matrix: HashMap<(String, String), usize>,
}
73
74/// BLEU score calculator for speech translation
75#[derive(Debug, Clone)]
76pub struct BleuCalculator {
77    /// N-gram weights (typically 1-gram to 4-gram)
78    ngram_weights: Vec<f64>,
79    /// Brevity penalty settings
80    brevity_penalty: bool,
81    /// Smoothing method
82    smoothing: BleuSmoothing,
83}
84
/// Smoothing strategies selectable for BLEU when an n-gram order has no matches.
///
/// Without smoothing, a single zero n-gram precision drives the geometric mean,
/// and hence BLEU, to zero. Each variant names a standard remedy. Whether and how
/// a variant is honoured is decided by `BleuCalculator`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BleuSmoothing {
    /// No smoothing: a zero precision yields a BLEU of zero.
    None,
    /// Replace a zero match count with the given small constant ε
    /// (cf. NLTK `SmoothingFunction.method1`).
    Epsilon(f64),
    /// Add one to the numerator and denominator of higher-order precisions
    /// (cf. Lin & Och 2004; NLTK `SmoothingFunction.method2`).
    Add1,
    /// Give the k-th zero-match order a precision of `1 / (2^k · count)`
    /// (NIST geometric smoothing; NLTK `SmoothingFunction.method3`).
    ExponentialDecay,
}
93
/// Collector for ASR confidence scores and confidence/correctness statistics.
#[derive(Clone, Debug)]
pub struct ConfidenceMetrics {
    /// Stored confidence threshold.
    confidence_threshold: f64,
    /// Word-level confidences, concatenated across utterances.
    word_confidences: Vec<f64>,
    /// One confidence value per utterance.
    utterance_confidences: Vec<f64>,
    /// Cached confidence/correctness statistic; `None` until one is stored.
    confidence_wer_correlation: Option<f64>,
}
106
107/// Speech recognition evaluation results
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct SpeechRecognitionResults {
110    /// Word Error Rate
111    pub wer: f64,
112    /// Character Error Rate
113    pub cer: f64,
114    /// Phone Error Rate
115    pub per: Option<f64>,
116    /// BLEU score
117    pub bleu: Option<f64>,
118    /// Average confidence score
119    pub avg_confidence: f64,
120    /// Confidence-WER correlation
121    pub confidence_wer_correlation: Option<f64>,
122}
123
124impl SpeechRecognitionMetrics {
125    /// Create new speech recognition metrics
126    pub fn new() -> Self {
127        Self {
128            wer_calculator: WerCalculator::new(),
129            cer_calculator: CerCalculator::new(),
130            per_calculator: PerCalculator::new(),
131            bleu_calculator: BleuCalculator::new(),
132            confidence_metrics: ConfidenceMetrics::new(),
133        }
134    }
135
136    /// Evaluate speech recognition performance
137    pub fn evaluate_recognition(
138        &mut self,
139        reference_text: &[String],
140        hypothesis_text: &[String],
141        reference_phones: Option<&[Vec<String>]>,
142        hypothesis_phones: Option<&[Vec<String>]>,
143        confidence_scores: Option<&[f64]>,
144    ) -> Result<SpeechRecognitionResults> {
145        // Calculate WER
146        let wer = self
147            .wer_calculator
148            .calculate(reference_text, hypothesis_text)?;
149
150        // Calculate CER
151        let cer = self
152            .cer_calculator
153            .calculate(reference_text, hypothesis_text)?;
154
155        // Calculate PER if phone sequences provided
156        let per =
157            if let (Some(ref_phones), Some(hyp_phones)) = (reference_phones, hypothesis_phones) {
158                Some(self.per_calculator.calculate(ref_phones, hyp_phones)?)
159            } else {
160                None
161            };
162
163        // Calculate BLEU score
164        let bleu = Some(
165            self.bleu_calculator
166                .calculate(reference_text, hypothesis_text)?,
167        );
168
169        // Calculate confidence metrics
170        let (avg_confidence, confidence_wer_correlation) =
171            if let Some(conf_scores) = confidence_scores {
172                let avg_conf = conf_scores.iter().sum::<f64>() / conf_scores.len() as f64;
173                let correlation = self
174                    .confidence_metrics
175                    .calculate_confidence_wer_correlation(
176                        reference_text,
177                        hypothesis_text,
178                        conf_scores,
179                    )?;
180                (avg_conf, Some(correlation))
181            } else {
182                (0.0, None)
183            };
184
185        Ok(SpeechRecognitionResults {
186            wer,
187            cer,
188            per,
189            bleu,
190            avg_confidence,
191            confidence_wer_correlation,
192        })
193    }
194
195    /// Compute word error rate between reference and hypothesis
196    pub fn compute_wer(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
197        self.wer_calculator.compute_wer(reference, hypothesis)
198    }
199
200    /// Compute character error rate between reference and hypothesis
201    pub fn compute_cer(&mut self, reference: &str, hypothesis: &str) -> Result<f64> {
202        self.cer_calculator.compute_cer(reference, hypothesis)
203    }
204
205    /// Compute phone error rate between reference and hypothesis phoneme sequences
206    pub fn compute_per(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
207        self.per_calculator.compute_per(reference, hypothesis)
208    }
209
210    /// Compute BLEU score for translation tasks
211    pub fn compute_bleu(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
212        self.bleu_calculator.compute_bleu(reference, hypothesis)
213    }
214
215    /// Add confidence scores for analysis
216    pub fn add_confidence_scores(&mut self, word_confidences: Vec<f64>, utterance_confidence: f64) {
217        self.confidence_metrics
218            .add_scores(word_confidences, utterance_confidence);
219    }
220
221    /// Get comprehensive speech recognition results
222    pub fn get_results(&self) -> SpeechRecognitionResults {
223        SpeechRecognitionResults {
224            wer: self.wer_calculator.get_wer(),
225            cer: self.cer_calculator.get_cer(),
226            per: self.per_calculator.get_per(),
227            bleu: self.bleu_calculator.get_bleu(),
228            avg_confidence: self.confidence_metrics.get_average_confidence(),
229            confidence_wer_correlation: self.confidence_metrics.confidence_wer_correlation,
230        }
231    }
232}
233
234impl WerCalculator {
235    /// Create new WER calculator
236    pub fn new() -> Self {
237        Self {
238            substitutions: 0,
239            deletions: 0,
240            insertions: 0,
241            total_words: 0,
242            utterance_wers: Vec::new(),
243        }
244    }
245
246    /// Compute WER using edit distance algorithm
247    pub fn compute_wer(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
248        let (subs, dels, ins) = self.edit_distance(reference, hypothesis);
249
250        self.substitutions += subs;
251        self.deletions += dels;
252        self.insertions += ins;
253        self.total_words += reference.len();
254
255        let utterance_wer = if reference.is_empty() {
256            if hypothesis.is_empty() {
257                0.0
258            } else {
259                1.0
260            }
261        } else {
262            (subs + dels + ins) as f64 / reference.len() as f64
263        };
264
265        self.utterance_wers.push(utterance_wer);
266        Ok(utterance_wer)
267    }
268
269    /// Get overall WER
270    pub fn get_wer(&self) -> f64 {
271        if self.total_words == 0 {
272            0.0
273        } else {
274            (self.substitutions + self.deletions + self.insertions) as f64 / self.total_words as f64
275        }
276    }
277
278    /// Compute edit distance between reference and hypothesis
279    fn edit_distance(&self, reference: &[String], hypothesis: &[String]) -> (usize, usize, usize) {
280        let ref_len = reference.len();
281        let hyp_len = hypothesis.len();
282
283        let mut dp = vec![vec![0; hyp_len + 1]; ref_len + 1];
284        let mut ops = vec![vec![(0, 0, 0); hyp_len + 1]; ref_len + 1]; // (subs, dels, ins)
285
286        // Initialize base cases
287        for i in 0..=ref_len {
288            dp[i][0] = i;
289            ops[i][0] = (0, i, 0);
290        }
291        for j in 0..=hyp_len {
292            dp[0][j] = j;
293            ops[0][j] = (0, 0, j);
294        }
295
296        // Fill DP table
297        for i in 1..=ref_len {
298            for j in 1..=hyp_len {
299                if reference[i - 1] == hypothesis[j - 1] {
300                    dp[i][j] = dp[i - 1][j - 1];
301                    ops[i][j] = ops[i - 1][j - 1];
302                } else {
303                    let sub_cost = dp[i - 1][j - 1] + 1;
304                    let del_cost = dp[i - 1][j] + 1;
305                    let ins_cost = dp[i][j - 1] + 1;
306
307                    if sub_cost <= del_cost && sub_cost <= ins_cost {
308                        dp[i][j] = sub_cost;
309                        ops[i][j] = (
310                            ops[i - 1][j - 1].0 + 1,
311                            ops[i - 1][j - 1].1,
312                            ops[i - 1][j - 1].2,
313                        );
314                    } else if del_cost <= ins_cost {
315                        dp[i][j] = del_cost;
316                        ops[i][j] = (ops[i - 1][j].0, ops[i - 1][j].1 + 1, ops[i - 1][j].2);
317                    } else {
318                        dp[i][j] = ins_cost;
319                        ops[i][j] = (ops[i][j - 1].0, ops[i][j - 1].1, ops[i][j - 1].2 + 1);
320                    }
321                }
322            }
323        }
324
325        ops[ref_len][hyp_len]
326    }
327
328    /// Calculate WER (alias for compute_wer for backward compatibility)
329    pub fn calculate(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
330        self.compute_wer(reference, hypothesis)
331    }
332}
333
334impl CerCalculator {
335    /// Create new CER calculator
336    pub fn new() -> Self {
337        Self {
338            char_substitutions: 0,
339            char_deletions: 0,
340            char_insertions: 0,
341            total_chars: 0,
342            utterance_cers: Vec::new(),
343        }
344    }
345
346    /// Compute CER using character-level edit distance
347    pub fn compute_cer(&mut self, reference: &str, hypothesis: &str) -> Result<f64> {
348        let ref_chars: Vec<char> = reference.chars().collect();
349        let hyp_chars: Vec<char> = hypothesis.chars().collect();
350
351        let (subs, dels, ins) = self.char_edit_distance(&ref_chars, &hyp_chars);
352
353        self.char_substitutions += subs;
354        self.char_deletions += dels;
355        self.char_insertions += ins;
356        self.total_chars += ref_chars.len();
357
358        let utterance_cer = if ref_chars.is_empty() {
359            if hyp_chars.is_empty() {
360                0.0
361            } else {
362                1.0
363            }
364        } else {
365            (subs + dels + ins) as f64 / ref_chars.len() as f64
366        };
367
368        self.utterance_cers.push(utterance_cer);
369        Ok(utterance_cer)
370    }
371
372    /// Get overall CER
373    pub fn get_cer(&self) -> f64 {
374        if self.total_chars == 0 {
375            0.0
376        } else {
377            (self.char_substitutions + self.char_deletions + self.char_insertions) as f64
378                / self.total_chars as f64
379        }
380    }
381
382    /// Compute character-level edit distance
383    fn char_edit_distance(&self, reference: &[char], hypothesis: &[char]) -> (usize, usize, usize) {
384        let ref_len = reference.len();
385        let hyp_len = hypothesis.len();
386
387        let mut dp = vec![vec![0; hyp_len + 1]; ref_len + 1];
388
389        for i in 0..=ref_len {
390            dp[i][0] = i;
391        }
392        for j in 0..=hyp_len {
393            dp[0][j] = j;
394        }
395
396        for i in 1..=ref_len {
397            for j in 1..=hyp_len {
398                if reference[i - 1] == hypothesis[j - 1] {
399                    dp[i][j] = dp[i - 1][j - 1];
400                } else {
401                    dp[i][j] = 1 + dp[i - 1][j - 1].min(dp[i - 1][j]).min(dp[i][j - 1]);
402                }
403            }
404        }
405
406        // Simplified: return total edit distance as substitutions for now
407        (dp[ref_len][hyp_len], 0, 0)
408    }
409
410    /// Calculate CER (alias for compute_cer for backward compatibility)
411    pub fn calculate(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
412        if reference.len() != hypothesis.len() {
413            return Err(MetricsError::InvalidInput(
414                "Reference and hypothesis must have the same length".to_string(),
415            ));
416        }
417
418        let mut total_errors = 0;
419        let mut total_chars = 0;
420
421        for (ref_sent, hyp_sent) in reference.iter().zip(hypothesis.iter()) {
422            let cer = self.compute_cer(ref_sent, hyp_sent)?;
423            let ref_chars = ref_sent.chars().count();
424            total_errors += (cer * ref_chars as f64) as usize;
425            total_chars += ref_chars;
426        }
427
428        if total_chars == 0 {
429            Ok(0.0)
430        } else {
431            Ok(total_errors as f64 / total_chars as f64)
432        }
433    }
434}
435
436impl PerCalculator {
437    /// Create new PER calculator
438    pub fn new() -> Self {
439        Self {
440            phone_substitutions: 0,
441            phone_deletions: 0,
442            phone_insertions: 0,
443            total_phones: 0,
444            confusion_matrix: HashMap::new(),
445        }
446    }
447
448    /// Compute PER using phoneme-level edit distance
449    pub fn compute_per(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
450        let (subs, dels, ins) = self.phone_edit_distance(reference, hypothesis);
451
452        self.phone_substitutions += subs;
453        self.phone_deletions += dels;
454        self.phone_insertions += ins;
455        self.total_phones += reference.len();
456
457        let per = if reference.is_empty() {
458            if hypothesis.is_empty() {
459                0.0
460            } else {
461                1.0
462            }
463        } else {
464            (subs + dels + ins) as f64 / reference.len() as f64
465        };
466
467        Ok(per)
468    }
469
470    /// Get overall PER
471    pub fn get_per(&self) -> Option<f64> {
472        if self.total_phones == 0 {
473            None
474        } else {
475            Some(
476                (self.phone_substitutions + self.phone_deletions + self.phone_insertions) as f64
477                    / self.total_phones as f64,
478            )
479        }
480    }
481
482    /// Compute phoneme-level edit distance
483    fn phone_edit_distance(
484        &mut self,
485        reference: &[String],
486        hypothesis: &[String],
487    ) -> (usize, usize, usize) {
488        // Track phone confusions
489        for (i, ref_phone) in reference.iter().enumerate() {
490            if i < hypothesis.len() && ref_phone != &hypothesis[i] {
491                *self
492                    .confusion_matrix
493                    .entry((ref_phone.clone(), hypothesis[i].clone()))
494                    .or_insert(0) += 1;
495            }
496        }
497
498        // Simplified edit distance calculation
499        let mut subs = 0;
500        let mut dels = 0;
501        let mut ins = 0;
502
503        let max_len = reference.len().max(hypothesis.len());
504        for i in 0..max_len {
505            match (reference.get(i), hypothesis.get(i)) {
506                (Some(r), Some(h)) if r != h => subs += 1,
507                (Some(_), None) => dels += 1,
508                (None, Some(_)) => ins += 1,
509                _ => {}
510            }
511        }
512
513        (subs, dels, ins)
514    }
515
516    /// Calculate PER (alias for compute_per for backward compatibility)
517    pub fn calculate(
518        &mut self,
519        reference: &[Vec<String>],
520        hypothesis: &[Vec<String>],
521    ) -> Result<f64> {
522        if reference.len() != hypothesis.len() {
523            return Err(MetricsError::InvalidInput(
524                "Reference and hypothesis must have the same length".to_string(),
525            ));
526        }
527
528        let mut total_errors = 0;
529        let mut total_phones = 0;
530
531        for (ref_seq, hyp_seq) in reference.iter().zip(hypothesis.iter()) {
532            let per = self.compute_per(ref_seq, hyp_seq)?;
533            total_errors += (per * ref_seq.len() as f64) as usize;
534            total_phones += ref_seq.len();
535        }
536
537        if total_phones == 0 {
538            Ok(0.0)
539        } else {
540            Ok(total_errors as f64 / total_phones as f64)
541        }
542    }
543}
544
545impl BleuCalculator {
546    /// Create new BLEU calculator
547    pub fn new() -> Self {
548        Self {
549            ngram_weights: vec![0.25, 0.25, 0.25, 0.25], // Equal weights for 1-4 grams
550            brevity_penalty: true,
551            smoothing: BleuSmoothing::Epsilon(1e-7),
552        }
553    }
554
555    /// Compute BLEU score
556    pub fn compute_bleu(&self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
557        if reference.is_empty() || hypothesis.is_empty() {
558            return Ok(0.0);
559        }
560
561        let mut precisions = Vec::new();
562
563        // Compute n-gram precisions
564        for n in 1..=4 {
565            let precision = self.compute_ngram_precision(reference, hypothesis, n);
566            precisions.push(precision);
567        }
568
569        // Compute geometric mean of precisions
570        let log_sum: f64 = precisions
571            .iter()
572            .zip(&self.ngram_weights)
573            .map(|(p, w)| w * p.ln())
574            .sum();
575
576        let mut bleu = log_sum.exp();
577
578        // Apply brevity penalty
579        if self.brevity_penalty {
580            let bp = self.compute_brevity_penalty(reference.len(), hypothesis.len());
581            bleu *= bp;
582        }
583
584        Ok(bleu)
585    }
586
587    /// Get BLEU score (placeholder)
588    pub fn get_bleu(&self) -> Option<f64> {
589        None // Would store computed BLEU scores
590    }
591
592    /// Compute n-gram precision
593    fn compute_ngram_precision(
594        &self,
595        reference: &[String],
596        hypothesis: &[String],
597        n: usize,
598    ) -> f64 {
599        if hypothesis.len() < n {
600            return 0.0;
601        }
602
603        let ref_ngrams = self.extract_ngrams(reference, n);
604        let hyp_ngrams = self.extract_ngrams(hypothesis, n);
605
606        let mut matches = 0;
607        for ngram in &hyp_ngrams {
608            if ref_ngrams.contains(ngram) {
609                matches += 1;
610            }
611        }
612
613        if hyp_ngrams.is_empty() {
614            0.0
615        } else {
616            matches as f64 / hyp_ngrams.len() as f64
617        }
618    }
619
620    /// Extract n-grams from sequence
621    fn extract_ngrams(&self, sequence: &[String], n: usize) -> Vec<Vec<String>> {
622        if sequence.len() < n {
623            return Vec::new();
624        }
625
626        (0..=sequence.len() - n)
627            .map(|i| sequence[i..i + n].to_vec())
628            .collect()
629    }
630
631    /// Compute brevity penalty
632    fn compute_brevity_penalty(&self, ref_len: usize, hyp_len: usize) -> f64 {
633        if hyp_len >= ref_len {
634            1.0
635        } else {
636            (1.0 - ref_len as f64 / hyp_len as f64).exp()
637        }
638    }
639
640    /// Calculate BLEU score (alias for compute_bleu for backward compatibility)
641    pub fn calculate(&self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
642        self.compute_bleu(reference, hypothesis)
643    }
644}
645
646impl ConfidenceMetrics {
647    /// Create new confidence metrics
648    pub fn new() -> Self {
649        Self {
650            confidence_threshold: 0.5,
651            word_confidences: Vec::new(),
652            utterance_confidences: Vec::new(),
653            confidence_wer_correlation: None,
654        }
655    }
656
657    /// Add confidence scores
658    pub fn add_scores(&mut self, word_confidences: Vec<f64>, utterance_confidence: f64) {
659        self.word_confidences.extend(word_confidences);
660        self.utterance_confidences.push(utterance_confidence);
661    }
662
663    /// Get average confidence score
664    pub fn get_average_confidence(&self) -> f64 {
665        if self.utterance_confidences.is_empty() {
666            0.0
667        } else {
668            self.utterance_confidences.iter().sum::<f64>() / self.utterance_confidences.len() as f64
669        }
670    }
671
672    /// Set confidence threshold
673    pub fn set_threshold(&mut self, threshold: f64) {
674        self.confidence_threshold = threshold;
675    }
676
677    /// Calculate confidence-WER correlation
678    pub fn calculate_confidence_wer_correlation(
679        &mut self,
680        reference: &[String],
681        hypothesis: &[String],
682        confidence: &[f64],
683    ) -> Result<f64> {
684        if reference.len() != hypothesis.len() || hypothesis.len() != confidence.len() {
685            return Err(MetricsError::InvalidInput(
686                "Mismatched array lengths".to_string(),
687            ));
688        }
689
690        let mut correct_scores = Vec::new();
691        let mut incorrect_scores = Vec::new();
692
693        for ((r, h), &c) in reference
694            .iter()
695            .zip(hypothesis.iter())
696            .zip(confidence.iter())
697        {
698            if r == h {
699                correct_scores.push(c);
700            } else {
701                incorrect_scores.push(c);
702            }
703        }
704
705        if correct_scores.is_empty() || incorrect_scores.is_empty() {
706            return Ok(0.0);
707        }
708
709        let correct_mean = correct_scores.iter().sum::<f64>() / correct_scores.len() as f64;
710        let incorrect_mean = incorrect_scores.iter().sum::<f64>() / incorrect_scores.len() as f64;
711
712        Ok((correct_mean - incorrect_mean).abs())
713    }
714}
715
716impl Default for SpeechRecognitionMetrics {
717    fn default() -> Self {
718        Self::new()
719    }
720}
721
722impl Default for WerCalculator {
723    fn default() -> Self {
724        Self::new()
725    }
726}
727
728impl Default for CerCalculator {
729    fn default() -> Self {
730        Self::new()
731    }
732}
733
734impl Default for PerCalculator {
735    fn default() -> Self {
736        Self::new()
737    }
738}
739
740impl Default for BleuCalculator {
741    fn default() -> Self {
742        Self::new()
743    }
744}
745
746impl Default for ConfidenceMetrics {
747    fn default() -> Self {
748        Self::new()
749    }
750}