1use async_trait::async_trait;
39use scirs2_core::random::Rng;
40use serde::{Deserialize, Serialize};
41use std::collections::HashMap;
42use std::path::PathBuf;
43use std::sync::Arc;
44use std::time::{Duration, SystemTime};
45use thiserror::Error;
46use tokio::sync::{RwLock, Semaphore};
47use tracing::{debug, error, info, warn};
48use voirs_sdk::{AudioBuffer, VoirsError};
49
50use crate::caching::CacheConfig;
51use crate::pronunciation::PronunciationEvaluatorImpl;
52use crate::quality::QualityEvaluator;
53use crate::traits::{
54 PronunciationEvaluator as PronunciationEvaluatorTrait, PronunciationScore,
55 QualityEvaluationConfig, QualityEvaluator as QualityEvaluatorTrait, QualityScore,
56};
57
/// Errors that can occur while building or running an evaluation workflow.
#[derive(Error, Debug)]
pub enum WorkflowError {
    /// A stage failed while executing.
    #[error("Stage '{stage}' execution failed: {message}")]
    StageExecutionError {
        /// Name of the stage that failed.
        stage: String,
        /// Human-readable description of the failure.
        message: String,
        /// Underlying error, when one is available.
        #[source]
        source: Option<Box<dyn std::error::Error + Send + Sync>>,
    },

    /// The workflow definition did not pass validation.
    #[error("Workflow validation failed: {message}")]
    ValidationError {
        /// Description of the validation problem.
        message: String,
    },

    /// A workflow or stage was configured incorrectly.
    #[error("Workflow configuration error: {message}")]
    ConfigurationError {
        /// Description of the configuration problem.
        message: String,
    },

    /// A stage dependency could not be satisfied.
    #[error("Dependency error: {message}")]
    DependencyError {
        /// Description of the dependency problem.
        message: String,
    },

    /// The workflow exceeded its time limit.
    #[error("Workflow timed out after {duration:?}")]
    TimeoutError {
        /// The time limit that was exceeded.
        duration: Duration,
    },

    /// A stage condition could not be evaluated.
    #[error("Condition evaluation failed: {message}")]
    ConditionError {
        /// Description of the evaluation problem.
        message: String,
    },

    /// Error propagated from the VoiRS SDK.
    #[error("VoiRS error: {0}")]
    VoirsError(#[from] VoirsError),

    /// JSON (de)serialization failure.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    /// Filesystem or other I/O failure.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// Error propagated from this crate's evaluators.
    #[error("Evaluation error: {0}")]
    EvaluationError(#[from] crate::EvaluationError),
}
124
/// Kind of work a workflow stage performs.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum StageType {
    /// Audio quality evaluation.
    QualityEvaluation,
    /// Pronunciation evaluation.
    PronunciationEvaluation,
    /// Preprocessing step.
    Preprocessing,
    /// Feature extraction step.
    FeatureExtraction,
    /// Statistical analysis step.
    StatisticalAnalysis,
    /// Export of collected results.
    ExportResults,
    /// User-defined stage, identified by the contained name.
    Custom(String),
}
143
/// Condition that decides whether a stage runs, based on the results of earlier stages.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StageCondition {
    /// Run unconditionally.
    Always,
    /// Run when the most recent stage result is a success (or when there is no earlier result).
    OnSuccess,
    /// Run when the most recent stage result is a failure.
    OnFailure,
    /// Run when a metric from an earlier quality score lies within the given bounds.
    MetricThreshold {
        /// Name of the metric to inspect (the workflow currently recognises `"overall_score"`).
        metric: String,
        /// Inclusive lower bound.
        min_value: f64,
        /// Optional inclusive upper bound.
        max_value: Option<f64>,
    },
    /// Custom condition expression; the workflow currently treats it as always true.
    Custom(String),
}
165
/// Configuration for a single workflow stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageConfig {
    /// Stage name; also the key under which its result is stored in the workflow context.
    pub name: String,
    /// Kind of work the stage performs.
    pub stage_type: StageType,
    /// Condition controlling whether the stage runs.
    pub condition: StageCondition,
    /// Number of retries allowed after the first failed attempt.
    pub max_retries: usize,
    /// Optional time limit for the stage, in seconds.
    pub timeout_seconds: Option<u64>,
    /// Whether caching is enabled for this stage.
    pub enable_cache: bool,
    /// Free-form, stage-specific parameters (for example `output_path` for the export stage).
    pub parameters: HashMap<String, serde_json::Value>,
    /// Names of stages this stage depends on; checked when the workflow is built.
    pub dependencies: Vec<String>,
}
186
187impl Default for StageConfig {
188 fn default() -> Self {
189 Self {
190 name: String::new(),
191 stage_type: StageType::Custom("default".to_string()),
192 condition: StageCondition::Always,
193 max_retries: 3,
194 timeout_seconds: Some(300), enable_cache: true,
196 parameters: HashMap::new(),
197 dependencies: Vec::new(),
198 }
199 }
200}
201
/// Outcome of running (or skipping) a single workflow stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StageResult {
    /// Name of the stage that produced this result.
    pub stage_name: String,
    /// Final status of the stage.
    pub status: StageStatus,
    /// Time spent in the stage, in milliseconds.
    pub duration_ms: u64,
    /// Quality score, when the stage produced one.
    pub quality_score: Option<QualityScore>,
    /// Additional stage-specific results as JSON values.
    pub custom_results: HashMap<String, serde_json::Value>,
    /// Error description when the stage did not succeed.
    pub error_message: Option<String>,
    /// Retry counter recorded by the workflow for this stage.
    pub retry_count: usize,
}
220
/// Lifecycle status of a workflow stage.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum StageStatus {
    /// Not started yet.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Success,
    /// Finished with an error.
    Failed,
    /// Not executed because its condition was not met.
    Skipped,
    /// Did not finish within its time limit.
    Timeout,
}
237
/// Behaviour every workflow stage must provide.
///
/// Implementors must be `Send + Sync` so stages can be shared as
/// `Arc<dyn WorkflowStageExecutor>`.
#[async_trait]
pub trait WorkflowStageExecutor: Send + Sync {
    /// Returns the configuration this stage was created with.
    fn config(&self) -> &StageConfig;

