Skip to main content

voirs_cli/commands/
cloning.rs

1//! Voice cloning commands for the VoiRS CLI
2
3use crate::{error::CliError, output::OutputFormatter};
4use clap::{Args, Subcommand};
5use dasp::{Frame, Sample};
6use hound;
7use std::path::PathBuf;
8#[cfg(feature = "cloning")]
9use voirs_cloning::{
10    CloningConfig, CloningMethod, SpeakerProfile, VoiceCloneRequest, VoiceCloner, VoiceSample,
11};
12
/// Voice cloning commands
// Clap renders each variant's `///` line below as that subcommand's help text, so
// maintainer notes in this enum use plain `//` comments. `execute_cloning_command`
// dispatches every variant to the matching `execute_*` handler in this file.
#[cfg(feature = "cloning")]
#[derive(Debug, Clone, Subcommand)]
pub enum CloningCommand {
    /// Clone voice from reference samples
    // Handled by `execute_clone`.
    Clone(CloneArgs),
    /// Quick clone from single audio file
    // Handled by `execute_quick_clone` (always uses `CloningMethod::OneShot`).
    Quick(QuickCloneArgs),
    /// List cached speaker profiles
    // Handled by `execute_list_profiles`.
    ListProfiles(ListProfilesArgs),
    /// Validate reference audio for cloning
    // Handled by `execute_validate`.
    Validate(ValidateArgs),
    /// Clear speaker cache
    // Handled by `execute_clear_cache`.
    ClearCache(ClearCacheArgs),
}
28
29#[derive(Debug, Clone, Args)]
30pub struct CloneArgs {
31    /// Reference audio files (multiple samples for better quality)
32    #[arg(long, required = true)]
33    pub reference_files: Vec<PathBuf>,
34    /// Text to synthesize with cloned voice
35    pub text: String,
36    /// Output audio file path
37    pub output: PathBuf,
38    /// Cloning method (few-shot, one-shot, zero-shot, fine-tuning)
39    #[arg(long, default_value = "few-shot")]
40    pub method: String,
41    /// Speaker name/ID for caching
42    #[arg(long)]
43    pub speaker_id: Option<String>,
44    /// Quality threshold (0.0 to 1.0)
45    #[arg(long, default_value = "0.7")]
46    pub quality_threshold: f32,
47    /// Sample rate for output audio
48    #[arg(long, default_value = "22050")]
49    pub sample_rate: u32,
50}
51
// Arguments for the `quick` subcommand (single-file, one-shot cloning).
// The `///` lines are clap help text; notes use `//`. The three fields without
// `#[arg(...)]` are positional, in order: reference_file, text, output.
#[derive(Debug, Clone, Args)]
pub struct QuickCloneArgs {
    /// Single reference audio file
    // Read with `load_audio_file`, which decodes WAV via `hound`.
    pub reference_file: PathBuf,
    /// Text to synthesize with cloned voice
    pub text: String,
    /// Output audio file path
    pub output: PathBuf,
    /// Sample rate for output audio
    // Only written into the output WAV header by `save_audio_file`; nothing in this file
    // resamples. NOTE(review): confirm the cloner produces audio at this rate.
    #[arg(long, default_value = "22050")]
    pub sample_rate: u32,
}
64
// Arguments for the `list-profiles` subcommand.
// The `///` lines are clap help text; notes use `//`.
#[derive(Debug, Clone, Args)]
pub struct ListProfilesArgs {
    /// Output format for the profile list
    // `execute_list_profiles` recognises "table" and "json" (exact, case-sensitive);
    // any other value prints one `id: name` line per profile.
    #[arg(long, default_value = "table")]
    pub format: String,
    /// Show detailed profile information
    // Only consulted by the "json" format in `execute_list_profiles`.
    #[arg(long)]
    pub detailed: bool,
}
74
// Arguments for the `validate` subcommand.
// The `///` lines are clap help text; notes use `//`.
#[derive(Debug, Clone, Args)]
pub struct ValidateArgs {
    /// Reference audio files to validate
    // Positional and required, so at least one path is always present.
    #[arg(required = true)]
    pub audio_files: Vec<PathBuf>,
    /// Output validation report format
    // `execute_validate` recognises "table" and "json" (exact, case-sensitive); any other
    // value prints one line per file.
    #[arg(long, default_value = "table")]
    pub format: String,
    /// Minimum quality threshold
    // A file is reported VALID when `validate_audio_quality(..) >= min_quality`.
    #[arg(long, default_value = "0.6")]
    pub min_quality: f32,
}
87
// Arguments for the `clear-cache` subcommand.
// The `///` line is clap help text; notes use `//`.
#[derive(Debug, Clone, Args)]
pub struct ClearCacheArgs {
    /// Confirm cache clearing without prompt
    // When false, `execute_clear_cache` asks on stdin and proceeds only on a reply
    // starting with `y` (case-insensitive).
    #[arg(long)]
    pub yes: bool,
}
94
/// Run a parsed `cloning` subcommand.
///
/// Every variant carries its own argument struct. That struct is handed to the matching
/// `execute_*` handler together with the shared output formatter.
///
/// # Errors
///
/// Propagates whatever `CliError` the selected handler returns.
#[cfg(feature = "cloning")]
pub async fn execute_cloning_command(
    command: CloningCommand,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    use CloningCommand as Cmd;

    let formatter = output_formatter;
    match command {
        Cmd::Clone(clone_args) => execute_clone(clone_args, formatter).await,
        Cmd::Quick(quick_args) => execute_quick_clone(quick_args, formatter).await,
        Cmd::ListProfiles(list_args) => execute_list_profiles(list_args, formatter).await,
        Cmd::Validate(validate_args) => execute_validate(validate_args, formatter).await,
        Cmd::ClearCache(clear_args) => execute_clear_cache(clear_args, formatter).await,
    }
}
109
/// Clone a voice from one or more reference WAV files and synthesize `args.text` with it.
///
/// The steps are:
/// - load every `--reference-files` entry;
/// - build a `SpeakerData` from them;
/// - run the requested `CloningMethod`;
/// - write the result as a 16-bit mono WAV to `args.output` via `save_audio_file`.
///
/// If the resulting similarity score is below `--quality-threshold`, a warning is printed.
/// The audio is still saved.
///
/// # Errors
///
/// - `CliError::invalid_parameter` for an unknown method, or a quality threshold outside
///   `[0.0, 1.0]` (including NaN).
/// - `CliError::config` when a reference file is missing or unreadable, the cloner cannot
///   be created, or `clone_voice` fails.
/// - `CliError::AudioError` when cloning reports failure or the output cannot be saved.
#[cfg(feature = "cloning")]
async fn execute_clone(
    args: CloneArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    // Validate cloning method (case-insensitive; hyphen and underscore spellings accepted).
    let method = match args.method.to_lowercase().as_str() {
        "few-shot" | "few_shot" => CloningMethod::FewShot,
        "one-shot" | "one_shot" => CloningMethod::OneShot,
        "zero-shot" | "zero_shot" => CloningMethod::ZeroShot,
        "fine-tuning" | "fine_tuning" => CloningMethod::FineTuning,
        "voice-conversion" | "voice_conversion" => CloningMethod::VoiceConversion,
        "hybrid" => CloningMethod::Hybrid,
        _ => return Err(CliError::invalid_parameter("method", "Invalid cloning method. Use: few-shot, one-shot, zero-shot, fine-tuning, voice-conversion, or hybrid")),
    };

    // `RangeInclusive::contains` is false for NaN, so a NaN threshold is rejected here.
    // Two ordered comparisons (`< 0.0 || > 1.0`) would both be false for NaN and let it through.
    if !(0.0..=1.0).contains(&args.quality_threshold) {
        return Err(CliError::invalid_parameter(
            "quality_threshold",
            "Quality threshold must be between 0.0 and 1.0",
        ));
    }

    // Load reference audio files
    println!(
        "Loading {} reference audio files...",
        args.reference_files.len()
    );
    let mut voice_samples = Vec::with_capacity(args.reference_files.len());

    for (i, file_path) in args.reference_files.iter().enumerate() {
        if !file_path.exists() {
            return Err(CliError::config(format!(
                "Reference file not found: {}",
                file_path.display()
            )));
        }

        println!("  Loading sample {}: {}", i + 1, file_path.display());

        let audio_data = load_audio_file(file_path).map_err(|e| {
            CliError::config(format!(
                "Failed to load audio file {}: {}",
                file_path.display(),
                e
            ))
        })?;

        voice_samples.push(VoiceSample::new(
            format!("sample_{}", i),
            audio_data.samples,
            audio_data.sample_rate,
        ));
    }

    let cloner = VoiceCloner::new()
        .map_err(|e| CliError::config(format!("Failed to create voice cloner: {}", e)))?;

    // Without an explicit --speaker-id, fall back to a random one.
    let speaker_id = args
        .speaker_id
        .unwrap_or_else(|| format!("speaker_{}", fastrand::u64(..)));

    println!("Cloning voice with method: {:?}", method);
    println!("Speaker ID: {}", speaker_id);
    println!("Target text: '{}'", args.text);

    println!("Processing voice cloning...");
    // The speaker id doubles as the profile's display name.
    let mut speaker_data = voirs_cloning::SpeakerData::new(SpeakerProfile::new(
        speaker_id.clone(),
        speaker_id.clone(),
    ));
    speaker_data.reference_samples = voice_samples;

    let request = VoiceCloneRequest::new(
        format!("clone_{}", fastrand::u64(..)),
        speaker_data,
        method,
        args.text.clone(),
    );

    let result = cloner
        .clone_voice(request)
        .await
        .map_err(|e| CliError::config(format!("Voice cloning failed: {}", e)))?;

    if !result.success {
        let error_msg = result
            .error_message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(CliError::AudioError(format!(
            "Voice cloning failed: {}",
            error_msg
        )));
    }

    // NOTE(review): the WAV header uses the requested --sample-rate. Confirm the cloner
    // produces audio at that rate; otherwise playback speed and pitch will be off.
    save_audio_file(&result.audio, args.sample_rate, &args.output)
        .map_err(|e| CliError::AudioError(format!("Failed to save audio: {}", e)))?;

    output_formatter.success(&format!(
        "Voice cloning completed! Quality score: {:.2}, Output saved to: {}",
        result.similarity_score,
        args.output.display()
    ));

    // Report, without failing, when the result falls short of --quality-threshold.
    // The `as f32` keeps the comparison valid whichever float width the score uses.
    if (result.similarity_score as f32) < args.quality_threshold {
        output_formatter.warning(&format!(
            "Similarity score {:.2} is below the requested quality threshold {:.2}",
            result.similarity_score, args.quality_threshold
        ));
    }

    Ok(())
}
222
/// One-shot clone from a single reference WAV, then synthesize `text` with it.
///
/// Uses fixed speaker identifiers (`quick_clone_speaker` / "Quick Clone") and
/// `CloningMethod::OneShot`. The result is written to `output` via `save_audio_file`.
///
/// # Errors
///
/// - `CliError::config` if the reference file is missing or the cloner cannot be created.
/// - `CliError::AudioError` if loading, cloning, or saving fails.
#[cfg(feature = "cloning")]
async fn execute_quick_clone(
    args: QuickCloneArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    let QuickCloneArgs {
        reference_file,
        text,
        output,
        sample_rate,
    } = args;

    if !reference_file.exists() {
        let missing = format!("Reference file not found: {}", reference_file.display());
        return Err(CliError::config(missing));
    }

    println!("Quick cloning from: {}", reference_file.display());
    println!("Target text: '{}'", text);

    let reference = load_audio_file(&reference_file)
        .map_err(|e| CliError::AudioError(format!("Failed to load reference audio: {}", e)))?;

    let voice_cloner = VoiceCloner::new()
        .map_err(|e| CliError::config(format!("Failed to create voice cloner: {}", e)))?;

    // A single reference sample feeding a fixed, throw-away speaker identity.
    let only_sample = VoiceSample::new(
        "quick_clone_sample".to_string(),
        reference.samples,
        reference.sample_rate,
    );
    let profile = SpeakerProfile::new(
        "quick_clone_speaker".to_string(),
        "Quick Clone".to_string(),
    );
    let mut speaker = voirs_cloning::SpeakerData::new(profile);
    speaker.reference_samples = vec![only_sample];

    let request_id = format!("quick_clone_{}", fastrand::u64(..));
    let request = VoiceCloneRequest::new(request_id, speaker, CloningMethod::OneShot, text);

    println!("Processing quick voice cloning...");
    let outcome = voice_cloner
        .clone_voice(request)
        .await
        .map_err(|e| CliError::AudioError(format!("Quick cloning failed: {}", e)))?;

    if !outcome.success {
        let reason = outcome
            .error_message
            .unwrap_or_else(|| "Unknown error".to_string());
        return Err(CliError::AudioError(format!(
            "Quick cloning failed: {}",
            reason
        )));
    }

    save_audio_file(&outcome.audio, sample_rate, &output)
        .map_err(|e| CliError::AudioError(format!("Failed to save audio: {}", e)))?;

    let summary = format!(
        "Quick cloning completed! Quality score: {:.2}, Output saved to: {}",
        outcome.similarity_score,
        output.display()
    );
    output_formatter.success(&summary);

    Ok(())
}
293
/// Print the voice cloner's cached speaker profiles in the requested format.
///
/// Supported formats:
/// - `"table"`: id / name / sample-count columns.
/// - `"json"`: an array of objects with the same data. With `--detailed`, each object also
///   gets a `details` object.
/// - any other value: one `id: name` line per profile.
///
/// An empty cache prints a notice, except in JSON mode, which prints `[]` so the output
/// stays machine-readable.
///
/// # Errors
///
/// - `CliError::config` if the cloner cannot be created.
/// - `CliError::Serialization` if the JSON report cannot be rendered.
#[cfg(feature = "cloning")]
async fn execute_list_profiles(
    args: ListProfilesArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    // Create voice cloner to access cached profiles
    let cloner = VoiceCloner::new()
        .map_err(|e| CliError::config(format!("Failed to create voice cloner: {}", e)))?;

    let profiles = cloner.list_cached_speakers().await;

    let is_json = args.format == "json";
    if profiles.is_empty() && !is_json {
        println!("No cached speaker profiles found.");
        return Ok(());
    }

    match args.format.as_str() {
        "table" => {
            println!("{:<20} {:<30} Samples", "Speaker ID", "Description");
            println!("{}", "-".repeat(60));
            for profile in profiles {
                println!(
                    "{:<20} {:<30} {}",
                    profile.id,
                    profile.name,
                    profile.samples.len()
                );
            }
        }
        "json" => {
            // Report the same fields as the table view, using the profile's real data
            // rather than placeholder strings.
            let json_profiles: Vec<_> = profiles
                .iter()
                .map(|profile| {
                    serde_json::json!({
                        "speaker_id": profile.id,
                        "description": profile.name,
                        "sample_count": profile.samples.len(),
                        "details": if args.detailed {
                            Some(serde_json::json!({"cached": true}))
                        } else {
                            None
                        }
                    })
                })
                .collect();

            println!(
                "{}",
                serde_json::to_string_pretty(&json_profiles).map_err(CliError::Serialization)?
            );
        }
        _ => {
            for profile in profiles {
                println!("{}: {}", profile.id, profile.name);
            }
        }
    }

    Ok(())
}
354
/// Check reference WAV files for suitability as voice-cloning input and print a report.
///
/// Each file is scored with `validate_audio_quality` and given one status:
/// - `VALID`: score at or above `--min-quality`;
/// - `LOW_QUALITY`: score below it;
/// - `NOT_FOUND`: the path does not exist;
/// - `LOAD_ERROR`: the file exists but could not be decoded.
///
/// The report goes to stdout as a table, JSON, or plain lines depending on `--format`.
/// Progress and load diagnostics go to stderr.
///
/// # Errors
///
/// - `CliError::invalid_parameter` if `--min-quality` is outside `[0.0, 1.0]` or is NaN.
/// - `CliError::Serialization` if the JSON report cannot be rendered.
///
/// Unsuitable files are reported in the output; they do not cause an error.
#[cfg(feature = "cloning")]
async fn execute_validate(
    args: ValidateArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    // `contains` is false for NaN. A NaN threshold would otherwise pass this check and then
    // mark every file LOW_QUALITY, because `quality >= NaN` is always false.
    if !(0.0..=1.0).contains(&args.min_quality) {
        return Err(CliError::invalid_parameter(
            "min_quality",
            "Minimum quality must be between 0.0 and 1.0",
        ));
    }

    // Progress goes to stderr so stdout carries only the report (keeps `--format json`
    // output parseable). NOTE(review): confirm `OutputFormatter::success`/`warning` below
    // don't also write to stdout in JSON mode.
    eprintln!(
        "Validating {} audio files for cloning...",
        args.audio_files.len()
    );

    // Bare file name for the report; empty when the path has no final component.
    let file_label = |path: &PathBuf| {
        path.file_name()
            .unwrap_or_default()
            .to_string_lossy()
            .to_string()
    };

    let mut validation_results = Vec::with_capacity(args.audio_files.len());

    for (i, file_path) in args.audio_files.iter().enumerate() {
        let (status, quality_score) = if !file_path.exists() {
            ("NOT_FOUND", 0.0)
        } else {
            eprintln!("  Validating: {}", file_path.display());
            match load_audio_file(file_path) {
                Ok(audio_data) => {
                    let quality = validate_audio_quality(&audio_data);
                    let status = if quality >= args.min_quality {
                        "VALID"
                    } else {
                        "LOW_QUALITY"
                    };
                    (status, quality)
                }
                Err(e) => {
                    // Still reported as LOAD_ERROR below; surface the reason as well.
                    eprintln!("    Failed to load {}: {}", file_path.display(), e);
                    ("LOAD_ERROR", 0.0)
                }
            }
        };

        validation_results.push((
            format!("File {}", i + 1),
            file_label(file_path),
            status.to_string(),
            quality_score,
        ));
    }

    let all_valid = validation_results
        .iter()
        .all(|(_, _, status, _)| status == "VALID");

    // Display results
    match args.format.as_str() {
        "table" => {
            println!("{:<10} {:<30} {:<12} Quality", "File", "Name", "Status");
            println!("{}", "-".repeat(70));
            for (file_num, name, status, quality) in validation_results {
                println!(
                    "{:<10} {:<30} {:<12} {:.2}",
                    file_num, name, status, quality
                );
            }
        }
        "json" => {
            let json_results: Vec<_> = validation_results
                .into_iter()
                .map(|(file_num, name, status, quality)| {
                    serde_json::json!({
                        "file": file_num,
                        "filename": name,
                        "status": status,
                        "quality_score": quality
                    })
                })
                .collect();

            println!(
                "{}",
                serde_json::to_string_pretty(&json_results).map_err(CliError::Serialization)?
            );
        }
        _ => {
            for (file_num, name, status, quality) in validation_results {
                println!(
                    "{} ({}): {} - Quality: {:.2}",
                    file_num, name, status, quality
                );
            }
        }
    }

    if all_valid {
        output_formatter.success("All audio files are valid for voice cloning!");
    } else {
        output_formatter
            .warning("Some audio files may not be suitable for high-quality voice cloning.");
    }

    Ok(())
}
472
/// Remove every cached speaker profile held by the voice cloner.
///
/// Unless `--yes` was given, the user is asked for confirmation on stdin. The command
/// proceeds only when the reply starts with `y` (case-insensitive); any other reply
/// cancels without error.
///
/// # Errors
///
/// - `CliError::Io` if stdin cannot be read.
/// - `CliError::config` if the cloner cannot be created or its cache cannot be cleared.
#[cfg(feature = "cloning")]
async fn execute_clear_cache(
    args: ClearCacheArgs,
    output_formatter: &OutputFormatter,
) -> Result<(), CliError> {
    // `--yes` short-circuits the prompt entirely.
    let confirmed = args.yes || {
        println!("This will clear all cached speaker profiles. Continue? (y/N)");
        let mut reply = String::new();
        std::io::stdin()
            .read_line(&mut reply)
            .map_err(CliError::Io)?;
        reply.trim().to_lowercase().starts_with('y')
    };

    if !confirmed {
        println!("Cache clearing cancelled.");
        return Ok(());
    }

    let voice_cloner = VoiceCloner::new()
        .map_err(|e| CliError::config(format!("Failed to create voice cloner: {}", e)))?;

    voice_cloner
        .clear_cache()
        .await
        .map_err(|e| CliError::config(format!("Failed to clear cache: {}", e)))?;

    output_formatter.success("Speaker profile cache cleared successfully!");
    Ok(())
}
503
/// Decoded WAV audio as returned by `load_audio_file` and scored by `validate_audio_quality`.
// Field order is intentional: the derived `Debug` output prints fields in declaration order.
#[derive(Debug)]
struct AudioData {
    /// Samples as `f32`, converted by `load_audio_file` from the file's native sample format
    /// (see that function for how channels are combined).
    samples: Vec<f32>,
    /// Sample rate in Hz, copied from the WAV header.
    sample_rate: u32,
}
510
511/// Load audio file and convert to f32 samples
512fn load_audio_file(path: &PathBuf) -> Result<AudioData, Box<dyn std::error::Error>> {
513    let mut reader = hound::WavReader::open(path)?;
514    let spec = reader.spec();
515
516    // Convert samples to f32
517    let samples: Result<Vec<f32>, _> = match spec.sample_format {
518        hound::SampleFormat::Float => reader.samples::<f32>().collect(),
519        hound::SampleFormat::Int => match spec.bits_per_sample {
520            8 => reader
521                .samples::<i8>()
522                .map(|s| s.map(|sample| sample as f32 / i8::MAX as f32))
523                .collect(),
524            16 => reader
525                .samples::<i16>()
526                .map(|s| s.map(|sample| sample as f32 / i16::MAX as f32))
527                .collect(),
528            24 => reader
529                .samples::<i32>()
530                .map(|s| s.map(|sample| (sample >> 8) as f32 / (i32::MAX >> 8) as f32))
531                .collect(),
532            32 => reader
533                .samples::<i32>()
534                .map(|s| s.map(|sample| sample as f32 / i32::MAX as f32))
535                .collect(),
536            _ => {
537                return Err(format!("Unsupported bit depth: {}", spec.bits_per_sample).into());
538            }
539        },
540    };
541
542    let samples = samples?;
543
544    // Convert to mono if stereo
545    let mono_samples = if spec.channels == 2 {
546        samples
547            .chunks(2)
548            .map(|frame| (frame[0] + frame[1]) / 2.0)
549            .collect()
550    } else {
551        samples
552    };
553
554    Ok(AudioData {
555        samples: mono_samples,
556        sample_rate: spec.sample_rate,
557    })
558}
559
560/// Save audio data to WAV file
561fn save_audio_file(
562    audio_data: &[f32],
563    sample_rate: u32,
564    path: &PathBuf,
565) -> Result<(), Box<dyn std::error::Error>> {
566    let spec = hound::WavSpec {
567        channels: 1,
568        sample_rate,
569        bits_per_sample: 16,
570        sample_format: hound::SampleFormat::Int,
571    };
572
573    let mut writer = hound::WavWriter::create(path, spec)?;
574
575    for &sample in audio_data {
576        let sample_i16 = (sample * i16::MAX as f32) as i16;
577        writer.write_sample(sample_i16)?;
578    }
579
580    writer.finalize()?;
581    Ok(())
582}
583
584/// Validate audio quality for voice cloning
585fn validate_audio_quality(audio_data: &AudioData) -> f32 {
586    let samples = &audio_data.samples;
587
588    if samples.is_empty() {
589        return 0.0;
590    }
591
592    // Calculate quality metrics
593    let mut quality_score: f32 = 1.0;
594
595    // 1. Check duration (should be at least 1 second, ideally 3-10 seconds)
596    let duration_seconds = samples.len() as f32 / audio_data.sample_rate as f32;
597    if duration_seconds < 1.0 {
598        quality_score *= 0.3; // Very short audio
599    } else if duration_seconds < 3.0 {
600        quality_score *= 0.7; // Short but usable
601    } else if duration_seconds > 30.0 {
602        quality_score *= 0.8; // Very long, might have issues
603    }
604
605    // 2. Check for silence or very low volume
606    let rms = (samples.iter().map(|&x| x * x).sum::<f32>() / samples.len() as f32).sqrt();
607    if rms < 0.01 {
608        quality_score *= 0.2; // Too quiet
609    } else if rms < 0.05 {
610        quality_score *= 0.6; // Quite quiet
611    }
612
613    // 3. Check for clipping
614    let clipped_samples = samples.iter().filter(|&&x| x.abs() > 0.95).count();
615    let clipping_ratio = clipped_samples as f32 / samples.len() as f32;
616    if clipping_ratio > 0.1 {
617        quality_score *= 0.4; // High clipping
618    } else if clipping_ratio > 0.01 {
619        quality_score *= 0.7; // Some clipping
620    }
621
622    // 4. Check dynamic range
623    let max_val = samples.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
624    let min_non_zero = samples
625        .iter()
626        .filter(|&&x| x.abs() > 0.001)
627        .map(|x| x.abs())
628        .fold(1.0f32, f32::min);
629
630    let dynamic_range = if min_non_zero > 0.0 {
631        20.0 * (max_val / min_non_zero).log10()
632    } else {
633        0.0
634    };
635
636    if dynamic_range < 20.0 {
637        quality_score *= 0.5; // Poor dynamic range
638    } else if dynamic_range < 40.0 {
639        quality_score *= 0.8; // Okay dynamic range
640    }
641
642    // 5. Check sample rate appropriateness
643    if audio_data.sample_rate < 16000 {
644        quality_score *= 0.6; // Low sample rate
645    } else if audio_data.sample_rate < 22050 {
646        quality_score *= 0.9; // Acceptable sample rate
647    }
648
649    quality_score.clamp(0.0, 1.0)
650}