1use crate::GlobalOptions;
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6use voirs_g2p::{
7 accuracy::{AccuracyBenchmark, TestCase},
8 LanguageCode,
9};
10use voirs_sdk::config::AppConfig;
11use voirs_sdk::types::SynthesisConfig;
12use voirs_sdk::VoirsPipeline;
13use voirs_sdk::{QualityLevel, Result};
14
/// Aggregated results of benchmarking a single TTS model.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Identifier of the benchmarked model.
    pub model_id: String,
    /// Mean wall-clock time per synthesis run.
    pub avg_synthesis_time: Duration,
    /// Mean duration of the generated audio (successful runs only).
    pub avg_audio_duration: Duration,
    /// Synthesis time divided by audio duration — lower means faster than real time.
    pub real_time_factor: f64,
    /// Mean memory delta sampled around synthesis calls, in megabytes.
    pub memory_usage_mb: f64,
    /// Composite quality score on a 0.0–5.0 scale (see `calculate_quality_score`).
    pub quality_score: f64,
    /// Fraction of synthesis runs that succeeded, in [0.0, 1.0].
    pub success_rate: f64,
    /// Phoneme-level G2P accuracy; `None` when accuracy testing was skipped or failed.
    pub phoneme_accuracy: Option<f64>,
    /// Word-level G2P accuracy; `None` when accuracy testing was skipped or failed.
    pub word_accuracy: Option<f64>,
    /// Overall accuracy target result; `None` when accuracy was not tested.
    pub accuracy_target_met: Option<bool>,
    /// English accuracy target (>= 95%) result; `None` when not tested.
    pub english_accuracy_target_met: Option<bool>,
    /// Japanese accuracy target (>= 90%) result; `None` when not tested.
    pub japanese_accuracy_target_met: Option<bool>,
    /// Whether average synthesis latency stayed under 1 ms.
    pub latency_target_met: bool,
    /// Whether average memory usage stayed under 100 MB.
    pub memory_target_met: bool,
}
33
34pub async fn run_benchmark_models(
36 model_ids: &[String],
37 iterations: u32,
38 include_accuracy: bool,
39 config: &AppConfig,
40 global: &GlobalOptions,
41) -> Result<()> {
42 if !global.quiet {
43 println!("Benchmarking TTS Models");
44 println!("=======================");
45 println!("Iterations: {}", iterations);
46 println!("Models: {}", model_ids.len());
47 if include_accuracy {
48 println!("Accuracy Testing: Enabled (CMU Test Set)");
49 }
50 println!();
51 }
52
53 let test_sentences = get_test_sentences();
54 let mut results = Vec::new();
55
56 let accuracy_benchmark = if include_accuracy {
58 if !global.quiet {
59 println!("Loading CMU accuracy test data...");
60 }
61 Some(load_cmu_accuracy_benchmark()?)
62 } else {
63 None
64 };
65
66 for model_id in model_ids {
67 if !global.quiet {
68 println!("Benchmarking model: {}", model_id);
69 }
70
71 let result = benchmark_model(
72 model_id,
73 &test_sentences,
74 iterations,
75 accuracy_benchmark.as_ref(),
76 config,
77 global,
78 )
79 .await?;
80 results.push(result);
81
82 if !global.quiet {
83 println!(" ✓ Completed\n");
84 }
85 }
86
87 display_benchmark_results(&results, global);
89
90 if results.len() > 1 {
92 generate_comparison_report(&results, global);
93 }
94
95 Ok(())
96}
97
98async fn benchmark_model(
100 model_id: &str,
101 test_sentences: &[String],
102 iterations: u32,
103 accuracy_benchmark: Option<&AccuracyBenchmark>,
104 config: &AppConfig,
105 global: &GlobalOptions,
106) -> Result<BenchmarkResult> {
107 if !global.quiet {
109 println!(" Loading model: {}", model_id);
110 }
111
112 let pipeline = load_model_pipeline(model_id, config, global).await?;
113
114 let synth_config = SynthesisConfig {
115 quality: QualityLevel::High,
116 ..Default::default()
117 };
118
119 let mut total_synthesis_time = Duration::from_secs(0);
120 let mut total_audio_duration = 0.0;
121 let mut successful_runs = 0;
122 let mut memory_samples = Vec::new();
123
124 for i in 0..iterations {
126 if !global.quiet && iterations > 1 {
127 print!(" Progress: {}/{}\r", i + 1, iterations);
128 }
129
130 for sentence in test_sentences {
131 let start_time = Instant::now();
132
133 let memory_before = get_memory_usage();
135
136 match pipeline
138 .synthesize_with_config(sentence, &synth_config)
139 .await
140 {
141 Ok(audio) => {
142 let synthesis_time = start_time.elapsed();
143 total_synthesis_time += synthesis_time;
144 total_audio_duration += audio.duration();
145 successful_runs += 1;
146
147 let memory_after = get_memory_usage();
149 memory_samples.push(memory_after - memory_before);
150 }
151 Err(e) => {
152 tracing::warn!("Synthesis failed for '{}': {}", sentence, e);
153 }
154 }
155 }
156 }
157
158 let total_runs = iterations as usize * test_sentences.len();
160 let avg_synthesis_time = total_synthesis_time / total_runs as u32;
161 let avg_audio_duration =
162 Duration::from_secs_f64(total_audio_duration as f64 / successful_runs as f64);
163 let real_time_factor = avg_synthesis_time.as_secs_f64() / avg_audio_duration.as_secs_f64();
164 let success_rate = successful_runs as f64 / total_runs as f64;
165 let avg_memory_usage = memory_samples.iter().sum::<f64>() / memory_samples.len() as f64;
166
167 let quality_score = calculate_quality_score(model_id, &real_time_factor, &success_rate);
169
170 let latency_target_met = check_latency_target(&avg_synthesis_time);
172 let memory_target_met = check_memory_target(avg_memory_usage);
173
174 if !global.quiet {
175 println!(" Performance Targets:");
176 println!(
177 " Latency (<1ms): {} - {:.2}ms",
178 if latency_target_met {
179 "✅ PASSED"
180 } else {
181 "❌ FAILED"
182 },
183 avg_synthesis_time.as_millis()
184 );
185 println!(
186 " Memory (<100MB): {} - {:.1}MB",
187 if memory_target_met {
188 "✅ PASSED"
189 } else {
190 "❌ FAILED"
191 },
192 avg_memory_usage
193 );
194 }
195
196 let (
198 phoneme_accuracy,
199 word_accuracy,
200 accuracy_target_met,
201 english_accuracy_target_met,
202 japanese_accuracy_target_met,
203 ) = if let Some(benchmark) = accuracy_benchmark {
204 if !global.quiet {
205 println!(" Running accuracy tests...");
206 }
207
208 match run_accuracy_test(&pipeline, benchmark, global).await {
212 Ok((phoneme_acc, word_acc, target_met, en_target_met, ja_target_met)) => {
213 if !global.quiet {
214 println!(" Phoneme Accuracy: {:.2}%", phoneme_acc * 100.0);
215 println!(" Word Accuracy: {:.2}%", word_acc * 100.0);
216 println!(
217 " Overall Target Met: {}",
218 if target_met { "✅" } else { "❌" }
219 );
220 println!(
221 " English Target (>95%): {}",
222 if en_target_met { "✅" } else { "❌" }
223 );
224 println!(
225 " Japanese Target (>90%): {}",
226 if ja_target_met { "✅" } else { "❌" }
227 );
228 }
229 (
230 Some(phoneme_acc),
231 Some(word_acc),
232 Some(target_met),
233 Some(en_target_met),
234 Some(ja_target_met),
235 )
236 }
237 Err(e) => {
238 if !global.quiet {
239 println!(" Accuracy test failed: {}", e);
240 }
241 (None, None, None, None, None)
242 }
243 }
244 } else {
245 (None, None, None, None, None)
246 };
247
248 Ok(BenchmarkResult {
249 model_id: model_id.to_string(),
250 avg_synthesis_time,
251 avg_audio_duration,
252 real_time_factor,
253 memory_usage_mb: avg_memory_usage,
254 quality_score,
255 success_rate,
256 phoneme_accuracy,
257 word_accuracy,
258 accuracy_target_met,
259 english_accuracy_target_met,
260 japanese_accuracy_target_met,
261 latency_target_met,
262 memory_target_met,
263 })
264}
265
/// Fixed English sentence set used as the synthesis workload.
fn get_test_sentences() -> Vec<String> {
    const SENTENCES: [&str; 5] = [
        "The quick brown fox jumps over the lazy dog.",
        "Hello, this is a test of the text-to-speech system.",
        "Artificial intelligence is transforming the way we communicate.",
        "The weather today is absolutely beautiful with clear skies.",
        "Machine learning models require careful tuning and validation.",
    ];
    SENTENCES.into_iter().map(String::from).collect()
}
276
/// Best-effort estimate of this process's memory usage, in megabytes.
///
/// Tries several sources in order and returns the first that succeeds:
/// 1. `VmRSS` from `/proc/self/status` (Linux): current resident set size.
/// 2. `getrusage(RUSAGE_SELF)` on Unix: `ru_maxrss` — note this is the *peak*
///    RSS, and the divisors reflect kilobytes on Linux vs. bytes on macOS.
/// 3. `/proc/meminfo` totals — NOTE(review): this is system-wide used memory,
///    not this process's footprint; it can greatly overstate usage. Confirm
///    this fallback is intentional.
/// 4. A fixed 50 MB placeholder when nothing else is available.
fn get_memory_usage() -> f64 {
    // Preferred source: current RSS of this process.
    if let Ok(status) = std::fs::read_to_string("/proc/self/status") {
        for line in status.lines() {
            if line.starts_with("VmRSS:") {
                // Line format: "VmRSS:   123456 kB" — field 1 is the size in kB.
                if let Some(kb_str) = line.split_whitespace().nth(1) {
                    if let Ok(kb) = kb_str.parse::<f64>() {
                        return kb / 1024.0;
                    }
                }
            }
        }
    }

    #[cfg(unix)]
    {
        // SAFETY: `usage.as_mut_ptr()` is a valid pointer to rusage-sized
        // storage, and the struct is only read (assume_init) after
        // getrusage reports success (returns 0).
        unsafe {
            let mut usage = std::mem::MaybeUninit::<libc::rusage>::uninit();
            if libc::getrusage(libc::RUSAGE_SELF, usage.as_mut_ptr()) == 0 {
                let usage = usage.assume_init();
                // ru_maxrss units differ by platform, hence the divisors.
                #[cfg(target_os = "linux")]
                return usage.ru_maxrss as f64 / 1024.0;
                #[cfg(target_os = "macos")]
                return usage.ru_maxrss as f64 / (1024.0 * 1024.0);
            }
        }
    }

    // System-wide fallback: MemTotal - MemAvailable (whole machine, not
    // per-process — see the NOTE in the doc comment above).
    if let Ok(meminfo) = std::fs::read_to_string("/proc/meminfo") {
        let mut total_kb = None;
        let mut available_kb = None;

        for line in meminfo.lines() {
            if line.starts_with("MemTotal:") {
                if let Some(kb_str) = line.split_whitespace().nth(1) {
                    if let Ok(kb) = kb_str.parse::<f64>() {
                        total_kb = Some(kb);
                    }
                }
            } else if line.starts_with("MemAvailable:") {
                if let Some(kb_str) = line.split_whitespace().nth(1) {
                    if let Ok(kb) = kb_str.parse::<f64>() {
                        available_kb = Some(kb);
                    }
                }
            }
        }

        if let (Some(total), Some(available)) = (total_kb, available_kb) {
            let used_mb = (total - available) / 1024.0;
            return used_mb;
        }
    }

    // Last resort: a fixed placeholder so callers always get a finite value.
    50.0
}
342
/// Composite 0.0–5.0 quality score from synthesis speed, reliability, and a
/// per-model-family expectation bonus.
///
/// Weighting: 40% speed, 40% reliability, 20% family bonus; the result is
/// clamped to [0.0, 5.0].
fn calculate_quality_score(model_id: &str, real_time_factor: &f64, success_rate: &f64) -> f64 {
    let rtf = *real_time_factor;
    let reliability = *success_rate;

    // Piecewise speed score: each bracket interpolates linearly, so faster
    // real-time factors score higher with no discontinuities at boundaries.
    let performance_score = if rtf < 0.05 {
        5.0
    } else if rtf < 0.1 {
        4.5 + 0.5 * (0.1 - rtf) / 0.05
    } else if rtf < 0.25 {
        3.5 + 1.0 * (0.25 - rtf) / 0.15
    } else if rtf < 0.5 {
        2.5 + 1.0 * (0.5 - rtf) / 0.25
    } else if rtf < 1.0 {
        1.5 + 1.0 * (1.0 - rtf) / 0.5
    } else if rtf < 2.0 {
        0.5 + 1.0 * (2.0 - rtf) / 1.0
    } else {
        (5.0 / rtf).min(0.5)
    };

    // Piecewise reliability score from the success rate.
    let reliability_score = if reliability >= 0.99 {
        5.0
    } else if reliability >= 0.95 {
        4.0 + 1.0 * (reliability - 0.95) / 0.04
    } else if reliability >= 0.90 {
        3.0 + 1.0 * (reliability - 0.90) / 0.05
    } else if reliability >= 0.75 {
        1.5 + 1.5 * (reliability - 0.75) / 0.15
    } else {
        reliability * 2.0
    };

    // (quality baseline, expected real-time factor) per model family,
    // keyed by substring; earlier entries take precedence.
    const FAMILY_PROFILES: &[(&[&str], (f64, f64))] = &[
        (&["hifigan"], (0.6, 0.15)),
        (&["waveglow", "wavernn"], (0.8, 0.5)),
        (&["melgan", "parallel-wavegan"], (0.5, 0.1)),
        (&["tacotron"], (0.5, 0.3)),
        (&["fastspeech"], (0.6, 0.2)),
        (&["vits"], (0.7, 0.25)),
        (&["diffwave", "diffusion"], (0.9, 1.0)),
    ];
    let (baseline, expected_rtf) = FAMILY_PROFILES
        .iter()
        .find(|(needles, _)| needles.iter().any(|n| model_id.contains(n)))
        .map(|&(_, profile)| profile)
        .unwrap_or((0.3, 0.5));

    // Full bonus when the model meets its family's speed expectation, linear
    // decay up to twice the expectation, nothing beyond that.
    let speed_bonus = if rtf <= expected_rtf {
        baseline
    } else if rtf <= expected_rtf * 2.0 {
        baseline * (1.0 - (rtf - expected_rtf) / expected_rtf)
    } else {
        0.0
    };

    (performance_score * 0.4 + reliability_score * 0.4 + speed_bonus * 0.2).clamp(0.0, 5.0)
}
413
414fn display_benchmark_results(results: &[BenchmarkResult], global: &GlobalOptions) {
416 if global.quiet {
417 return;
418 }
419
420 println!("Benchmark Results:");
421 println!("==================");
422
423 for result in results {
424 println!("\nModel: {}", result.model_id);
425 println!(
426 " Avg Synthesis Time: {:.2}ms",
427 result.avg_synthesis_time.as_millis()
428 );
429 println!(
430 " Avg Audio Duration: {:.2}ms",
431 result.avg_audio_duration.as_millis()
432 );
433 println!(" Real-time Factor: {:.2}x", result.real_time_factor);
434 println!(" Memory Usage: {:.1} MB", result.memory_usage_mb);
435 println!(" Quality Score: {:.1}/5.0", result.quality_score);
436 println!(" Success Rate: {:.1}%", result.success_rate * 100.0);
437
438 if let Some(phoneme_acc) = result.phoneme_accuracy {
440 println!(" Phoneme Accuracy: {:.2}%", phoneme_acc * 100.0);
441 }
442 if let Some(word_acc) = result.word_accuracy {
443 println!(" Word Accuracy: {:.2}%", word_acc * 100.0);
444 }
445 if let Some(target_met) = result.accuracy_target_met {
446 println!(" Accuracy Targets:");
447 if let Some(en_target) = result.english_accuracy_target_met {
448 println!(
449 " English (>95%): {}",
450 if en_target {
451 "✅ PASSED"
452 } else {
453 "❌ FAILED"
454 }
455 );
456 }
457 if let Some(ja_target) = result.japanese_accuracy_target_met {
458 println!(
459 " Japanese (>90%): {}",
460 if ja_target {
461 "✅ PASSED"
462 } else {
463 "❌ FAILED"
464 }
465 );
466 }
467 println!(
468 " Overall: {}",
469 if target_met {
470 "✅ PASSED"
471 } else {
472 "❌ FAILED"
473 }
474 );
475 }
476
477 println!(
479 " Latency Target (<1ms): {}",
480 if result.latency_target_met {
481 "✅ PASSED"
482 } else {
483 "❌ FAILED"
484 }
485 );
486 println!(
487 " Memory Target (<100MB): {}",
488 if result.memory_target_met {
489 "✅ PASSED"
490 } else {
491 "❌ FAILED"
492 }
493 );
494 }
495}
496
497fn generate_comparison_report(results: &[BenchmarkResult], global: &GlobalOptions) {
499 if global.quiet {
500 return;
501 }
502
503 println!("\n\nComparison Report:");
504 println!("==================");
505
506 let fastest = results
508 .iter()
509 .min_by(|a, b| a.real_time_factor.partial_cmp(&b.real_time_factor).unwrap());
510 let most_reliable = results
511 .iter()
512 .max_by(|a, b| a.success_rate.partial_cmp(&b.success_rate).unwrap());
513 let highest_quality = results
514 .iter()
515 .max_by(|a, b| a.quality_score.partial_cmp(&b.quality_score).unwrap());
516 let most_efficient = results
517 .iter()
518 .min_by(|a, b| a.memory_usage_mb.partial_cmp(&b.memory_usage_mb).unwrap());
519
520 if let Some(model) = fastest {
521 println!(
522 "🏃 Fastest Model: {} ({:.2}x real-time)",
523 model.model_id, model.real_time_factor
524 );
525 }
526
527 if let Some(model) = most_reliable {
528 println!(
529 "🎯 Most Reliable: {} ({:.1}% success rate)",
530 model.model_id,
531 model.success_rate * 100.0
532 );
533 }
534
535 if let Some(model) = highest_quality {
536 println!(
537 "⭐ Highest Quality: {} ({:.1}/5.0)",
538 model.model_id, model.quality_score
539 );
540 }
541
542 if let Some(model) = most_efficient {
543 println!(
544 "💾 Most Memory Efficient: {} ({:.1} MB)",
545 model.model_id, model.memory_usage_mb
546 );
547 }
548
549 println!("\n📊 Performance Target Analysis:");
551 let models_meeting_latency = results.iter().filter(|r| r.latency_target_met).count();
552 let models_meeting_memory = results.iter().filter(|r| r.memory_target_met).count();
553 let models_meeting_all_targets = results
554 .iter()
555 .filter(|r| {
556 r.latency_target_met && r.memory_target_met && r.accuracy_target_met.unwrap_or(false)
557 })
558 .count();
559
560 println!(
561 " 🚀 Models meeting latency target (<1ms): {}/{}",
562 models_meeting_latency,
563 results.len()
564 );
565 println!(
566 " 🧠 Models meeting memory target (<100MB): {}/{}",
567 models_meeting_memory,
568 results.len()
569 );
570
571 if results.iter().any(|r| r.accuracy_target_met.is_some()) {
572 println!(
573 " 🎯 Models meeting all targets: {}/{}",
574 models_meeting_all_targets,
575 results.len()
576 );
577
578 if models_meeting_all_targets > 0 {
579 println!(" ✅ Production-ready models found!");
580 } else {
581 println!(" ⚠️ No models currently meet all production targets");
582 }
583 }
584}
585
/// Load a `VoirsPipeline` for a locally cached model.
///
/// Resolves `<cache>/models/<model_id>`, reads the model's `config.json`
/// metadata (synthesizing defaults when absent), verifies the acoustic and
/// vocoder weight files exist, then builds the pipeline with GPU/thread
/// settings derived from `config` and `global`.
///
/// # Errors
/// Returns a config error when the model directory, its metadata, or either
/// weight file is missing/invalid, or when the pipeline build fails.
async fn load_model_pipeline(
    model_id: &str,
    config: &AppConfig,
    global: &GlobalOptions,
) -> Result<VoirsPipeline> {
    let cache_dir = config.pipeline.effective_cache_dir();
    let model_path = cache_dir.join("models").join(model_id);

    if !model_path.exists() {
        return Err(voirs_sdk::VoirsError::config_error(format!(
            "Model '{}' not found in cache. Please download it first using 'voirs download-model {}'",
            model_id, model_id
        )));
    }

    // Prefer explicit metadata shipped with the model; otherwise fall back
    // to defaults so benchmarking can still proceed.
    let model_config_path = model_path.join("config.json");
    let model_config = if model_config_path.exists() {
        let config_content = std::fs::read_to_string(&model_config_path).map_err(|e| {
            voirs_sdk::VoirsError::IoError {
                path: model_config_path.clone(),
                operation: voirs_sdk::error::IoOperation::Read,
                source: e,
            }
        })?;

        serde_json::from_str::<ModelMetadata>(&config_content).map_err(|e| {
            voirs_sdk::VoirsError::config_error(format!(
                "Invalid model config for '{}': {}",
                model_id, e
            ))
        })?
    } else {
        // Default metadata used when the model ships without config.json.
        ModelMetadata {
            id: model_id.to_string(),
            name: model_id.to_string(),
            description: format!("Model {}", model_id),
            model_type: "neural".to_string(),
            quality: QualityLevel::High,
            requires_gpu: false,
            memory_requirements_mb: 512,
            acoustic_model: "model.safetensors".to_string(),
            vocoder_model: "vocoder.safetensors".to_string(),
            g2p_model: None,
        }
    };

    if !global.quiet {
        println!(
            " Model: {} ({})",
            model_config.name, model_config.description
        );
        println!(
            " Type: {}, Quality: {:?}",
            model_config.model_type, model_config.quality
        );
        println!(
            " Memory required: {} MB",
            model_config.memory_requirements_mb
        );
    }

    // NOTE(review): get_memory_usage() reports this process's *current*
    // usage, not free system memory, so this "availability" comparison can
    // warn spuriously (or miss real pressure) — confirm intended semantics.
    let available_memory = get_memory_usage();
    if model_config.memory_requirements_mb as f64 > available_memory {
        tracing::warn!(
            "Model '{}' requires {} MB but only {:.1} MB may be available",
            model_id,
            model_config.memory_requirements_mb,
            available_memory
        );
    }

    // Check both weight files up front before the (expensive) pipeline build.
    let acoustic_path = model_path.join(&model_config.acoustic_model);
    let vocoder_path = model_path.join(&model_config.vocoder_model);

    if !acoustic_path.exists() {
        return Err(voirs_sdk::VoirsError::config_error(format!(
            "Acoustic model file not found: {}",
            acoustic_path.display()
        )));
    }

    if !vocoder_path.exists() {
        return Err(voirs_sdk::VoirsError::config_error(format!(
            "Vocoder model file not found: {}",
            vocoder_path.display()
        )));
    }

    let mut builder = VoirsPipeline::builder().with_quality(model_config.quality);

    // GPU is enabled whenever config or the CLI flag asks for it; the first
    // two branches differ only in the message printed. Note a model with
    // requires_gpu == true still runs without GPU if neither flag is set.
    if model_config.requires_gpu && (config.pipeline.use_gpu || global.gpu) {
        builder = builder.with_gpu_acceleration(true);
        if !global.quiet {
            println!(" GPU acceleration: enabled (required by model)");
        }
    } else if config.pipeline.use_gpu || global.gpu {
        builder = builder.with_gpu_acceleration(true);
        if !global.quiet {
            println!(" GPU acceleration: enabled");
        }
    } else if !global.quiet {
        println!(" GPU acceleration: disabled");
    }

    // An explicit thread count wins; otherwise use the effective default.
    if let Some(threads) = config.pipeline.num_threads {
        builder = builder.with_threads(threads);
        if !global.quiet {
            println!(" Threads: {}", threads);
        }
    } else {
        let default_threads = config.pipeline.effective_thread_count();
        builder = builder.with_threads(default_threads);
        if !global.quiet {
            println!(" Threads: {} (auto)", default_threads);
        }
    }

    let pipeline = builder.build().await.map_err(|e| {
        voirs_sdk::VoirsError::config_error(format!("Failed to load model '{}': {}", model_id, e))
    })?;

    if !global.quiet {
        println!(" ✓ Model loaded successfully");
    }

    Ok(pipeline)
}
723
/// Build the G2P accuracy benchmark from an embedded CMU-style test set.
///
/// The set covers English words (basic vocabulary, vowel minimal pairs,
/// consonant clusters, irregular spellings, multisyllabic words, proper
/// nouns, technical terms) plus Japanese words in kana and kanji. Expected
/// phonemes are IPA-style for English and romaji-style for Japanese.
fn load_cmu_accuracy_benchmark() -> Result<AccuracyBenchmark> {
    let mut benchmark = AccuracyBenchmark::new();

    // (word, expected phoneme sequence, language)
    let cmu_test_cases = vec![
        // Basic English vocabulary.
        ("hello", vec!["h", "ə", "ˈl", "oʊ"], LanguageCode::EnUs),
        ("world", vec!["w", "ɜːr", "l", "d"], LanguageCode::EnUs),
        ("cat", vec!["k", "æ", "t"], LanguageCode::EnUs),
        ("dog", vec!["d", "ɔː", "ɡ"], LanguageCode::EnUs),
        ("house", vec!["h", "aʊ", "s"], LanguageCode::EnUs),
        ("tree", vec!["t", "r", "iː"], LanguageCode::EnUs),
        ("water", vec!["ˈw", "ɔː", "t", "ər"], LanguageCode::EnUs),
        ("phone", vec!["f", "oʊ", "n"], LanguageCode::EnUs),
        // Vowel minimal pairs (b_t frame).
        ("beat", vec!["b", "iː", "t"], LanguageCode::EnUs),
        ("bit", vec!["b", "ɪ", "t"], LanguageCode::EnUs),
        ("bet", vec!["b", "ɛ", "t"], LanguageCode::EnUs),
        ("bat", vec!["b", "æ", "t"], LanguageCode::EnUs),
        ("bot", vec!["b", "ɑː", "t"], LanguageCode::EnUs),
        ("boat", vec!["b", "oʊ", "t"], LanguageCode::EnUs),
        ("boot", vec!["b", "uː", "t"], LanguageCode::EnUs),
        ("but", vec!["b", "ʌ", "t"], LanguageCode::EnUs),
        // Consonant clusters.
        ("street", vec!["s", "t", "r", "iː", "t"], LanguageCode::EnUs),
        ("spring", vec!["s", "p", "r", "ɪ", "ŋ"], LanguageCode::EnUs),
        ("school", vec!["s", "k", "uː", "l"], LanguageCode::EnUs),
        ("throw", vec!["θ", "r", "oʊ"], LanguageCode::EnUs),
        ("three", vec!["θ", "r", "iː"], LanguageCode::EnUs),
        // Irregular spellings.
        ("one", vec!["w", "ʌ", "n"], LanguageCode::EnUs),
        ("two", vec!["t", "uː"], LanguageCode::EnUs),
        ("eight", vec!["eɪ", "t"], LanguageCode::EnUs),
        ("through", vec!["θ", "r", "uː"], LanguageCode::EnUs),
        ("though", vec!["ð", "oʊ"], LanguageCode::EnUs),
        ("rough", vec!["r", "ʌ", "f"], LanguageCode::EnUs),
        // Multisyllabic words.
        (
            "computer",
            vec!["k", "ə", "m", "ˈp", "j", "uː", "t", "ər"],
            LanguageCode::EnUs,
        ),
        (
            "beautiful",
            vec!["ˈb", "j", "uː", "t", "ɪ", "f", "ə", "l"],
            LanguageCode::EnUs,
        ),
        (
            "restaurant",
            vec!["ˈr", "ɛ", "s", "t", "ər", "ɑː", "n", "t"],
            LanguageCode::EnUs,
        ),
        (
            "university",
            vec!["j", "uː", "n", "ɪ", "ˈv", "ɜːr", "s", "ə", "t", "i"],
            LanguageCode::EnUs,
        ),
        (
            "pronunciation",
            vec!["p", "r", "ə", "n", "ʌ", "n", "s", "i", "ˈeɪ", "ʃ", "ə", "n"],
            LanguageCode::EnUs,
        ),
        // Proper nouns.
        (
            "california",
            vec!["k", "æ", "l", "ɪ", "ˈf", "ɔːr", "n", "j", "ə"],
            LanguageCode::EnUs,
        ),
        (
            "washington",
            vec!["ˈw", "ɑː", "ʃ", "ɪ", "ŋ", "t", "ə", "n"],
            LanguageCode::EnUs,
        ),
        (
            "america",
            vec!["ə", "ˈm", "ɛr", "ɪ", "k", "ə"],
            LanguageCode::EnUs,
        ),
        // Technical vocabulary.
        (
            "technology",
            vec!["t", "ɛ", "k", "ˈn", "ɑː", "l", "ə", "dʒ", "i"],
            LanguageCode::EnUs,
        ),
        (
            "artificial",
            vec!["ɑːr", "t", "ɪ", "ˈf", "ɪ", "ʃ", "ə", "l"],
            LanguageCode::EnUs,
        ),
        (
            "intelligence",
            vec!["ɪ", "n", "ˈt", "ɛ", "l", "ɪ", "dʒ", "ə", "n", "s"],
            LanguageCode::EnUs,
        ),
        (
            "synthesis",
            vec!["ˈs", "ɪ", "n", "θ", "ə", "s", "ɪ", "s"],
            LanguageCode::EnUs,
        ),
        // Japanese: greetings and loanwords (kana), then kanji words.
        (
            "こんにちは",
            vec!["k", "o", "n", "n", "i", "ch", "i", "w", "a"],
            LanguageCode::Ja,
        ),
        (
            "ありがとう",
            vec!["a", "r", "i", "g", "a", "t", "o", "u"],
            LanguageCode::Ja,
        ),
        (
            "おはよう",
            vec!["o", "h", "a", "y", "o", "u"],
            LanguageCode::Ja,
        ),
        (
            "さよなら",
            vec!["s", "a", "y", "o", "n", "a", "r", "a"],
            LanguageCode::Ja,
        ),
        (
            "コンピュータ",
            vec!["k", "o", "n", "p", "y", "u", "u", "t", "a"],
            LanguageCode::Ja,
        ),
        (
            "テクノロジー",
            vec!["t", "e", "k", "u", "n", "o", "r", "o", "j", "i", "i"],
            LanguageCode::Ja,
        ),
        (
            "アニメーション",
            vec!["a", "n", "i", "m", "e", "e", "sh", "o", "n"],
            LanguageCode::Ja,
        ),
        (
            "大学",
            vec!["d", "a", "i", "g", "a", "k", "u"],
            LanguageCode::Ja,
        ),
        (
            "東京",
            vec!["t", "o", "u", "k", "y", "o", "u"],
            LanguageCode::Ja,
        ),
        (
            "日本語",
            vec!["n", "i", "h", "o", "n", "g", "o"],
            LanguageCode::Ja,
        ),
    ];

    for (word, phonemes, lang) in cmu_test_cases {
        benchmark.add_test_case(TestCase {
            word: word.to_string(),
            expected_phonemes: phonemes.into_iter().map(|p| p.to_string()).collect(),
            language: lang,
        });
    }

    Ok(benchmark)
}
887
888async fn run_accuracy_test(
890 _pipeline: &VoirsPipeline,
891 benchmark: &AccuracyBenchmark,
892 _global: &GlobalOptions,
893) -> Result<(f64, f64, bool, bool, bool)> {
894 let g2p = create_test_g2p_system();
905
906 let metrics = benchmark
907 .evaluate(&g2p)
908 .await
909 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Accuracy test failed: {}", e)))?;
910
911 let english_target_met =
914 if let Some(en_metrics) = metrics.language_metrics.get(&LanguageCode::EnUs) {
915 en_metrics.accuracy >= 0.95
916 } else {
917 false
918 };
919
920 let japanese_target_met =
921 if let Some(ja_metrics) = metrics.language_metrics.get(&LanguageCode::Ja) {
922 ja_metrics.accuracy >= 0.90
923 } else {
924 false
925 };
926
927 let target_met = english_target_met || japanese_target_met;
930
931 Ok((
932 metrics.phoneme_accuracy,
933 metrics.word_accuracy,
934 target_met,
935 english_target_met,
936 japanese_target_met,
937 ))
938}
939
/// Build the G2P system used for accuracy evaluation.
///
/// NOTE(review): this returns `DummyG2p`, so reported accuracy reflects the
/// dummy backend rather than a production G2P model — confirm this is the
/// intended behavior for the benchmark.
fn create_test_g2p_system() -> impl voirs_g2p::G2p {
    voirs_g2p::DummyG2p::new()
}
946
/// True when average synthesis latency is strictly below one millisecond.
fn check_latency_target(avg_synthesis_time: &Duration) -> bool {
    *avg_synthesis_time < Duration::from_millis(1)
}
951
/// True when average memory usage stays strictly under the 100 MB budget.
fn check_memory_target(memory_usage_mb: f64) -> bool {
    const MEMORY_BUDGET_MB: f64 = 100.0;
    memory_usage_mb < MEMORY_BUDGET_MB
}
956
/// Metadata describing a cached model, deserialized from the model's
/// `config.json` (defaults are synthesized when the file is absent).
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct ModelMetadata {
    /// Unique model identifier.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Free-form model description.
    pub description: String,
    /// Model family/type label (e.g. "neural" in the default metadata).
    pub model_type: String,
    /// Synthesis quality level configured for this model.
    pub quality: QualityLevel,
    /// Whether the model is declared to require GPU acceleration.
    pub requires_gpu: bool,
    /// Approximate memory needed to load the model, in megabytes.
    pub memory_requirements_mb: u32,
    /// Acoustic model file name within the model directory.
    pub acoustic_model: String,
    /// Vocoder model file name within the model directory.
    pub vocoder_model: String,
    /// Optional G2P model file name; `None` when the model has no G2P asset.
    pub g2p_model: Option<String>,
}
971
#[cfg(test)]
mod tests {
    use super::*;

