1#![allow(clippy::too_many_arguments)]
8#![allow(dead_code)]
9
10use crate::error::{MetricsError, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14#[derive(Debug, Clone)]
16pub struct SpeechRecognitionMetrics {
17 wer_calculator: WerCalculator,
19 cer_calculator: CerCalculator,
21 per_calculator: PerCalculator,
23 bleu_calculator: BleuCalculator,
25 confidence_metrics: ConfidenceMetrics,
27}
28
/// Running word error rate (WER) state accumulated across utterances.
#[derive(Debug, Clone)]
pub struct WerCalculator {
    /// Total word substitutions seen so far.
    substitutions: usize,
    /// Total word deletions seen so far.
    deletions: usize,
    /// Total word insertions seen so far.
    insertions: usize,
    /// Total number of reference words seen so far (the WER denominator).
    total_words: usize,
    /// WER of each scored utterance, in scoring order.
    utterance_wers: Vec<f64>,
}
43
/// Running character error rate (CER) state accumulated across utterances.
#[derive(Debug, Clone)]
pub struct CerCalculator {
    /// Total character substitutions seen so far.
    char_substitutions: usize,
    /// Total character deletions seen so far.
    char_deletions: usize,
    /// Total character insertions seen so far.
    char_insertions: usize,
    /// Total number of reference characters seen so far (the CER denominator).
    total_chars: usize,
    /// CER of each scored utterance, in scoring order.
    utterance_cers: Vec<f64>,
}
58
/// Running phone error rate (PER) state accumulated across phone sequences.
#[derive(Debug, Clone)]
pub struct PerCalculator {
    /// Total phone substitutions seen so far.
    phone_substitutions: usize,
    /// Total phone deletions seen so far.
    phone_deletions: usize,
    /// Total phone insertions seen so far.
    phone_insertions: usize,
    /// Total number of reference phones seen so far (the PER denominator).
    total_phones: usize,
    /// Confusion counts keyed by `(reference phone, hypothesis phone)`.
    confusion_matrix: HashMap<(String, String), usize>,
}
73
74#[derive(Debug, Clone)]
76pub struct BleuCalculator {
77 ngram_weights: Vec<f64>,
79 brevity_penalty: bool,
81 smoothing: BleuSmoothing,
83}
84
/// Smoothing strategies selectable for BLEU n-gram precisions.
#[derive(Debug, Clone)]
pub enum BleuSmoothing {
    /// No smoothing.
    None,
    /// Epsilon smoothing parameterised by the given small constant.
    Epsilon(f64),
    /// Add-one smoothing.
    Add1,
    /// Exponential-decay smoothing.
    ExponentialDecay,
}
93
/// Bookkeeping for recogniser confidence scores.
#[derive(Debug, Clone)]
pub struct ConfidenceMetrics {
    /// Confidence cut-off; `new()` sets it to 0.5 and `set_threshold` overrides it.
    confidence_threshold: f64,
    /// Word-level confidence scores collected so far.
    word_confidences: Vec<f64>,
    /// Utterance-level confidence scores collected so far.
    utterance_confidences: Vec<f64>,
    /// Cached confidence-vs-error statistic exposed through the aggregate
    /// results; starts as `None`.
    confidence_wer_correlation: Option<f64>,
}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct SpeechRecognitionResults {
110 pub wer: f64,
112 pub cer: f64,
114 pub per: Option<f64>,
116 pub bleu: Option<f64>,
118 pub avg_confidence: f64,
120 pub confidence_wer_correlation: Option<f64>,
122}
123
124impl SpeechRecognitionMetrics {
125 pub fn new() -> Self {
127 Self {
128 wer_calculator: WerCalculator::new(),
129 cer_calculator: CerCalculator::new(),
130 per_calculator: PerCalculator::new(),
131 bleu_calculator: BleuCalculator::new(),
132 confidence_metrics: ConfidenceMetrics::new(),
133 }
134 }
135
136 pub fn evaluate_recognition(
138 &mut self,
139 reference_text: &[String],
140 hypothesis_text: &[String],
141 reference_phones: Option<&[Vec<String>]>,
142 hypothesis_phones: Option<&[Vec<String>]>,
143 confidence_scores: Option<&[f64]>,
144 ) -> Result<SpeechRecognitionResults> {
145 let wer = self
147 .wer_calculator
148 .calculate(reference_text, hypothesis_text)?;
149
150 let cer = self
152 .cer_calculator
153 .calculate(reference_text, hypothesis_text)?;
154
155 let per =
157 if let (Some(ref_phones), Some(hyp_phones)) = (reference_phones, hypothesis_phones) {
158 Some(self.per_calculator.calculate(ref_phones, hyp_phones)?)
159 } else {
160 None
161 };
162
163 let bleu = Some(
165 self.bleu_calculator
166 .calculate(reference_text, hypothesis_text)?,
167 );
168
169 let (avg_confidence, confidence_wer_correlation) =
171 if let Some(conf_scores) = confidence_scores {
172 let avg_conf = conf_scores.iter().sum::<f64>() / conf_scores.len() as f64;
173 let correlation = self
174 .confidence_metrics
175 .calculate_confidence_wer_correlation(
176 reference_text,
177 hypothesis_text,
178 conf_scores,
179 )?;
180 (avg_conf, Some(correlation))
181 } else {
182 (0.0, None)
183 };
184
185 Ok(SpeechRecognitionResults {
186 wer,
187 cer,
188 per,
189 bleu,
190 avg_confidence,
191 confidence_wer_correlation,
192 })
193 }
194
195 pub fn compute_wer(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
197 self.wer_calculator.compute_wer(reference, hypothesis)
198 }
199
200 pub fn compute_cer(&mut self, reference: &str, hypothesis: &str) -> Result<f64> {
202 self.cer_calculator.compute_cer(reference, hypothesis)
203 }
204
205 pub fn compute_per(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
207 self.per_calculator.compute_per(reference, hypothesis)
208 }
209
210 pub fn compute_bleu(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
212 self.bleu_calculator.compute_bleu(reference, hypothesis)
213 }
214
215 pub fn add_confidence_scores(&mut self, word_confidences: Vec<f64>, utterance_confidence: f64) {
217 self.confidence_metrics
218 .add_scores(word_confidences, utterance_confidence);
219 }
220
221 pub fn get_results(&self) -> SpeechRecognitionResults {
223 SpeechRecognitionResults {
224 wer: self.wer_calculator.get_wer(),
225 cer: self.cer_calculator.get_cer(),
226 per: self.per_calculator.get_per(),
227 bleu: self.bleu_calculator.get_bleu(),
228 avg_confidence: self.confidence_metrics.get_average_confidence(),
229 confidence_wer_correlation: self.confidence_metrics.confidence_wer_correlation,
230 }
231 }
232}
233
234impl WerCalculator {
235 pub fn new() -> Self {
237 Self {
238 substitutions: 0,
239 deletions: 0,
240 insertions: 0,
241 total_words: 0,
242 utterance_wers: Vec::new(),
243 }
244 }
245
246 pub fn compute_wer(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
248 let (subs, dels, ins) = self.edit_distance(reference, hypothesis);
249
250 self.substitutions += subs;
251 self.deletions += dels;
252 self.insertions += ins;
253 self.total_words += reference.len();
254
255 let utterance_wer = if reference.is_empty() {
256 if hypothesis.is_empty() {
257 0.0
258 } else {
259 1.0
260 }
261 } else {
262 (subs + dels + ins) as f64 / reference.len() as f64
263 };
264
265 self.utterance_wers.push(utterance_wer);
266 Ok(utterance_wer)
267 }
268
269 pub fn get_wer(&self) -> f64 {
271 if self.total_words == 0 {
272 0.0
273 } else {
274 (self.substitutions + self.deletions + self.insertions) as f64 / self.total_words as f64
275 }
276 }
277
278 fn edit_distance(&self, reference: &[String], hypothesis: &[String]) -> (usize, usize, usize) {
280 let ref_len = reference.len();
281 let hyp_len = hypothesis.len();
282
283 let mut dp = vec![vec![0; hyp_len + 1]; ref_len + 1];
284 let mut ops = vec![vec![(0, 0, 0); hyp_len + 1]; ref_len + 1]; for i in 0..=ref_len {
288 dp[i][0] = i;
289 ops[i][0] = (0, i, 0);
290 }
291 for j in 0..=hyp_len {
292 dp[0][j] = j;
293 ops[0][j] = (0, 0, j);
294 }
295
296 for i in 1..=ref_len {
298 for j in 1..=hyp_len {
299 if reference[i - 1] == hypothesis[j - 1] {
300 dp[i][j] = dp[i - 1][j - 1];
301 ops[i][j] = ops[i - 1][j - 1];
302 } else {
303 let sub_cost = dp[i - 1][j - 1] + 1;
304 let del_cost = dp[i - 1][j] + 1;
305 let ins_cost = dp[i][j - 1] + 1;
306
307 if sub_cost <= del_cost && sub_cost <= ins_cost {
308 dp[i][j] = sub_cost;
309 ops[i][j] = (
310 ops[i - 1][j - 1].0 + 1,
311 ops[i - 1][j - 1].1,
312 ops[i - 1][j - 1].2,
313 );
314 } else if del_cost <= ins_cost {
315 dp[i][j] = del_cost;
316 ops[i][j] = (ops[i - 1][j].0, ops[i - 1][j].1 + 1, ops[i - 1][j].2);
317 } else {
318 dp[i][j] = ins_cost;
319 ops[i][j] = (ops[i][j - 1].0, ops[i][j - 1].1, ops[i][j - 1].2 + 1);
320 }
321 }
322 }
323 }
324
325 ops[ref_len][hyp_len]
326 }
327
328 pub fn calculate(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
330 self.compute_wer(reference, hypothesis)
331 }
332}
333
334impl CerCalculator {
335 pub fn new() -> Self {
337 Self {
338 char_substitutions: 0,
339 char_deletions: 0,
340 char_insertions: 0,
341 total_chars: 0,
342 utterance_cers: Vec::new(),
343 }
344 }
345
346 pub fn compute_cer(&mut self, reference: &str, hypothesis: &str) -> Result<f64> {
348 let ref_chars: Vec<char> = reference.chars().collect();
349 let hyp_chars: Vec<char> = hypothesis.chars().collect();
350
351 let (subs, dels, ins) = self.char_edit_distance(&ref_chars, &hyp_chars);
352
353 self.char_substitutions += subs;
354 self.char_deletions += dels;
355 self.char_insertions += ins;
356 self.total_chars += ref_chars.len();
357
358 let utterance_cer = if ref_chars.is_empty() {
359 if hyp_chars.is_empty() {
360 0.0
361 } else {
362 1.0
363 }
364 } else {
365 (subs + dels + ins) as f64 / ref_chars.len() as f64
366 };
367
368 self.utterance_cers.push(utterance_cer);
369 Ok(utterance_cer)
370 }
371
372 pub fn get_cer(&self) -> f64 {
374 if self.total_chars == 0 {
375 0.0
376 } else {
377 (self.char_substitutions + self.char_deletions + self.char_insertions) as f64
378 / self.total_chars as f64
379 }
380 }
381
382 fn char_edit_distance(&self, reference: &[char], hypothesis: &[char]) -> (usize, usize, usize) {
384 let ref_len = reference.len();
385 let hyp_len = hypothesis.len();
386
387 let mut dp = vec![vec![0; hyp_len + 1]; ref_len + 1];
388
389 for i in 0..=ref_len {
390 dp[i][0] = i;
391 }
392 for j in 0..=hyp_len {
393 dp[0][j] = j;
394 }
395
396 for i in 1..=ref_len {
397 for j in 1..=hyp_len {
398 if reference[i - 1] == hypothesis[j - 1] {
399 dp[i][j] = dp[i - 1][j - 1];
400 } else {
401 dp[i][j] = 1 + dp[i - 1][j - 1].min(dp[i - 1][j]).min(dp[i][j - 1]);
402 }
403 }
404 }
405
406 (dp[ref_len][hyp_len], 0, 0)
408 }
409
410 pub fn calculate(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
412 if reference.len() != hypothesis.len() {
413 return Err(MetricsError::InvalidInput(
414 "Reference and hypothesis must have the same length".to_string(),
415 ));
416 }
417
418 let mut total_errors = 0;
419 let mut total_chars = 0;
420
421 for (ref_sent, hyp_sent) in reference.iter().zip(hypothesis.iter()) {
422 let cer = self.compute_cer(ref_sent, hyp_sent)?;
423 let ref_chars = ref_sent.chars().count();
424 total_errors += (cer * ref_chars as f64) as usize;
425 total_chars += ref_chars;
426 }
427
428 if total_chars == 0 {
429 Ok(0.0)
430 } else {
431 Ok(total_errors as f64 / total_chars as f64)
432 }
433 }
434}
435
436impl PerCalculator {
437 pub fn new() -> Self {
439 Self {
440 phone_substitutions: 0,
441 phone_deletions: 0,
442 phone_insertions: 0,
443 total_phones: 0,
444 confusion_matrix: HashMap::new(),
445 }
446 }
447
448 pub fn compute_per(&mut self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
450 let (subs, dels, ins) = self.phone_edit_distance(reference, hypothesis);
451
452 self.phone_substitutions += subs;
453 self.phone_deletions += dels;
454 self.phone_insertions += ins;
455 self.total_phones += reference.len();
456
457 let per = if reference.is_empty() {
458 if hypothesis.is_empty() {
459 0.0
460 } else {
461 1.0
462 }
463 } else {
464 (subs + dels + ins) as f64 / reference.len() as f64
465 };
466
467 Ok(per)
468 }
469
470 pub fn get_per(&self) -> Option<f64> {
472 if self.total_phones == 0 {
473 None
474 } else {
475 Some(
476 (self.phone_substitutions + self.phone_deletions + self.phone_insertions) as f64
477 / self.total_phones as f64,
478 )
479 }
480 }
481
482 fn phone_edit_distance(
484 &mut self,
485 reference: &[String],
486 hypothesis: &[String],
487 ) -> (usize, usize, usize) {
488 for (i, ref_phone) in reference.iter().enumerate() {
490 if i < hypothesis.len() && ref_phone != &hypothesis[i] {
491 *self
492 .confusion_matrix
493 .entry((ref_phone.clone(), hypothesis[i].clone()))
494 .or_insert(0) += 1;
495 }
496 }
497
498 let mut subs = 0;
500 let mut dels = 0;
501 let mut ins = 0;
502
503 let max_len = reference.len().max(hypothesis.len());
504 for i in 0..max_len {
505 match (reference.get(i), hypothesis.get(i)) {
506 (Some(r), Some(h)) if r != h => subs += 1,
507 (Some(_), None) => dels += 1,
508 (None, Some(_)) => ins += 1,
509 _ => {}
510 }
511 }
512
513 (subs, dels, ins)
514 }
515
516 pub fn calculate(
518 &mut self,
519 reference: &[Vec<String>],
520 hypothesis: &[Vec<String>],
521 ) -> Result<f64> {
522 if reference.len() != hypothesis.len() {
523 return Err(MetricsError::InvalidInput(
524 "Reference and hypothesis must have the same length".to_string(),
525 ));
526 }
527
528 let mut total_errors = 0;
529 let mut total_phones = 0;
530
531 for (ref_seq, hyp_seq) in reference.iter().zip(hypothesis.iter()) {
532 let per = self.compute_per(ref_seq, hyp_seq)?;
533 total_errors += (per * ref_seq.len() as f64) as usize;
534 total_phones += ref_seq.len();
535 }
536
537 if total_phones == 0 {
538 Ok(0.0)
539 } else {
540 Ok(total_errors as f64 / total_phones as f64)
541 }
542 }
543}
544
545impl BleuCalculator {
546 pub fn new() -> Self {
548 Self {
549 ngram_weights: vec![0.25, 0.25, 0.25, 0.25], brevity_penalty: true,
551 smoothing: BleuSmoothing::Epsilon(1e-7),
552 }
553 }
554
555 pub fn compute_bleu(&self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
557 if reference.is_empty() || hypothesis.is_empty() {
558 return Ok(0.0);
559 }
560
561 let mut precisions = Vec::new();
562
563 for n in 1..=4 {
565 let precision = self.compute_ngram_precision(reference, hypothesis, n);
566 precisions.push(precision);
567 }
568
569 let log_sum: f64 = precisions
571 .iter()
572 .zip(&self.ngram_weights)
573 .map(|(p, w)| w * p.ln())
574 .sum();
575
576 let mut bleu = log_sum.exp();
577
578 if self.brevity_penalty {
580 let bp = self.compute_brevity_penalty(reference.len(), hypothesis.len());
581 bleu *= bp;
582 }
583
584 Ok(bleu)
585 }
586
587 pub fn get_bleu(&self) -> Option<f64> {
589 None }
591
592 fn compute_ngram_precision(
594 &self,
595 reference: &[String],
596 hypothesis: &[String],
597 n: usize,
598 ) -> f64 {
599 if hypothesis.len() < n {
600 return 0.0;
601 }
602
603 let ref_ngrams = self.extract_ngrams(reference, n);
604 let hyp_ngrams = self.extract_ngrams(hypothesis, n);
605
606 let mut matches = 0;
607 for ngram in &hyp_ngrams {
608 if ref_ngrams.contains(ngram) {
609 matches += 1;
610 }
611 }
612
613 if hyp_ngrams.is_empty() {
614 0.0
615 } else {
616 matches as f64 / hyp_ngrams.len() as f64
617 }
618 }
619
620 fn extract_ngrams(&self, sequence: &[String], n: usize) -> Vec<Vec<String>> {
622 if sequence.len() < n {
623 return Vec::new();
624 }
625
626 (0..=sequence.len() - n)
627 .map(|i| sequence[i..i + n].to_vec())
628 .collect()
629 }
630
631 fn compute_brevity_penalty(&self, ref_len: usize, hyp_len: usize) -> f64 {
633 if hyp_len >= ref_len {
634 1.0
635 } else {
636 (1.0 - ref_len as f64 / hyp_len as f64).exp()
637 }
638 }
639
640 pub fn calculate(&self, reference: &[String], hypothesis: &[String]) -> Result<f64> {
642 self.compute_bleu(reference, hypothesis)
643 }
644}
645
646impl ConfidenceMetrics {
647 pub fn new() -> Self {
649 Self {
650 confidence_threshold: 0.5,
651 word_confidences: Vec::new(),
652 utterance_confidences: Vec::new(),
653 confidence_wer_correlation: None,
654 }
655 }
656
657 pub fn add_scores(&mut self, word_confidences: Vec<f64>, utterance_confidence: f64) {
659 self.word_confidences.extend(word_confidences);
660 self.utterance_confidences.push(utterance_confidence);
661 }
662
663 pub fn get_average_confidence(&self) -> f64 {
665 if self.utterance_confidences.is_empty() {
666 0.0
667 } else {
668 self.utterance_confidences.iter().sum::<f64>() / self.utterance_confidences.len() as f64
669 }
670 }
671
672 pub fn set_threshold(&mut self, threshold: f64) {
674 self.confidence_threshold = threshold;
675 }
676
677 pub fn calculate_confidence_wer_correlation(
679 &mut self,
680 reference: &[String],
681 hypothesis: &[String],
682 confidence: &[f64],
683 ) -> Result<f64> {
684 if reference.len() != hypothesis.len() || hypothesis.len() != confidence.len() {
685 return Err(MetricsError::InvalidInput(
686 "Mismatched array lengths".to_string(),
687 ));
688 }
689
690 let mut correct_scores = Vec::new();
691 let mut incorrect_scores = Vec::new();
692
693 for ((r, h), &c) in reference
694 .iter()
695 .zip(hypothesis.iter())
696 .zip(confidence.iter())
697 {
698 if r == h {
699 correct_scores.push(c);
700 } else {
701 incorrect_scores.push(c);
702 }
703 }
704
705 if correct_scores.is_empty() || incorrect_scores.is_empty() {
706 return Ok(0.0);
707 }
708
709 let correct_mean = correct_scores.iter().sum::<f64>() / correct_scores.len() as f64;
710 let incorrect_mean = incorrect_scores.iter().sum::<f64>() / incorrect_scores.len() as f64;
711
712 Ok((correct_mean - incorrect_mean).abs())
713 }
714}
715
716impl Default for SpeechRecognitionMetrics {
717 fn default() -> Self {
718 Self::new()
719 }
720}
721
722impl Default for WerCalculator {
723 fn default() -> Self {
724 Self::new()
725 }
726}
727
728impl Default for CerCalculator {
729 fn default() -> Self {
730 Self::new()
731 }
732}
733
734impl Default for PerCalculator {
735 fn default() -> Self {
736 Self::new()
737 }
738}
739
740impl Default for BleuCalculator {
741 fn default() -> Self {
742 Self::new()
743 }
744}
745
746impl Default for ConfidenceMetrics {
747 fn default() -> Self {
748 Self::new()
749 }
750}