Skip to main content

voirs_cli/commands/batch/
parallel.rs

1//! Parallel processing for batch operations.
2
3use super::{files::BatchInput, BatchConfig};
4use crate::GlobalOptions;
5use indicatif::{ProgressBar, ProgressStyle};
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use tokio::sync::{mpsc, Semaphore};
9use voirs_sdk::config::AppConfig;
10use voirs_sdk::types::SynthesisConfig;
11use voirs_sdk::VoirsPipeline;
12use voirs_sdk::{AudioFormat, QualityLevel, Result};
13
14/// Result of processing a single batch item
15#[derive(Debug, Clone)]
16pub struct ProcessingResult {
17    /// Input that was processed
18    pub input: BatchInput,
19    /// Whether processing succeeded
20    pub success: bool,
21    /// Error message if failed
22    pub error: Option<String>,
23    /// Output file path if successful
24    pub output_path: Option<std::path::PathBuf>,
25    /// Processing time
26    pub duration: Duration,
27    /// Generated audio duration
28    pub audio_duration: Option<f32>,
29}
30
/// Aggregate figures describing a finished batch run.
#[derive(Clone, Debug)]
pub struct BatchStatistics {
    /// Number of items that produced a result (successes plus failures)
    pub total_items: usize,
    /// Number of items that completed successfully
    pub successful_items: usize,
    /// Number of items that failed
    pub failed_items: usize,
    /// Elapsed time for the whole batch
    pub total_time: Duration,
    /// `total_time` divided by the number of items
    pub avg_time_per_item: Duration,
    /// Sum of the audio durations of all successful items
    pub total_audio_duration: f32,
    /// Successful items completed per second of `total_time`
    pub throughput: f32,
}
49
50/// Process multiple inputs in parallel
51pub async fn process_inputs_parallel(
52    inputs: &[BatchInput],
53    batch_config: &BatchConfig,
54    app_config: &AppConfig,
55    global: &GlobalOptions,
56) -> Result<()> {
57    if inputs.is_empty() {
58        if !global.quiet {
59            println!("No inputs to process");
60        }
61        return Ok(());
62    }
63
64    let start_time = Instant::now();
65
66    // Create progress bar
67    let progress_bar = if !global.quiet {
68        let pb = ProgressBar::new(inputs.len() as u64);
69        pb.set_style(
70            ProgressStyle::default_bar()
71                .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
72                .unwrap()
73                .progress_chars("#>-")
74        );
75        Some(pb)
76    } else {
77        None
78    };
79
80    // Create channels for results
81    let (result_tx, mut result_rx) = mpsc::unbounded_channel::<ProcessingResult>();
82
83    // Create semaphore to limit concurrent workers
84    let semaphore = Arc::new(Semaphore::new(batch_config.workers));
85
86    // Pipeline configuration that will be used by each worker
87    let pipeline_config = (
88        batch_config.quality,
89        app_config.pipeline.use_gpu || global.gpu,
90    );
91
92    // Spawn worker tasks
93    let mut handles = Vec::new();
94
95    for (index, input) in inputs.iter().enumerate() {
96        let semaphore = semaphore.clone();
97        let batch_config = batch_config.clone();
98        let input = input.clone();
99        let result_tx = result_tx.clone();
100        let (quality, use_gpu) = pipeline_config;
101
102        let handle = tokio::spawn(async move {
103            let _permit = semaphore.acquire().await.unwrap();
104            let result = process_single_input_with_own_pipeline(
105                input,
106                index,
107                &batch_config,
108                quality,
109                use_gpu,
110            )
111            .await;
112            let _ = result_tx.send(result);
113        });
114
115        handles.push(handle);
116    }
117
118    // Drop the original sender so the receiver knows when all workers are done
119    drop(result_tx);
120
121    // Collect results
122    let mut results = Vec::new();
123    let mut successful_count = 0;
124    let mut failed_count = 0;
125    let mut total_audio_duration = 0.0;
126
127    while let Some(result) = result_rx.recv().await {
128        if result.success {
129            successful_count += 1;
130            if let Some(duration) = result.audio_duration {
131                total_audio_duration += duration;
132            }
133        } else {
134            failed_count += 1;
135            if !global.quiet {
136                if let Some(error) = &result.error {
137                    tracing::warn!("Failed to process '{}': {}", result.input.id, error);
138                }
139            }
140        }
141
142        results.push(result);
143
144        // Update progress
145        if let Some(pb) = &progress_bar {
146            pb.inc(1);
147            pb.set_message(format!("✓ {} ✗ {}", successful_count, failed_count));
148        }
149    }
150
151    // Wait for all workers to complete
152    for handle in handles {
153        let _ = handle.await;
154    }
155
156    if let Some(pb) = &progress_bar {
157        pb.finish_with_message("Processing complete");
158    }
159
160    // Calculate and display statistics
161    let total_time = start_time.elapsed();
162    let statistics = BatchStatistics {
163        total_items: results.len(),
164        successful_items: successful_count,
165        failed_items: failed_count,
166        total_time,
167        avg_time_per_item: if results.len() > 0 {
168            total_time / results.len() as u32
169        } else {
170            Duration::from_secs(0)
171        },
172        total_audio_duration,
173        throughput: if total_time.as_secs_f32() > 0.0 {
174            successful_count as f32 / total_time.as_secs_f32()
175        } else {
176            0.0
177        },
178    };
179
180    display_statistics(&statistics, global);
181
182    // Handle failed items
183    if failed_count > 0 && !global.quiet {
184        println!("\nFailed items:");
185        for result in &results {
186            if !result.success {
187                println!(
188                    "  - {}: {}",
189                    result.input.id,
190                    result.error.as_deref().unwrap_or("Unknown error")
191                );
192            }
193        }
194    }
195
196    Ok(())
197}
198
199/// Process a single input item with its own pipeline instance
200async fn process_single_input_with_own_pipeline(
201    input: BatchInput,
202    index: usize,
203    batch_config: &BatchConfig,
204    quality: QualityLevel,
205    use_gpu: bool,
206) -> ProcessingResult {
207    let start_time = Instant::now();
208
209    // Create a separate pipeline instance for this worker to avoid race conditions
210    let pipeline = match VoirsPipeline::builder()
211        .with_quality(quality)
212        .with_gpu_acceleration(use_gpu)
213        .build()
214        .await
215    {
216        Ok(pipeline) => pipeline,
217        Err(e) => {
218            return ProcessingResult {
219                input,
220                success: false,
221                error: Some(format!("Failed to create pipeline: {}", e)),
222                output_path: None,
223                duration: start_time.elapsed(),
224                audio_duration: None,
225            };
226        }
227    };
228
229    process_single_input_impl(input, index, &pipeline, batch_config, start_time).await
230}
231
232/// Process a single input item (implementation)
233async fn process_single_input_impl(
234    input: BatchInput,
235    index: usize,
236    pipeline: &VoirsPipeline,
237    batch_config: &BatchConfig,
238    start_time: Instant,
239) -> ProcessingResult {
240    // Create synthesis config with overrides from input
241    let synth_config = SynthesisConfig {
242        speaking_rate: input.rate.unwrap_or(batch_config.speaking_rate),
243        pitch_shift: input.pitch.unwrap_or(batch_config.pitch),
244        volume_gain: input.volume.unwrap_or(batch_config.volume),
245        quality: batch_config.quality,
246        ..Default::default()
247    };
248
249    // Attempt synthesis
250    match pipeline
251        .synthesize_with_config(&input.text, &synth_config)
252        .await
253    {
254        Ok(audio) => {
255            // Generate output filename
256            let format = batch_config.format;
257            let filename = super::files::generate_output_filename(&input, index, format);
258            let output_path = batch_config.output_dir.join(filename);
259
260            // Save audio file
261            match audio.save(&output_path, format) {
262                Ok(_) => ProcessingResult {
263                    input,
264                    success: true,
265                    error: None,
266                    output_path: Some(output_path),
267                    duration: start_time.elapsed(),
268                    audio_duration: Some(audio.duration()),
269                },
270                Err(e) => ProcessingResult {
271                    input,
272                    success: false,
273                    error: Some(format!("Failed to save audio: {}", e)),
274                    output_path: None,
275                    duration: start_time.elapsed(),
276                    audio_duration: None,
277                },
278            }
279        }
280        Err(e) => ProcessingResult {
281            input,
282            success: false,
283            error: Some(format!("Synthesis failed: {}", e)),
284            output_path: None,
285            duration: start_time.elapsed(),
286            audio_duration: None,
287        },
288    }
289}
290
291/// Display batch processing statistics
292fn display_statistics(stats: &BatchStatistics, global: &GlobalOptions) {
293    if global.quiet {
294        return;
295    }
296
297    println!("\nBatch Processing Statistics:");
298    println!("============================");
299    println!("Total items: {}", stats.total_items);
300    println!(
301        "Successful: {} ({:.1}%)",
302        stats.successful_items,
303        (stats.successful_items as f32 / stats.total_items as f32) * 100.0
304    );
305    println!(
306        "Failed: {} ({:.1}%)",
307        stats.failed_items,
308        (stats.failed_items as f32 / stats.total_items as f32) * 100.0
309    );
310    println!("Total time: {:.2}s", stats.total_time.as_secs_f32());
311    println!(
312        "Average time per item: {:.2}s",
313        stats.avg_time_per_item.as_secs_f32()
314    );
315    println!("Total audio generated: {:.2}s", stats.total_audio_duration);
316    println!("Throughput: {:.2} items/second", stats.throughput);
317
318    if stats.total_audio_duration > 0.0 {
319        let real_time_factor = stats.total_time.as_secs_f32() / stats.total_audio_duration;
320        println!("Real-time factor: {:.2}x", real_time_factor);
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an input with the given id and no per-item overrides.
    fn plain_input(id: &str) -> BatchInput {
        BatchInput {
            id: id.to_string(),
            text: "Test text".to_string(),
            filename: None,
            voice: None,
            rate: None,
            pitch: None,
            volume: None,
            metadata: HashMap::new(),
        }
    }

    /// The statistics struct stores the figures it was constructed with.
    #[test]
    fn test_batch_statistics_calculation() {
        let total_time = Duration::from_secs(60);
        let avg_time_per_item = Duration::from_millis(600);

        let stats = BatchStatistics {
            total_items: 100,
            successful_items: 95,
            failed_items: 5,
            total_time,
            avg_time_per_item,
            total_audio_duration: 120.0,
            throughput: 1.58,
        };

        assert_eq!(stats.total_items, 100);
        assert_eq!(stats.successful_items, 95);
        assert_eq!(stats.failed_items, 5);
        assert_eq!(stats.throughput, 1.58);
    }

    /// A successful result carries an output path and no error.
    #[test]
    fn test_processing_result_creation() {
        let result = ProcessingResult {
            input: plain_input("test"),
            success: true,
            error: None,
            output_path: Some(std::path::PathBuf::from("/tmp/output.wav")),
            duration: Duration::from_millis(500),
            audio_duration: Some(2.5),
        };

        assert!(result.success);
        assert!(result.error.is_none());
        assert!(result.output_path.is_some());
        assert_eq!(result.input.id, "test");
    }
}