    /// Runs the stage.
    ///
    /// # Arguments
    /// * `audio` - audio to evaluate.
    /// * `reference` - optional reference audio.
    /// * `context` - shared workflow context (parameters and earlier results).
    ///
    /// # Errors
    /// Returns a [`WorkflowError`] when the stage cannot complete.
    async fn execute(
        &self,
        audio: &AudioBuffer,
        reference: Option<&AudioBuffer>,
        context: &WorkflowContext,
    ) -> Result<StageResult, WorkflowError>;

    /// Validates the stage configuration. The default implementation accepts everything.
    fn validate(&self) -> Result<(), WorkflowError> {
        Ok(())
    }

    /// Names of the stages this stage depends on; by default a copy of
    /// `config().dependencies`.
    fn dependencies(&self) -> Vec<String> {
        self.config().dependencies.clone()
    }
}
262
/// Workflow stage that runs the quality evaluator on the input audio.
pub struct QualityEvaluationStage {
    /// Stage configuration.
    config: StageConfig,
    /// Shared evaluator instance, guarded by an async read/write lock.
    evaluator: Arc<RwLock<QualityEvaluator>>,
}
268
269impl QualityEvaluationStage {
270 pub async fn new(config: StageConfig) -> Result<Self, WorkflowError> {
272 let evaluator =
273 QualityEvaluator::new()
274 .await
275 .map_err(|e| WorkflowError::ConfigurationError {
276 message: format!("Failed to create quality evaluator: {}", e),
277 })?;
278
279 Ok(Self {
280 config,
281 evaluator: Arc::new(RwLock::new(evaluator)),
282 })
283 }
284}
285
286#[async_trait]
287impl WorkflowStageExecutor for QualityEvaluationStage {
288 fn config(&self) -> &StageConfig {
289 &self.config
290 }
291
292 async fn execute(
293 &self,
294 audio: &AudioBuffer,
295 reference: Option<&AudioBuffer>,
296 _context: &WorkflowContext,
297 ) -> Result<StageResult, WorkflowError> {
298 let start = SystemTime::now();
299 let evaluator = self.evaluator.read().await;
300
301 let eval_config = QualityEvaluationConfig::default();
302 let quality_score = evaluator
303 .evaluate_quality(audio, reference, Some(&eval_config))
304 .await?;
305
306 let duration_ms = SystemTime::now()
307 .duration_since(start)
308 .unwrap_or(Duration::ZERO)
309 .as_millis() as u64;
310
311 Ok(StageResult {
312 stage_name: self.config.name.clone(),
313 status: StageStatus::Success,
314 duration_ms,
315 quality_score: Some(quality_score),
316 custom_results: HashMap::new(),
317 error_message: None,
318 retry_count: 0,
319 })
320 }
321}
322
/// Workflow stage that runs the pronunciation evaluator on the input audio.
pub struct PronunciationEvaluationStage {
    /// Stage configuration.
    config: StageConfig,
    /// Shared evaluator instance, guarded by an async read/write lock.
    evaluator: Arc<RwLock<PronunciationEvaluatorImpl>>,
}
328
329impl PronunciationEvaluationStage {
330 pub async fn new(config: StageConfig) -> Result<Self, WorkflowError> {
332 let evaluator = PronunciationEvaluatorImpl::new().await.map_err(|e| {
333 WorkflowError::ConfigurationError {
334 message: format!("Failed to create pronunciation evaluator: {}", e),
335 }
336 })?;
337
338 Ok(Self {
339 config,
340 evaluator: Arc::new(RwLock::new(evaluator)),
341 })
342 }
343}
344
345#[async_trait]
346impl WorkflowStageExecutor for PronunciationEvaluationStage {
347 fn config(&self) -> &StageConfig {
348 &self.config
349 }
350
351 async fn execute(
352 &self,
353 audio: &AudioBuffer,
354 _reference: Option<&AudioBuffer>,
355 context: &WorkflowContext,
356 ) -> Result<StageResult, WorkflowError> {
357 let start = SystemTime::now();
358
359 let expected_text = "Hello world"; let _language = "en-US"; let evaluator = self.evaluator.read().await;
366 let pronunciation_score = evaluator
367 .evaluate_pronunciation(audio, expected_text, None)
368 .await?;
369
370 let duration_ms = SystemTime::now()
371 .duration_since(start)
372 .unwrap_or(Duration::ZERO)
373 .as_millis() as u64;
374
375 let mut custom_results = HashMap::new();
377 custom_results.insert(
378 "pronunciation_score".to_string(),
379 serde_json::json!({
380 "overall_score": pronunciation_score.overall_score,
381 "fluency_score": pronunciation_score.fluency_score,
382 "rhythm_score": pronunciation_score.rhythm_score,
383 }),
384 );
385
386 Ok(StageResult {
387 stage_name: self.config.name.clone(),
388 status: StageStatus::Success,
389 duration_ms,
390 quality_score: None,
391 custom_results,
392 error_message: None,
393 retry_count: 0,
394 })
395 }
396}
397
/// Workflow stage that writes the collected workflow results to a file.
pub struct ExportResultsStage {
    /// Stage configuration; the `output_path` parameter names the destination file.
    config: StageConfig,
}

