// File: voirs_cli/commands/vocoder_inference.rs

1//! Vocoder inference command - convert mel spectrograms to audio
2
3use crate::GlobalOptions;
4use candle_core::{Device, Tensor};
5use std::path::{Path, PathBuf};
6use voirs_sdk::Result;
7use voirs_vocoder::models::diffwave::{DiffWave, SamplingMethod};
8
/// Configuration for vocoder inference operations
///
/// Consolidates parameters for converting mel spectrograms to audio waveforms
/// using trained vocoder models. Supports both single-file and batch processing.
///
/// # Examples
///
/// Single file inference:
/// ```no_run
/// use voirs_cli::commands::vocoder_inference::VocoderInferenceConfig;
/// use std::path::Path;
///
/// let config = VocoderInferenceConfig {
///     checkpoint: Path::new("./checkpoints/vocoder.safetensors"),
///     mel_path: Some(Path::new("./input.mel")),
///     output: Path::new("./output.wav"),
///     steps: 50,
///     quality: Some("balanced"),
///     batch_input: None,
///     batch_output: None,
///     show_metrics: false,
/// };
/// ```
///
/// Batch processing:
/// ```no_run
/// use voirs_cli::commands::vocoder_inference::VocoderInferenceConfig;
/// use std::path::{Path, PathBuf};
///
/// // The directories must be bound to named variables: borrowing a
/// // temporary (`Some(&PathBuf::from("./mel_dir"))`) would be dropped at
/// // the end of the statement and fail to compile (E0716).
/// let mel_dir = PathBuf::from("./mel_dir");
/// let audio_dir = PathBuf::from("./audio_dir");
/// let config = VocoderInferenceConfig {
///     checkpoint: Path::new("./checkpoints/vocoder.safetensors"),
///     mel_path: None,
///     output: Path::new("./output_dir"),
///     steps: 50,
///     quality: Some("high"),
///     batch_input: Some(&mel_dir),
///     batch_output: Some(&audio_dir),
///     show_metrics: true,
/// };
/// ```
#[derive(Debug)]
pub struct VocoderInferenceConfig<'a> {
    /// Path to trained vocoder checkpoint file
    pub checkpoint: &'a Path,
    /// Optional path to input mel spectrogram file (single file mode)
    pub mel_path: Option<&'a Path>,
    /// Output path for generated audio file or batch directory
    pub output: &'a Path,
    /// Number of diffusion steps for generation (higher = better quality, slower)
    pub steps: usize,
    /// Quality preset: "fast" (20 steps), "balanced" (50 steps), or "high" (100 steps)
    pub quality: Option<&'a str>,
    /// Optional directory for batch input (batch mode)
    pub batch_input: Option<&'a PathBuf>,
    /// Optional directory for batch output (batch mode)
    pub batch_output: Option<&'a PathBuf>,
    /// Whether to display performance metrics after inference
    pub show_metrics: bool,
}
68
/// Quality preset for vocoder inference
///
/// Maps a user-facing preset name to a fixed diffusion step count
/// (more steps trade generation speed for audio fidelity).
#[derive(Debug, Clone, Copy)]
enum QualityPreset {
    /// 20 steps, faster generation
    Fast,
    /// 50 steps, balance of speed and quality
    Balanced,
    /// 100 steps, best quality
    High,
}
76
77impl QualityPreset {
78    fn from_str(s: &str) -> Result<Self> {
79        match s.to_lowercase().as_str() {
80            "fast" => Ok(Self::Fast),
81            "balanced" => Ok(Self::Balanced),
82            "high" => Ok(Self::High),
83            _ => Err(voirs_sdk::VoirsError::config_error(format!(
84                "Invalid quality preset: {}. Use 'fast', 'balanced', or 'high'",
85                s
86            ))),
87        }
88    }
89
90    fn steps(&self) -> usize {
91        match self {
92            Self::Fast => 20,
93            Self::Balanced => 50,
94            Self::High => 100,
95        }
96    }
97}
98
99/// Run vocoder inference: mel spectrogram → audio waveform
100///
101/// # Arguments
102/// * `config` - Vocoder inference configuration
103/// * `global` - Global CLI options
104pub async fn run_vocoder_inference(
105    config: VocoderInferenceConfig<'_>,
106    global: &GlobalOptions,
107) -> Result<()> {
108    // Check for batch mode
109    if config.batch_input.is_some() || config.batch_output.is_some() {
110        if config.batch_input.is_none() || config.batch_output.is_none() {
111            return Err(voirs_sdk::VoirsError::config_error(
112                "Batch mode requires both --batch-input and --batch-output",
113            ));
114        }
115        return run_batch_inference(
116            config.checkpoint,
117            config.batch_input.unwrap(),
118            config.batch_output.unwrap(),
119            config.steps,
120            config.quality,
121            config.show_metrics,
122            global,
123        )
124        .await;
125    }
126
127    // Single file mode
128    run_single_inference(
129        config.checkpoint,
130        config.mel_path,
131        config.output,
132        config.steps,
133        config.quality,
134        config.show_metrics,
135        global,
136    )
137    .await
138}
139
/// Run single file inference
///
/// Loads the DiffWave checkpoint, obtains a mel spectrogram (from
/// `mel_path`, or a random dummy when `None`), runs DDIM sampling, and
/// writes a 16-bit mono WAV to `output`.
///
/// # Arguments
/// * `checkpoint` - SafeTensors vocoder checkpoint path
/// * `mel_path` - Optional mel input; `None` generates dummy data
/// * `output` - Destination WAV path
/// * `steps` - Diffusion step count; overridden by `quality` when given
/// * `quality` - Optional preset name ("fast", "balanced", "high")
/// * `show_metrics` - Print timing/RTF metrics (even in quiet mode)
/// * `global` - Global CLI options (quiet flag, GPU selection)
async fn run_single_inference(
    checkpoint: &Path,
    mel_path: Option<&Path>,
    output: &Path,
    mut steps: usize,
    quality: Option<&str>,
    show_metrics: bool,
    global: &GlobalOptions,
) -> Result<()> {
    // Apply quality preset if specified — the preset's step count takes
    // precedence over the explicit `steps` argument.
    if let Some(quality_str) = quality {
        let preset = QualityPreset::from_str(quality_str)?;
        steps = preset.steps();
        if !global.quiet {
            println!("Using quality preset: {:?} ({} steps)", preset, steps);
        }
    }
    use std::time::Instant;
    let total_start = Instant::now();

    if !global.quiet {
        println!("🎵 VoiRS Vocoder Inference");
        println!("═══════════════════════════════════════");
        println!("Checkpoint: {}", checkpoint.display());
        if let Some(mel) = mel_path {
            println!("Mel spec:   {}", mel.display());
        } else {
            println!("Mel spec:   <generating dummy>");
        }
        println!("Output:     {}", output.display());
        println!("Steps:      {}", steps);
        println!("═══════════════════════════════════════\n");
    }

    // Determine device: CUDA only when both requested (--gpu) and compiled
    // in; falls back to CPU otherwise. Note CUDA init failure falls back
    // silently via unwrap_or.
    let device = if global.gpu {
        #[cfg(feature = "cuda")]
        {
            Device::new_cuda(0).unwrap_or(Device::Cpu)
        }
        #[cfg(not(feature = "cuda"))]
        {
            if !global.quiet {
                println!("⚠️  GPU requested but CUDA not available, using CPU");
            }
            Device::Cpu
        }
    } else {
        Device::Cpu
    };

    if !global.quiet {
        println!("📦 Loading DiffWave model from checkpoint...");
    }

    // Load DiffWave model from SafeTensors
    let model = DiffWave::load_from_safetensors(checkpoint, device.clone()).map_err(|e| {
        voirs_sdk::VoirsError::config_error(format!("Failed to load DiffWave model: {}", e))
    })?;

    if !global.quiet {
        println!("✓ Model loaded successfully");
        println!("  Parameters: {}", model.num_parameters());
        println!();
    }

    // Load or generate mel spectrogram (dummy path is for smoke-testing
    // the vocoder without an acoustic model).
    let mel_tensor = if let Some(mel_file) = mel_path {
        if !global.quiet {
            println!("📊 Loading mel spectrogram from file...");
        }
        load_mel_spectrogram(mel_file, &device)?
    } else {
        if !global.quiet {
            println!("📊 Generating dummy mel spectrogram...");
        }
        generate_dummy_mel_spectrogram(&device)?
    };

    if !global.quiet {
        println!("✓ Mel spectrogram ready");
        println!("  Shape: {:?}", mel_tensor.dims());
        println!();
    }

    // Run inference
    if !global.quiet {
        println!("🔄 Running vocoder inference...");
        println!("  Sampling method: DDIM");
        println!("  Diffusion steps: {}", steps);
    }

    // eta = 0.0 selects the deterministic DDIM variant.
    let sampling_method = SamplingMethod::DDIM { steps, eta: 0.0 };
    let audio_tensor = model
        .inference(&mel_tensor, sampling_method)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Inference failed: {}", e)))?;

    if !global.quiet {
        println!("✓ Inference complete");
        println!("  Audio shape: {:?}", audio_tensor.dims());
        println!();
    }

    // Save audio
    if !global.quiet {
        println!("💾 Saving audio to {}...", output.display());
    }

    // 22050 Hz hard-coded — assumed to match the model's training sample
    // rate (TODO confirm; also used in the RTF computation below).
    save_audio_tensor(&audio_tensor, output, 22050)?;

    let total_time = total_start.elapsed();

    if !global.quiet {
        println!("✅ Vocoder inference complete!");
        println!("  Output: {}", output.display());
    }

    // Show metrics if requested — intentionally prints even with --quiet.
    if show_metrics {
        println!();
        println!("╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌");
        println!("Performance Metrics:");
        println!("╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌");
        println!("Total time:        {:.3}s", total_time.as_secs_f64());
        // RTF only printable when the audio tensor is rank-3
        // (presumably [batch, channels, samples] — dims3 fails otherwise).
        if let Ok(dims) = audio_tensor.dims3() {
            let (_, _, samples) = dims;
            let duration_sec = samples as f64 / 22050.0;
            let rtf = total_time.as_secs_f64() / duration_sec;
            println!("Audio duration:    {:.2}s", duration_sec);
            println!("Real-time factor:  {:.3}x", rtf);
        }
        println!("╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌");
    }

    Ok(())
}
277
278/// Load mel spectrogram from file
279///
280/// Supports multiple formats:
281/// - NumPy (.npy): Native parser for NumPy binary format
282/// - SafeTensors (.safetensors): Uses safetensors crate
283/// - PyTorch (.pt, .pth): Requires conversion (see error message for guidance)
284fn load_mel_spectrogram(path: &Path, device: &Device) -> Result<Tensor> {
285    // Check file extension and load appropriately
286    match path.extension().and_then(|e| e.to_str()) {
287        Some("npy") => load_numpy_file(path, device),
288        Some("pt") | Some("pth") => load_pytorch_file(path, device),
289        Some("safetensors") => load_safetensors_file(path, device),
290        _ => Err(voirs_sdk::VoirsError::UnsupportedFileFormat {
291            path: path.to_path_buf(),
292            format: path
293                .extension()
294                .and_then(|e| e.to_str())
295                .unwrap_or("unknown")
296                .to_string(),
297        }),
298    }
299}
300
301/// Load NumPy file (.npy)
302fn load_numpy_file(path: &Path, device: &Device) -> Result<Tensor> {
303    // Read entire file
304    let data = std::fs::read(path).map_err(|e| voirs_sdk::VoirsError::IoError {
305        path: path.to_path_buf(),
306        operation: voirs_sdk::error::IoOperation::Read,
307        source: e,
308    })?;
309
310    // Parse NumPy .npy format manually
311    // Format: Magic (6 bytes) + Version (2 bytes) + Header Len (2/4 bytes) + Header (JSON-like dict) + Data
312
313    // Check magic number: b'\x93NUMPY'
314    if data.len() < 10 || &data[0..6] != b"\x93NUMPY" {
315        return Err(voirs_sdk::VoirsError::config_error(
316            "Invalid NumPy file: magic number mismatch",
317        ));
318    }
319
320    let major_version = data[6];
321    let minor_version = data[7];
322
323    if major_version != 1 && major_version != 2 {
324        return Err(voirs_sdk::VoirsError::config_error(format!(
325            "Unsupported NumPy version: {}.{}",
326            major_version, minor_version
327        )));
328    }
329
330    // Read header length (little-endian)
331    let header_len = if major_version == 1 {
332        u16::from_le_bytes([data[8], data[9]]) as usize
333    } else {
334        u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
335    };
336
337    let header_start = if major_version == 1 { 10 } else { 12 };
338    let header_end = header_start + header_len;
339
340    if data.len() < header_end {
341        return Err(voirs_sdk::VoirsError::config_error(
342            "Invalid NumPy file: truncated header",
343        ));
344    }
345
346    // Parse header (Python dict-like string)
347    let header_str = std::str::from_utf8(&data[header_start..header_end])
348        .map_err(|_| voirs_sdk::VoirsError::config_error("Invalid NumPy header: not UTF-8"))?;
349
350    // Extract shape from header (format: 'shape': (dim0, dim1, ...), )
351    let shape = parse_numpy_shape(header_str)?;
352
353    // Extract dtype (we only support float32 for now)
354    let dtype = parse_numpy_dtype(header_str)?;
355    if dtype != "f4" && dtype != "<f4" && dtype != "float32" {
356        return Err(voirs_sdk::VoirsError::config_error(format!(
357            "Unsupported NumPy dtype: {}. Only float32 is supported.",
358            dtype
359        )));
360    }
361
362    // Read data
363    let data_start = header_end;
364    let num_elements: usize = shape.iter().product();
365    let expected_bytes = num_elements * 4; // f32 = 4 bytes
366
367    if data.len() < data_start + expected_bytes {
368        return Err(voirs_sdk::VoirsError::config_error(
369            "Invalid NumPy file: insufficient data",
370        ));
371    }
372
373    // Convert bytes to f32 (little-endian)
374    let f32_data: Vec<f32> = data[data_start..data_start + expected_bytes]
375        .chunks_exact(4)
376        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
377        .collect();
378
379    // Create tensor
380    let tensor = Tensor::from_vec(f32_data, shape.as_slice(), device).map_err(|e| {
381        voirs_sdk::VoirsError::config_error(format!(
382            "Failed to create tensor from NumPy data: {}",
383            e
384        ))
385    })?;
386
387    Ok(tensor)
388}
389
390/// Parse shape from NumPy header
391fn parse_numpy_shape(header: &str) -> Result<Vec<usize>> {
392    // Header format: {'descr': '<f4', 'fortran_order': False, 'shape': (80, 100), }
393    // Extract shape tuple
394    let shape_start = header
395        .find("'shape':")
396        .or_else(|| header.find("\"shape\":"))
397        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy header missing 'shape' field"))?;
398
399    let shape_str = &header[shape_start..];
400    let tuple_start = shape_str
401        .find('(')
402        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy shape malformed"))?;
403    let tuple_end = shape_str
404        .find(')')
405        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy shape malformed"))?;
406
407    let tuple_content = &shape_str[tuple_start + 1..tuple_end];
408
409    if tuple_content.trim().is_empty() {
410        // Scalar array
411        return Ok(vec![1]);
412    }
413
414    // Parse dimensions
415    let dims: Result<Vec<usize>> = tuple_content
416        .split(',')
417        .filter(|s| !s.trim().is_empty())
418        .map(|s| {
419            s.trim().parse::<usize>().map_err(|_| {
420                voirs_sdk::VoirsError::config_error(format!("Invalid dimension: {}", s))
421            })
422        })
423        .collect();
424
425    dims
426}
427
428/// Parse dtype from NumPy header
429fn parse_numpy_dtype(header: &str) -> Result<String> {
430    // Extract descr field
431    let descr_start = header
432        .find("'descr':")
433        .or_else(|| header.find("\"descr\":"))
434        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy header missing 'descr' field"))?;
435
436    let descr_str = &header[descr_start..];
437
438    // Find the value (between quotes)
439    let value_start = descr_str
440        .find('\'')
441        .or_else(|| descr_str.find('"'))
442        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy descr malformed"))?;
443
444    let value_str = &descr_str[value_start + 1..];
445    let value_end = value_str
446        .find('\'')
447        .or_else(|| value_str.find('"'))
448        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy descr malformed"))?;
449
450    Ok(value_str[..value_end].to_string())
451}
452
/// Load PyTorch file (.pt, .pth)
///
/// Always returns an error: .pt files use Python's pickle format, which
/// this build does not parse. The error text walks the user through
/// converting to a supported format instead.
fn load_pytorch_file(path: &Path, _device: &Device) -> Result<Tensor> {
    // PyTorch .pt files use Python's pickle format, which is complex to parse in pure Rust
    // For now, we provide helpful guidance for users

    Err(voirs_sdk::VoirsError::config_error(format!(
        "PyTorch .pt file loading requires Python interop or conversion.\n\
        \n\
        Alternatives:\n\
        1. Convert to NumPy: python -c \"import torch, numpy as np; np.save('output.npy', torch.load('{}').numpy())\"\n\
        2. Convert to SafeTensors: Use safetensors.torch.save_file() in Python\n\
        3. Use ONNX format: Export model to ONNX and use --input-format onnx\n\
        \n\
        For native PyTorch support, compile with 'tch-rs' feature (requires libtorch).",
        path.display()
    )))
}
470
471/// Load SafeTensors file
472fn load_safetensors_file(path: &Path, device: &Device) -> Result<Tensor> {
473    use safetensors::SafeTensors;
474
475    let data = std::fs::read(path).map_err(|e| voirs_sdk::VoirsError::IoError {
476        path: path.to_path_buf(),
477        operation: voirs_sdk::error::IoOperation::Read,
478        source: e,
479    })?;
480
481    let tensors = SafeTensors::deserialize(&data).map_err(|e| {
482        voirs_sdk::VoirsError::config_error(format!("Failed to load SafeTensors: {}", e))
483    })?;
484
485    // Assume the first tensor is the mel spectrogram
486    let names = tensors.names();
487    let tensor_name = names
488        .first()
489        .ok_or_else(|| voirs_sdk::VoirsError::config_error("No tensors found in file"))?;
490
491    let tensor_view = tensors
492        .tensor(tensor_name)
493        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to get tensor: {}", e)))?;
494
495    let shape: Vec<usize> = tensor_view.shape().to_vec();
496    let data = tensor_view.data();
497
498    // Convert bytes to f32
499    let f32_data: Vec<f32> = data
500        .chunks_exact(4)
501        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
502        .collect();
503
504    let tensor = Tensor::from_vec(f32_data, shape.as_slice(), device).map_err(|e| {
505        voirs_sdk::VoirsError::config_error(format!("Failed to create tensor: {}", e))
506    })?;
507    Ok(tensor)
508}
509
510/// Generate dummy mel spectrogram for testing
511fn generate_dummy_mel_spectrogram(device: &Device) -> Result<Tensor> {
512    // Create a dummy mel spectrogram: [batch=1, mel_channels=80, time=100]
513    let batch_size = 1;
514    let mel_channels = 80;
515    let time_frames = 100;
516
517    // Generate random values (in practice, this would be from an acoustic model)
518    let data: Vec<f32> = (0..(batch_size * mel_channels * time_frames))
519        .map(|_| fastrand::f32() * 2.0 - 1.0) // Random values between -1 and 1
520        .collect();
521
522    let tensor =
523        Tensor::from_vec(data, (batch_size, mel_channels, time_frames), device).map_err(|e| {
524            voirs_sdk::VoirsError::config_error(format!("Failed to create tensor: {}", e))
525        })?;
526    Ok(tensor)
527}
528
529/// Save audio tensor to WAV file
530fn save_audio_tensor(tensor: &Tensor, output: &Path, sample_rate: u32) -> Result<()> {
531    use hound::{WavSpec, WavWriter};
532
533    // Extract audio data from tensor
534    let audio_data: Vec<f32> = tensor
535        .flatten_all()
536        .map_err(|e| {
537            voirs_sdk::VoirsError::config_error(format!("Failed to flatten tensor: {}", e))
538        })?
539        .to_vec1()
540        .map_err(|e| {
541            voirs_sdk::VoirsError::config_error(format!("Failed to convert tensor to vec: {}", e))
542        })?;
543
544    // Create WAV spec
545    let spec = WavSpec {
546        channels: 1,
547        sample_rate,
548        bits_per_sample: 16,
549        sample_format: hound::SampleFormat::Int,
550    };
551
552    // Create WAV writer
553    let mut writer =
554        WavWriter::create(output, spec).map_err(|e| voirs_sdk::VoirsError::IoError {
555            path: output.to_path_buf(),
556            operation: voirs_sdk::error::IoOperation::Write,
557            source: std::io::Error::new(std::io::ErrorKind::Other, e),
558        })?;
559
560    // Write samples (convert f32 to i16)
561    for &sample in &audio_data {
562        let sample_i16 = (sample * 32767.0).clamp(-32768.0, 32767.0) as i16;
563        writer
564            .write_sample(sample_i16)
565            .map_err(|e| voirs_sdk::VoirsError::IoError {
566                path: output.to_path_buf(),
567                operation: voirs_sdk::error::IoOperation::Write,
568                source: std::io::Error::new(std::io::ErrorKind::Other, e),
569            })?;
570    }
571
572    writer
573        .finalize()
574        .map_err(|e| voirs_sdk::VoirsError::IoError {
575            path: output.to_path_buf(),
576            operation: voirs_sdk::error::IoOperation::Write,
577            source: std::io::Error::new(std::io::ErrorKind::Other, e),
578        })?;
579
580    Ok(())
581}
582
/// Run batch inference on directory of mel spectrograms
///
/// Loads the vocoder model once, then converts every recognized mel file
/// (.npy / .safetensors / .pt / .pth) in `input_dir` into `<stem>.wav`
/// inside `output_dir`, tracking per-file and aggregate timings.
///
/// # Arguments
/// * `checkpoint` - SafeTensors vocoder checkpoint path
/// * `input_dir` - Directory scanned (non-recursively) for mel files
/// * `output_dir` - Directory for generated WAVs (created if missing)
/// * `steps` - Diffusion step count; overridden by `quality` when given
/// * `quality` - Optional preset name ("fast", "balanced", "high")
/// * `show_metrics` - Force the summary table even in quiet mode
/// * `global` - Global CLI options (quiet flag, GPU selection)
///
/// # Errors
/// Fails fast on a missing input directory, empty file list, or model
/// load failure. Per-file inference errors are counted and reported but
/// do not abort the batch.
async fn run_batch_inference(
    checkpoint: &Path,
    input_dir: &Path,
    output_dir: &Path,
    mut steps: usize,
    quality: Option<&str>,
    show_metrics: bool,
    global: &GlobalOptions,
) -> Result<()> {
    use std::time::Instant;

    // Apply quality preset — overrides the explicit `steps` argument.
    // NOTE(review): unlike single-file mode, no message is printed here.
    if let Some(quality_str) = quality {
        let preset = QualityPreset::from_str(quality_str)?;
        steps = preset.steps();
    }

    if !global.quiet {
        println!("🎵 VoiRS Batch Vocoder Inference");
        println!("═══════════════════════════════════════");
        println!("Checkpoint:  {}", checkpoint.display());
        println!("Input dir:   {}", input_dir.display());
        println!("Output dir:  {}", output_dir.display());
        println!("Steps:       {}", steps);
        if let Some(q) = quality {
            println!("Quality:     {}", q);
        }
        println!("═══════════════════════════════════════\n");
    }

    // Validate input directory
    if !input_dir.is_dir() {
        return Err(voirs_sdk::VoirsError::config_error(format!(
            "Input directory not found: {}",
            input_dir.display()
        )));
    }

    // Create output directory (and any missing parents)
    std::fs::create_dir_all(output_dir)?;

    // Find all mel spectrogram files by extension; unreadable directory
    // entries are silently skipped (filter_map on entry.ok()).
    let mel_files: Vec<_> = std::fs::read_dir(input_dir)
        .map_err(|e| voirs_sdk::VoirsError::IoError {
            path: input_dir.to_path_buf(),
            operation: voirs_sdk::error::IoOperation::Read,
            source: e,
        })?
        .filter_map(|entry| entry.ok())
        .map(|entry| entry.path())
        .filter(|path| {
            path.extension()
                .and_then(|e| e.to_str())
                .map(|ext| matches!(ext, "npy" | "safetensors" | "pt" | "pth"))
                .unwrap_or(false)
        })
        .collect();

    if mel_files.is_empty() {
        return Err(voirs_sdk::VoirsError::config_error(
            "No mel spectrogram files found in input directory",
        ));
    }

    if !global.quiet {
        println!("Found {} mel spectrogram files", mel_files.len());
        println!();
    }

    // Load model once and reuse it for every file. CUDA init failure
    // falls back to CPU silently (unwrap_or).
    let device = if global.gpu {
        #[cfg(feature = "cuda")]
        {
            Device::new_cuda(0).unwrap_or(Device::Cpu)
        }
        #[cfg(not(feature = "cuda"))]
        {
            Device::Cpu
        }
    } else {
        Device::Cpu
    };

    let model = DiffWave::load_from_safetensors(checkpoint, device.clone())?;

    // Performance tracking: total_time sums per-file wall time, while
    // batch_start measures the whole loop (they can differ slightly).
    let mut total_time = 0.0;
    let mut successful = 0;
    let mut failed = 0;
    let batch_start = Instant::now();

    // Process each file sequentially; output name is the input stem + .wav
    for (idx, mel_file) in mel_files.iter().enumerate() {
        let file_start = Instant::now();

        let output_name = mel_file
            .file_stem()
            .and_then(|n| n.to_str())
            .unwrap_or("output");
        let output_path = output_dir.join(format!("{}.wav", output_name));

        if !global.quiet {
            println!(
                "[{}/{}] Processing {}...",
                idx + 1,
                mel_files.len(),
                mel_file.display()
            );
        }

        // Process file — failures are recorded, not propagated, so one bad
        // file doesn't kill the rest of the batch.
        let result =
            process_single_mel(&model, mel_file, &output_path, steps, &device, global).await;

        let file_time = file_start.elapsed().as_secs_f64();
        total_time += file_time;

        match result {
            Ok(_) => {
                successful += 1;
                if !global.quiet {
                    println!("  ✓ Complete in {:.2}s", file_time);
                }
            }
            Err(e) => {
                failed += 1;
                eprintln!("  ✗ Failed: {}", e);
            }
        }
    }

    let total_elapsed = batch_start.elapsed().as_secs_f64();

    // Display results — shown when not quiet OR when metrics were
    // explicitly requested.
    if !global.quiet || show_metrics {
        println!();
        println!("╔═══════════════════════════════════════╗");
        println!("║       Batch Inference Complete        ║");
        println!("╠═══════════════════════════════════════╣");
        println!("║ Total files:    {:<21} ║", mel_files.len());
        println!("║ Successful:     {:<21} ║", successful);
        println!("║ Failed:         {:<21} ║", failed);
        println!("║ Total time:     {:<18.2}s ║", total_elapsed);
        println!(
            "║ Avg time/file:  {:<18.2}s ║",
            total_time / mel_files.len() as f64
        );
        if successful > 0 {
            println!(
                "║ Throughput:     {:<18.2}/s ║",
                successful as f64 / total_elapsed
            );
        }
        println!("╚═══════════════════════════════════════╝");
    }

    Ok(())
}
742
743/// Process a single mel spectrogram file
744async fn process_single_mel(
745    model: &DiffWave,
746    mel_path: &Path,
747    output_path: &Path,
748    steps: usize,
749    device: &Device,
750    _global: &GlobalOptions,
751) -> Result<()> {
752    // Load mel spectrogram
753    let mel_tensor = load_mel_spectrogram(mel_path, device)?;
754
755    // Run inference
756    let sampling_method = SamplingMethod::DDIM { steps, eta: 0.0 };
757    let audio_tensor = model
758        .inference(&mel_tensor, sampling_method)
759        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Inference failed: {}", e)))?;
760
761    // Save audio
762    save_audio_tensor(&audio_tensor, output_path, 22050)?;
763
764    Ok(())
765}