1use crate::GlobalOptions;
4use candle_core::{Device, Tensor};
5use std::path::{Path, PathBuf};
6use voirs_sdk::Result;
7use voirs_vocoder::models::diffwave::{DiffWave, SamplingMethod};
8
/// Borrowed CLI arguments for a vocoder inference run (single-file or batch).
#[derive(Debug)]
pub struct VocoderInferenceConfig<'a> {
    /// Path to the DiffWave checkpoint (loaded via SafeTensors).
    pub checkpoint: &'a Path,
    /// Optional mel-spectrogram input file; when `None`, a dummy mel is generated.
    pub mel_path: Option<&'a Path>,
    /// Output WAV path for single-file mode.
    pub output: &'a Path,
    /// Number of diffusion steps; overridden when `quality` is set.
    pub steps: usize,
    /// Optional quality preset name: "fast", "balanced", or "high".
    pub quality: Option<&'a str>,
    /// Input directory for batch mode; must be paired with `batch_output`.
    pub batch_input: Option<&'a PathBuf>,
    /// Output directory for batch mode; must be paired with `batch_input`.
    pub batch_output: Option<&'a PathBuf>,
    /// When true, print timing / real-time-factor metrics after inference.
    pub show_metrics: bool,
}
68
/// Sampling-quality presets, each mapping to a fixed diffusion step count.
#[derive(Debug, Clone, Copy)]
enum QualityPreset {
    /// Fastest sampling (fewest diffusion steps).
    Fast,
    /// Default speed/fidelity trade-off.
    Balanced,
    /// Highest fidelity (most diffusion steps).
    High,
}
76
77impl QualityPreset {
78 fn from_str(s: &str) -> Result<Self> {
79 match s.to_lowercase().as_str() {
80 "fast" => Ok(Self::Fast),
81 "balanced" => Ok(Self::Balanced),
82 "high" => Ok(Self::High),
83 _ => Err(voirs_sdk::VoirsError::config_error(format!(
84 "Invalid quality preset: {}. Use 'fast', 'balanced', or 'high'",
85 s
86 ))),
87 }
88 }
89
90 fn steps(&self) -> usize {
91 match self {
92 Self::Fast => 20,
93 Self::Balanced => 50,
94 Self::High => 100,
95 }
96 }
97}
98
99pub async fn run_vocoder_inference(
105 config: VocoderInferenceConfig<'_>,
106 global: &GlobalOptions,
107) -> Result<()> {
108 if config.batch_input.is_some() || config.batch_output.is_some() {
110 if config.batch_input.is_none() || config.batch_output.is_none() {
111 return Err(voirs_sdk::VoirsError::config_error(
112 "Batch mode requires both --batch-input and --batch-output",
113 ));
114 }
115 return run_batch_inference(
116 config.checkpoint,
117 config.batch_input.unwrap(),
118 config.batch_output.unwrap(),
119 config.steps,
120 config.quality,
121 config.show_metrics,
122 global,
123 )
124 .await;
125 }
126
127 run_single_inference(
129 config.checkpoint,
130 config.mel_path,
131 config.output,
132 config.steps,
133 config.quality,
134 config.show_metrics,
135 global,
136 )
137 .await
138}
139
/// Run DiffWave vocoder inference for a single mel spectrogram.
///
/// Resolves the optional quality preset (which overrides `steps`), selects a
/// device, loads the model, obtains a mel tensor (from file or a random
/// dummy), samples audio via DDIM, and writes a WAV at a fixed 22050 Hz.
/// Progress is printed unless `global.quiet`; metrics print when requested.
async fn run_single_inference(
    checkpoint: &Path,
    mel_path: Option<&Path>,
    output: &Path,
    mut steps: usize,
    quality: Option<&str>,
    show_metrics: bool,
    global: &GlobalOptions,
) -> Result<()> {
    // A quality preset takes precedence over the explicit step count.
    if let Some(quality_str) = quality {
        let preset = QualityPreset::from_str(quality_str)?;
        steps = preset.steps();
        if !global.quiet {
            println!("Using quality preset: {:?} ({} steps)", preset, steps);
        }
    }
    use std::time::Instant;
    // Total wall-clock time, including model load and file I/O.
    let total_start = Instant::now();

    if !global.quiet {
        println!("🎵 VoiRS Vocoder Inference");
        println!("═══════════════════════════════════════");
        println!("Checkpoint: {}", checkpoint.display());
        if let Some(mel) = mel_path {
            println!("Mel spec: {}", mel.display());
        } else {
            println!("Mel spec: <generating dummy>");
        }
        println!("Output: {}", output.display());
        println!("Steps: {}", steps);
        println!("═══════════════════════════════════════\n");
    }

    // Device selection: CUDA only when compiled with the `cuda` feature.
    // If CUDA device 0 cannot be created, this falls back to CPU silently.
    let device = if global.gpu {
        #[cfg(feature = "cuda")]
        {
            Device::new_cuda(0).unwrap_or(Device::Cpu)
        }
        #[cfg(not(feature = "cuda"))]
        {
            if !global.quiet {
                println!("⚠️ GPU requested but CUDA not available, using CPU");
            }
            Device::Cpu
        }
    } else {
        Device::Cpu
    };

    if !global.quiet {
        println!("📦 Loading DiffWave model from checkpoint...");
    }

    // Model-load failures are surfaced as config errors with context.
    let model = DiffWave::load_from_safetensors(checkpoint, device.clone()).map_err(|e| {
        voirs_sdk::VoirsError::config_error(format!("Failed to load DiffWave model: {}", e))
    })?;

    if !global.quiet {
        println!("✓ Model loaded successfully");
        println!(" Parameters: {}", model.num_parameters());
        println!();
    }

    // Mel input: load from file when provided, otherwise generate random data
    // (useful for smoke-testing a checkpoint without a real mel).
    let mel_tensor = if let Some(mel_file) = mel_path {
        if !global.quiet {
            println!("📊 Loading mel spectrogram from file...");
        }
        load_mel_spectrogram(mel_file, &device)?
    } else {
        if !global.quiet {
            println!("📊 Generating dummy mel spectrogram...");
        }
        generate_dummy_mel_spectrogram(&device)?
    };

    if !global.quiet {
        println!("✓ Mel spectrogram ready");
        println!(" Shape: {:?}", mel_tensor.dims());
        println!();
    }

    if !global.quiet {
        println!("🔄 Running vocoder inference...");
        println!(" Sampling method: DDIM");
        println!(" Diffusion steps: {}", steps);
    }

    // DDIM with eta = 0.0 gives deterministic sampling given the noise seed.
    let sampling_method = SamplingMethod::DDIM { steps, eta: 0.0 };
    let audio_tensor = model
        .inference(&mel_tensor, sampling_method)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Inference failed: {}", e)))?;

    if !global.quiet {
        println!("✓ Inference complete");
        println!(" Audio shape: {:?}", audio_tensor.dims());
        println!();
    }

    if !global.quiet {
        println!("💾 Saving audio to {}...", output.display());
    }

    // Sample rate is fixed at 22050 Hz to match the vocoder's output rate.
    save_audio_tensor(&audio_tensor, output, 22050)?;

    let total_time = total_start.elapsed();

    if !global.quiet {
        println!("✅ Vocoder inference complete!");
        println!(" Output: {}", output.display());
    }

    if show_metrics {
        println!();
        println!("╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌");
        println!("Performance Metrics:");
        println!("╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌");
        println!("Total time: {:.3}s", total_time.as_secs_f64());
        // RTF only when the audio tensor is rank-3 (batch, channel, samples);
        // RTF < 1.0 means faster than real time.
        if let Ok(dims) = audio_tensor.dims3() {
            let (_, _, samples) = dims;
            let duration_sec = samples as f64 / 22050.0;
            let rtf = total_time.as_secs_f64() / duration_sec;
            println!("Audio duration: {:.2}s", duration_sec);
            println!("Real-time factor: {:.3}x", rtf);
        }
        println!("╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌");
    }

    Ok(())
}
277
278fn load_mel_spectrogram(path: &Path, device: &Device) -> Result<Tensor> {
285 match path.extension().and_then(|e| e.to_str()) {
287 Some("npy") => load_numpy_file(path, device),
288 Some("pt") | Some("pth") => load_pytorch_file(path, device),
289 Some("safetensors") => load_safetensors_file(path, device),
290 _ => Err(voirs_sdk::VoirsError::UnsupportedFileFormat {
291 path: path.to_path_buf(),
292 format: path
293 .extension()
294 .and_then(|e| e.to_str())
295 .unwrap_or("unknown")
296 .to_string(),
297 }),
298 }
299}
300
301fn load_numpy_file(path: &Path, device: &Device) -> Result<Tensor> {
303 let data = std::fs::read(path).map_err(|e| voirs_sdk::VoirsError::IoError {
305 path: path.to_path_buf(),
306 operation: voirs_sdk::error::IoOperation::Read,
307 source: e,
308 })?;
309
310 if data.len() < 10 || &data[0..6] != b"\x93NUMPY" {
315 return Err(voirs_sdk::VoirsError::config_error(
316 "Invalid NumPy file: magic number mismatch",
317 ));
318 }
319
320 let major_version = data[6];
321 let minor_version = data[7];
322
323 if major_version != 1 && major_version != 2 {
324 return Err(voirs_sdk::VoirsError::config_error(format!(
325 "Unsupported NumPy version: {}.{}",
326 major_version, minor_version
327 )));
328 }
329
330 let header_len = if major_version == 1 {
332 u16::from_le_bytes([data[8], data[9]]) as usize
333 } else {
334 u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
335 };
336
337 let header_start = if major_version == 1 { 10 } else { 12 };
338 let header_end = header_start + header_len;
339
340 if data.len() < header_end {
341 return Err(voirs_sdk::VoirsError::config_error(
342 "Invalid NumPy file: truncated header",
343 ));
344 }
345
346 let header_str = std::str::from_utf8(&data[header_start..header_end])
348 .map_err(|_| voirs_sdk::VoirsError::config_error("Invalid NumPy header: not UTF-8"))?;
349
350 let shape = parse_numpy_shape(header_str)?;
352
353 let dtype = parse_numpy_dtype(header_str)?;
355 if dtype != "f4" && dtype != "<f4" && dtype != "float32" {
356 return Err(voirs_sdk::VoirsError::config_error(format!(
357 "Unsupported NumPy dtype: {}. Only float32 is supported.",
358 dtype
359 )));
360 }
361
362 let data_start = header_end;
364 let num_elements: usize = shape.iter().product();
365 let expected_bytes = num_elements * 4; if data.len() < data_start + expected_bytes {
368 return Err(voirs_sdk::VoirsError::config_error(
369 "Invalid NumPy file: insufficient data",
370 ));
371 }
372
373 let f32_data: Vec<f32> = data[data_start..data_start + expected_bytes]
375 .chunks_exact(4)
376 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
377 .collect();
378
379 let tensor = Tensor::from_vec(f32_data, shape.as_slice(), device).map_err(|e| {
381 voirs_sdk::VoirsError::config_error(format!(
382 "Failed to create tensor from NumPy data: {}",
383 e
384 ))
385 })?;
386
387 Ok(tensor)
388}
389
390fn parse_numpy_shape(header: &str) -> Result<Vec<usize>> {
392 let shape_start = header
395 .find("'shape':")
396 .or_else(|| header.find("\"shape\":"))
397 .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy header missing 'shape' field"))?;
398
399 let shape_str = &header[shape_start..];
400 let tuple_start = shape_str
401 .find('(')
402 .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy shape malformed"))?;
403 let tuple_end = shape_str
404 .find(')')
405 .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy shape malformed"))?;
406
407 let tuple_content = &shape_str[tuple_start + 1..tuple_end];
408
409 if tuple_content.trim().is_empty() {
410 return Ok(vec![1]);
412 }
413
414 let dims: Result<Vec<usize>> = tuple_content
416 .split(',')
417 .filter(|s| !s.trim().is_empty())
418 .map(|s| {
419 s.trim().parse::<usize>().map_err(|_| {
420 voirs_sdk::VoirsError::config_error(format!("Invalid dimension: {}", s))
421 })
422 })
423 .collect();
424
425 dims
426}
427
428fn parse_numpy_dtype(header: &str) -> Result<String> {
430 let descr_start = header
432 .find("'descr':")
433 .or_else(|| header.find("\"descr\":"))
434 .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy header missing 'descr' field"))?;
435
436 let descr_str = &header[descr_start..];
437
438 let value_start = descr_str
440 .find('\'')
441 .or_else(|| descr_str.find('"'))
442 .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy descr malformed"))?;
443
444 let value_str = &descr_str[value_start + 1..];
445 let value_end = value_str
446 .find('\'')
447 .or_else(|| value_str.find('"'))
448 .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy descr malformed"))?;
449
450 Ok(value_str[..value_end].to_string())
451}
452
453fn load_pytorch_file(path: &Path, _device: &Device) -> Result<Tensor> {
455 Err(voirs_sdk::VoirsError::config_error(format!(
459 "PyTorch .pt file loading requires Python interop or conversion.\n\
460 \n\
461 Alternatives:\n\
462 1. Convert to NumPy: python -c \"import torch, numpy as np; np.save('output.npy', torch.load('{}').numpy())\"\n\
463 2. Convert to SafeTensors: Use safetensors.torch.save_file() in Python\n\
464 3. Use ONNX format: Export model to ONNX and use --input-format onnx\n\
465 \n\
466 For native PyTorch support, compile with 'tch-rs' feature (requires libtorch).",
467 path.display()
468 )))
469}
470
471fn load_safetensors_file(path: &Path, device: &Device) -> Result<Tensor> {
473 use safetensors::SafeTensors;
474
475 let data = std::fs::read(path).map_err(|e| voirs_sdk::VoirsError::IoError {
476 path: path.to_path_buf(),
477 operation: voirs_sdk::error::IoOperation::Read,
478 source: e,
479 })?;
480
481 let tensors = SafeTensors::deserialize(&data).map_err(|e| {
482 voirs_sdk::VoirsError::config_error(format!("Failed to load SafeTensors: {}", e))
483 })?;
484
485 let names = tensors.names();
487 let tensor_name = names
488 .first()
489 .ok_or_else(|| voirs_sdk::VoirsError::config_error("No tensors found in file"))?;
490
491 let tensor_view = tensors
492 .tensor(tensor_name)
493 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to get tensor: {}", e)))?;
494
495 let shape: Vec<usize> = tensor_view.shape().to_vec();
496 let data = tensor_view.data();
497
498 let f32_data: Vec<f32> = data
500 .chunks_exact(4)
501 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
502 .collect();
503
504 let tensor = Tensor::from_vec(f32_data, shape.as_slice(), device).map_err(|e| {
505 voirs_sdk::VoirsError::config_error(format!("Failed to create tensor: {}", e))
506 })?;
507 Ok(tensor)
508}
509
510fn generate_dummy_mel_spectrogram(device: &Device) -> Result<Tensor> {
512 let batch_size = 1;
514 let mel_channels = 80;
515 let time_frames = 100;
516
517 let data: Vec<f32> = (0..(batch_size * mel_channels * time_frames))
519 .map(|_| fastrand::f32() * 2.0 - 1.0) .collect();
521
522 let tensor =
523 Tensor::from_vec(data, (batch_size, mel_channels, time_frames), device).map_err(|e| {
524 voirs_sdk::VoirsError::config_error(format!("Failed to create tensor: {}", e))
525 })?;
526 Ok(tensor)
527}
528
529fn save_audio_tensor(tensor: &Tensor, output: &Path, sample_rate: u32) -> Result<()> {
531 use hound::{WavSpec, WavWriter};
532
533 let audio_data: Vec<f32> = tensor
535 .flatten_all()
536 .map_err(|e| {
537 voirs_sdk::VoirsError::config_error(format!("Failed to flatten tensor: {}", e))
538 })?
539 .to_vec1()
540 .map_err(|e| {
541 voirs_sdk::VoirsError::config_error(format!("Failed to convert tensor to vec: {}", e))
542 })?;
543
544 let spec = WavSpec {
546 channels: 1,
547 sample_rate,
548 bits_per_sample: 16,
549 sample_format: hound::SampleFormat::Int,
550 };
551
552 let mut writer =
554 WavWriter::create(output, spec).map_err(|e| voirs_sdk::VoirsError::IoError {
555 path: output.to_path_buf(),
556 operation: voirs_sdk::error::IoOperation::Write,
557 source: std::io::Error::new(std::io::ErrorKind::Other, e),
558 })?;
559
560 for &sample in &audio_data {
562 let sample_i16 = (sample * 32767.0).clamp(-32768.0, 32767.0) as i16;
563 writer
564 .write_sample(sample_i16)
565 .map_err(|e| voirs_sdk::VoirsError::IoError {
566 path: output.to_path_buf(),
567 operation: voirs_sdk::error::IoOperation::Write,
568 source: std::io::Error::new(std::io::ErrorKind::Other, e),
569 })?;
570 }
571
572 writer
573 .finalize()
574 .map_err(|e| voirs_sdk::VoirsError::IoError {
575 path: output.to_path_buf(),
576 operation: voirs_sdk::error::IoOperation::Write,
577 source: std::io::Error::new(std::io::ErrorKind::Other, e),
578 })?;
579
580 Ok(())
581}
582
583async fn run_batch_inference(
585 checkpoint: &Path,
586 input_dir: &Path,
587 output_dir: &Path,
588 mut steps: usize,
589 quality: Option<&str>,
590 show_metrics: bool,
591 global: &GlobalOptions,
592) -> Result<()> {
593 use std::time::Instant;
594
595 if let Some(quality_str) = quality {
597 let preset = QualityPreset::from_str(quality_str)?;
598 steps = preset.steps();
599 }
600
601 if !global.quiet {
602 println!("🎵 VoiRS Batch Vocoder Inference");
603 println!("═══════════════════════════════════════");
604 println!("Checkpoint: {}", checkpoint.display());
605 println!("Input dir: {}", input_dir.display());
606 println!("Output dir: {}", output_dir.display());
607 println!("Steps: {}", steps);
608 if let Some(q) = quality {
609 println!("Quality: {}", q);
610 }
611 println!("═══════════════════════════════════════\n");
612 }
613
614 if !input_dir.is_dir() {
616 return Err(voirs_sdk::VoirsError::config_error(format!(
617 "Input directory not found: {}",
618 input_dir.display()
619 )));
620 }
621
622 std::fs::create_dir_all(output_dir)?;
624
625 let mel_files: Vec<_> = std::fs::read_dir(input_dir)
627 .map_err(|e| voirs_sdk::VoirsError::IoError {
628 path: input_dir.to_path_buf(),
629 operation: voirs_sdk::error::IoOperation::Read,
630 source: e,
631 })?
632 .filter_map(|entry| entry.ok())
633 .map(|entry| entry.path())
634 .filter(|path| {
635 path.extension()
636 .and_then(|e| e.to_str())
637 .map(|ext| matches!(ext, "npy" | "safetensors" | "pt" | "pth"))
638 .unwrap_or(false)
639 })
640 .collect();
641
642 if mel_files.is_empty() {
643 return Err(voirs_sdk::VoirsError::config_error(
644 "No mel spectrogram files found in input directory",
645 ));
646 }
647
648 if !global.quiet {
649 println!("Found {} mel spectrogram files", mel_files.len());
650 println!();
651 }
652
653 let device = if global.gpu {
655 #[cfg(feature = "cuda")]
656 {
657 Device::new_cuda(0).unwrap_or(Device::Cpu)
658 }
659 #[cfg(not(feature = "cuda"))]
660 {
661 Device::Cpu
662 }
663 } else {
664 Device::Cpu
665 };
666
667 let model = DiffWave::load_from_safetensors(checkpoint, device.clone())?;
668
669 let mut total_time = 0.0;
671 let mut successful = 0;
672 let mut failed = 0;
673 let batch_start = Instant::now();
674
675 for (idx, mel_file) in mel_files.iter().enumerate() {
677 let file_start = Instant::now();
678
679 let output_name = mel_file
680 .file_stem()
681 .and_then(|n| n.to_str())
682 .unwrap_or("output");
683 let output_path = output_dir.join(format!("{}.wav", output_name));
684
685 if !global.quiet {
686 println!(
687 "[{}/{}] Processing {}...",
688 idx + 1,
689 mel_files.len(),
690 mel_file.display()
691 );
692 }
693
694 let result =
696 process_single_mel(&model, mel_file, &output_path, steps, &device, global).await;
697
698 let file_time = file_start.elapsed().as_secs_f64();
699 total_time += file_time;
700
701 match result {
702 Ok(_) => {
703 successful += 1;
704 if !global.quiet {
705 println!(" ✓ Complete in {:.2}s", file_time);
706 }
707 }
708 Err(e) => {
709 failed += 1;
710 eprintln!(" ✗ Failed: {}", e);
711 }
712 }
713 }
714
715 let total_elapsed = batch_start.elapsed().as_secs_f64();
716
717 if !global.quiet || show_metrics {
719 println!();
720 println!("╔═══════════════════════════════════════╗");
721 println!("║ Batch Inference Complete ║");
722 println!("╠═══════════════════════════════════════╣");
723 println!("║ Total files: {:<21} ║", mel_files.len());
724 println!("║ Successful: {:<21} ║", successful);
725 println!("║ Failed: {:<21} ║", failed);
726 println!("║ Total time: {:<18.2}s ║", total_elapsed);
727 println!(
728 "║ Avg time/file: {:<18.2}s ║",
729 total_time / mel_files.len() as f64
730 );
731 if successful > 0 {
732 println!(
733 "║ Throughput: {:<18.2}/s ║",
734 successful as f64 / total_elapsed
735 );
736 }
737 println!("╚═══════════════════════════════════════╝");
738 }
739
740 Ok(())
741}
742
743async fn process_single_mel(
745 model: &DiffWave,
746 mel_path: &Path,
747 output_path: &Path,
748 steps: usize,
749 device: &Device,
750 _global: &GlobalOptions,
751) -> Result<()> {
752 let mel_tensor = load_mel_spectrogram(mel_path, device)?;
754
755 let sampling_method = SamplingMethod::DDIM { steps, eta: 0.0 };
757 let audio_tensor = model
758 .inference(&mel_tensor, sampling_method)
759 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Inference failed: {}", e)))?;
760
761 save_audio_tensor(&audio_tensor, output_path, 22050)?;
763
764 Ok(())
765}