impl ExportResultsStage {
    /// Creates the export stage from its configuration. No validation happens here;
    /// a missing `output_path` is reported when the stage executes.
    pub fn new(config: StageConfig) -> Self {
        Self { config }
    }
}
409
410#[async_trait]
411impl WorkflowStageExecutor for ExportResultsStage {
412 fn config(&self) -> &StageConfig {
413 &self.config
414 }
415
416 async fn execute(
417 &self,
418 _audio: &AudioBuffer,
419 _reference: Option<&AudioBuffer>,
420 context: &WorkflowContext,
421 ) -> Result<StageResult, WorkflowError> {
422 let start = SystemTime::now();
423
424 let output_path = self
426 .config
427 .parameters
428 .get("output_path")
429 .and_then(|v| v.as_str())
430 .ok_or_else(|| WorkflowError::ConfigurationError {
431 message: "Missing 'output_path' parameter for export stage".to_string(),
432 })?;
433
434 let workflow_results = context.get_all_results();
436 let output_data = serde_json::to_string_pretty(&workflow_results)?;
437
438 tokio::fs::write(output_path, output_data).await?;
440
441 let duration_ms = SystemTime::now()
442 .duration_since(start)
443 .unwrap_or(Duration::ZERO)
444 .as_millis() as u64;
445
446 info!("Exported workflow results to: {}", output_path);
447
448 Ok(StageResult {
449 stage_name: self.config.name.clone(),
450 status: StageStatus::Success,
451 duration_ms,
452 quality_score: None,
453 custom_results: HashMap::new(),
454 error_message: None,
455 retry_count: 0,
456 })
457 }
458}
459
/// Shared state passed to every stage of a workflow.
///
/// Parameters and results live behind `Arc<RwLock<..>>`, so clones of a context
/// share the same underlying maps.
#[derive(Clone)]
pub struct WorkflowContext {
    /// Workflow-level parameters, keyed by name.
    parameters: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Stage results, keyed by stage name.
    results: Arc<RwLock<HashMap<String, StageResult>>>,
    /// Cache configuration, when caching is enabled for the workflow.
    cache_config: Option<CacheConfig>,
}
467
468impl WorkflowContext {
469 pub fn new() -> Self {
471 Self {
472 parameters: Arc::new(RwLock::new(HashMap::new())),
473 results: Arc::new(RwLock::new(HashMap::new())),
474 cache_config: None,
475 }
476 }
477
478 pub async fn set_parameter(&self, key: String, value: serde_json::Value) {
480 let mut params = self.parameters.write().await;
481 params.insert(key, value);
482 }
483
484 pub fn get_parameter(&self, key: &str) -> Option<serde_json::Value> {
486 None
489 }
490
491 pub async fn set_result(&self, stage_name: String, result: StageResult) {
493 let mut results = self.results.write().await;
494 results.insert(stage_name, result);
495 }
496
497 pub async fn get_result(&self, stage_name: &str) -> Option<StageResult> {
499 let results = self.results.read().await;
500 results.get(stage_name).cloned()
501 }
502
503 pub fn get_all_results(&self) -> HashMap<String, StageResult> {
505 HashMap::new()
507 }
508
509 pub fn with_cache(mut self, cache_config: CacheConfig) -> Self {
511 self.cache_config = Some(cache_config);
512 self
513 }
514}
515
impl Default for WorkflowContext {
    /// Equivalent to [`WorkflowContext::new`].
    fn default() -> Self {
        Self::new()
    }
}
521
/// Configuration for a whole workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowConfig {
    /// Workflow name, used in log messages and in the workflow result.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Number of permits given to the workflow's stage semaphore.
    pub max_parallel_stages: usize,
    /// Optional time limit for the whole workflow, in seconds.
    pub global_timeout_seconds: Option<u64>,
    /// Whether a cache configuration is attached to the workflow context.
    pub enable_cache: bool,
    /// Cache configuration to use when caching is enabled; the default is used when `None`.
    pub cache_config: Option<CacheConfig>,
    /// Backoff settings applied between stage retries.
    pub retry_policy: RetryPolicy,
}
540
541impl Default for WorkflowConfig {
542 fn default() -> Self {
543 Self {
544 name: "default_workflow".to_string(),
545 description: None,
546 max_parallel_stages: 4,
547 global_timeout_seconds: Some(1800), enable_cache: true,
549 cache_config: None,
550 retry_policy: RetryPolicy::default(),
551 }
552 }
553}
554
/// Retry and backoff settings for a workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of attempts.
    /// NOTE(review): the retry loop in this file reads `StageConfig::max_retries`,
    /// not this field — confirm where `max_attempts` is meant to apply.
    pub max_attempts: usize,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Factor the delay is multiplied by for each further retry.
    pub backoff_multiplier: f64,
    /// Upper bound for the delay between retries, in milliseconds.
    pub max_delay_ms: u64,
}
567
568impl Default for RetryPolicy {
569 fn default() -> Self {
570 Self {
571 max_attempts: 3,
572 initial_delay_ms: 1000,
573 backoff_multiplier: 2.0,
574 max_delay_ms: 30000,
575 }
576 }
577}
578
/// An ordered list of stages plus the shared state needed to run them.
pub struct Workflow {
    /// Workflow-level configuration.
    config: WorkflowConfig,
    /// Stages, executed in the order they appear here.
    stages: Vec<Arc<dyn WorkflowStageExecutor>>,
    /// Context shared with every stage (parameters and stage results).
    context: WorkflowContext,
    /// Semaphore sized from `max_parallel_stages`; a permit is held while a stage runs.
    semaphore: Arc<Semaphore>,
}
586
587impl Workflow {
588 pub fn new(config: WorkflowConfig, stages: Vec<Arc<dyn WorkflowStageExecutor>>) -> Self {
590 let semaphore = Arc::new(Semaphore::new(config.max_parallel_stages));
591 let mut context = WorkflowContext::new();
592
593 if config.enable_cache {
595 let cache_config = config.cache_config.clone().unwrap_or_default();
596 context = context.with_cache(cache_config);
597 }
598
599 Self {
600 config,
601 stages,
602 context,
603 semaphore,
604 }
605 }
606
607 pub async fn execute(
609 &self,
610 audio: &AudioBuffer,
611 reference: Option<&AudioBuffer>,
612 ) -> Result<WorkflowResult, WorkflowError> {
613 let workflow_start = SystemTime::now();
614 info!("Starting workflow: {}", self.config.name);
615
616 let mut stage_results = Vec::new();
617
618 for stage in &self.stages {
619 let _permit =
620 self.semaphore
621 .acquire()
622 .await
623 .map_err(|e| WorkflowError::StageExecutionError {
624 stage: stage.config().name.clone(),
625 message: format!("Failed to acquire semaphore: {}", e),
626 source: None,
627 })?;
628
629 if !self
631 .should_execute_stage(stage.config(), &stage_results)
632 .await?
633 {
634 info!("Skipping stage '{}' due to condition", stage.config().name);
635 stage_results.push(StageResult {
636 stage_name: stage.config().name.clone(),
637 status: StageStatus::Skipped,
638 duration_ms: 0,
639 quality_score: None,
640 custom_results: HashMap::new(),
641 error_message: None,
642 retry_count: 0,
643 });
644 continue;
645 }
646
647 let result = self
649 .execute_stage_with_retry(stage.as_ref(), audio, reference)
650 .await?;
651
652 self.context
653 .set_result(stage.config().name.clone(), result.clone())
654 .await;
655 stage_results.push(result);
656 }
657
658 let total_duration_ms = SystemTime::now()
659 .duration_since(workflow_start)
660 .unwrap_or(Duration::ZERO)
661 .as_millis() as u64;
662
663 info!(
664 "Workflow '{}' completed in {}ms",
665 self.config.name, total_duration_ms
666 );
667
668 Ok(WorkflowResult {
669 workflow_name: self.config.name.clone(),
670 stage_results,
671 total_duration_ms,
672 status: WorkflowStatus::Success,
673 error_message: None,
674 })
675 }
676
677 async fn should_execute_stage(
679 &self,
680 config: &StageConfig,
681 previous_results: &[StageResult],
682 ) -> Result<bool, WorkflowError> {
683 match &config.condition {
684 StageCondition::Always => Ok(true),
685 StageCondition::OnSuccess => {
686 if let Some(last_result) = previous_results.last() {
687 Ok(last_result.status == StageStatus::Success)
688 } else {
689 Ok(true)
690 }
691 }
692 StageCondition::OnFailure => {
693 if let Some(last_result) = previous_results.last() {
694 Ok(last_result.status == StageStatus::Failed)
695 } else {
696 Ok(false)
697 }
698 }
699 StageCondition::MetricThreshold {
700 metric,
701 min_value,
702 max_value,
703 } => {
704 for result in previous_results.iter().rev() {
706 if let Some(quality) = &result.quality_score {
707 let value = match metric.as_str() {
708 "overall_score" => quality.overall_score as f64,
709 _ => continue, };
711
712 let meets_min = value >= *min_value;
713 let meets_max = max_value.map_or(true, |max| value <= max);
714 return Ok(meets_min && meets_max);
715 }
716 }
717 Ok(false)
718 }
719 StageCondition::Custom(_) => {
720 warn!("Custom conditions not yet implemented, defaulting to true");
721 Ok(true)
722 }
723 }
724 }
725
726 async fn execute_stage_with_retry(
728 &self,
729 stage: &dyn WorkflowStageExecutor,
730 audio: &AudioBuffer,
731 reference: Option<&AudioBuffer>,
732 ) -> Result<StageResult, WorkflowError> {
733 let config = stage.config();
734 let max_retries = config.max_retries;
735 let mut retry_count = 0;
736 let mut last_error = None;
737
738 while retry_count <= max_retries {
739 match stage.execute(audio, reference, &self.context).await {
740 Ok(mut result) => {
741 result.retry_count = retry_count;
742 return Ok(result);
743 }
744 Err(e) => {
745 error!(
746 "Stage '{}' failed (attempt {}/{}): {}",
747 config.name,
748 retry_count + 1,
749 max_retries + 1,
750 e
751 );
752 last_error = Some(e);
753 retry_count += 1;
754
755 if retry_count <= max_retries {
756 let delay_ms = self.calculate_backoff_delay(retry_count);
758 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
759 }
760 }
761 }
762 }
763
764 Ok(StageResult {
766 stage_name: config.name.clone(),
767 status: StageStatus::Failed,
768 duration_ms: 0,
769 quality_score: None,
770 custom_results: HashMap::new(),
771 error_message: Some(
772 last_error
773 .map(|e| e.to_string())
774 .unwrap_or_else(|| "Unknown error".to_string()),
775 ),
776 retry_count,
777 })
778 }
779
780 fn calculate_backoff_delay(&self, retry_count: usize) -> u64 {
782 let policy = &self.config.retry_policy;
783 let delay =
784 policy.initial_delay_ms as f64 * policy.backoff_multiplier.powi(retry_count as i32 - 1);
785 delay.min(policy.max_delay_ms as f64) as u64
786 }
787}
788
/// Outcome of a complete workflow run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowResult {
    /// Name of the workflow that ran.
    pub workflow_name: String,
    /// Per-stage results, in execution order (skipped stages included).
    pub stage_results: Vec<StageResult>,
    /// Total time of the run, in milliseconds.
    pub total_duration_ms: u64,
    /// Overall status of the run.
    pub status: WorkflowStatus,
    /// Error description, when one is available.
    pub error_message: Option<String>,
}
803
/// Overall status of a workflow run.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkflowStatus {
    /// The workflow completed successfully.
    Success,
    /// Some stages succeeded and some did not.
    PartialSuccess,
    /// The workflow failed.
    Failed,
    /// The workflow was cancelled.
    Cancelled,
}
816
/// Fluent builder for [`Workflow`].
pub struct WorkflowBuilder {
    /// Configuration accumulated so far.
    config: WorkflowConfig,
    /// Stages added so far, in execution order.
    stages: Vec<Arc<dyn WorkflowStageExecutor>>,
}
822
823impl WorkflowBuilder {
824 pub fn new(name: impl Into<String>) -> Self {
826 Self {
827 config: WorkflowConfig {
828 name: name.into(),
829 ..Default::default()
830 },
831 stages: Vec::new(),
832 }
833 }
834
835 pub fn description(mut self, description: impl Into<String>) -> Self {
837 self.config.description = Some(description.into());
838 self
839 }
840
841 pub fn max_parallel_stages(mut self, max: usize) -> Self {
843 self.config.max_parallel_stages = max;
844 self
845 }
846
847 pub fn global_timeout(mut self, seconds: u64) -> Self {
849 self.config.global_timeout_seconds = Some(seconds);
850 self
851 }
852
853 pub fn enable_cache(mut self, enable: bool) -> Self {
855 self.config.enable_cache = enable;
856 self
857 }
858
859 pub fn cache_config(mut self, config: CacheConfig) -> Self {
861 self.config.cache_config = Some(config);
862 self
863 }
864
865 pub fn add_stage(mut self, stage: Arc<dyn WorkflowStageExecutor>) -> Self {
867 self.stages.push(stage);
868 self
869 }
870
871 pub fn build(self) -> Result<Workflow, WorkflowError> {
873 if self.stages.is_empty() {
874 return Err(WorkflowError::ValidationError {
875 message: "Workflow must have at least one stage".to_string(),
876 });
877 }
878
879 self.validate_dependencies()?;
881
882 Ok(Workflow::new(self.config, self.stages))
883 }
884
885 fn validate_dependencies(&self) -> Result<(), WorkflowError> {
887 let stage_names: Vec<String> = self
888 .stages
889 .iter()
890 .map(|s| s.config().name.clone())
891 .collect();
892
893 for stage in &self.stages {
894 for dep in stage.dependencies() {
895 if !stage_names.contains(&dep) {
896 return Err(WorkflowError::DependencyError {
897 message: format!(
898 "Stage '{}' depends on '{}' which is not in the workflow",
899 stage.config().name,
900 dep
901 ),
902 });
903 }
904 }
905 }
906
907 Ok(())
908 }
909}
910
911pub struct WorkflowStage;
913
914impl WorkflowStage {
915 pub fn quality_evaluation(config: StageConfig) -> Arc<dyn WorkflowStageExecutor> {
917 Arc::new(ExportResultsStage::new(config))
920 }
921
922 pub fn pronunciation_evaluation(config: StageConfig) -> Arc<dyn WorkflowStageExecutor> {
924 Arc::new(ExportResultsStage::new(config))
925 }
926
927 pub fn export_results(config: StageConfig) -> Arc<dyn WorkflowStageExecutor> {
929 Arc::new(ExportResultsStage::new(config))
930 }
931}
932
#[cfg(test)]
mod tests {
    use super::*;

