voirs_cli/commands/batch/
parallel.rs1use super::{files::BatchInput, BatchConfig};
4use crate::GlobalOptions;
5use indicatif::{ProgressBar, ProgressStyle};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::{mpsc, Semaphore};
9use voirs_sdk::config::AppConfig;
10use voirs_sdk::types::SynthesisConfig;
11use voirs_sdk::VoirsPipeline;
12use voirs_sdk::{AudioFormat, QualityLevel, Result};
13
14#[derive(Debug, Clone)]
16pub struct ProcessingResult {
17 pub input: BatchInput,
19 pub success: bool,
21 pub error: Option<String>,
23 pub output_path: Option<std::path::PathBuf>,
25 pub duration: Duration,
27 pub audio_duration: Option<f32>,
29}
30
/// Aggregate figures describing one completed batch run.
#[derive(Debug, Clone)]
pub struct BatchStatistics {
    /// Number of items for which a result was recorded.
    pub total_items: usize,
    /// Number of items that were synthesized and saved successfully.
    pub successful_items: usize,
    /// Number of items that failed.
    pub failed_items: usize,
    /// Wall-clock time for the whole batch.
    pub total_time: Duration,
    /// `total_time` divided by the number of items (not the mean of per-item durations).
    pub avg_time_per_item: Duration,
    /// Sum of the audio durations of all successful items.
    pub total_audio_duration: f32,
    /// Successful items per second of wall-clock time.
    pub throughput: f32,
}
49
50pub async fn process_inputs_parallel(
52 inputs: &[BatchInput],
53 batch_config: &BatchConfig,
54 app_config: &AppConfig,
55 global: &GlobalOptions,
56) -> Result<()> {
57 if inputs.is_empty() {
58 if !global.quiet {
59 println!("No inputs to process");
60 }
61 return Ok(());
62 }
63
64 let start_time = Instant::now();
65
66 let progress_bar = if !global.quiet {
68 let pb = ProgressBar::new(inputs.len() as u64);
69 pb.set_style(
70 ProgressStyle::default_bar()
71 .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
72 .unwrap()
73 .progress_chars("#>-")
74 );
75 Some(pb)
76 } else {
77 None
78 };
79
80 let (result_tx, mut result_rx) = mpsc::unbounded_channel::<ProcessingResult>();
82
83 let semaphore = Arc::new(Semaphore::new(batch_config.workers));
85
86 let pipeline_config = (
88 batch_config.quality,
89 app_config.pipeline.use_gpu || global.gpu,
90 );
91
92 let mut handles = Vec::new();
94
95 for (index, input) in inputs.iter().enumerate() {
96 let semaphore = semaphore.clone();
97 let batch_config = batch_config.clone();
98 let input = input.clone();
99 let result_tx = result_tx.clone();
100 let (quality, use_gpu) = pipeline_config;
101
102 let handle = tokio::spawn(async move {
103 let _permit = semaphore.acquire().await.unwrap();
104 let result = process_single_input_with_own_pipeline(
105 input,
106 index,
107 &batch_config,
108 quality,
109 use_gpu,
110 )
111 .await;
112 let _ = result_tx.send(result);
113 });
114
115 handles.push(handle);
116 }
117
118 drop(result_tx);
120
121 let mut results = Vec::new();
123 let mut successful_count = 0;
124 let mut failed_count = 0;
125 let mut total_audio_duration = 0.0;
126
127 while let Some(result) = result_rx.recv().await {
128 if result.success {
129 successful_count += 1;
130 if let Some(duration) = result.audio_duration {
131 total_audio_duration += duration;
132 }
133 } else {
134 failed_count += 1;
135 if !global.quiet {
136 if let Some(error) = &result.error {
137 tracing::warn!("Failed to process '{}': {}", result.input.id, error);
138 }
139 }
140 }
141
142 results.push(result);
143
144 if let Some(pb) = &progress_bar {
146 pb.inc(1);
147 pb.set_message(format!("✓ {} ✗ {}", successful_count, failed_count));
148 }
149 }
150
151 for handle in handles {
153 let _ = handle.await;
154 }
155
156 if let Some(pb) = &progress_bar {
157 pb.finish_with_message("Processing complete");
158 }
159
160 let total_time = start_time.elapsed();
162 let statistics = BatchStatistics {
163 total_items: results.len(),
164 successful_items: successful_count,
165 failed_items: failed_count,
166 total_time,
167 avg_time_per_item: if results.len() > 0 {
168 total_time / results.len() as u32
169 } else {
170 Duration::from_secs(0)
171 },
172 total_audio_duration,
173 throughput: if total_time.as_secs_f32() > 0.0 {
174 successful_count as f32 / total_time.as_secs_f32()
175 } else {
176 0.0
177 },
178 };
179
180 display_statistics(&statistics, global);
181
182 if failed_count > 0 && !global.quiet {
184 println!("\nFailed items:");
185 for result in &results {
186 if !result.success {
187 println!(
188 " - {}: {}",
189 result.input.id,
190 result.error.as_deref().unwrap_or("Unknown error")
191 );
192 }
193 }
194 }
195
196 Ok(())
197}
198
199async fn process_single_input_with_own_pipeline(
201 input: BatchInput,
202 index: usize,
203 batch_config: &BatchConfig,
204 quality: QualityLevel,
205 use_gpu: bool,
206) -> ProcessingResult {
207 let start_time = Instant::now();
208
209 let pipeline = match VoirsPipeline::builder()
211 .with_quality(quality)
212 .with_gpu_acceleration(use_gpu)
213 .build()
214 .await
215 {
216 Ok(pipeline) => pipeline,
217 Err(e) => {
218 return ProcessingResult {
219 input,
220 success: false,
221 error: Some(format!("Failed to create pipeline: {}", e)),
222 output_path: None,
223 duration: start_time.elapsed(),
224 audio_duration: None,
225 };
226 }
227 };
228
229 process_single_input_impl(input, index, &pipeline, batch_config, start_time).await
230}
231
232async fn process_single_input_impl(
234 input: BatchInput,
235 index: usize,
236 pipeline: &VoirsPipeline,
237 batch_config: &BatchConfig,
238 start_time: Instant,
239) -> ProcessingResult {
240 let synth_config = SynthesisConfig {
242 speaking_rate: input.rate.unwrap_or(batch_config.speaking_rate),
243 pitch_shift: input.pitch.unwrap_or(batch_config.pitch),
244 volume_gain: input.volume.unwrap_or(batch_config.volume),
245 quality: batch_config.quality,
246 ..Default::default()
247 };
248
249 match pipeline
251 .synthesize_with_config(&input.text, &synth_config)
252 .await
253 {
254 Ok(audio) => {
255 let format = batch_config.format;
257 let filename = super::files::generate_output_filename(&input, index, format);
258 let output_path = batch_config.output_dir.join(filename);
259
260 match audio.save(&output_path, format) {
262 Ok(_) => ProcessingResult {
263 input,
264 success: true,
265 error: None,
266 output_path: Some(output_path),
267 duration: start_time.elapsed(),
268 audio_duration: Some(audio.duration()),
269 },
270 Err(e) => ProcessingResult {
271 input,
272 success: false,
273 error: Some(format!("Failed to save audio: {}", e)),
274 output_path: None,
275 duration: start_time.elapsed(),
276 audio_duration: None,
277 },
278 }
279 }
280 Err(e) => ProcessingResult {
281 input,
282 success: false,
283 error: Some(format!("Synthesis failed: {}", e)),
284 output_path: None,
285 duration: start_time.elapsed(),
286 audio_duration: None,
287 },
288 }
289}
290
291fn display_statistics(stats: &BatchStatistics, global: &GlobalOptions) {
293 if global.quiet {
294 return;
295 }
296
297 println!("\nBatch Processing Statistics:");
298 println!("============================");
299 println!("Total items: {}", stats.total_items);
300 println!(
301 "Successful: {} ({:.1}%)",
302 stats.successful_items,
303 (stats.successful_items as f32 / stats.total_items as f32) * 100.0
304 );
305 println!(
306 "Failed: {} ({:.1}%)",
307 stats.failed_items,
308 (stats.failed_items as f32 / stats.total_items as f32) * 100.0
309 );
310 println!("Total time: {:.2}s", stats.total_time.as_secs_f32());
311 println!(
312 "Average time per item: {:.2}s",
313 stats.avg_time_per_item.as_secs_f32()
314 );
315 println!("Total audio generated: {:.2}s", stats.total_audio_duration);
316 println!("Throughput: {:.2} items/second", stats.throughput);
317
318 if stats.total_audio_duration > 0.0 {
319 let real_time_factor = stats.total_time.as_secs_f32() / stats.total_audio_duration;
320 println!("Real-time factor: {:.2}x", real_time_factor);
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A statistics value reads back exactly the figures it was built with.
    #[test]
    fn test_batch_statistics_calculation() {
        let total_time = Duration::from_secs(60);
        let avg_time_per_item = Duration::from_millis(600);

        let stats = BatchStatistics {
            total_items: 100,
            successful_items: 95,
            failed_items: 5,
            total_time,
            avg_time_per_item,
            total_audio_duration: 120.0,
            throughput: 1.58,
        };

        assert_eq!(stats.total_items, 100);
        assert_eq!(stats.successful_items, 95);
        assert_eq!(stats.failed_items, 5);
        assert_eq!(stats.throughput, 1.58);
    }

    /// A successful result keeps its input and output path and carries no error.
    #[test]
    fn test_processing_result_creation() {
        let input = BatchInput {
            id: String::from("test"),
            text: String::from("Test text"),
            filename: None,
            voice: None,
            rate: None,
            pitch: None,
            volume: None,
            metadata: HashMap::new(),
        };
        let saved_to = std::path::PathBuf::from("/tmp/output.wav");

        let result = ProcessingResult {
            input: input.clone(),
            success: true,
            error: None,
            output_path: Some(saved_to),
            duration: Duration::from_millis(500),
            audio_duration: Some(2.5),
        };

        assert!(result.success);
        assert!(result.error.is_none());
        assert!(result.output_path.is_some());
        assert_eq!(result.input.id, "test");
    }
}