Skip to main content

voirs_cli/workflow/
definition.rs

1//! Workflow Definition Types
2//!
3//! This module defines the data structures for workflow definitions,
4//! including steps, conditions, retries, and metadata.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9
10use crate::error::CliError;
11
12type Result<T> = std::result::Result<T, CliError>;
13
14/// Workflow definition
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Workflow {
17    /// Workflow metadata
18    pub metadata: WorkflowMetadata,
19    /// Global variables
20    #[serde(default)]
21    pub variables: HashMap<String, Variable>,
22    /// Workflow steps
23    pub steps: Vec<Step>,
24    /// Workflow-level configuration
25    #[serde(default)]
26    pub config: WorkflowConfig,
27}
28
29impl Workflow {
30    /// Create a new workflow
31    pub fn new(name: &str, version: &str, description: &str) -> Self {
32        Self {
33            metadata: WorkflowMetadata {
34                name: name.to_string(),
35                version: version.to_string(),
36                description: description.to_string(),
37                author: None,
38                tags: Vec::new(),
39            },
40            variables: HashMap::new(),
41            steps: Vec::new(),
42            config: WorkflowConfig::default(),
43        }
44    }
45
46    /// Load workflow from YAML file
47    pub async fn load_from_file(path: &Path) -> Result<Self> {
48        let content = tokio::fs::read_to_string(path).await?;
49
50        if path.extension().is_some_and(|ext| ext == "json") {
51            Ok(serde_json::from_str(&content)?)
52        } else {
53            // Assume YAML for .yaml, .yml, or no extension
54            Ok(serde_yaml::from_str(&content).map_err(|e| {
55                CliError::SerializationError(format!("Failed to parse YAML: {}", e))
56            })?)
57        }
58    }
59
60    /// Save workflow to file
61    pub async fn save_to_file(&self, path: &Path) -> Result<()> {
62        let content = if path.extension().is_some_and(|ext| ext == "json") {
63            serde_json::to_string_pretty(self)?
64        } else {
65            serde_yaml::to_string(self).map_err(|e| {
66                CliError::SerializationError(format!("Failed to serialize to YAML: {}", e))
67            })?
68        };
69
70        tokio::fs::write(path, content).await?;
71        Ok(())
72    }
73
74    /// Add a step to the workflow
75    pub fn add_step(&mut self, step: Step) {
76        self.steps.push(step);
77    }
78
79    /// Add a variable
80    pub fn add_variable(&mut self, name: String, value: Variable) {
81        self.variables.insert(name, value);
82    }
83
84    /// Get step by name
85    pub fn get_step(&self, name: &str) -> Option<&Step> {
86        self.steps.iter().find(|s| s.name == name)
87    }
88
89    /// Validate workflow structure
90    pub fn validate(&self) -> Result<()> {
91        // Check for duplicate step names
92        let mut step_names = std::collections::HashSet::new();
93        for step in &self.steps {
94            if !step_names.insert(&step.name) {
95                return Err(CliError::Workflow(format!(
96                    "Duplicate step name: {}",
97                    step.name
98                )));
99            }
100        }
101
102        // Check dependencies exist
103        for step in &self.steps {
104            for dep in &step.depends_on {
105                if !self.steps.iter().any(|s| s.name == dep.step_name) {
106                    return Err(CliError::Workflow(format!(
107                        "Step '{}' depends on non-existent step '{}'",
108                        step.name, dep.step_name
109                    )));
110                }
111            }
112        }
113
114        Ok(())
115    }
116}
117
118/// Workflow metadata
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct WorkflowMetadata {
121    /// Workflow name
122    pub name: String,
123    /// Version string
124    pub version: String,
125    /// Human-readable description
126    pub description: String,
127    /// Optional author information
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub author: Option<String>,
130    /// Tags for categorization
131    #[serde(default)]
132    pub tags: Vec<String>,
133}
134
135/// Workflow configuration
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct WorkflowConfig {
138    /// Maximum parallel steps
139    #[serde(default = "default_max_parallel")]
140    pub max_parallel: usize,
141    /// Timeout in seconds (0 = no timeout)
142    #[serde(default)]
143    pub timeout_seconds: u64,
144    /// Whether to continue on error
145    #[serde(default)]
146    pub continue_on_error: bool,
147    /// Whether to save state for resumption
148    #[serde(default = "default_true")]
149    pub save_state: bool,
150}
151
/// Serde default for `WorkflowConfig::max_parallel`.
fn default_max_parallel() -> usize {
    const MAX_PARALLEL_STEPS: usize = 4;
    MAX_PARALLEL_STEPS
}
155
/// Serde default for boolean fields that are enabled unless specified otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
159
160impl Default for WorkflowConfig {
161    fn default() -> Self {
162        Self {
163            max_parallel: default_max_parallel(),
164            timeout_seconds: 0,
165            continue_on_error: false,
166            save_state: true,
167        }
168    }
169}
170
171/// Workflow step definition
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct Step {
174    /// Unique step name
175    pub name: String,
176    /// Step type
177    #[serde(rename = "type")]
178    pub step_type: StepType,
179    /// Optional description
180    #[serde(skip_serializing_if = "Option::is_none")]
181    pub description: Option<String>,
182    /// Step parameters
183    #[serde(default)]
184    pub parameters: HashMap<String, serde_json::Value>,
185    /// Execution condition
186    #[serde(skip_serializing_if = "Option::is_none")]
187    pub condition: Option<Condition>,
188    /// Dependencies on other steps
189    #[serde(default)]
190    pub depends_on: Vec<StepDependency>,
191    /// Retry strategy
192    #[serde(skip_serializing_if = "Option::is_none")]
193    pub retry: Option<RetryStrategy>,
194    /// For-each loop variable
195    #[serde(skip_serializing_if = "Option::is_none")]
196    pub for_each: Option<String>,
197    /// Whether to run in parallel with other steps
198    #[serde(default)]
199    pub parallel: bool,
200}
201
202/// Step type enumeration
203#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
204#[serde(rename_all = "lowercase")]
205pub enum StepType {
206    /// Synthesis step
207    Synthesize,
208    /// Quality validation step
209    Validate,
210    /// File operation step
211    FileOp,
212    /// Command execution step
213    Command,
214    /// Script execution step
215    Script,
216    /// Conditional branch step
217    Branch,
218    /// Loop step
219    Loop,
220    /// Sub-workflow invocation
221    Workflow,
222    /// Wait/delay step
223    Wait,
224    /// Notification step
225    Notify,
226}
227
228/// Step dependency
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct StepDependency {
231    /// Name of the step to depend on
232    pub step_name: String,
233    /// Whether the dependency must succeed
234    #[serde(default = "default_true")]
235    pub must_succeed: bool,
236}
237
238/// Condition definition
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct Condition {
241    /// Left operand (variable or value)
242    pub left: String,
243    /// Operator
244    pub operator: ConditionOperator,
245    /// Right operand (variable or value)
246    pub right: String,
247}
248
249impl Condition {
250    /// Create a new condition
251    pub fn new(left: String, operator: ConditionOperator, right: String) -> Self {
252        Self {
253            left,
254            operator,
255            right,
256        }
257    }
258
259    /// Evaluate condition with variable substitution
260    pub fn evaluate(&self, variables: &HashMap<String, serde_json::Value>) -> bool {
261        let left_val = self.resolve_value(&self.left, variables);
262        let right_val = self.resolve_value(&self.right, variables);
263
264        match self.operator {
265            ConditionOperator::Equals => left_val == right_val,
266            ConditionOperator::NotEquals => left_val != right_val,
267            ConditionOperator::GreaterThan => {
268                self.compare_numeric(&left_val, &right_val, |a, b| a > b)
269            }
270            ConditionOperator::LessThan => {
271                self.compare_numeric(&left_val, &right_val, |a, b| a < b)
272            }
273            ConditionOperator::GreaterOrEqual => {
274                self.compare_numeric(&left_val, &right_val, |a, b| a >= b)
275            }
276            ConditionOperator::LessOrEqual => {
277                self.compare_numeric(&left_val, &right_val, |a, b| a <= b)
278            }
279            ConditionOperator::Contains => {
280                if let (Some(left_str), Some(right_str)) = (left_val.as_str(), right_val.as_str()) {
281                    left_str.contains(right_str)
282                } else {
283                    false
284                }
285            }
286            ConditionOperator::Matches => {
287                // Regex match (simplified for now)
288                if let (Some(left_str), Some(right_str)) = (left_val.as_str(), right_val.as_str()) {
289                    regex::Regex::new(right_str)
290                        .map(|re| re.is_match(left_str))
291                        .unwrap_or(false)
292                } else {
293                    false
294                }
295            }
296        }
297    }
298
299    fn resolve_value(
300        &self,
301        value: &str,
302        variables: &HashMap<String, serde_json::Value>,
303    ) -> serde_json::Value {
304        // Check if it's a variable reference ${var}
305        if let Some(var_name) = value.strip_prefix("${").and_then(|s| s.strip_suffix('}')) {
306            variables
307                .get(var_name)
308                .cloned()
309                .unwrap_or(serde_json::Value::Null)
310        } else {
311            // Try to parse as JSON value
312            serde_json::from_str(value)
313                .unwrap_or_else(|_| serde_json::Value::String(value.to_string()))
314        }
315    }
316
317    fn compare_numeric<F>(&self, left: &serde_json::Value, right: &serde_json::Value, op: F) -> bool
318    where
319        F: Fn(f64, f64) -> bool,
320    {
321        match (left.as_f64(), right.as_f64()) {
322            (Some(l), Some(r)) => op(l, r),
323            _ => false,
324        }
325    }
326}
327
328/// Condition operators
329#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
330#[serde(rename_all = "lowercase")]
331pub enum ConditionOperator {
332    /// Equality
333    #[serde(rename = "==")]
334    Equals,
335    /// Inequality
336    #[serde(rename = "!=")]
337    NotEquals,
338    /// Greater than
339    #[serde(rename = ">")]
340    GreaterThan,
341    /// Less than
342    #[serde(rename = "<")]
343    LessThan,
344    /// Greater or equal
345    #[serde(rename = ">=")]
346    GreaterOrEqual,
347    /// Less or equal
348    #[serde(rename = "<=")]
349    LessOrEqual,
350    /// String contains
351    Contains,
352    /// Regex match
353    Matches,
354}
355
356/// Retry strategy
357#[derive(Debug, Clone, Serialize, Deserialize)]
358pub struct RetryStrategy {
359    /// Maximum retry attempts
360    pub max_attempts: usize,
361    /// Backoff strategy
362    pub backoff: BackoffType,
363    /// Initial delay in milliseconds
364    #[serde(default = "default_initial_delay")]
365    pub initial_delay_ms: u64,
366    /// Maximum delay in milliseconds
367    #[serde(default = "default_max_delay")]
368    pub max_delay_ms: u64,
369    /// Multiplier for exponential backoff
370    #[serde(default = "default_backoff_multiplier")]
371    pub backoff_multiplier: f64,
372}
373
/// Serde default for `RetryStrategy::initial_delay_ms`: one second.
fn default_initial_delay() -> u64 {
    const ONE_SECOND_MS: u64 = 1_000;
    ONE_SECOND_MS
}
377
/// Serde default for `RetryStrategy::max_delay_ms`: one minute.
fn default_max_delay() -> u64 {
    const ONE_MINUTE_MS: u64 = 60 * 1_000;
    ONE_MINUTE_MS
}
381
/// Serde default for `RetryStrategy::backoff_multiplier`.
fn default_backoff_multiplier() -> f64 {
    const DOUBLING: f64 = 2.0;
    DOUBLING
}
385
386impl Default for RetryStrategy {
387    fn default() -> Self {
388        Self {
389            max_attempts: 3,
390            backoff: BackoffType::Exponential,
391            initial_delay_ms: default_initial_delay(),
392            max_delay_ms: default_max_delay(),
393            backoff_multiplier: default_backoff_multiplier(),
394        }
395    }
396}
397
398/// Backoff types
399#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
400#[serde(rename_all = "lowercase")]
401pub enum BackoffType {
402    /// Fixed delay
403    Fixed,
404    /// Linear increase
405    Linear,
406    /// Exponential backoff
407    Exponential,
408    /// Exponential with jitter
409    ExponentialJitter,
410}
411
412/// Variable value types
413#[derive(Debug, Clone, Serialize, Deserialize)]
414#[serde(untagged)]
415pub enum Variable {
416    /// String value
417    String(String),
418    /// Numeric value
419    Number(f64),
420    /// Boolean value
421    Boolean(bool),
422    /// Array of values
423    Array(Vec<serde_json::Value>),
424    /// Object/map
425    Object(HashMap<String, serde_json::Value>),
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a step with the given name and type and every optional field unset.
    fn bare_step(name: &str, step_type: StepType) -> Step {
        Step {
            name: name.to_string(),
            step_type,
            description: None,
            parameters: HashMap::new(),
            condition: None,
            depends_on: Vec::new(),
            retry: None,
            for_each: None,
            parallel: false,
        }
    }

    /// Builds a variable map holding a single `name -> value` entry.
    fn single_var(name: &str, value: serde_json::Value) -> HashMap<String, serde_json::Value> {
        let mut variables = HashMap::new();
        variables.insert(name.to_string(), value);
        variables
    }

    #[test]
    fn test_workflow_creation() {
        let wf = Workflow::new("test", "1.0", "Test workflow");
        assert_eq!(wf.metadata.name, "test");
        assert_eq!(wf.metadata.version, "1.0");
        assert!(wf.steps.is_empty());
    }

    #[test]
    fn test_workflow_add_step() {
        let mut wf = Workflow::new("test", "1.0", "Test workflow");
        wf.add_step(bare_step("step1", StepType::Synthesize));

        assert_eq!(wf.steps.len(), 1);
        assert_eq!(wf.steps[0].name, "step1");
    }

    #[test]
    fn test_workflow_validation_duplicate_names() {
        let mut wf = Workflow::new("test", "1.0", "Test workflow");
        wf.add_step(bare_step("duplicate", StepType::Synthesize));
        wf.add_step(bare_step("duplicate", StepType::Validate));

        assert!(wf.validate().is_err());
    }

    #[test]
    fn test_condition_evaluation_equals() {
        let condition = Condition::new(
            "${status}".to_string(),
            ConditionOperator::Equals,
            "success".to_string(),
        );
        let variables = single_var(
            "status",
            serde_json::Value::String("success".to_string()),
        );

        assert!(condition.evaluate(&variables));
    }

    #[test]
    fn test_condition_evaluation_greater_than() {
        let condition = Condition::new(
            "${score}".to_string(),
            ConditionOperator::GreaterThan,
            "4.0".to_string(),
        );
        let variables = single_var("score", serde_json::json!(4.5));

        assert!(condition.evaluate(&variables));
    }

    #[test]
    fn test_condition_evaluation_contains() {
        let condition = Condition::new(
            "${output}".to_string(),
            ConditionOperator::Contains,
            "error".to_string(),
        );
        let variables = single_var(
            "output",
            serde_json::Value::String("An error occurred".to_string()),
        );

        assert!(condition.evaluate(&variables));
    }

    #[test]
    fn test_retry_strategy_defaults() {
        let RetryStrategy {
            max_attempts,
            backoff,
            initial_delay_ms,
            ..
        } = RetryStrategy::default();
        assert_eq!(max_attempts, 3);
        assert_eq!(backoff, BackoffType::Exponential);
        assert_eq!(initial_delay_ms, 1000);
    }

    #[test]
    fn test_workflow_config_defaults() {
        let cfg = WorkflowConfig::default();
        assert_eq!(cfg.max_parallel, 4);
        assert_eq!(cfg.timeout_seconds, 0);
        assert!(!cfg.continue_on_error);
        assert!(cfg.save_state);
    }
}