    /// The synthesis workload must contain a usable number of sentences.
    #[test]
    fn test_get_test_sentences() {
        let sentences = get_test_sentences();
        assert!(!sentences.is_empty());
        assert!(sentences.len() >= 3);
    }

    /// Quality scores are always clamped to the 0.0–5.0 range.
    #[test]
    fn test_calculate_quality_score() {
        let score = calculate_quality_score("hifigan-v1", &0.5, &1.0);
        assert!(score >= 0.0 && score <= 5.0);
    }

    /// Memory usage estimation never returns a negative value.
    #[test]
    fn test_get_memory_usage() {
        let usage = get_memory_usage();
        assert!(usage >= 0.0);
    }

    /// Latency target is strict: exactly 1 ms does NOT pass.
    #[test]
    fn test_check_latency_target() {
        let fast_duration = Duration::from_micros(500);
        assert!(check_latency_target(&fast_duration));

        let slow_duration = Duration::from_millis(2);
        assert!(!check_latency_target(&slow_duration));

        // Boundary: 1 ms exactly is a failure.
        let edge_duration = Duration::from_millis(1);
        assert!(!check_latency_target(&edge_duration));
    }

    /// Memory target is strict: exactly 100 MB does NOT pass.
    #[test]
    fn test_check_memory_target() {
        assert!(check_memory_target(50.0));

        assert!(!check_memory_target(150.0));

        // Boundary: 100 MB exactly is a failure.
        assert!(!check_memory_target(100.0));
    }
}
1021}