    /// Stage defaults: three retries, caching on, unconditional execution.
    #[test]
    fn test_stage_config_default() {
        let StageConfig {
            max_retries,
            enable_cache,
            condition,
            ..
        } = StageConfig::default();

        assert_eq!(max_retries, 3);
        assert!(enable_cache);
        assert_eq!(condition, StageCondition::Always);
    }

    /// Workflow defaults: four parallel stages, caching on.
    #[test]
    fn test_workflow_config_default() {
        let defaults = WorkflowConfig::default();

        assert_eq!(defaults.max_parallel_stages, 4);
        assert!(defaults.enable_cache);
    }

    /// Retry defaults: three attempts, 1000 ms initial delay, doubling backoff.
    #[test]
    fn test_retry_policy_default() {
        let defaults = RetryPolicy::default();

        assert_eq!(defaults.max_attempts, 3);
        assert_eq!(defaults.initial_delay_ms, 1000);
        let multiplier_error = (defaults.backoff_multiplier - 2.0).abs();
        assert!(multiplier_error < f64::EPSILON);
    }

    /// Builder setters are reflected in the accumulated configuration.
    #[test]
    fn test_workflow_builder() {
        let built = WorkflowBuilder::new("test_workflow")
            .description("Test workflow")
            .max_parallel_stages(2)
            .enable_cache(false);
        let settings = &built.config;

        assert_eq!(settings.name, "test_workflow");
        assert_eq!(settings.max_parallel_stages, 2);
        assert!(!settings.enable_cache);
    }

