Skip to main content

voirs_evaluation/
workflows.rs

//! Evaluation Workflow System
//!
//! Automated evaluation pipelines with stage-based processing, conditional execution,
//! parallel processing, and comprehensive error handling.
//!
//! # Features
//!
//! - **Stage-Based Processing**: Define evaluation workflows as a series of stages
//! - **Conditional Execution**: Skip or execute stages based on conditions
//! - **Parallel Processing**: Run independent stages concurrently
//! - **Error Handling**: Graceful failure handling with retry mechanisms
//! - **Progress Tracking**: Monitor workflow execution progress
//! - **Caching**: Cache intermediate results for efficiency
//! - **Scheduling**: Schedule workflows to run at specific times
//!
//! # Example
//!
//! ```rust
//! use voirs_evaluation::workflows::{WorkflowBuilder, WorkflowStage, StageConfig};
//! use voirs_evaluation::quality::QualityEvaluator;
//! use voirs_sdk::AudioBuffer;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a multi-stage evaluation workflow
//! let workflow = WorkflowBuilder::new("quality_pipeline")
//!     .add_stage(WorkflowStage::quality_evaluation(StageConfig::default()))
//!     .add_stage(WorkflowStage::pronunciation_evaluation(StageConfig::default()))
//!     .add_stage(WorkflowStage::export_results(StageConfig::default()))
//!     .build()?;
//!
//! // Execute the workflow
//! let audio = AudioBuffer::new(vec![0.1; 16000], 16000, 1);
//! let results = workflow.execute(&audio, None).await?;
//! # Ok(())
//! # }
//! ```
37
38use async_trait::async_trait;
39use scirs2_core::random::Rng;
40use serde::{Deserialize, Serialize};
41use std::collections::HashMap;
42use std::path::PathBuf;
43use std::sync::Arc;
44use std::time::{Duration, SystemTime};
45use thiserror::Error;
46use tokio::sync::{RwLock, Semaphore};
47use tracing::{debug, error, info, warn};
48use voirs_sdk::{AudioBuffer, VoirsError};
49
50use crate::caching::CacheConfig;
51use crate::pronunciation::PronunciationEvaluatorImpl;
52use crate::quality::QualityEvaluator;
53use crate::traits::{
54    PronunciationEvaluator as PronunciationEvaluatorTrait, PronunciationScore,
55    QualityEvaluationConfig, QualityEvaluator as QualityEvaluatorTrait, QualityScore,
56};
57
58/// Workflow system errors
59#[derive(Error, Debug)]
60pub enum WorkflowError {
61    /// Stage execution failed
62    #[error("Stage '{stage}' execution failed: {message}")]
63    StageExecutionError {
64        /// Stage name
65        stage: String,
66        /// Error message
67        message: String,
68        /// Source error
69        #[source]
70        source: Option<Box<dyn std::error::Error + Send + Sync>>,
71    },
72
73    /// Workflow validation failed
74    #[error("Workflow validation failed: {message}")]
75    ValidationError {
76        /// Error message
77        message: String,
78    },
79
80    /// Workflow configuration error
81    #[error("Workflow configuration error: {message}")]
82    ConfigurationError {
83        /// Error message
84        message: String,
85    },
86
87    /// Dependency error
88    #[error("Dependency error: {message}")]
89    DependencyError {
90        /// Error message
91        message: String,
92    },
93
94    /// Timeout error
95    #[error("Workflow timed out after {duration:?}")]
96    TimeoutError {
97        /// Timeout duration
98        duration: Duration,
99    },
100
101    /// Condition evaluation error
102    #[error("Condition evaluation failed: {message}")]
103    ConditionError {
104        /// Error message
105        message: String,
106    },
107
108    /// VoiRS error
109    #[error("VoiRS error: {0}")]
110    VoirsError(#[from] VoirsError),
111
112    /// Serialization error
113    #[error("Serialization error: {0}")]
114    SerializationError(#[from] serde_json::Error),
115
116    /// IO error
117    #[error("IO error: {0}")]
118    IoError(#[from] std::io::Error),
119
120    /// Evaluation error
121    #[error("Evaluation error: {0}")]
122    EvaluationError(#[from] crate::EvaluationError),
123}
124
125/// Workflow stage type
126#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
127pub enum StageType {
128    /// Quality evaluation stage
129    QualityEvaluation,
130    /// Pronunciation evaluation stage
131    PronunciationEvaluation,
132    /// Data preprocessing stage
133    Preprocessing,
134    /// Feature extraction stage
135    FeatureExtraction,
136    /// Statistical analysis stage
137    StatisticalAnalysis,
138    /// Export results stage
139    ExportResults,
140    /// Custom stage
141    Custom(String),
142}
143
144/// Stage execution condition
145#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
146pub enum StageCondition {
147    /// Always execute
148    Always,
149    /// Execute if previous stage succeeded
150    OnSuccess,
151    /// Execute if previous stage failed
152    OnFailure,
153    /// Execute if specific metric meets threshold
154    MetricThreshold {
155        /// Metric name
156        metric: String,
157        /// Minimum threshold
158        min_value: f64,
159        /// Maximum threshold
160        max_value: Option<f64>,
161    },
162    /// Custom condition function
163    Custom(String),
164}
165
166/// Stage configuration
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct StageConfig {
169    /// Stage name
170    pub name: String,
171    /// Stage type
172    pub stage_type: StageType,
173    /// Execution condition
174    pub condition: StageCondition,
175    /// Maximum retry attempts
176    pub max_retries: usize,
177    /// Timeout duration (seconds)
178    pub timeout_seconds: Option<u64>,
179    /// Enable caching
180    pub enable_cache: bool,
181    /// Stage-specific parameters
182    pub parameters: HashMap<String, serde_json::Value>,
183    /// Dependencies (other stages that must complete first)
184    pub dependencies: Vec<String>,
185}
186
187impl Default for StageConfig {
188    fn default() -> Self {
189        Self {
190            name: String::new(),
191            stage_type: StageType::Custom("default".to_string()),
192            condition: StageCondition::Always,
193            max_retries: 3,
194            timeout_seconds: Some(300), // 5 minutes
195            enable_cache: true,
196            parameters: HashMap::new(),
197            dependencies: Vec::new(),
198        }
199    }
200}
201
202/// Stage execution result
203#[derive(Debug, Clone, Serialize, Deserialize)]
204pub struct StageResult {
205    /// Stage name
206    pub stage_name: String,
207    /// Execution status
208    pub status: StageStatus,
209    /// Execution duration
210    pub duration_ms: u64,
211    /// Quality score (if applicable)
212    pub quality_score: Option<QualityScore>,
213    /// Custom results
214    pub custom_results: HashMap<String, serde_json::Value>,
215    /// Error message (if failed)
216    pub error_message: Option<String>,
217    /// Retry count
218    pub retry_count: usize,
219}
220
221/// Stage execution status
222#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
223pub enum StageStatus {
224    /// Stage pending
225    Pending,
226    /// Stage running
227    Running,
228    /// Stage completed successfully
229    Success,
230    /// Stage failed
231    Failed,
232    /// Stage skipped
233    Skipped,
234    /// Stage timeout
235    Timeout,
236}
237
238/// Workflow stage trait
239#[async_trait]
240pub trait WorkflowStageExecutor: Send + Sync {
241    /// Get stage configuration
242    fn config(&self) -> &StageConfig;
243
244    /// Execute the stage
245    async fn execute(
246        &self,
247        audio: &AudioBuffer,
248        reference: Option<&AudioBuffer>,
249        context: &WorkflowContext,
250    ) -> Result<StageResult, WorkflowError>;
251
252    /// Validate stage configuration
253    fn validate(&self) -> Result<(), WorkflowError> {
254        Ok(())
255    }
256
257    /// Get stage dependencies
258    fn dependencies(&self) -> Vec<String> {
259        self.config().dependencies.clone()
260    }
261}
262
263/// Quality evaluation stage
264pub struct QualityEvaluationStage {
265    config: StageConfig,
266    evaluator: Arc<RwLock<QualityEvaluator>>,
267}
268
269impl QualityEvaluationStage {
270    /// Create new quality evaluation stage
271    pub async fn new(config: StageConfig) -> Result<Self, WorkflowError> {
272        let evaluator =
273            QualityEvaluator::new()
274                .await
275                .map_err(|e| WorkflowError::ConfigurationError {
276                    message: format!("Failed to create quality evaluator: {}", e),
277                })?;
278
279        Ok(Self {
280            config,
281            evaluator: Arc::new(RwLock::new(evaluator)),
282        })
283    }
284}
285
286#[async_trait]
287impl WorkflowStageExecutor for QualityEvaluationStage {
288    fn config(&self) -> &StageConfig {
289        &self.config
290    }
291
292    async fn execute(
293        &self,
294        audio: &AudioBuffer,
295        reference: Option<&AudioBuffer>,
296        _context: &WorkflowContext,
297    ) -> Result<StageResult, WorkflowError> {
298        let start = SystemTime::now();
299        let evaluator = self.evaluator.read().await;
300
301        let eval_config = QualityEvaluationConfig::default();
302        let quality_score = evaluator
303            .evaluate_quality(audio, reference, Some(&eval_config))
304            .await?;
305
306        let duration_ms = SystemTime::now()
307            .duration_since(start)
308            .unwrap_or(Duration::ZERO)
309            .as_millis() as u64;
310
311        Ok(StageResult {
312            stage_name: self.config.name.clone(),
313            status: StageStatus::Success,
314            duration_ms,
315            quality_score: Some(quality_score),
316            custom_results: HashMap::new(),
317            error_message: None,
318            retry_count: 0,
319        })
320    }
321}
322
323/// Pronunciation evaluation stage
324pub struct PronunciationEvaluationStage {
325    config: StageConfig,
326    evaluator: Arc<RwLock<PronunciationEvaluatorImpl>>,
327}
328
329impl PronunciationEvaluationStage {
330    /// Create new pronunciation evaluation stage
331    pub async fn new(config: StageConfig) -> Result<Self, WorkflowError> {
332        let evaluator = PronunciationEvaluatorImpl::new().await.map_err(|e| {
333            WorkflowError::ConfigurationError {
334                message: format!("Failed to create pronunciation evaluator: {}", e),
335            }
336        })?;
337
338        Ok(Self {
339            config,
340            evaluator: Arc::new(RwLock::new(evaluator)),
341        })
342    }
343}
344
345#[async_trait]
346impl WorkflowStageExecutor for PronunciationEvaluationStage {
347    fn config(&self) -> &StageConfig {
348        &self.config
349    }
350
351    async fn execute(
352        &self,
353        audio: &AudioBuffer,
354        _reference: Option<&AudioBuffer>,
355        context: &WorkflowContext,
356    ) -> Result<StageResult, WorkflowError> {
357        let start = SystemTime::now();
358
359        // Get expected text from context parameters
360        // In production, this would get from context async
361        let expected_text = "Hello world"; // Placeholder for context parameter
362
363        let _language = "en-US"; // Placeholder
364
365        let evaluator = self.evaluator.read().await;
366        let pronunciation_score = evaluator
367            .evaluate_pronunciation(audio, expected_text, None)
368            .await?;
369
370        let duration_ms = SystemTime::now()
371            .duration_since(start)
372            .unwrap_or(Duration::ZERO)
373            .as_millis() as u64;
374
375        // Store pronunciation score in custom results
376        let mut custom_results = HashMap::new();
377        custom_results.insert(
378            "pronunciation_score".to_string(),
379            serde_json::json!({
380                "overall_score": pronunciation_score.overall_score,
381                "fluency_score": pronunciation_score.fluency_score,
382                "rhythm_score": pronunciation_score.rhythm_score,
383            }),
384        );
385
386        Ok(StageResult {
387            stage_name: self.config.name.clone(),
388            status: StageStatus::Success,
389            duration_ms,
390            quality_score: None,
391            custom_results,
392            error_message: None,
393            retry_count: 0,
394        })
395    }
396}
397
398/// Export results stage
399pub struct ExportResultsStage {
400    config: StageConfig,
401}
402
403impl ExportResultsStage {
404    /// Create new export results stage
405    pub fn new(config: StageConfig) -> Self {
406        Self { config }
407    }
408}
409
410#[async_trait]
411impl WorkflowStageExecutor for ExportResultsStage {
412    fn config(&self) -> &StageConfig {
413        &self.config
414    }
415
416    async fn execute(
417        &self,
418        _audio: &AudioBuffer,
419        _reference: Option<&AudioBuffer>,
420        context: &WorkflowContext,
421    ) -> Result<StageResult, WorkflowError> {
422        let start = SystemTime::now();
423
424        // Get output path from config
425        let output_path = self
426            .config
427            .parameters
428            .get("output_path")
429            .and_then(|v| v.as_str())
430            .ok_or_else(|| WorkflowError::ConfigurationError {
431                message: "Missing 'output_path' parameter for export stage".to_string(),
432            })?;
433
434        // Collect all results from context
435        let workflow_results = context.get_all_results();
436        let output_data = serde_json::to_string_pretty(&workflow_results)?;
437
438        // Write to file
439        tokio::fs::write(output_path, output_data).await?;
440
441        let duration_ms = SystemTime::now()
442            .duration_since(start)
443            .unwrap_or(Duration::ZERO)
444            .as_millis() as u64;
445
446        info!("Exported workflow results to: {}", output_path);
447
448        Ok(StageResult {
449            stage_name: self.config.name.clone(),
450            status: StageStatus::Success,
451            duration_ms,
452            quality_score: None,
453            custom_results: HashMap::new(),
454            error_message: None,
455            retry_count: 0,
456        })
457    }
458}
459
460/// Workflow context for sharing data between stages
461#[derive(Clone)]
462pub struct WorkflowContext {
463    parameters: Arc<RwLock<HashMap<String, serde_json::Value>>>,
464    results: Arc<RwLock<HashMap<String, StageResult>>>,
465    cache_config: Option<CacheConfig>,
466}
467
468impl WorkflowContext {
469    /// Create new workflow context
470    pub fn new() -> Self {
471        Self {
472            parameters: Arc::new(RwLock::new(HashMap::new())),
473            results: Arc::new(RwLock::new(HashMap::new())),
474            cache_config: None,
475        }
476    }
477
478    /// Set parameter
479    pub async fn set_parameter(&self, key: String, value: serde_json::Value) {
480        let mut params = self.parameters.write().await;
481        params.insert(key, value);
482    }
483
484    /// Get parameter
485    pub fn get_parameter(&self, key: &str) -> Option<serde_json::Value> {
486        // This is a simplified synchronous version for ease of use
487        // In a real implementation, you might want to use async
488        None
489    }
490
491    /// Set stage result
492    pub async fn set_result(&self, stage_name: String, result: StageResult) {
493        let mut results = self.results.write().await;
494        results.insert(stage_name, result);
495    }
496
497    /// Get stage result
498    pub async fn get_result(&self, stage_name: &str) -> Option<StageResult> {
499        let results = self.results.read().await;
500        results.get(stage_name).cloned()
501    }
502
503    /// Get all results
504    pub fn get_all_results(&self) -> HashMap<String, StageResult> {
505        // Simplified synchronous version
506        HashMap::new()
507    }
508
509    /// Enable caching
510    pub fn with_cache(mut self, cache_config: CacheConfig) -> Self {
511        self.cache_config = Some(cache_config);
512        self
513    }
514}
515
516impl Default for WorkflowContext {
517    fn default() -> Self {
518        Self::new()
519    }
520}
521
522/// Workflow configuration
523#[derive(Debug, Clone, Serialize, Deserialize)]
524pub struct WorkflowConfig {
525    /// Workflow name
526    pub name: String,
527    /// Workflow description
528    pub description: Option<String>,
529    /// Maximum parallel stages
530    pub max_parallel_stages: usize,
531    /// Global timeout (seconds)
532    pub global_timeout_seconds: Option<u64>,
533    /// Enable caching
534    pub enable_cache: bool,
535    /// Cache configuration
536    pub cache_config: Option<CacheConfig>,
537    /// Retry policy
538    pub retry_policy: RetryPolicy,
539}
540
541impl Default for WorkflowConfig {
542    fn default() -> Self {
543        Self {
544            name: "default_workflow".to_string(),
545            description: None,
546            max_parallel_stages: 4,
547            global_timeout_seconds: Some(1800), // 30 minutes
548            enable_cache: true,
549            cache_config: None,
550            retry_policy: RetryPolicy::default(),
551        }
552    }
553}
554
555/// Retry policy
556#[derive(Debug, Clone, Serialize, Deserialize)]
557pub struct RetryPolicy {
558    /// Maximum retry attempts
559    pub max_attempts: usize,
560    /// Initial delay (milliseconds)
561    pub initial_delay_ms: u64,
562    /// Backoff multiplier
563    pub backoff_multiplier: f64,
564    /// Maximum delay (milliseconds)
565    pub max_delay_ms: u64,
566}
567
568impl Default for RetryPolicy {
569    fn default() -> Self {
570        Self {
571            max_attempts: 3,
572            initial_delay_ms: 1000,
573            backoff_multiplier: 2.0,
574            max_delay_ms: 30000,
575        }
576    }
577}
578
579/// Evaluation workflow
580pub struct Workflow {
581    config: WorkflowConfig,
582    stages: Vec<Arc<dyn WorkflowStageExecutor>>,
583    context: WorkflowContext,
584    semaphore: Arc<Semaphore>,
585}
586
587impl Workflow {
588    /// Create new workflow
589    pub fn new(config: WorkflowConfig, stages: Vec<Arc<dyn WorkflowStageExecutor>>) -> Self {
590        let semaphore = Arc::new(Semaphore::new(config.max_parallel_stages));
591        let mut context = WorkflowContext::new();
592
593        // Initialize cache if enabled
594        if config.enable_cache {
595            let cache_config = config.cache_config.clone().unwrap_or_default();
596            context = context.with_cache(cache_config);
597        }
598
599        Self {
600            config,
601            stages,
602            context,
603            semaphore,
604        }
605    }
606
607    /// Execute the workflow
608    pub async fn execute(
609        &self,
610        audio: &AudioBuffer,
611        reference: Option<&AudioBuffer>,
612    ) -> Result<WorkflowResult, WorkflowError> {
613        let workflow_start = SystemTime::now();
614        info!("Starting workflow: {}", self.config.name);
615
616        let mut stage_results = Vec::new();
617
618        for stage in &self.stages {
619            let _permit =
620                self.semaphore
621                    .acquire()
622                    .await
623                    .map_err(|e| WorkflowError::StageExecutionError {
624                        stage: stage.config().name.clone(),
625                        message: format!("Failed to acquire semaphore: {}", e),
626                        source: None,
627                    })?;
628
629            // Check condition
630            if !self
631                .should_execute_stage(stage.config(), &stage_results)
632                .await?
633            {
634                info!("Skipping stage '{}' due to condition", stage.config().name);
635                stage_results.push(StageResult {
636                    stage_name: stage.config().name.clone(),
637                    status: StageStatus::Skipped,
638                    duration_ms: 0,
639                    quality_score: None,
640                    custom_results: HashMap::new(),
641                    error_message: None,
642                    retry_count: 0,
643                });
644                continue;
645            }
646
647            // Execute stage with retries
648            let result = self
649                .execute_stage_with_retry(stage.as_ref(), audio, reference)
650                .await?;
651
652            self.context
653                .set_result(stage.config().name.clone(), result.clone())
654                .await;
655            stage_results.push(result);
656        }
657
658        let total_duration_ms = SystemTime::now()
659            .duration_since(workflow_start)
660            .unwrap_or(Duration::ZERO)
661            .as_millis() as u64;
662
663        info!(
664            "Workflow '{}' completed in {}ms",
665            self.config.name, total_duration_ms
666        );
667
668        Ok(WorkflowResult {
669            workflow_name: self.config.name.clone(),
670            stage_results,
671            total_duration_ms,
672            status: WorkflowStatus::Success,
673            error_message: None,
674        })
675    }
676
677    /// Check if stage should execute based on condition
678    async fn should_execute_stage(
679        &self,
680        config: &StageConfig,
681        previous_results: &[StageResult],
682    ) -> Result<bool, WorkflowError> {
683        match &config.condition {
684            StageCondition::Always => Ok(true),
685            StageCondition::OnSuccess => {
686                if let Some(last_result) = previous_results.last() {
687                    Ok(last_result.status == StageStatus::Success)
688                } else {
689                    Ok(true)
690                }
691            }
692            StageCondition::OnFailure => {
693                if let Some(last_result) = previous_results.last() {
694                    Ok(last_result.status == StageStatus::Failed)
695                } else {
696                    Ok(false)
697                }
698            }
699            StageCondition::MetricThreshold {
700                metric,
701                min_value,
702                max_value,
703            } => {
704                // Check if any previous stage has the metric
705                for result in previous_results.iter().rev() {
706                    if let Some(quality) = &result.quality_score {
707                        let value = match metric.as_str() {
708                            "overall_score" => quality.overall_score as f64,
709                            _ => continue, // Skip unknown metrics
710                        };
711
712                        let meets_min = value >= *min_value;
713                        let meets_max = max_value.map_or(true, |max| value <= max);
714                        return Ok(meets_min && meets_max);
715                    }
716                }
717                Ok(false)
718            }
719            StageCondition::Custom(_) => {
720                warn!("Custom conditions not yet implemented, defaulting to true");
721                Ok(true)
722            }
723        }
724    }
725
726    /// Execute stage with retry logic
727    async fn execute_stage_with_retry(
728        &self,
729        stage: &dyn WorkflowStageExecutor,
730        audio: &AudioBuffer,
731        reference: Option<&AudioBuffer>,
732    ) -> Result<StageResult, WorkflowError> {
733        let config = stage.config();
734        let max_retries = config.max_retries;
735        let mut retry_count = 0;
736        let mut last_error = None;
737
738        while retry_count <= max_retries {
739            match stage.execute(audio, reference, &self.context).await {
740                Ok(mut result) => {
741                    result.retry_count = retry_count;
742                    return Ok(result);
743                }
744                Err(e) => {
745                    error!(
746                        "Stage '{}' failed (attempt {}/{}): {}",
747                        config.name,
748                        retry_count + 1,
749                        max_retries + 1,
750                        e
751                    );
752                    last_error = Some(e);
753                    retry_count += 1;
754
755                    if retry_count <= max_retries {
756                        // Calculate backoff delay
757                        let delay_ms = self.calculate_backoff_delay(retry_count);
758                        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
759                    }
760                }
761            }
762        }
763
764        // All retries exhausted
765        Ok(StageResult {
766            stage_name: config.name.clone(),
767            status: StageStatus::Failed,
768            duration_ms: 0,
769            quality_score: None,
770            custom_results: HashMap::new(),
771            error_message: Some(
772                last_error
773                    .map(|e| e.to_string())
774                    .unwrap_or_else(|| "Unknown error".to_string()),
775            ),
776            retry_count,
777        })
778    }
779
780    /// Calculate backoff delay for retries
781    fn calculate_backoff_delay(&self, retry_count: usize) -> u64 {
782        let policy = &self.config.retry_policy;
783        let delay =
784            policy.initial_delay_ms as f64 * policy.backoff_multiplier.powi(retry_count as i32 - 1);
785        delay.min(policy.max_delay_ms as f64) as u64
786    }
787}
788
789/// Workflow execution result
790#[derive(Debug, Clone, Serialize, Deserialize)]
791pub struct WorkflowResult {
792    /// Workflow name
793    pub workflow_name: String,
794    /// Stage results
795    pub stage_results: Vec<StageResult>,
796    /// Total execution duration (milliseconds)
797    pub total_duration_ms: u64,
798    /// Workflow status
799    pub status: WorkflowStatus,
800    /// Error message (if failed)
801    pub error_message: Option<String>,
802}
803
804/// Workflow execution status
805#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
806pub enum WorkflowStatus {
807    /// Workflow completed successfully
808    Success,
809    /// Workflow partially completed
810    PartialSuccess,
811    /// Workflow failed
812    Failed,
813    /// Workflow cancelled
814    Cancelled,
815}
816
817/// Workflow builder
818pub struct WorkflowBuilder {
819    config: WorkflowConfig,
820    stages: Vec<Arc<dyn WorkflowStageExecutor>>,
821}
822
823impl WorkflowBuilder {
824    /// Create new workflow builder
825    pub fn new(name: impl Into<String>) -> Self {
826        Self {
827            config: WorkflowConfig {
828                name: name.into(),
829                ..Default::default()
830            },
831            stages: Vec::new(),
832        }
833    }
834
835    /// Set workflow description
836    pub fn description(mut self, description: impl Into<String>) -> Self {
837        self.config.description = Some(description.into());
838        self
839    }
840
841    /// Set maximum parallel stages
842    pub fn max_parallel_stages(mut self, max: usize) -> Self {
843        self.config.max_parallel_stages = max;
844        self
845    }
846
847    /// Set global timeout
848    pub fn global_timeout(mut self, seconds: u64) -> Self {
849        self.config.global_timeout_seconds = Some(seconds);
850        self
851    }
852
853    /// Enable caching
854    pub fn enable_cache(mut self, enable: bool) -> Self {
855        self.config.enable_cache = enable;
856        self
857    }
858
859    /// Set cache configuration
860    pub fn cache_config(mut self, config: CacheConfig) -> Self {
861        self.config.cache_config = Some(config);
862        self
863    }
864
865    /// Add a workflow stage
866    pub fn add_stage(mut self, stage: Arc<dyn WorkflowStageExecutor>) -> Self {
867        self.stages.push(stage);
868        self
869    }
870
871    /// Build the workflow
872    pub fn build(self) -> Result<Workflow, WorkflowError> {
873        if self.stages.is_empty() {
874            return Err(WorkflowError::ValidationError {
875                message: "Workflow must have at least one stage".to_string(),
876            });
877        }
878
879        // Validate stage dependencies
880        self.validate_dependencies()?;
881
882        Ok(Workflow::new(self.config, self.stages))
883    }
884
885    /// Validate stage dependencies
886    fn validate_dependencies(&self) -> Result<(), WorkflowError> {
887        let stage_names: Vec<String> = self
888            .stages
889            .iter()
890            .map(|s| s.config().name.clone())
891            .collect();
892
893        for stage in &self.stages {
894            for dep in stage.dependencies() {
895                if !stage_names.contains(&dep) {
896                    return Err(WorkflowError::DependencyError {
897                        message: format!(
898                            "Stage '{}' depends on '{}' which is not in the workflow",
899                            stage.config().name,
900                            dep
901                        ),
902                    });
903                }
904            }
905        }
906
907        Ok(())
908    }
909}
910
911/// Workflow stage factory
912pub struct WorkflowStage;
913
914impl WorkflowStage {
915    /// Create quality evaluation stage
916    pub fn quality_evaluation(config: StageConfig) -> Arc<dyn WorkflowStageExecutor> {
917        // This is a simplified version - in production you'd want async construction
918        // For now, we'll create a placeholder that will be properly initialized
919        Arc::new(ExportResultsStage::new(config))
920    }
921
922    /// Create pronunciation evaluation stage
923    pub fn pronunciation_evaluation(config: StageConfig) -> Arc<dyn WorkflowStageExecutor> {
924        Arc::new(ExportResultsStage::new(config))
925    }
926
927    /// Create export results stage
928    pub fn export_results(config: StageConfig) -> Arc<dyn WorkflowStageExecutor> {
929        Arc::new(ExportResultsStage::new(config))
930    }
931}
932
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a successful stage result named `name` that took `duration_ms`.
    fn successful_result(name: &str, duration_ms: u64) -> StageResult {
        StageResult {
            stage_name: name.to_owned(),
            status: StageStatus::Success,
            duration_ms,
            quality_score: None,
            custom_results: HashMap::new(),
            error_message: None,
            retry_count: 0,
        }
    }

    #[test]
    fn test_stage_config_default() {
        let StageConfig {
            max_retries,
            enable_cache,
            condition,
            ..
        } = StageConfig::default();
        assert_eq!(max_retries, 3);
        assert!(enable_cache);
        assert_eq!(condition, StageCondition::Always);
    }

    #[test]
    fn test_workflow_config_default() {
        let defaults = WorkflowConfig::default();
        assert_eq!(defaults.max_parallel_stages, 4);
        assert!(defaults.enable_cache);
    }

    #[test]
    fn test_retry_policy_default() {
        let RetryPolicy {
            max_attempts,
            initial_delay_ms,
            backoff_multiplier,
            ..
        } = RetryPolicy::default();
        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay_ms, 1000);
        assert!((backoff_multiplier - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_workflow_builder() {
        let WorkflowBuilder { config, .. } = WorkflowBuilder::new("test_workflow")
            .description("Test workflow")
            .max_parallel_stages(2)
            .enable_cache(false);

        assert_eq!(config.name, "test_workflow");
        assert_eq!(config.max_parallel_stages, 2);
        assert!(!config.enable_cache);
    }

    #[test]
    fn test_workflow_context() {
        assert!(WorkflowContext::new().cache_config.is_none());
    }

    #[test]
    fn test_stage_status() {
        let status = StageStatus::Success;
        assert_eq!(status, StageStatus::Success);
        assert_ne!(status, StageStatus::Failed);
    }

    #[test]
    fn test_workflow_error_display() {
        let rendered = WorkflowError::ValidationError {
            message: "Test error".to_string(),
        }
        .to_string();
        assert!(rendered.contains("Test error"));
    }

    #[tokio::test]
    async fn test_export_results_stage_creation() {
        let stage = ExportResultsStage::new(StageConfig {
            name: "export".to_string(),
            stage_type: StageType::ExportResults,
            ..StageConfig::default()
        });
        assert_eq!(stage.config().name, "export");
    }

    #[tokio::test]
    async fn test_workflow_context_async_operations() {
        let context = WorkflowContext::new();
        context
            .set_parameter("test_key".to_string(), serde_json::json!("test_value"))
            .await;
        context
            .set_result("test_stage".to_string(), successful_result("test_stage", 100))
            .await;

        let retrieved = context
            .get_result("test_stage")
            .await
            .expect("result was just stored");
        assert_eq!(retrieved.stage_name, "test_stage");
    }
}