1use crate::{error::CliError, output::OutputFormatter};
4use clap::{Args, Subcommand};
5use dasp::{Frame, Sample};
6use hound;
7use std::path::PathBuf;
8#[cfg(feature = "cloning")]
9use voirs_cloning::{
10 CloningConfig, CloningMethod, SpeakerProfile, VoiceCloneRequest, VoiceCloner, VoiceSample,
11};
12
// Subcommands of the voice-cloning CLI group. Each variant wraps the parsed
// arguments for one subcommand. The `///` comments on the variants are picked
// up by clap's derive and shown as the subcommand descriptions in `--help`.
#[cfg(feature = "cloning")]
#[derive(Debug, Clone, Subcommand)]
pub enum CloningCommand {
    /// Clone a voice from one or more reference audio files
    Clone(CloneArgs),
    /// Quickly clone a voice from a single reference audio file
    Quick(QuickCloneArgs),
    /// List cached speaker profiles
    ListProfiles(ListProfilesArgs),
    /// Check whether audio files are suitable as cloning references
    Validate(ValidateArgs),
    /// Clear all cached speaker profiles
    ClearCache(ClearCacheArgs),
}
28
29#[derive(Debug, Clone, Args)]
30pub struct CloneArgs {
31 #[arg(long, required = true)]
33 pub reference_files: Vec<PathBuf>,
34 pub text: String,
36 pub output: PathBuf,
38 #[arg(long, default_value = "few-shot")]
40 pub method: String,
41 #[arg(long)]
43 pub speaker_id: Option<String>,
44 #[arg(long, default_value = "0.7")]
46 pub quality_threshold: f32,
47 #[arg(long, default_value = "22050")]
49 pub sample_rate: u32,
50}
51
52#[derive(Debug, Clone, Args)]
53pub struct QuickCloneArgs {
54 pub reference_file: PathBuf,
56 pub text: String,
58 pub output: PathBuf,
60 #[arg(long, default_value = "22050")]
62 pub sample_rate: u32,
63}
64
65#[derive(Debug, Clone, Args)]
66pub struct ListProfilesArgs {
67 #[arg(long, default_value = "table")]
69 pub format: String,
70 #[arg(long)]
72 pub detailed: bool,
73}
74
75#[derive(Debug, Clone, Args)]
76pub struct ValidateArgs {
77 #[arg(required = true)]
79 pub audio_files: Vec<PathBuf>,
80 #[arg(long, default_value = "table")]
82 pub format: String,
83 #[arg(long, default_value = "0.6")]
85 pub min_quality: f32,
86}
87
88#[derive(Debug, Clone, Args)]
89pub struct ClearCacheArgs {
90 #[arg(long)]
92 pub yes: bool,
93}
94
/// Entry point for the cloning command group: forwards the parsed subcommand
/// and the shared formatter to the matching handler and returns its result
/// unchanged.
#[cfg(feature = "cloning")]
pub async fn execute_cloning_command(
    command: CloningCommand,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    let formatter = output_formatter;
    match command {
        CloningCommand::Clone(clone_args) => execute_clone(clone_args, formatter).await,
        CloningCommand::Quick(quick_args) => execute_quick_clone(quick_args, formatter).await,
        CloningCommand::ListProfiles(list_args) => {
            execute_list_profiles(list_args, formatter).await
        }
        CloningCommand::Validate(validate_args) => {
            execute_validate(validate_args, formatter).await
        }
        CloningCommand::ClearCache(clear_args) => {
            execute_clear_cache(clear_args, formatter).await
        }
    }
}
109
/// Clones a voice from the reference recordings in `args` and writes the
/// synthesized speech for `args.text` to `args.output` as a WAV file.
///
/// Progress is printed to stdout; the final success message goes through
/// `output_formatter`.
///
/// # Errors
///
/// Returns a `CliError` when the method name is not recognised, the quality
/// threshold is not within `0.0..=1.0` (NaN included), a reference file is
/// missing or cannot be loaded, the cloner cannot be created, the cloning
/// request fails, or the output file cannot be written.
#[cfg(feature = "cloning")]
async fn execute_clone(
    args: CloneArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    // Accept both dash and underscore spellings, case-insensitively.
    let method = match args.method.to_lowercase().as_str() {
        "few-shot" | "few_shot" => CloningMethod::FewShot,
        "one-shot" | "one_shot" => CloningMethod::OneShot,
        "zero-shot" | "zero_shot" => CloningMethod::ZeroShot,
        "fine-tuning" | "fine_tuning" => CloningMethod::FineTuning,
        "voice-conversion" | "voice_conversion" => CloningMethod::VoiceConversion,
        "hybrid" => CloningMethod::Hybrid,
        _ => return Err(CliError::invalid_parameter("method", "Invalid cloning method. Use: few-shot, one-shot, zero-shot, fine-tuning, voice-conversion, or hybrid")),
    };

    // `contains` is false for NaN, so a NaN threshold is rejected as well
    // (a pair of `<` / `>` comparisons would let it through).
    if !(0.0..=1.0).contains(&args.quality_threshold) {
        return Err(CliError::invalid_parameter(
            "quality_threshold",
            "Quality threshold must be between 0.0 and 1.0",
        ));
    }
    // NOTE(review): the threshold is only validated here; nothing in this
    // function forwards it to the cloner or compares it with the result's
    // similarity score. Confirm whether it is meant to be applied.

    println!(
        "Loading {} reference audio files...",
        args.reference_files.len()
    );

    let mut voice_samples = Vec::with_capacity(args.reference_files.len());
    for (i, file_path) in args.reference_files.iter().enumerate() {
        if !file_path.exists() {
            return Err(CliError::config(format!(
                "Reference file not found: {}",
                file_path.display()
            )));
        }

        println!(" Loading sample {}: {}", i + 1, file_path.display());

        let audio_data = load_audio_file(file_path).map_err(|e| {
            CliError::config(format!(
                "Failed to load audio file {}: {}",
                file_path.display(),
                e
            ))
        })?;

        voice_samples.push(VoiceSample::new(
            format!("sample_{}", i),
            audio_data.samples,
            audio_data.sample_rate,
        ));
    }

    let cloner = VoiceCloner::new()
        .map_err(|e| CliError::config(format!("Failed to create voice cloner: {}", e)))?;

    let speaker_id = args
        .speaker_id
        .unwrap_or_else(|| format!("speaker_{}", fastrand::u64(..)));

    println!("Cloning voice with method: {:?}", method);
    println!("Speaker ID: {}", speaker_id);
    println!("Target text: '{}'", args.text);

    println!("Processing voice cloning...");

    // The speaker ID doubles as the profile's second constructor argument;
    // the owned string is moved into the last use instead of being cloned.
    let mut speaker_data =
        voirs_cloning::SpeakerData::new(SpeakerProfile::new(speaker_id.clone(), speaker_id));
    speaker_data.reference_samples = voice_samples;

    let request = VoiceCloneRequest::new(
        format!("clone_{}", fastrand::u64(..)),
        speaker_data,
        method,
        args.text,
    );

    let result = cloner
        .clone_voice(request)
        .await
        .map_err(|e| CliError::config(format!("Voice cloning failed: {}", e)))?;

    // A request can complete without a transport error yet still report failure.
    if !result.success {
        let error_msg = result
            .error_message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(CliError::AudioError(format!(
            "Voice cloning failed: {}",
            error_msg
        )));
    }

    save_audio_file(&result.audio, args.sample_rate, &args.output)
        .map_err(|e| CliError::AudioError(format!("Failed to save audio: {}", e)))?;

    output_formatter.success(&format!(
        "Voice cloning completed! Quality score: {:.2}, Output saved to: {}",
        result.similarity_score,
        args.output.display()
    ));

    Ok(())
}
222
/// One-shot cloning from a single reference file: loads the recording, issues
/// a `CloningMethod::OneShot` request for `args.text`, and writes the result
/// to `args.output` as a WAV file.
///
/// # Errors
///
/// Returns a `CliError` when the reference file is missing or cannot be
/// loaded, the cloner cannot be created, the cloning request fails, or the
/// output file cannot be written.
#[cfg(feature = "cloning")]
async fn execute_quick_clone(
    args: QuickCloneArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    let reference = &args.reference_file;
    if !reference.exists() {
        return Err(CliError::config(format!(
            "Reference file not found: {}",
            reference.display()
        )));
    }

    println!("Quick cloning from: {}", reference.display());
    println!("Target text: '{}'", args.text);

    let audio_data = load_audio_file(reference)
        .map_err(|e| CliError::AudioError(format!("Failed to load reference audio: {}", e)))?;

    let cloner = VoiceCloner::new()
        .map_err(|e| CliError::config(format!("Failed to create voice cloner: {}", e)))?;

    // A fixed-name speaker holding exactly one reference sample.
    let speaker_data = {
        let profile = SpeakerProfile::new(
            "quick_clone_speaker".to_string(),
            "Quick Clone".to_string(),
        );
        let mut data = voirs_cloning::SpeakerData::new(profile);
        data.reference_samples = vec![VoiceSample::new(
            "quick_clone_sample".to_string(),
            audio_data.samples,
            audio_data.sample_rate,
        )];
        data
    };

    let request_id = format!("quick_clone_{}", fastrand::u64(..));
    let request = VoiceCloneRequest::new(
        request_id,
        speaker_data,
        CloningMethod::OneShot,
        args.text.clone(),
    );

    println!("Processing quick voice cloning...");
    let result = cloner
        .clone_voice(request)
        .await
        .map_err(|e| CliError::AudioError(format!("Quick cloning failed: {}", e)))?;

    // Bail out first so the success path reads straight through.
    if !result.success {
        let error_msg = result
            .error_message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(CliError::AudioError(format!(
            "Quick cloning failed: {}",
            error_msg
        )));
    }

    save_audio_file(&result.audio, args.sample_rate, &args.output)
        .map_err(|e| CliError::AudioError(format!("Failed to save audio: {}", e)))?;

    output_formatter.success(&format!(
        "Quick cloning completed! Quality score: {:.2}, Output saved to: {}",
        result.similarity_score,
        args.output.display()
    ));

    Ok(())
}
293
/// Prints the cached speaker profiles in the format selected by `args.format`.
///
/// `"table"` prints an aligned table, `"json"` prints a pretty-printed JSON
/// array, and any other value prints one `id: name` line per profile. All
/// three formats report the same profile fields (`id`, `name`, and for
/// table/JSON the number of samples).
///
/// `output_formatter` is accepted for signature parity with the other
/// handlers but is not used here.
///
/// # Errors
///
/// Returns a `CliError` when the cloner cannot be created or the JSON output
/// cannot be serialized.
#[cfg(feature = "cloning")]
async fn execute_list_profiles(
    args: ListProfilesArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    let cloner = VoiceCloner::new()
        .map_err(|e| CliError::config(format!("Failed to create voice cloner: {}", e)))?;

    let profiles = cloner.list_cached_speakers().await;

    if profiles.is_empty() {
        println!("No cached speaker profiles found.");
        return Ok(());
    }

    match args.format.as_str() {
        "table" => {
            println!("{:<20} {:<30} Samples", "Speaker ID", "Description");
            println!("{}", "-".repeat(60));
            for profile in profiles {
                println!(
                    "{:<20} {:<30} {}",
                    profile.id,
                    profile.name,
                    profile.samples.len()
                );
            }
        }
        "json" => {
            // Report the same fields the table shows, rather than serializing
            // the whole profile under "speaker_id" with placeholder values.
            let json_profiles: Vec<_> = profiles
                .iter()
                .map(|profile| {
                    // `null` unless --detailed was requested.
                    let details = if args.detailed {
                        Some(serde_json::json!({"cached": true}))
                    } else {
                        None
                    };
                    serde_json::json!({
                        "speaker_id": profile.id.to_string(),
                        "description": profile.name.to_string(),
                        "sample_count": profile.samples.len(),
                        "details": details
                    })
                })
                .collect();

            println!(
                "{}",
                serde_json::to_string_pretty(&json_profiles).map_err(CliError::Serialization)?
            );
        }
        _ => {
            for profile in profiles {
                println!("{}: {}", profile.id, profile.name);
            }
        }
    }

    Ok(())
}
354
/// Scores each file in `args.audio_files` with `validate_audio_quality` and
/// prints a report in the format selected by `args.format`.
///
/// Every file gets one of four statuses: `NOT_FOUND`, `LOAD_ERROR`,
/// `LOW_QUALITY` (score below `args.min_quality`) or `VALID`. A summary line
/// is emitted through `output_formatter`. The function returns `Ok(())` even
/// when some files are not valid.
///
/// # Errors
///
/// Returns a `CliError` when `args.min_quality` is not within `0.0..=1.0`
/// (NaN included) or the JSON report cannot be serialized.
#[cfg(feature = "cloning")]
async fn execute_validate(
    args: ValidateArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    // `contains` is false for NaN, so a NaN minimum is rejected as well.
    if !(0.0..=1.0).contains(&args.min_quality) {
        return Err(CliError::invalid_parameter(
            "min_quality",
            "Minimum quality must be between 0.0 and 1.0",
        ));
    }

    println!(
        "Validating {} audio files for cloning...",
        args.audio_files.len()
    );

    let mut validation_results = Vec::with_capacity(args.audio_files.len());
    let mut all_valid = true;

    for (i, file_path) in args.audio_files.iter().enumerate() {
        let (quality_score, status) = if !file_path.exists() {
            (0.0, "NOT_FOUND")
        } else {
            println!(" Validating: {}", file_path.display());

            match load_audio_file(file_path) {
                Ok(audio_data) => {
                    let quality = validate_audio_quality(&audio_data);
                    if quality >= args.min_quality {
                        (quality, "VALID")
                    } else {
                        (quality, "LOW_QUALITY")
                    }
                }
                Err(e) => {
                    // Say why the file could not be read instead of dropping
                    // the error; stderr keeps the report on stdout intact.
                    eprintln!(" Failed to load {}: {}", file_path.display(), e);
                    (0.0, "LOAD_ERROR")
                }
            }
        };

        all_valid &= status == "VALID";

        validation_results.push((
            format!("File {}", i + 1),
            file_path
                .file_name()
                .unwrap_or_default()
                .to_string_lossy()
                .to_string(),
            status.to_string(),
            quality_score,
        ));
    }

    match args.format.as_str() {
        "table" => {
            println!("{:<10} {:<30} {:<12} Quality", "File", "Name", "Status");
            println!("{}", "-".repeat(70));
            for (file_num, name, status, quality) in validation_results {
                println!(
                    "{:<10} {:<30} {:<12} {:.2}",
                    file_num, name, status, quality
                );
            }
        }
        "json" => {
            let json_results: Vec<_> = validation_results
                .into_iter()
                .map(|(file_num, name, status, quality)| {
                    serde_json::json!({
                        "file": file_num,
                        "filename": name,
                        "status": status,
                        "quality_score": quality
                    })
                })
                .collect();

            println!(
                "{}",
                serde_json::to_string_pretty(&json_results).map_err(CliError::Serialization)?
            );
        }
        _ => {
            for (file_num, name, status, quality) in validation_results {
                println!(
                    "{} ({}): {} - Quality: {:.2}",
                    file_num, name, status, quality
                );
            }
        }
    }

    if all_valid {
        output_formatter.success("All audio files are valid for voice cloning!");
    } else {
        output_formatter
            .warning("Some audio files may not be suitable for high-quality voice cloning.");
    }

    Ok(())
}
472
/// Clears the cached speaker profiles, asking for confirmation on stdin
/// unless `args.yes` is set.
///
/// Any answer whose trimmed, lower-cased form starts with `y` counts as
/// consent; anything else cancels and returns `Ok(())` without touching the
/// cache.
///
/// # Errors
///
/// Returns a `CliError` when reading the answer fails, the cloner cannot be
/// created, or clearing the cache fails.
#[cfg(feature = "cloning")]
async fn execute_clear_cache(
    args: ClearCacheArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    // `--yes` short-circuits the prompt entirely.
    let confirmed = args.yes || {
        println!("This will clear all cached speaker profiles. Continue? (y/N)");
        let mut answer = String::new();
        std::io::stdin()
            .read_line(&mut answer)
            .map_err(CliError::Io)?;
        answer.trim().to_lowercase().starts_with('y')
    };

    if !confirmed {
        println!("Cache clearing cancelled.");
        return Ok(());
    }

    let cloner = VoiceCloner::new()
        .map_err(|e| CliError::config(format!("Failed to create voice cloner: {}", e)))?;

    cloner
        .clear_cache()
        .await
        .map_err(|e| CliError::config(format!("Failed to clear cache: {}", e)))?;

    output_formatter.success("Speaker profile cache cleared successfully!");
    Ok(())
}
503
/// Decoded audio as produced by `load_audio_file`: a flat sample buffer plus
/// the sample rate taken from the source file's header.
#[derive(Debug)]
struct AudioData {
    /// Sample values as 32-bit floats.
    samples: Vec<f32>,
    /// Sample rate in Hz.
    sample_rate: u32,
}
510
511fn load_audio_file(path: &PathBuf) -> Result<AudioData, Box<dyn std::error::Error>> {
513 let mut reader = hound::WavReader::open(path)?;
514 let spec = reader.spec();
515
516 let samples: Result<Vec<f32>, _> = match spec.sample_format {
518 hound::SampleFormat::Float => reader.samples::<f32>().collect(),
519 hound::SampleFormat::Int => match spec.bits_per_sample {
520 8 => reader
521 .samples::<i8>()
522 .map(|s| s.map(|sample| sample as f32 / i8::MAX as f32))
523 .collect(),
524 16 => reader
525 .samples::<i16>()
526 .map(|s| s.map(|sample| sample as f32 / i16::MAX as f32))
527 .collect(),
528 24 => reader
529 .samples::<i32>()
530 .map(|s| s.map(|sample| (sample >> 8) as f32 / (i32::MAX >> 8) as f32))
531 .collect(),
532 32 => reader
533 .samples::<i32>()
534 .map(|s| s.map(|sample| sample as f32 / i32::MAX as f32))
535 .collect(),
536 _ => {
537 return Err(format!("Unsupported bit depth: {}", spec.bits_per_sample).into());
538 }
539 },
540 };
541
542 let samples = samples?;
543
544 let mono_samples = if spec.channels == 2 {
546 samples
547 .chunks(2)
548 .map(|frame| (frame[0] + frame[1]) / 2.0)
549 .collect()
550 } else {
551 samples
552 };
553
554 Ok(AudioData {
555 samples: mono_samples,
556 sample_rate: spec.sample_rate,
557 })
558}
559
560fn save_audio_file(
562 audio_data: &[f32],
563 sample_rate: u32,
564 path: &PathBuf,
565) -> Result<(), Box<dyn std::error::Error>> {
566 let spec = hound::WavSpec {
567 channels: 1,
568 sample_rate,
569 bits_per_sample: 16,
570 sample_format: hound::SampleFormat::Int,
571 };
572
573 let mut writer = hound::WavWriter::create(path, spec)?;
574
575 for &sample in audio_data {
576 let sample_i16 = (sample * i16::MAX as f32) as i16;
577 writer.write_sample(sample_i16)?;
578 }
579
580 writer.finalize()?;
581 Ok(())
582}
583
584fn validate_audio_quality(audio_data: &AudioData) -> f32 {
586 let samples = &audio_data.samples;
587
588 if samples.is_empty() {
589 return 0.0;
590 }
591
592 let mut quality_score: f32 = 1.0;
594
595 let duration_seconds = samples.len() as f32 / audio_data.sample_rate as f32;
597 if duration_seconds < 1.0 {
598 quality_score *= 0.3; } else if duration_seconds < 3.0 {
600 quality_score *= 0.7; } else if duration_seconds > 30.0 {
602 quality_score *= 0.8; }
604
605 let rms = (samples.iter().map(|&x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
607 if rms < 0.01 {
608 quality_score *= 0.2; } else if rms < 0.05 {
610 quality_score *= 0.6; }
612
613 let clipped_samples = samples.iter().filter(|&&x| x.abs() > 0.95).count();
615 let clipping_ratio = clipped_samples as f32 / samples.len() as f32;
616 if clipping_ratio > 0.1 {
617 quality_score *= 0.4; } else if clipping_ratio > 0.01 {
619 quality_score *= 0.7; }
621
622 let max_val = samples.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
624 let min_non_zero = samples
625 .iter()
626 .filter(|&&x| x.abs() > 0.001)
627 .map(|x| x.abs())
628 .fold(1.0f32, f32::min);
629
630 let dynamic_range = if min_non_zero > 0.0 {
631 20.0 * (max_val / min_non_zero).log10()
632 } else {
633 0.0
634 };
635
636 if dynamic_range < 20.0 {
637 quality_score *= 0.5; } else if dynamic_range < 40.0 {
639 quality_score *= 0.8; }
641
642 if audio_data.sample_rate < 16000 {
644 quality_score *= 0.6; } else if audio_data.sample_rate < 22050 {
646 quality_score *= 0.9; }
648
649 quality_score.clamp(0.0, 1.0)
650}