1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9
10use crate::error::CliError;
11
/// Module-local shorthand for `std::result::Result` with [`CliError`] as the error type.
type Result<T> = std::result::Result<T, CliError>;
13
/// A serializable workflow definition: metadata, named variables, a list of
/// steps, and workflow-level configuration.
///
/// `variables` and `config` may be omitted from serialized input; serde then
/// fills them with their `Default` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Workflow {
    /// Descriptive information (name, version, description, author, tags).
    pub metadata: WorkflowMetadata,
    /// Variables keyed by name; an empty map when absent from the input.
    #[serde(default)]
    pub variables: HashMap<String, Variable>,
    /// The workflow's steps, in the order they were declared or added.
    pub steps: Vec<Step>,
    /// Workflow-level settings; `WorkflowConfig::default()` when absent.
    #[serde(default)]
    pub config: WorkflowConfig,
}
28
29impl Workflow {
30 pub fn new(name: &str, version: &str, description: &str) -> Self {
32 Self {
33 metadata: WorkflowMetadata {
34 name: name.to_string(),
35 version: version.to_string(),
36 description: description.to_string(),
37 author: None,
38 tags: Vec::new(),
39 },
40 variables: HashMap::new(),
41 steps: Vec::new(),
42 config: WorkflowConfig::default(),
43 }
44 }
45
46 pub async fn load_from_file(path: &Path) -> Result<Self> {
48 let content = tokio::fs::read_to_string(path).await?;
49
50 if path.extension().is_some_and(|ext| ext == "json") {
51 Ok(serde_json::from_str(&content)?)
52 } else {
53 Ok(serde_yaml::from_str(&content).map_err(|e| {
55 CliError::SerializationError(format!("Failed to parse YAML: {}", e))
56 })?)
57 }
58 }
59
60 pub async fn save_to_file(&self, path: &Path) -> Result<()> {
62 let content = if path.extension().is_some_and(|ext| ext == "json") {
63 serde_json::to_string_pretty(self)?
64 } else {
65 serde_yaml::to_string(self).map_err(|e| {
66 CliError::SerializationError(format!("Failed to serialize to YAML: {}", e))
67 })?
68 };
69
70 tokio::fs::write(path, content).await?;
71 Ok(())
72 }
73
74 pub fn add_step(&mut self, step: Step) {
76 self.steps.push(step);
77 }
78
79 pub fn add_variable(&mut self, name: String, value: Variable) {
81 self.variables.insert(name, value);
82 }
83
84 pub fn get_step(&self, name: &str) -> Option<&Step> {
86 self.steps.iter().find(|s| s.name == name)
87 }
88
89 pub fn validate(&self) -> Result<()> {
91 let mut step_names = std::collections::HashSet::new();
93 for step in &self.steps {
94 if !step_names.insert(&step.name) {
95 return Err(CliError::Workflow(format!(
96 "Duplicate step name: {}",
97 step.name
98 )));
99 }
100 }
101
102 for step in &self.steps {
104 for dep in &step.depends_on {
105 if !self.steps.iter().any(|s| s.name == dep.step_name) {
106 return Err(CliError::Workflow(format!(
107 "Step '{}' depends on non-existent step '{}'",
108 step.name, dep.step_name
109 )));
110 }
111 }
112 }
113
114 Ok(())
115 }
116}
117
/// Descriptive information attached to a [`Workflow`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowMetadata {
    /// Workflow name.
    pub name: String,
    /// Version string; stored as given, with no format enforced here.
    pub version: String,
    /// Free-form description.
    pub description: String,
    /// Optional author; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub author: Option<String>,
    /// Tags; an empty list when absent from the input.
    #[serde(default)]
    pub tags: Vec<String>,
}
134
/// Workflow-level settings. Every field is optional in serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowConfig {
    /// Defaults to 4 when omitted.
    /// NOTE(review): presumably the cap on concurrently running steps — confirm against the executor.
    #[serde(default = "default_max_parallel")]
    pub max_parallel: usize,
    /// Defaults to 0 when omitted.
    /// NOTE(review): 0 presumably means "no timeout" — confirm against the executor.
    #[serde(default)]
    pub timeout_seconds: u64,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub continue_on_error: bool,
    /// Defaults to `true` when omitted.
    #[serde(default = "default_true")]
    pub save_state: bool,
}
151
/// Serde default for `WorkflowConfig::max_parallel`.
fn default_max_parallel() -> usize {
    const DEFAULT_MAX_PARALLEL: usize = 4;
    DEFAULT_MAX_PARALLEL
}
155
/// Serde default for boolean fields that should be enabled when omitted
/// (`WorkflowConfig::save_state`, `StepDependency::must_succeed`).
fn default_true() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
159
160impl Default for WorkflowConfig {
161 fn default() -> Self {
162 Self {
163 max_parallel: default_max_parallel(),
164 timeout_seconds: 0,
165 continue_on_error: false,
166 save_state: true,
167 }
168 }
169}
170
/// A single step of a [`Workflow`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Step {
    /// Step name; `Workflow::validate` requires it to be unique within a workflow.
    pub name: String,
    /// Kind of step; serialized under the key `type`.
    #[serde(rename = "type")]
    pub step_type: StepType,
    /// Optional description; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Arbitrary JSON parameters keyed by name; empty when absent from the input.
    #[serde(default)]
    pub parameters: HashMap<String, serde_json::Value>,
    /// Optional condition; left out of serialized output when `None`.
    /// NOTE(review): presumably the step is skipped when this evaluates to false — confirm against the executor.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub condition: Option<Condition>,
    /// Steps this one depends on; `Workflow::validate` requires each to exist.
    #[serde(default)]
    pub depends_on: Vec<StepDependency>,
    /// Optional retry settings; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry: Option<RetryStrategy>,
    /// Optional iteration expression; left out of serialized output when `None`.
    /// NOTE(review): format of this string is not defined in this file — confirm against the executor.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub for_each: Option<String>,
    /// Defaults to `false` when omitted.
    /// NOTE(review): presumably allows this step to run alongside others — confirm against the executor.
    #[serde(default)]
    pub parallel: bool,
}
201
/// The kind of work a [`Step`] performs.
///
/// Variants are (de)serialized as their names in lowercase with no separator,
/// so `FileOp` is written as `fileop`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum StepType {
    /// Serialized as `synthesize`.
    Synthesize,
    /// Serialized as `validate`.
    Validate,
    /// Serialized as `fileop`.
    FileOp,
    /// Serialized as `command`.
    Command,
    /// Serialized as `script`.
    Script,
    /// Serialized as `branch`.
    Branch,
    /// Serialized as `loop`.
    Loop,
    /// Serialized as `workflow`.
    Workflow,
    /// Serialized as `wait`.
    Wait,
    /// Serialized as `notify`.
    Notify,
}
227
/// A reference from one step to another step it depends on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepDependency {
    /// Name of the step depended upon; must match an existing step's `name`
    /// for `Workflow::validate` to pass.
    pub step_name: String,
    /// Defaults to `true` when omitted.
    /// NOTE(review): presumably whether the dependency must finish successfully
    /// (rather than merely finish) — confirm against the executor.
    #[serde(default = "default_true")]
    pub must_succeed: bool,
}
237
/// A binary comparison between two operands.
///
/// Each operand is either a variable reference written as `${name}` or a
/// literal; see `Condition::evaluate` for how operands are resolved.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Condition {
    /// Left-hand operand (variable reference or literal).
    pub left: String,
    /// Comparison to apply.
    pub operator: ConditionOperator,
    /// Right-hand operand (variable reference or literal).
    pub right: String,
}
248
249impl Condition {
250 pub fn new(left: String, operator: ConditionOperator, right: String) -> Self {
252 Self {
253 left,
254 operator,
255 right,
256 }
257 }
258
259 pub fn evaluate(&self, variables: &HashMap<String, serde_json::Value>) -> bool {
261 let left_val = self.resolve_value(&self.left, variables);
262 let right_val = self.resolve_value(&self.right, variables);
263
264 match self.operator {
265 ConditionOperator::Equals => left_val == right_val,
266 ConditionOperator::NotEquals => left_val != right_val,
267 ConditionOperator::GreaterThan => {
268 self.compare_numeric(&left_val, &right_val, |a, b| a > b)
269 }
270 ConditionOperator::LessThan => {
271 self.compare_numeric(&left_val, &right_val, |a, b| a < b)
272 }
273 ConditionOperator::GreaterOrEqual => {
274 self.compare_numeric(&left_val, &right_val, |a, b| a >= b)
275 }
276 ConditionOperator::LessOrEqual => {
277 self.compare_numeric(&left_val, &right_val, |a, b| a <= b)
278 }
279 ConditionOperator::Contains => {
280 if let (Some(left_str), Some(right_str)) = (left_val.as_str(), right_val.as_str()) {
281 left_str.contains(right_str)
282 } else {
283 false
284 }
285 }
286 ConditionOperator::Matches => {
287 if let (Some(left_str), Some(right_str)) = (left_val.as_str(), right_val.as_str()) {
289 regex::Regex::new(right_str)
290 .map(|re| re.is_match(left_str))
291 .unwrap_or(false)
292 } else {
293 false
294 }
295 }
296 }
297 }
298
299 fn resolve_value(
300 &self,
301 value: &str,
302 variables: &HashMap<String, serde_json::Value>,
303 ) -> serde_json::Value {
304 if let Some(var_name) = value.strip_prefix("${").and_then(|s| s.strip_suffix('}')) {
306 variables
307 .get(var_name)
308 .cloned()
309 .unwrap_or(serde_json::Value::Null)
310 } else {
311 serde_json::from_str(value)
313 .unwrap_or_else(|_| serde_json::Value::String(value.to_string()))
314 }
315 }
316
317 fn compare_numeric<F>(&self, left: &serde_json::Value, right: &serde_json::Value, op: F) -> bool
318 where
319 F: Fn(f64, f64) -> bool,
320 {
321 match (left.as_f64(), right.as_f64()) {
322 (Some(l), Some(r)) => op(l, r),
323 _ => false,
324 }
325 }
326}
327
/// Comparison operators usable in a [`Condition`].
///
/// The relational operators are (de)serialized as their symbols; `Contains`
/// and `Matches` use their lowercase names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ConditionOperator {
    /// Serialized as `==`.
    #[serde(rename = "==")]
    Equals,
    /// Serialized as `!=`.
    #[serde(rename = "!=")]
    NotEquals,
    /// Serialized as `>`.
    #[serde(rename = ">")]
    GreaterThan,
    /// Serialized as `<`.
    #[serde(rename = "<")]
    LessThan,
    /// Serialized as `>=`.
    #[serde(rename = ">=")]
    GreaterOrEqual,
    /// Serialized as `<=`.
    #[serde(rename = "<=")]
    LessOrEqual,
    /// Substring test between two strings; serialized as `contains`.
    Contains,
    /// Regular-expression match of the left string against the right pattern;
    /// serialized as `matches`.
    Matches,
}
355
/// Retry settings for a [`Step`].
///
/// `max_attempts` and `backoff` are required in serialized input; the other
/// fields fall back to the defaults noted below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryStrategy {
    /// NOTE(review): presumably the total number of attempts including the first — confirm against the executor.
    pub max_attempts: usize,
    /// How the delay between attempts evolves.
    pub backoff: BackoffType,
    /// Defaults to 1000 when omitted.
    #[serde(default = "default_initial_delay")]
    pub initial_delay_ms: u64,
    /// Defaults to 60000 when omitted.
    #[serde(default = "default_max_delay")]
    pub max_delay_ms: u64,
    /// Defaults to 2.0 when omitted.
    #[serde(default = "default_backoff_multiplier")]
    pub backoff_multiplier: f64,
}
373
/// Serde default for `RetryStrategy::initial_delay_ms`.
fn default_initial_delay() -> u64 {
    const INITIAL_DELAY_MS: u64 = 1_000;
    INITIAL_DELAY_MS
}
377
/// Serde default for `RetryStrategy::max_delay_ms`.
fn default_max_delay() -> u64 {
    const MAX_DELAY_MS: u64 = 60 * 1_000;
    MAX_DELAY_MS
}
381
/// Serde default for `RetryStrategy::backoff_multiplier`.
fn default_backoff_multiplier() -> f64 {
    const BACKOFF_MULTIPLIER: f64 = 2.0;
    BACKOFF_MULTIPLIER
}
385
386impl Default for RetryStrategy {
387 fn default() -> Self {
388 Self {
389 max_attempts: 3,
390 backoff: BackoffType::Exponential,
391 initial_delay_ms: default_initial_delay(),
392 max_delay_ms: default_max_delay(),
393 backoff_multiplier: default_backoff_multiplier(),
394 }
395 }
396}
397
/// Shape of the delay sequence used by a [`RetryStrategy`].
///
/// Variants are (de)serialized as their names in lowercase with no separator,
/// so `ExponentialJitter` is written as `exponentialjitter`.
/// NOTE(review): the delay formula for each variant is not defined in this
/// file — confirm against the retry implementation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum BackoffType {
    /// Serialized as `fixed`.
    Fixed,
    /// Serialized as `linear`.
    Linear,
    /// Serialized as `exponential`.
    Exponential,
    /// Serialized as `exponentialjitter`.
    ExponentialJitter,
}
411
/// A workflow variable value.
///
/// The enum is untagged: the value is written without a variant wrapper, and
/// on input serde tries the variants in declaration order and keeps the first
/// that fits. No variant accepts a null value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Variable {
    /// Text value.
    String(String),
    /// Numeric value, held as `f64`.
    Number(f64),
    /// Boolean value.
    Boolean(bool),
    /// Sequence of arbitrary JSON values.
    Array(Vec<serde_json::Value>),
    /// Map from string keys to arbitrary JSON values.
    Object(HashMap<String, serde_json::Value>),
}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a step with the given name and type and every optional field
    /// left at its empty value, so tests only spell out what they care about.
    fn make_step(name: &str, step_type: StepType) -> Step {
        Step {
            name: name.to_string(),
            step_type,
            description: None,
            parameters: HashMap::new(),
            condition: None,
            depends_on: Vec::new(),
            retry: None,
            for_each: None,
            parallel: false,
        }
    }

    /// Builds a dependency on `step_name` that must succeed.
    fn depends_on(step_name: &str) -> StepDependency {
        StepDependency {
            step_name: step_name.to_string(),
            must_succeed: true,
        }
    }

    #[test]
    fn test_workflow_creation() {
        let workflow = Workflow::new("test", "1.0", "Test workflow");
        assert_eq!(workflow.metadata.name, "test");
        assert_eq!(workflow.metadata.version, "1.0");
        assert_eq!(workflow.steps.len(), 0);
    }

    #[test]
    fn test_workflow_add_step() {
        let mut workflow = Workflow::new("test", "1.0", "Test workflow");

        workflow.add_step(make_step("step1", StepType::Synthesize));

        assert_eq!(workflow.steps.len(), 1);
        assert_eq!(workflow.steps[0].name, "step1");
    }

    #[test]
    fn test_workflow_get_step() {
        let mut workflow = Workflow::new("test", "1.0", "Test workflow");
        workflow.add_step(make_step("step1", StepType::Command));

        assert!(workflow.get_step("step1").is_some());
        assert!(workflow.get_step("absent").is_none());
    }

    #[test]
    fn test_workflow_validation_duplicate_names() {
        let mut workflow = Workflow::new("test", "1.0", "Test workflow");

        workflow.add_step(make_step("duplicate", StepType::Synthesize));
        workflow.add_step(make_step("duplicate", StepType::Validate));

        assert!(workflow.validate().is_err());
    }

    #[test]
    fn test_workflow_validation_missing_dependency() {
        let mut workflow = Workflow::new("test", "1.0", "Test workflow");

        let mut step = make_step("build", StepType::Command);
        step.depends_on.push(depends_on("absent"));
        workflow.add_step(step);

        assert!(workflow.validate().is_err());
    }

    #[test]
    fn test_workflow_validation_satisfied_dependency() {
        let mut workflow = Workflow::new("test", "1.0", "Test workflow");

        workflow.add_step(make_step("prepare", StepType::Command));
        let mut step = make_step("build", StepType::Command);
        step.depends_on.push(depends_on("prepare"));
        workflow.add_step(step);

        assert!(workflow.validate().is_ok());
    }

    #[test]
    fn test_condition_evaluation_equals() {
        let condition = Condition::new(
            "${status}".to_string(),
            ConditionOperator::Equals,
            "success".to_string(),
        );

        let mut variables = HashMap::new();
        variables.insert(
            "status".to_string(),
            serde_json::Value::String("success".to_string()),
        );

        assert!(condition.evaluate(&variables));
    }

    #[test]
    fn test_condition_evaluation_unset_variable() {
        let condition = Condition::new(
            "${status}".to_string(),
            ConditionOperator::Equals,
            "success".to_string(),
        );

        // An unset variable resolves to JSON null, which is not the string "success".
        let variables = HashMap::new();

        assert!(!condition.evaluate(&variables));
    }

    #[test]
    fn test_condition_evaluation_greater_than() {
        let condition = Condition::new(
            "${score}".to_string(),
            ConditionOperator::GreaterThan,
            "4.0".to_string(),
        );

        let mut variables = HashMap::new();
        variables.insert("score".to_string(), serde_json::json!(4.5));

        assert!(condition.evaluate(&variables));
    }

    #[test]
    fn test_condition_evaluation_contains() {
        let condition = Condition::new(
            "${output}".to_string(),
            ConditionOperator::Contains,
            "error".to_string(),
        );

        let mut variables = HashMap::new();
        variables.insert(
            "output".to_string(),
            serde_json::Value::String("An error occurred".to_string()),
        );

        assert!(condition.evaluate(&variables));
    }

    #[test]
    fn test_retry_strategy_defaults() {
        let retry = RetryStrategy::default();
        assert_eq!(retry.max_attempts, 3);
        assert_eq!(retry.backoff, BackoffType::Exponential);
        assert_eq!(retry.initial_delay_ms, 1000);
    }

    #[test]
    fn test_workflow_config_defaults() {
        let config = WorkflowConfig::default();
        assert_eq!(config.max_parallel, 4);
        assert_eq!(config.timeout_seconds, 0);
        assert!(!config.continue_on_error);
        assert!(config.save_state);
    }
}