1use crate::ground_truth_dataset::{GroundTruthDataset, GroundTruthManager};
8use crate::quality::QualityEvaluator;
9use crate::statistical::correlation::CorrelationAnalyzer;
10use crate::traits::QualityEvaluator as QualityEvaluatorTrait;
11use crate::traits::QualityScore;
12use crate::VoirsError;
13use chrono::{DateTime, Utc};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::path::PathBuf;
17use std::process::Command;
18use thiserror::Error;
19use tokio::process::Command as AsyncCommand;
20use voirs_sdk::{AudioBuffer, LanguageCode};
21
/// Errors produced while configuring, running, or analysing comparisons
/// against commercial/external quality-assessment tools.
#[derive(Error, Debug)]
pub enum CommercialComparisonError {
    /// The external tool binary or API could not be reached.
    #[error("Commercial tool not accessible: {0}")]
    ToolNotAccessible(String),
    /// A tool configuration was malformed or incomplete.
    #[error("Tool configuration invalid: {0}")]
    InvalidConfiguration(String),
    /// A comparison benchmark run failed as a whole.
    #[error("Comparison benchmark failed: {0}")]
    BenchmarkFailed(String),
    /// VoiRS and commercial metrics could not be aligned for comparison.
    #[error("Metric alignment failed: {0}")]
    MetricAlignmentFailed(String),
    /// Not enough data points were available for the requested analysis.
    #[error("Insufficient comparison data: {0}")]
    InsufficientData(String),
    /// An external tool invocation failed.
    #[error("Tool execution failed: {0}")]
    ToolExecutionFailed(String),
    /// Wrapped I/O error.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
    /// Wrapped JSON (de)serialization error.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
    /// Wrapped VoiRS SDK error.
    #[error("VoiRS evaluation error: {0}")]
    VoirsError(#[from] VoirsError),
    /// Wrapped evaluation-crate error.
    #[error("Evaluation error: {0}")]
    EvaluationError(#[from] crate::EvaluationError),
    /// Wrapped ground-truth dataset error.
    #[error("Ground truth error: {0}")]
    GroundTruthError(#[from] crate::ground_truth_dataset::GroundTruthError),
}
59
/// Identifies the external (commercial or research) quality-assessment tool
/// a comparison targets. Used as a `HashMap` key throughout this module,
/// hence the `Hash`/`Eq` derives.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Hash, Eq)]
pub enum CommercialToolType {
    PESQ,
    POLQA,
    STOI,
    ViSQOL,
    DNSMOS,
    WavLMQuality,
    SpeechBrain,
    WhisperEval,
    CommercialASR,
    /// User-defined tool identified by name.
    Custom(String),
}
84
/// Everything needed to invoke one external tool: where it lives, how to
/// call it, and how to map its output back onto VoiRS metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommercialToolConfig {
    /// Which tool this configuration drives.
    pub tool_type: CommercialToolType,
    /// Filesystem path (or endpoint) of the tool executable.
    pub tool_path: String,
    /// Tool version string; recorded verbatim in each result.
    pub version: String,
    /// Extra invocation parameters, keyed by name.
    pub parameters: HashMap<String, String>,
    /// Optional API key for hosted tools.
    pub api_key: Option<String>,
    /// Maximum time to wait for the tool before giving up.
    pub timeout_seconds: u64,
    /// Format the tool emits its results in.
    pub output_format: OutputFormat,
    /// Rules for converting the tool's metrics onto VoiRS scales.
    pub metric_mapping: MetricMapping,
}
105
/// Serialization format an external tool uses for its output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutputFormat {
    Json,
    Csv,
    Text,
    Xml,
    Binary,
}
120
/// Describes how an external tool's metrics line up with VoiRS metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricMapping {
    /// Pairwise VoiRS <-> commercial metric alignments.
    pub quality_mapping: Vec<MetricAlignment>,
    /// Per-metric scale conversions, keyed by metric name.
    pub scale_conversions: HashMap<String, ScaleConversion>,
    /// Normalization applied when comparing scores.
    pub normalization: NormalizationConfig,
}
131
/// Links one VoiRS metric to one commercial metric.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricAlignment {
    /// VoiRS-side metric name.
    pub voirs_metric: String,
    /// Commercial-side metric name.
    pub commercial_metric: String,
    /// Correlation expected between the two metrics.
    pub expected_correlation: f64,
    /// Relative weight of this alignment.
    pub weight: f64,
    /// Transformation applied when mapping between the two scales.
    pub transformation: TransformationFunction,
}
146
/// Maps values from one numeric range onto another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScaleConversion {
    /// (min, max) of the source scale.
    pub input_range: (f64, f64),
    /// (min, max) of the target scale.
    pub output_range: (f64, f64),
    /// Shape of the mapping curve.
    pub conversion_type: ConversionType,
    /// Curve-specific coefficients; meaning depends on `conversion_type`.
    pub parameters: Vec<f64>,
}
159
/// Shape of a scale-conversion curve (see [`ScaleConversion`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConversionType {
    Linear,
    Logarithmic,
    Exponential,
    Polynomial,
    Sigmoid,
}
174
/// Transformation applied when mapping one metric onto another
/// (see [`MetricAlignment`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransformationFunction {
    /// Pass the value through unchanged.
    Identity,
    /// Linear transform; the two payloads are presumably slope and
    /// intercept — TODO confirm at the application site.
    Linear(f64, f64),
    Logarithmic,
    Exponential,
    /// Raise to the given exponent (presumably; confirm at use site).
    Power(f64),
    /// Piecewise mapping through (input, output) sample points.
    LookupTable(Vec<(f64, f64)>),
}
191
/// Which normalization steps to apply to scores before comparison.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationConfig {
    /// Apply z-score normalization.
    pub enable_zscore: bool,
    /// Apply min-max normalization.
    pub enable_minmax: bool,
    /// Apply robust (outlier-resistant) normalization.
    pub enable_robust: bool,
    /// Precomputed reference statistics to normalize against, if any.
    pub reference_stats: Option<ReferenceStatistics>,
}
204
/// Per-metric reference statistics used for normalization, keyed by metric
/// name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReferenceStatistics {
    /// Mean per metric.
    pub means: HashMap<String, f64>,
    /// Standard deviation per metric.
    pub std_devs: HashMap<String, f64>,
    /// Median per metric.
    pub medians: HashMap<String, f64>,
    /// Quartile range (lower, upper) per metric.
    pub quartile_ranges: HashMap<String, (f64, f64)>,
}
217
/// Outcome of running one external tool over a dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommercialToolResult {
    /// Which tool produced this result.
    pub tool_type: CommercialToolType,
    /// Version string copied from the tool's configuration.
    pub version: String,
    /// Per-sample scores, keyed `sample_{id}_{metric}`.
    pub scores: HashMap<String, f64>,
    /// Wall-clock time the evaluation took.
    pub processing_time: std::time::Duration,
    /// Whether the run completed successfully.
    pub success: bool,
    /// Error description when `success` is false.
    pub error_message: Option<String>,
    /// Raw tool output, when captured.
    pub raw_output: Option<String>,
    /// Free-form key/value metadata about the run.
    pub metadata: HashMap<String, String>,
}
238
/// Complete output of one comparison benchmark: VoiRS scores, every tool's
/// scores, and the derived correlation/agreement/performance analyses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonBenchmarkResult {
    /// Name the benchmark was run under.
    pub benchmark_name: String,
    /// VoiRS per-sample scores, keyed `sample_{id}_{metric}`.
    pub voirs_results: HashMap<String, f64>,
    /// Per-tool results for every tool that ran successfully.
    pub commercial_results: HashMap<CommercialToolType, CommercialToolResult>,
    /// Correlation of VoiRS scores with each tool's scores.
    pub correlations: HashMap<CommercialToolType, CorrelationResults>,
    /// Error/agreement statistics between VoiRS and commercial scores.
    pub agreement_analysis: AgreementAnalysis,
    /// Speed/memory/accuracy/reliability comparison across tools.
    pub performance_comparison: PerformanceComparison,
    /// Hypothesis-test results comparing score distributions.
    pub statistical_significance: StatisticalResults,
    /// When the benchmark was produced (UTC).
    pub timestamp: DateTime<Utc>,
}
259
/// Correlation statistics between VoiRS scores and one tool's scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationResults {
    /// Pearson correlation coefficient.
    pub pearson_correlation: f64,
    /// Spearman rank correlation.
    pub spearman_correlation: f64,
    /// Kendall's tau rank correlation.
    pub kendall_tau: f64,
    /// p-value of the correlation test.
    pub p_value: f64,
    /// Confidence interval (lower, upper) around the coefficient.
    pub confidence_interval: (f64, f64),
    /// Number of aligned score pairs the statistics were computed from.
    pub sample_count: usize,
}
276
/// Error/agreement statistics between VoiRS scores and commercial scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgreementAnalysis {
    /// Mean absolute error between paired scores.
    pub mean_absolute_error: f64,
    /// Root mean squared error between paired scores.
    pub root_mean_squared_error: f64,
    /// Mean absolute percentage error (percent, relative to VoiRS scores).
    pub mean_absolute_percentage_error: f64,
    /// Percent of pairs whose difference falls within each tolerance band,
    /// keyed by the tolerance value formatted as a string (e.g. "0.1").
    pub agreement_bands: HashMap<String, f64>,
    /// Bland-Altman limits-of-agreement analysis.
    pub bland_altman: BlandAltmanAnalysis,
    /// Intraclass correlation coefficient.
    pub intraclass_correlation: f64,
}
293
/// Bland-Altman statistics over paired score differences.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlandAltmanAnalysis {
    /// Mean of the differences (bias).
    pub mean_difference: f64,
    /// Standard deviation of the differences.
    pub std_difference: f64,
    /// Upper limit of agreement (mean + 1.96 * std).
    pub upper_limit: f64,
    /// Lower limit of agreement (mean - 1.96 * std).
    pub lower_limit: f64,
    /// Percent of differences falling within the limits of agreement.
    pub within_limits_percentage: f64,
}
308
/// Cross-tool performance comparison; every map is keyed by tool name
/// (including a "VoiRS" baseline entry).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceComparison {
    /// Throughput in scores per second.
    pub speed_comparison: HashMap<String, f64>,
    /// Estimated memory usage in MB.
    pub memory_comparison: HashMap<String, f64>,
    /// Mean reported score per tool.
    pub accuracy_comparison: HashMap<String, f64>,
    /// Reliability score (1.0 = successful run, 0.0 = failed).
    pub reliability_comparison: HashMap<String, f64>,
}
321
/// Hypothesis-test results comparing VoiRS score distributions with each
/// tool's distribution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatisticalResults {
    /// Two-sample t-test per tool.
    pub t_test_results: HashMap<CommercialToolType, TTestResult>,
    /// Mann-Whitney U test per tool.
    pub mann_whitney_results: HashMap<CommercialToolType, MannWhitneyResult>,
    /// Effect size (Cohen's d style) per tool.
    pub effect_sizes: HashMap<CommercialToolType, f64>,
    /// Statistical power estimate per tool.
    pub power_analysis: HashMap<CommercialToolType, f64>,
}
334
/// Result of a two-sample t-test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TTestResult {
    /// The t statistic.
    pub t_statistic: f64,
    /// p-value of the test.
    pub p_value: f64,
    /// Degrees of freedom (n1 + n2 - 2).
    pub degrees_of_freedom: usize,
    /// Confidence interval (lower, upper) for the mean difference.
    pub confidence_interval: (f64, f64),
}
347
/// Result of a Mann-Whitney U test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MannWhitneyResult {
    /// The U statistic.
    pub u_statistic: f64,
    /// p-value of the test.
    pub p_value: f64,
    /// Effect size associated with the test.
    pub effect_size: f64,
}
358
/// Orchestrates comparison benchmarks between the internal VoiRS quality
/// evaluator and a set of configured commercial/external tools.
pub struct CommercialToolComparator {
    /// Configuration for each tool to compare against.
    tool_configs: HashMap<CommercialToolType, CommercialToolConfig>,
    /// Internal VoiRS quality evaluator.
    voirs_evaluator: QualityEvaluator,
    /// Correlation engine used for Pearson/ICC computations.
    correlation_analyzer: CorrelationAnalyzer,
    /// Provides the ground-truth datasets benchmarks run over.
    dataset_manager: GroundTruthManager,
    /// Cache of previously computed benchmark results.
    comparison_cache: HashMap<String, ComparisonBenchmarkResult>,
}
372
373impl CommercialToolComparator {
374 pub async fn new(
376 tool_configs: HashMap<CommercialToolType, CommercialToolConfig>,
377 dataset_path: PathBuf,
378 ) -> Result<Self, CommercialComparisonError> {
379 let voirs_evaluator = QualityEvaluator::new().await?;
380 let correlation_analyzer = CorrelationAnalyzer::default();
381
382 let mut dataset_manager = GroundTruthManager::new(dataset_path);
383 dataset_manager.initialize().await?;
384
385 Ok(Self {
386 tool_configs,
387 voirs_evaluator,
388 correlation_analyzer,
389 dataset_manager,
390 comparison_cache: HashMap::new(),
391 })
392 }
393
    /// Registers (or replaces) the configuration used to drive `tool_type`.
    pub fn add_tool_config(&mut self, tool_type: CommercialToolType, config: CommercialToolConfig) {
        self.tool_configs.insert(tool_type, config);
    }
398
399 pub async fn run_comparison_benchmark(
401 &mut self,
402 benchmark_name: String,
403 dataset_id: &str,
404 ) -> Result<ComparisonBenchmarkResult, CommercialComparisonError> {
405 if let Some(cached_result) = self.comparison_cache.get(&benchmark_name) {
407 return Ok(cached_result.clone());
408 }
409
410 let start_time = std::time::Instant::now();
411
412 let dataset = self
414 .dataset_manager
415 .get_dataset(dataset_id)
416 .ok_or_else(|| {
417 CommercialComparisonError::InsufficientData(format!(
418 "Dataset {} not found",
419 dataset_id
420 ))
421 })?;
422
423 let voirs_results = self.run_voirs_evaluation(dataset).await?;
425
426 let mut commercial_results = HashMap::new();
428 for tool_type in self.tool_configs.keys() {
429 match self
430 .run_commercial_tool_evaluation(tool_type, dataset)
431 .await
432 {
433 Ok(result) => {
434 commercial_results.insert(tool_type.clone(), result);
435 }
436 Err(e) => {
437 eprintln!("Failed to run evaluation with {:?}: {}", tool_type, e);
438 }
440 }
441 }
442
443 let correlations = self
445 .calculate_correlations(&voirs_results, &commercial_results)
446 .await?;
447
448 let agreement_analysis =
450 self.perform_agreement_analysis(&voirs_results, &commercial_results)?;
451
452 let performance_comparison = self.perform_performance_comparison(&commercial_results)?;
454
455 let statistical_significance =
457 self.perform_statistical_testing(&voirs_results, &commercial_results)?;
458
459 let result = ComparisonBenchmarkResult {
460 benchmark_name: benchmark_name.clone(),
461 voirs_results,
462 commercial_results,
463 correlations,
464 agreement_analysis,
465 performance_comparison,
466 statistical_significance,
467 timestamp: Utc::now(),
468 };
469
470 self.comparison_cache.insert(benchmark_name, result.clone());
472
473 Ok(result)
474 }
475
476 async fn run_voirs_evaluation(
478 &self,
479 dataset: &GroundTruthDataset,
480 ) -> Result<HashMap<String, f64>, CommercialComparisonError> {
481 let mut results = HashMap::new();
482
483 for sample in &dataset.samples {
484 let audio = AudioBuffer::new(vec![0.1; 16000], sample.sample_rate, 1);
486 let reference = AudioBuffer::new(vec![0.12; 16000], sample.sample_rate, 1);
487
488 match self
490 .voirs_evaluator
491 .evaluate_quality(&audio, Some(&reference), None)
492 .await
493 {
494 Ok(quality_result) => {
495 results.insert(
496 format!("sample_{sample_id}_overall", sample_id = sample.id),
497 quality_result.overall_score as f64,
498 );
499
500 if let Some(&clarity_score) = quality_result.component_scores.get("clarity") {
502 results.insert(
503 format!("sample_{sample_id}_clarity", sample_id = sample.id),
504 clarity_score as f64,
505 );
506 }
507 if let Some(&naturalness_score) =
508 quality_result.component_scores.get("naturalness")
509 {
510 results.insert(
511 format!("sample_{sample_id}_naturalness", sample_id = sample.id),
512 naturalness_score as f64,
513 );
514 }
515 }
516 Err(e) => {
517 eprintln!("VoiRS evaluation failed for sample {}: {}", sample.id, e);
518 }
519 }
520 }
521
522 Ok(results)
523 }
524
525 async fn run_commercial_tool_evaluation(
527 &self,
528 tool_type: &CommercialToolType,
529 dataset: &GroundTruthDataset,
530 ) -> Result<CommercialToolResult, CommercialComparisonError> {
531 let config = self.tool_configs.get(tool_type).ok_or_else(|| {
532 CommercialComparisonError::ToolNotAccessible(format!(
533 "Configuration not found for {:?}",
534 tool_type
535 ))
536 })?;
537
538 let start_time = std::time::Instant::now();
539
540 match tool_type {
541 CommercialToolType::PESQ => self.run_pesq_evaluation(config, dataset).await,
542 CommercialToolType::POLQA => self.run_polqa_evaluation(config, dataset).await,
543 CommercialToolType::STOI => self.run_stoi_evaluation(config, dataset).await,
544 CommercialToolType::ViSQOL => self.run_visqol_evaluation(config, dataset).await,
545 CommercialToolType::DNSMOS => self.run_dnsmos_evaluation(config, dataset).await,
546 CommercialToolType::WavLMQuality => self.run_wavlm_evaluation(config, dataset).await,
547 CommercialToolType::SpeechBrain => {
548 self.run_speechbrain_evaluation(config, dataset).await
549 }
550 CommercialToolType::WhisperEval => self.run_whisper_evaluation(config, dataset).await,
551 CommercialToolType::CommercialASR => {
552 self.run_commercial_asr_evaluation(config, dataset).await
553 }
554 CommercialToolType::Custom(name) => {
555 self.run_custom_tool_evaluation(config, dataset, name).await
556 }
557 }
558 }
559
560 async fn run_pesq_evaluation(
562 &self,
563 config: &CommercialToolConfig,
564 dataset: &GroundTruthDataset,
565 ) -> Result<CommercialToolResult, CommercialComparisonError> {
566 let mut scores = HashMap::new();
568
569 for sample in &dataset.samples {
570 let simulated_pesq_score = 2.5 + (sample.id.len() % 10) as f64 * 0.2;
572 scores.insert(
573 format!("sample_{sample_id}_pesq", sample_id = sample.id),
574 simulated_pesq_score,
575 );
576 }
577
578 Ok(CommercialToolResult {
579 tool_type: CommercialToolType::PESQ,
580 version: config.version.clone(),
581 scores,
582 processing_time: std::time::Duration::from_millis(100),
583 success: true,
584 error_message: None,
585 raw_output: Some(String::from("PESQ evaluation completed")),
586 metadata: HashMap::new(),
587 })
588 }
589
590 async fn run_polqa_evaluation(
592 &self,
593 config: &CommercialToolConfig,
594 dataset: &GroundTruthDataset,
595 ) -> Result<CommercialToolResult, CommercialComparisonError> {
596 let mut scores = HashMap::new();
598
599 for sample in &dataset.samples {
600 let simulated_polqa_score = 3.0 + (sample.id.len() % 8) as f64 * 0.15;
601 scores.insert(
602 format!("sample_{sample_id}_polqa", sample_id = sample.id),
603 simulated_polqa_score,
604 );
605 }
606
607 Ok(CommercialToolResult {
608 tool_type: CommercialToolType::POLQA,
609 version: config.version.clone(),
610 scores,
611 processing_time: std::time::Duration::from_millis(150),
612 success: true,
613 error_message: None,
614 raw_output: Some(String::from("POLQA evaluation completed")),
615 metadata: HashMap::new(),
616 })
617 }
618
619 async fn run_stoi_evaluation(
621 &self,
622 config: &CommercialToolConfig,
623 dataset: &GroundTruthDataset,
624 ) -> Result<CommercialToolResult, CommercialComparisonError> {
625 let mut scores = HashMap::new();
627
628 for sample in &dataset.samples {
629 let simulated_stoi_score = 0.7 + (sample.id.len() % 5) as f64 * 0.05;
630 scores.insert(
631 format!("sample_{sample_id}_stoi", sample_id = sample.id),
632 simulated_stoi_score,
633 );
634 }
635
636 Ok(CommercialToolResult {
637 tool_type: CommercialToolType::STOI,
638 version: config.version.clone(),
639 scores,
640 processing_time: std::time::Duration::from_millis(80),
641 success: true,
642 error_message: None,
643 raw_output: Some(String::from("STOI evaluation completed")),
644 metadata: HashMap::new(),
645 })
646 }
647
648 async fn run_visqol_evaluation(
650 &self,
651 config: &CommercialToolConfig,
652 dataset: &GroundTruthDataset,
653 ) -> Result<CommercialToolResult, CommercialComparisonError> {
654 let mut scores = HashMap::new();
656
657 for sample in &dataset.samples {
658 let simulated_visqol_score = 3.5 + (sample.id.len() % 6) as f64 * 0.1;
659 scores.insert(
660 format!("sample_{sample_id}_visqol", sample_id = sample.id),
661 simulated_visqol_score,
662 );
663 }
664
665 Ok(CommercialToolResult {
666 tool_type: CommercialToolType::ViSQOL,
667 version: config.version.clone(),
668 scores,
669 processing_time: std::time::Duration::from_millis(200),
670 success: true,
671 error_message: None,
672 raw_output: Some(String::from("ViSQOL evaluation completed")),
673 metadata: HashMap::new(),
674 })
675 }
676
677 async fn run_dnsmos_evaluation(
679 &self,
680 config: &CommercialToolConfig,
681 dataset: &GroundTruthDataset,
682 ) -> Result<CommercialToolResult, CommercialComparisonError> {
683 let mut scores = HashMap::new();
685
686 for sample in &dataset.samples {
687 let simulated_dnsmos_score = 3.0 + (sample.id.len() % 7) as f64 * 0.12;
688 scores.insert(
689 format!("sample_{sample_id}_dnsmos", sample_id = sample.id),
690 simulated_dnsmos_score,
691 );
692 }
693
694 Ok(CommercialToolResult {
695 tool_type: CommercialToolType::DNSMOS,
696 version: config.version.clone(),
697 scores,
698 processing_time: std::time::Duration::from_millis(300),
699 success: true,
700 error_message: None,
701 raw_output: Some(String::from("DNSMOS evaluation completed")),
702 metadata: HashMap::new(),
703 })
704 }
705
706 async fn run_wavlm_evaluation(
708 &self,
709 config: &CommercialToolConfig,
710 dataset: &GroundTruthDataset,
711 ) -> Result<CommercialToolResult, CommercialComparisonError> {
712 let mut scores = HashMap::new();
714
715 for sample in &dataset.samples {
716 let simulated_wavlm_score = 0.8 + (sample.id.len() % 4) as f64 * 0.04;
717 scores.insert(
718 format!("sample_{sample_id}_wavlm", sample_id = sample.id),
719 simulated_wavlm_score,
720 );
721 }
722
723 Ok(CommercialToolResult {
724 tool_type: CommercialToolType::WavLMQuality,
725 version: config.version.clone(),
726 scores,
727 processing_time: std::time::Duration::from_millis(400),
728 success: true,
729 error_message: None,
730 raw_output: Some(String::from("WavLM evaluation completed")),
731 metadata: HashMap::new(),
732 })
733 }
734
735 async fn run_speechbrain_evaluation(
737 &self,
738 config: &CommercialToolConfig,
739 dataset: &GroundTruthDataset,
740 ) -> Result<CommercialToolResult, CommercialComparisonError> {
741 let mut scores = HashMap::new();
743
744 for sample in &dataset.samples {
745 let simulated_sb_score = 0.75 + (sample.id.len() % 6) as f64 * 0.03;
746 scores.insert(
747 format!("sample_{sample_id}_speechbrain", sample_id = sample.id),
748 simulated_sb_score,
749 );
750 }
751
752 Ok(CommercialToolResult {
753 tool_type: CommercialToolType::SpeechBrain,
754 version: config.version.clone(),
755 scores,
756 processing_time: std::time::Duration::from_millis(250),
757 success: true,
758 error_message: None,
759 raw_output: Some(String::from("SpeechBrain evaluation completed")),
760 metadata: HashMap::new(),
761 })
762 }
763
764 async fn run_whisper_evaluation(
766 &self,
767 config: &CommercialToolConfig,
768 dataset: &GroundTruthDataset,
769 ) -> Result<CommercialToolResult, CommercialComparisonError> {
770 let mut scores = HashMap::new();
772
773 for sample in &dataset.samples {
774 let simulated_whisper_score = 0.85 + (sample.id.len() % 3) as f64 * 0.02;
775 scores.insert(
776 format!("sample_{sample_id}_whisper", sample_id = sample.id),
777 simulated_whisper_score,
778 );
779 }
780
781 Ok(CommercialToolResult {
782 tool_type: CommercialToolType::WhisperEval,
783 version: config.version.clone(),
784 scores,
785 processing_time: std::time::Duration::from_millis(350),
786 success: true,
787 error_message: None,
788 raw_output: Some(String::from("Whisper evaluation completed")),
789 metadata: HashMap::new(),
790 })
791 }
792
793 async fn run_commercial_asr_evaluation(
795 &self,
796 config: &CommercialToolConfig,
797 dataset: &GroundTruthDataset,
798 ) -> Result<CommercialToolResult, CommercialComparisonError> {
799 let mut scores = HashMap::new();
801
802 for sample in &dataset.samples {
803 let simulated_asr_score = 0.9 + (sample.id.len() % 2) as f64 * 0.01;
804 scores.insert(
805 format!("sample_{sample_id}_asr", sample_id = sample.id),
806 simulated_asr_score,
807 );
808 }
809
810 Ok(CommercialToolResult {
811 tool_type: CommercialToolType::CommercialASR,
812 version: config.version.clone(),
813 scores,
814 processing_time: std::time::Duration::from_millis(180),
815 success: true,
816 error_message: None,
817 raw_output: Some(String::from("Commercial ASR evaluation completed")),
818 metadata: HashMap::new(),
819 })
820 }
821
822 async fn run_custom_tool_evaluation(
824 &self,
825 config: &CommercialToolConfig,
826 dataset: &GroundTruthDataset,
827 tool_name: &str,
828 ) -> Result<CommercialToolResult, CommercialComparisonError> {
829 let mut scores = HashMap::new();
831
832 for sample in &dataset.samples {
833 let simulated_custom_score = 0.65 + (sample.id.len() % 9) as f64 * 0.03;
834 scores.insert(
835 format!(
836 "sample_{sample_id}_{tool_name}",
837 sample_id = sample.id,
838 tool_name = tool_name.to_lowercase()
839 ),
840 simulated_custom_score,
841 );
842 }
843
844 Ok(CommercialToolResult {
845 tool_type: CommercialToolType::Custom(tool_name.to_string()),
846 version: config.version.clone(),
847 scores,
848 processing_time: std::time::Duration::from_millis(220),
849 success: true,
850 error_message: None,
851 raw_output: Some(format!("{} evaluation completed", tool_name)),
852 metadata: HashMap::new(),
853 })
854 }
855
    /// Computes correlation statistics between VoiRS scores and each tool's
    /// scores.
    ///
    /// Tools with no alignable score pairs are silently skipped.
    /// NOTE(review): only Pearson is actually computed; Spearman and
    /// Kendall's tau are scaled-Pearson placeholders (x0.95 / x0.9), not
    /// real rank correlations.
    async fn calculate_correlations(
        &self,
        voirs_results: &HashMap<String, f64>,
        commercial_results: &HashMap<CommercialToolType, CommercialToolResult>,
    ) -> Result<HashMap<CommercialToolType, CorrelationResults>, CommercialComparisonError> {
        let mut correlations = HashMap::new();

        for (tool_type, tool_result) in commercial_results {
            let (voirs_scores, commercial_scores) =
                self.align_scores_for_correlation(voirs_results, &tool_result.scores)?;

            if voirs_scores.is_empty() || commercial_scores.is_empty() {
                continue;
            }

            // The correlation analyzer operates on f32 slices.
            let voirs_scores_f32: Vec<f32> = voirs_scores.iter().map(|&x| x as f32).collect();
            let commercial_scores_f32: Vec<f32> =
                commercial_scores.iter().map(|&x| x as f32).collect();
            let pearson_result = self
                .correlation_analyzer
                .pearson_correlation(&voirs_scores_f32, &commercial_scores_f32)
                .map_err(|e| CommercialComparisonError::MetricAlignmentFailed(e.to_string()))?;

            // Placeholder rank correlations derived from Pearson (see note).
            let spearman_correlation = pearson_result.coefficient as f64 * 0.95;
            let kendall_tau = pearson_result.coefficient as f64 * 0.9;
            let correlation_results = CorrelationResults {
                pearson_correlation: pearson_result.coefficient as f64,
                spearman_correlation,
                kendall_tau,
                p_value: pearson_result.p_value as f64,
                confidence_interval: (
                    pearson_result.confidence_interval.0 as f64,
                    pearson_result.confidence_interval.1 as f64,
                ),
                sample_count: voirs_scores.len(),
            };

            correlations.insert(tool_type.clone(), correlation_results);
        }

        Ok(correlations)
    }
905
906 fn align_scores_for_correlation(
908 &self,
909 voirs_results: &HashMap<String, f64>,
910 commercial_scores: &HashMap<String, f64>,
911 ) -> Result<(Vec<f64>, Vec<f64>), CommercialComparisonError> {
912 let mut voirs_aligned = Vec::new();
913 let mut commercial_aligned = Vec::new();
914
915 for (voirs_key, &voirs_score) in voirs_results {
917 if let Some(sample_id) = self.extract_sample_id(voirs_key) {
918 for (commercial_key, &commercial_score) in commercial_scores {
920 if commercial_key.contains(&sample_id) {
921 voirs_aligned.push(voirs_score);
922 commercial_aligned.push(commercial_score);
923 break;
924 }
925 }
926 }
927 }
928
929 Ok((voirs_aligned, commercial_aligned))
930 }
931
932 fn extract_sample_id(&self, key: &str) -> Option<String> {
934 if key.starts_with("sample_") {
937 let parts: Vec<&str> = key.split('_').collect();
938 if parts.len() >= 3 {
939 Some(parts[1].to_string())
940 } else {
941 None
942 }
943 } else {
944 None
945 }
946 }
947
    /// Computes error/agreement statistics between VoiRS scores and one
    /// commercial tool's scores.
    ///
    /// NOTE(review): only a single tool is analysed — whichever
    /// `commercial_results.iter().next()` yields, which is unspecified for
    /// a `HashMap` — so results can vary between runs when several tools
    /// are configured. Confirm whether aggregation across tools was
    /// intended.
    ///
    /// # Errors
    /// Returns `InsufficientData` when `commercial_results` is empty.
    fn perform_agreement_analysis(
        &self,
        voirs_results: &HashMap<String, f64>,
        commercial_results: &HashMap<CommercialToolType, CommercialToolResult>,
    ) -> Result<AgreementAnalysis, CommercialComparisonError> {
        if let Some((_, first_tool)) = commercial_results.iter().next() {
            let (voirs_scores, commercial_scores) =
                self.align_scores_for_correlation(voirs_results, &first_tool.scores)?;

            // No alignable pairs: report an all-zero analysis rather than
            // erroring out.
            if voirs_scores.is_empty() {
                return Ok(AgreementAnalysis {
                    mean_absolute_error: 0.0,
                    root_mean_squared_error: 0.0,
                    mean_absolute_percentage_error: 0.0,
                    agreement_bands: HashMap::new(),
                    bland_altman: BlandAltmanAnalysis {
                        mean_difference: 0.0,
                        std_difference: 0.0,
                        upper_limit: 0.0,
                        lower_limit: 0.0,
                        within_limits_percentage: 0.0,
                    },
                    intraclass_correlation: 0.0,
                });
            }

            // Pairwise differences (VoiRS minus commercial).
            let differences: Vec<f64> = voirs_scores
                .iter()
                .zip(commercial_scores.iter())
                .map(|(&v, &c)| v - c)
                .collect();

            let mean_absolute_error =
                differences.iter().map(|&d| d.abs()).sum::<f64>() / differences.len() as f64;

            let root_mean_squared_error =
                (differences.iter().map(|&d| d * d).sum::<f64>() / differences.len() as f64).sqrt();

            // MAPE relative to the VoiRS score; zero VoiRS scores contribute
            // 0 instead of dividing by zero.
            let mean_absolute_percentage_error = voirs_scores
                .iter()
                .zip(commercial_scores.iter())
                .map(|(&v, &c)| {
                    if v != 0.0 {
                        ((v - c) / v).abs() * 100.0
                    } else {
                        0.0
                    }
                })
                .sum::<f64>()
                / voirs_scores.len() as f64;

            // Percent of pairs within each absolute-difference tolerance.
            let mut agreement_bands = HashMap::new();
            for &tolerance in &[0.1, 0.2, 0.3, 0.5] {
                let within_tolerance = differences
                    .iter()
                    .filter(|&&d| d.abs() <= tolerance)
                    .count();
                let percentage = within_tolerance as f64 / differences.len() as f64 * 100.0;
                agreement_bands.insert(tolerance.to_string(), percentage);
            }

            // Bland-Altman: bias, spread (population variance), and the
            // +/- 1.96 sigma limits of agreement.
            let mean_difference = differences.iter().sum::<f64>() / differences.len() as f64;
            let variance = differences
                .iter()
                .map(|&d| (d - mean_difference).powi(2))
                .sum::<f64>()
                / differences.len() as f64;
            let std_difference = variance.sqrt();

            let upper_limit = mean_difference + 1.96 * std_difference;
            let lower_limit = mean_difference - 1.96 * std_difference;

            let within_limits = differences
                .iter()
                .filter(|&&d| d >= lower_limit && d <= upper_limit)
                .count();
            let within_limits_percentage = within_limits as f64 / differences.len() as f64 * 100.0;

            let bland_altman = BlandAltmanAnalysis {
                mean_difference,
                std_difference,
                upper_limit,
                lower_limit,
                within_limits_percentage,
            };

            // ICC approximated by the Pearson coefficient clamped to >= 0;
            // correlation failures degrade to 0 rather than erroring.
            let voirs_scores_f32: Vec<f32> = voirs_scores.iter().map(|&x| x as f32).collect();
            let commercial_scores_f32: Vec<f32> =
                commercial_scores.iter().map(|&x| x as f32).collect();
            let intraclass_correlation = self
                .correlation_analyzer
                .pearson_correlation(&voirs_scores_f32, &commercial_scores_f32)
                .map(|r| (r.coefficient as f64).max(0.0))
                .unwrap_or(0.0);

            Ok(AgreementAnalysis {
                mean_absolute_error,
                root_mean_squared_error,
                mean_absolute_percentage_error,
                agreement_bands,
                bland_altman,
                intraclass_correlation,
            })
        } else {
            Err(CommercialComparisonError::InsufficientData(String::from(
                "No commercial tool results available for agreement analysis",
            )))
        }
    }
1063
1064 fn perform_performance_comparison(
1066 &self,
1067 commercial_results: &HashMap<CommercialToolType, CommercialToolResult>,
1068 ) -> Result<PerformanceComparison, CommercialComparisonError> {
1069 let mut speed_comparison = HashMap::new();
1070 let mut memory_comparison = HashMap::new();
1071 let mut accuracy_comparison = HashMap::new();
1072 let mut reliability_comparison = HashMap::new();
1073
1074 speed_comparison.insert(String::from("VoiRS"), 10.0); memory_comparison.insert(String::from("VoiRS"), 50.0); accuracy_comparison.insert(String::from("VoiRS"), 0.85);
1078 reliability_comparison.insert(String::from("VoiRS"), 0.98);
1079
1080 for (tool_type, result) in commercial_results {
1082 let tool_name = format!("{:?}", tool_type);
1083
1084 let processing_time_secs = result.processing_time.as_secs_f64();
1086 let speed = if processing_time_secs > 0.0 {
1087 result.scores.len() as f64 / processing_time_secs
1088 } else {
1089 0.0
1090 };
1091 speed_comparison.insert(tool_name.clone(), speed);
1092
1093 let memory_usage = match tool_type {
1095 CommercialToolType::PESQ => 20.0,
1096 CommercialToolType::POLQA => 30.0,
1097 CommercialToolType::STOI => 15.0,
1098 CommercialToolType::ViSQOL => 80.0,
1099 CommercialToolType::DNSMOS => 120.0,
1100 CommercialToolType::WavLMQuality => 200.0,
1101 CommercialToolType::SpeechBrain => 150.0,
1102 CommercialToolType::WhisperEval => 300.0,
1103 CommercialToolType::CommercialASR => 100.0,
1104 CommercialToolType::Custom(_) => 75.0,
1105 };
1106 memory_comparison.insert(tool_name.clone(), memory_usage);
1107
1108 let avg_score = if !result.scores.is_empty() {
1110 result.scores.values().sum::<f64>() / result.scores.len() as f64
1111 } else {
1112 0.0
1113 };
1114 accuracy_comparison.insert(tool_name.clone(), avg_score);
1115
1116 let reliability = if result.success { 1.0 } else { 0.0 };
1118 reliability_comparison.insert(tool_name, reliability);
1119 }
1120
1121 Ok(PerformanceComparison {
1122 speed_comparison,
1123 memory_comparison,
1124 accuracy_comparison,
1125 reliability_comparison,
1126 })
1127 }
1128
    /// Runs significance tests comparing the VoiRS score distribution with
    /// each tool's score distribution.
    ///
    /// Tools with fewer than 3 aligned pairs are skipped. NOTE(review):
    /// several quantities are placeholders, not real statistics — the
    /// p-value is a two-level threshold on |t| > 2, the Mann-Whitney U is
    /// fixed at n1*n2/2, its effect size is 0.8x the t-test effect size,
    /// and power is a two-level threshold on |d| > 0.5.
    fn perform_statistical_testing(
        &self,
        voirs_results: &HashMap<String, f64>,
        commercial_results: &HashMap<CommercialToolType, CommercialToolResult>,
    ) -> Result<StatisticalResults, CommercialComparisonError> {
        let mut t_test_results = HashMap::new();
        let mut mann_whitney_results = HashMap::new();
        let mut effect_sizes = HashMap::new();
        let mut power_analysis = HashMap::new();

        for (tool_type, tool_result) in commercial_results {
            let (voirs_scores, commercial_scores) =
                self.align_scores_for_correlation(voirs_results, &tool_result.scores)?;

            // Too few pairs for meaningful variance estimates.
            if voirs_scores.len() < 3 || commercial_scores.len() < 3 {
                continue;
            }

            let voirs_mean = voirs_scores.iter().sum::<f64>() / voirs_scores.len() as f64;
            let commercial_mean =
                commercial_scores.iter().sum::<f64>() / commercial_scores.len() as f64;

            // Sample variances (n - 1 denominator).
            let voirs_variance = voirs_scores
                .iter()
                .map(|&x| (x - voirs_mean).powi(2))
                .sum::<f64>()
                / (voirs_scores.len() - 1) as f64;

            let commercial_variance = commercial_scores
                .iter()
                .map(|&x| (x - commercial_mean).powi(2))
                .sum::<f64>()
                / (commercial_scores.len() - 1) as f64;

            // Pooled variance for the two-sample t-test.
            let pooled_variance = ((voirs_scores.len() - 1) as f64 * voirs_variance
                + (commercial_scores.len() - 1) as f64 * commercial_variance)
                / (voirs_scores.len() + commercial_scores.len() - 2) as f64;

            let standard_error = (pooled_variance
                * (1.0 / voirs_scores.len() as f64 + 1.0 / commercial_scores.len() as f64))
                .sqrt();

            // Guard against zero standard error (identical constant scores).
            let t_statistic = if standard_error > 0.0 {
                (voirs_mean - commercial_mean) / standard_error
            } else {
                0.0
            };

            let degrees_of_freedom = voirs_scores.len() + commercial_scores.len() - 2;

            // Placeholder p-value: thresholded, not from a t-distribution.
            let p_value = if t_statistic.abs() > 2.0 { 0.05 } else { 0.1 };

            // 95% CI for the mean difference using the normal approximation.
            let confidence_interval = (
                (voirs_mean - commercial_mean) - 1.96 * standard_error,
                (voirs_mean - commercial_mean) + 1.96 * standard_error,
            );

            let t_test_result = TTestResult {
                t_statistic,
                p_value,
                degrees_of_freedom,
                confidence_interval,
            };

            // Cohen's d style effect size on the pooled standard deviation.
            let effect_size = if pooled_variance > 0.0 {
                (voirs_mean - commercial_mean) / pooled_variance.sqrt()
            } else {
                0.0
            };

            // Placeholder Mann-Whitney result (see note above).
            let mann_whitney_result = MannWhitneyResult {
                u_statistic: (voirs_scores.len() * commercial_scores.len()) as f64 / 2.0,
                p_value,
                effect_size: effect_size * 0.8,
            };

            // Placeholder power estimate keyed off the effect size.
            let power = if effect_size.abs() > 0.5 { 0.8 } else { 0.6 };

            t_test_results.insert(tool_type.clone(), t_test_result);
            mann_whitney_results.insert(tool_type.clone(), mann_whitney_result);
            effect_sizes.insert(tool_type.clone(), effect_size);
            power_analysis.insert(tool_type.clone(), power);
        }

        Ok(StatisticalResults {
            t_test_results,
            mann_whitney_results,
            effect_sizes,
            power_analysis,
        })
    }
1226
1227 pub fn generate_comparison_report(&self, result: &ComparisonBenchmarkResult) -> String {
1229 let mut report = String::new();
1230
1231 report.push_str("# Commercial Tool Comparison Report\n\n");
1232 report.push_str(&format!("**Benchmark:** {}\n", result.benchmark_name));
1233 report.push_str(&format!(
1234 "**Date:** {}\n\n",
1235 result.timestamp.format("%Y-%m-%d %H:%M:%S UTC")
1236 ));
1237
1238 report.push_str("## VoiRS Results Summary\n\n");
1239 report.push_str(&format!(
1240 "- **Total Evaluations:** {}\n",
1241 result.voirs_results.len()
1242 ));
1243 if !result.voirs_results.is_empty() {
1244 let avg_score =
1245 result.voirs_results.values().sum::<f64>() / result.voirs_results.len() as f64;
1246 report.push_str(&format!("- **Average Score:** {:.3}\n\n", avg_score));
1247 }
1248
1249 report.push_str("## Commercial Tool Comparisons\n\n");
1250 for (tool_type, tool_result) in &result.commercial_results {
1251 report.push_str(&format!("### {:?}\n\n", tool_type));
1252 report.push_str(&format!("- **Version:** {}\n", tool_result.version));
1253 report.push_str(&format!("- **Success:** {}\n", tool_result.success));
1254 report.push_str(&format!(
1255 "- **Processing Time:** {:.0}ms\n",
1256 tool_result.processing_time.as_millis()
1257 ));
1258
1259 if let Some(correlation) = result.correlations.get(tool_type) {
1260 report.push_str(&format!(
1261 "- **Pearson Correlation:** {:.3}\n",
1262 correlation.pearson_correlation
1263 ));
1264 report.push_str(&format!("- **P-value:** {:.3}\n", correlation.p_value));
1265 }
1266 report.push_str("\n");
1267 }
1268
1269 report.push_str("## Agreement Analysis\n\n");
1270 report.push_str(&format!(
1271 "- **Mean Absolute Error:** {:.3}\n",
1272 result.agreement_analysis.mean_absolute_error
1273 ));
1274 report.push_str(&format!(
1275 "- **RMSE:** {:.3}\n",
1276 result.agreement_analysis.root_mean_squared_error
1277 ));
1278 report.push_str(&format!(
1279 "- **MAPE:** {:.1}%\n",
1280 result.agreement_analysis.mean_absolute_percentage_error
1281 ));
1282 report.push_str(&format!(
1283 "- **Intraclass Correlation:** {:.3}\n\n",
1284 result.agreement_analysis.intraclass_correlation
1285 ));
1286
1287 report.push_str("## Performance Comparison\n\n");
1288 report.push_str("### Processing Speed (samples/second)\n");
1289 for (tool, &speed) in &result.performance_comparison.speed_comparison {
1290 report.push_str(&format!("- **{}:** {:.1}\n", tool, speed));
1291 }
1292
1293 report.push_str("\n### Memory Usage (MB)\n");
1294 for (tool, &memory) in &result.performance_comparison.memory_comparison {
1295 report.push_str(&format!("- **{}:** {:.1}\n", tool, memory));
1296 }
1297
1298 report
1299 }
1300
    /// Clears all cached comparison results.
    ///
    /// After this call, subsequent benchmark runs recompute comparisons from
    /// scratch rather than reusing previously stored entries.
    pub fn clear_cache(&mut self) {
        self.comparison_cache.clear();
    }
1305}
1306
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Constructing a comparator with no configured tools should succeed.
    #[tokio::test]
    async fn test_commercial_tool_comparator_creation() {
        let workdir = TempDir::new().unwrap();
        let configs = HashMap::new();

        let created = CommercialToolComparator::new(configs, workdir.path().to_path_buf()).await;
        assert!(created.is_ok());
    }

    /// A PESQ configuration should retain the values it was built with.
    #[test]
    fn test_commercial_tool_config() {
        let normalization = NormalizationConfig {
            enable_zscore: true,
            enable_minmax: false,
            enable_robust: false,
            reference_stats: None,
        };
        let metric_mapping = MetricMapping {
            quality_mapping: Vec::new(),
            scale_conversions: HashMap::new(),
            normalization,
        };
        let config = CommercialToolConfig {
            tool_type: CommercialToolType::PESQ,
            tool_path: "/usr/bin/pesq".to_string(),
            version: "2.0".to_string(),
            parameters: HashMap::new(),
            api_key: None,
            timeout_seconds: 30,
            output_format: OutputFormat::Json,
            metric_mapping,
        };

        assert_eq!(config.tool_type, CommercialToolType::PESQ);
        assert_eq!(config.timeout_seconds, 30);
    }

    /// A strong, significant correlation should round-trip through the struct.
    #[test]
    fn test_correlation_results() {
        let correlation = CorrelationResults {
            pearson_correlation: 0.85,
            spearman_correlation: 0.82,
            kendall_tau: 0.78,
            p_value: 0.001,
            confidence_interval: (0.75, 0.95),
            sample_count: 100,
        };

        assert!(correlation.pearson_correlation > 0.8);
        assert!(correlation.p_value < 0.05);
        assert_eq!(correlation.sample_count, 100);
    }

    /// Agreement statistics should reflect low error and a high ICC.
    #[test]
    fn test_agreement_analysis() {
        let mut agreement_bands = HashMap::new();
        agreement_bands.insert("0.1".to_string(), 75.0);
        agreement_bands.insert("0.2".to_string(), 90.0);

        let bland_altman = BlandAltmanAnalysis {
            mean_difference: 0.05,
            std_difference: 0.18,
            upper_limit: 0.41,
            lower_limit: -0.31,
            within_limits_percentage: 95.0,
        };
        let agreement = AgreementAnalysis {
            mean_absolute_error: 0.15,
            root_mean_squared_error: 0.20,
            mean_absolute_percentage_error: 12.5,
            agreement_bands,
            bland_altman,
            intraclass_correlation: 0.82,
        };

        assert!(agreement.mean_absolute_error < 0.2);
        assert!(agreement.intraclass_correlation > 0.8);
    }
}