    /// A fresh context carries no cache configuration.
    #[test]
    fn test_workflow_context() {
        let fresh = WorkflowContext::new();
        assert!(fresh.cache_config.is_none());
    }

    /// Stage statuses compare by variant.
    #[test]
    fn test_stage_status() {
        let succeeded = StageStatus::Success;

        assert_eq!(succeeded, StageStatus::Success);
        assert_ne!(succeeded, StageStatus::Failed);
    }

    /// The validation error's display text includes its message.
    #[test]
    fn test_workflow_error_display() {
        let rendered = WorkflowError::ValidationError {
            message: "Test error".to_string(),
        }
        .to_string();

        assert!(rendered.contains("Test error"));
    }

    /// The export stage keeps the configuration it was created with.
    #[tokio::test]
    async fn test_export_results_stage_creation() {
        let export_stage = ExportResultsStage::new(StageConfig {
            name: "export".to_string(),
            stage_type: StageType::ExportResults,
            ..Default::default()
        });

        assert_eq!(export_stage.config().name, "export");
    }

    /// Results stored in the context can be read back by stage name.
    #[tokio::test]
    async fn test_workflow_context_async_operations() {
        let shared = WorkflowContext::new();
        shared
            .set_parameter("test_key".to_string(), serde_json::json!("test_value"))
            .await;

        let stored = StageResult {
            stage_name: "test_stage".to_string(),
            status: StageStatus::Success,
            duration_ms: 100,
            quality_score: None,
            custom_results: HashMap::new(),
            error_message: None,
            retry_count: 0,
        };
        shared
            .set_result("test_stage".to_string(), stored.clone())
            .await;

        let fetched = shared.get_result("test_stage").await;
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().stage_name, "test_stage");
    }
}