1use crate::error::WorkflowError;
5use crate::workflow::templates::{TemplateContext, evaluate_condition, render_value};
6use crate::workflow::types::{
7 ApprovalDecision, ErrorAction, GateType, WorkflowDefinition, WorkflowState, WorkflowStatus,
8};
9use async_trait::async_trait;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::path::PathBuf;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15use uuid::Uuid;
16
17#[async_trait]
20pub trait ToolExecutor: Send + Sync {
21 async fn execute_tool(&self, tool_name: &str, args: Value) -> Result<Value, String>;
22}
23
24#[async_trait]
26pub trait ApprovalHandler: Send + Sync {
27 async fn request_approval(
28 &self,
29 workflow: &str,
30 step_id: &str,
31 message: &str,
32 preview: Option<&str>,
33 ) -> ApprovalDecision;
34}
35
36pub struct AutoApproveHandler;
38
39#[async_trait]
40impl ApprovalHandler for AutoApproveHandler {
41 async fn request_approval(
42 &self,
43 _workflow: &str,
44 _step_id: &str,
45 _message: &str,
46 _preview: Option<&str>,
47 ) -> ApprovalDecision {
48 ApprovalDecision::Approved
49 }
50}
51
52pub struct AutoDenyHandler;
54
55#[async_trait]
56impl ApprovalHandler for AutoDenyHandler {
57 async fn request_approval(
58 &self,
59 _workflow: &str,
60 _step_id: &str,
61 _message: &str,
62 _preview: Option<&str>,
63 ) -> ApprovalDecision {
64 ApprovalDecision::Denied
65 }
66}
67
68pub struct WorkflowExecutor {
70 tool_executor: Arc<dyn ToolExecutor>,
71 approval_handler: Arc<dyn ApprovalHandler>,
72 runs: Arc<Mutex<HashMap<Uuid, WorkflowState>>>,
73 state_path: Option<PathBuf>,
74}
75
76impl WorkflowExecutor {
77 pub fn new(
78 tool_executor: Arc<dyn ToolExecutor>,
79 approval_handler: Arc<dyn ApprovalHandler>,
80 state_path: Option<PathBuf>,
81 ) -> Self {
82 Self {
83 tool_executor,
84 approval_handler,
85 runs: Arc::new(Mutex::new(HashMap::new())),
86 state_path,
87 }
88 }
89
90 pub async fn start(
92 &self,
93 workflow: &WorkflowDefinition,
94 inputs: HashMap<String, Value>,
95 ) -> Result<WorkflowState, WorkflowError> {
96 let mut state = WorkflowState::new(workflow.name.clone(), inputs);
97 state.status = WorkflowStatus::Running;
98 state.updated_at = chrono::Utc::now();
99
100 {
102 let mut runs = self.runs.lock().await;
103 runs.insert(state.run_id, state.clone());
104 }
105
106 let final_state = self.execute_steps(workflow, state).await?;
108
109 {
111 let mut runs = self.runs.lock().await;
112 runs.insert(final_state.run_id, final_state.clone());
113 }
114
115 if let Some(ref path) = self.state_path {
117 self.persist_state(&final_state, path).await?;
118 }
119
120 Ok(final_state)
121 }
122
123 pub async fn resume(
125 &self,
126 run_id: Uuid,
127 workflow: &WorkflowDefinition,
128 decision: ApprovalDecision,
129 ) -> Result<WorkflowState, WorkflowError> {
130 let state = {
131 let runs = self.runs.lock().await;
132 runs.get(&run_id)
133 .cloned()
134 .ok_or(WorkflowError::RunNotFound { run_id })?
135 };
136
137 if state.status != WorkflowStatus::WaitingApproval {
138 return Err(WorkflowError::StepFailed {
139 step: format!("step_{}", state.current_step_index),
140 message: format!("Cannot resume workflow in status: {}", state.status),
141 });
142 }
143
144 let mut state = state;
145
146 match decision {
147 ApprovalDecision::Approved => {
148 let step = &workflow.steps[state.current_step_index];
150 let ctx = TemplateContext::new(state.inputs.clone(), state.step_outputs.clone());
151
152 let rendered_params =
153 render_value(&serde_json::to_value(&step.params).unwrap(), &ctx).map_err(
154 |e| WorkflowError::StepFailed {
155 step: step.id.clone(),
156 message: e.to_string(),
157 },
158 )?;
159
160 let output = self
161 .tool_executor
162 .execute_tool(&step.tool, rendered_params)
163 .await
164 .map_err(|e| WorkflowError::StepFailed {
165 step: step.id.clone(),
166 message: e,
167 })?;
168
169 state.step_outputs.insert(step.id.clone(), output);
170 state.current_step_index += 1;
171 state.status = WorkflowStatus::Running;
172 state.updated_at = chrono::Utc::now();
173
174 let final_state = self.execute_steps(workflow, state).await?;
176 let mut runs = self.runs.lock().await;
177 runs.insert(final_state.run_id, final_state.clone());
178 Ok(final_state)
179 }
180 ApprovalDecision::Denied => {
181 state.status = WorkflowStatus::Failed;
182 state.error = Some("Approval denied by user".to_string());
183 state.updated_at = chrono::Utc::now();
184 let mut runs = self.runs.lock().await;
185 runs.insert(state.run_id, state.clone());
186 Ok(state)
187 }
188 }
189 }
190
191 pub async fn cancel(&self, run_id: Uuid) -> Result<WorkflowState, WorkflowError> {
193 let mut runs = self.runs.lock().await;
194 let state = runs
195 .get_mut(&run_id)
196 .ok_or(WorkflowError::RunNotFound { run_id })?;
197 state.status = WorkflowStatus::Cancelled;
198 state.updated_at = chrono::Utc::now();
199 Ok(state.clone())
200 }
201
202 pub async fn get_status(&self, run_id: Uuid) -> Result<WorkflowState, WorkflowError> {
204 let runs = self.runs.lock().await;
205 runs.get(&run_id)
206 .cloned()
207 .ok_or(WorkflowError::RunNotFound { run_id })
208 }
209
210 pub async fn list_runs(&self) -> Vec<WorkflowState> {
212 let runs = self.runs.lock().await;
213 runs.values().cloned().collect()
214 }
215
216 async fn execute_steps(
218 &self,
219 workflow: &WorkflowDefinition,
220 mut state: WorkflowState,
221 ) -> Result<WorkflowState, WorkflowError> {
222 while state.current_step_index < workflow.steps.len() {
223 let step = &workflow.steps[state.current_step_index];
224 let ctx = TemplateContext::new(state.inputs.clone(), state.step_outputs.clone());
225
226 if let Some(ref condition) = step.condition {
228 let should_run = evaluate_condition(condition, &ctx).unwrap_or(false);
229 if !should_run {
230 state.current_step_index += 1;
231 state.updated_at = chrono::Utc::now();
232 continue;
233 }
234 }
235
236 if let Some(ref gate) = step.gate
238 && gate.gate_type == GateType::ApprovalRequired
239 {
240 let message = step.gate_message.as_deref().unwrap_or(&gate.message);
241 let preview = step.gate_preview.as_deref().or(gate.preview.as_deref());
242
243 let decision = self
244 .approval_handler
245 .request_approval(&state.workflow_name, &step.id, message, preview)
246 .await;
247
248 match decision {
249 ApprovalDecision::Denied => {
250 state.status = WorkflowStatus::WaitingApproval;
251 state.updated_at = chrono::Utc::now();
252 return Ok(state);
253 }
254 ApprovalDecision::Approved => {
255 }
257 }
258 }
259
260 let params_value =
262 serde_json::to_value(&step.params).unwrap_or(Value::Object(Default::default()));
263 let rendered_params =
264 render_value(¶ms_value, &ctx).map_err(|e| WorkflowError::StepFailed {
265 step: step.id.clone(),
266 message: e.to_string(),
267 })?;
268
269 let result = self
271 .tool_executor
272 .execute_tool(&step.tool, rendered_params)
273 .await;
274
275 match result {
276 Ok(output) => {
277 state.step_outputs.insert(step.id.clone(), output);
278 state.current_step_index += 1;
279 state.updated_at = chrono::Utc::now();
280 }
281 Err(err) => match &step.on_error {
282 Some(ErrorAction::Skip) => {
283 state
284 .step_outputs
285 .insert(step.id.clone(), Value::String(format!("skipped: {}", err)));
286 state.current_step_index += 1;
287 state.updated_at = chrono::Utc::now();
288 }
289 Some(ErrorAction::Retry { max_retries }) => {
290 let mut retries = 0;
291 let mut last_err = err;
292 while retries < *max_retries {
293 retries += 1;
294 let ctx2 = TemplateContext::new(
295 state.inputs.clone(),
296 state.step_outputs.clone(),
297 );
298 let params_value2 = serde_json::to_value(&step.params)
299 .unwrap_or(Value::Object(Default::default()));
300 let rendered2 = render_value(¶ms_value2, &ctx2).map_err(|e| {
301 WorkflowError::StepFailed {
302 step: step.id.clone(),
303 message: e.to_string(),
304 }
305 })?;
306 match self.tool_executor.execute_tool(&step.tool, rendered2).await {
307 Ok(output) => {
308 state.step_outputs.insert(step.id.clone(), output);
309 state.current_step_index += 1;
310 state.updated_at = chrono::Utc::now();
311 last_err = String::new();
312 break;
313 }
314 Err(e) => {
315 last_err = e;
316 }
317 }
318 }
319 if !last_err.is_empty() {
320 state.status = WorkflowStatus::Failed;
321 state.error = Some(format!(
322 "Step '{}' failed after {} retries: {}",
323 step.id, max_retries, last_err
324 ));
325 state.updated_at = chrono::Utc::now();
326 return Ok(state);
327 }
328 }
329 Some(ErrorAction::Fail) | None => {
330 state.status = WorkflowStatus::Failed;
331 state.error = Some(format!("Step '{}' failed: {}", step.id, err));
332 state.updated_at = chrono::Utc::now();
333 return Ok(state);
334 }
335 },
336 }
337 }
338
339 state.status = WorkflowStatus::Completed;
340 state.updated_at = chrono::Utc::now();
341 Ok(state)
342 }
343
344 async fn persist_state(
346 &self,
347 state: &WorkflowState,
348 base_path: &PathBuf,
349 ) -> Result<(), WorkflowError> {
350 let file_path = base_path.join(format!("{}.json", state.run_id));
351 let json = serde_json::to_string_pretty(state).map_err(|e| WorkflowError::StepFailed {
352 step: "persistence".to_string(),
353 message: e.to_string(),
354 })?;
355 tokio::fs::create_dir_all(base_path)
356 .await
357 .map_err(|e| WorkflowError::StepFailed {
358 step: "persistence".to_string(),
359 message: e.to_string(),
360 })?;
361 tokio::fs::write(file_path, json)
362 .await
363 .map_err(|e| WorkflowError::StepFailed {
364 step: "persistence".to_string(),
365 message: e.to_string(),
366 })?;
367 Ok(())
368 }
369
370 pub async fn load_state(
372 base_path: &std::path::Path,
373 run_id: Uuid,
374 ) -> Result<WorkflowState, WorkflowError> {
375 let file_path = base_path.join(format!("{}.json", run_id));
376 let json = tokio::fs::read_to_string(file_path)
377 .await
378 .map_err(|_| WorkflowError::RunNotFound { run_id })?;
379 serde_json::from_str(&json).map_err(|e| WorkflowError::ParseError {
380 message: e.to_string(),
381 })
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::parser::parse_workflow;
    use std::collections::VecDeque;

    /// Tool executor that replays a scripted queue of responses in order and
    /// returns `"default_output"` once the queue is exhausted.
    struct MockToolExecutor {
        responses: Mutex<VecDeque<Result<Value, String>>>,
    }

    impl MockToolExecutor {
        fn new(responses: Vec<Result<Value, String>>) -> Self {
            Self {
                responses: Mutex::new(responses.into()),
            }
        }

        /// `count` successful responses: `"output_0"`, `"output_1"`, ...
        fn succeeding(count: usize) -> Self {
            Self::new(
                (0..count)
                    .map(|i| Ok(Value::String(format!("output_{}", i))))
                    .collect(),
            )
        }
    }

    #[async_trait]
    impl ToolExecutor for MockToolExecutor {
        async fn execute_tool(&self, _tool_name: &str, _args: Value) -> Result<Value, String> {
            self.responses
                .lock()
                .await
                .pop_front()
                .unwrap_or_else(|| Ok(Value::String("default_output".to_string())))
        }
    }

    fn simple_workflow_yaml() -> &'static str {
        r#"
name: test_workflow
description: A test workflow
steps:
  - id: step1
    tool: echo
    params:
      text: "hello"
"#
    }

    fn multi_step_yaml() -> &'static str {
        r#"
name: multi_step
description: Multi-step workflow
inputs:
  - name: greeting
    type: string
steps:
  - id: step1
    tool: echo
    params:
      text: "{{ inputs.greeting }}"
  - id: step2
    tool: echo
    params:
      text: "{{ steps.step1.output }}"
  - id: step3
    tool: echo
    params:
      text: "final"
"#
    }

    fn gated_workflow_yaml() -> &'static str {
        r#"
name: gated
description: Workflow with gate
steps:
  - id: step1
    tool: echo
    params:
      text: "before gate"
  - id: gated_step
    tool: echo
    params:
      text: "after gate"
    gate:
      type: approval_required
      message: "Approve this?"
"#
    }

    #[tokio::test]
    async fn test_executor_start_creates_run() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(1)),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(simple_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.workflow_name, "test_workflow");
        assert_eq!(state.status, WorkflowStatus::Completed);
    }

    #[tokio::test]
    async fn test_executor_step_executes_tool() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::new(vec![Ok(Value::String(
                "tool_output".to_string(),
            ))])),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(simple_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::Completed);
        assert!(state.step_outputs.contains_key("step1"));
        assert_eq!(
            state.step_outputs["step1"],
            Value::String("tool_output".to_string())
        );
    }

    #[tokio::test]
    async fn test_executor_multi_step_sequential() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(3)),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(multi_step_yaml()).unwrap();
        let mut inputs = HashMap::new();
        inputs.insert("greeting".to_string(), Value::String("hi".to_string()));
        let state = executor.start(&wf, inputs).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::Completed);
        assert_eq!(state.step_outputs.len(), 3);
        assert!(state.step_outputs.contains_key("step1"));
        assert!(state.step_outputs.contains_key("step2"));
        assert!(state.step_outputs.contains_key("step3"));
    }

    #[tokio::test]
    async fn test_executor_step_output_forwarded() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::new(vec![
                Ok(Value::String("from_step1".to_string())),
                Ok(Value::String("from_step2".to_string())),
                Ok(Value::String("from_step3".to_string())),
            ])),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(multi_step_yaml()).unwrap();
        let mut inputs = HashMap::new();
        inputs.insert("greeting".to_string(), Value::String("hi".to_string()));
        let state = executor.start(&wf, inputs).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::Completed);
        assert_eq!(
            state.step_outputs["step1"],
            Value::String("from_step1".to_string())
        );
    }

    #[tokio::test]
    async fn test_executor_gate_pauses_workflow() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(2)),
            Arc::new(AutoDenyHandler),
            None,
        );
        let wf = parse_workflow(gated_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::WaitingApproval);
        // Paused before running `gated_step`.
        assert_eq!(state.current_step_index, 1);
    }

    #[tokio::test]
    async fn test_executor_resume_after_approval() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(3)),
            Arc::new(AutoDenyHandler),
            None,
        );
        let wf = parse_workflow(gated_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::WaitingApproval);

        let resumed = executor
            .resume(state.run_id, &wf, ApprovalDecision::Approved)
            .await
            .unwrap();
        assert_eq!(resumed.status, WorkflowStatus::Completed);
    }

    #[tokio::test]
    async fn test_executor_resume_denied_fails_run() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(2)),
            Arc::new(AutoDenyHandler),
            None,
        );
        let wf = parse_workflow(gated_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::WaitingApproval);

        let resumed = executor
            .resume(state.run_id, &wf, ApprovalDecision::Denied)
            .await
            .unwrap();
        assert_eq!(resumed.status, WorkflowStatus::Failed);
        assert_eq!(resumed.error.as_deref(), Some("Approval denied by user"));
        assert!(!resumed.step_outputs.contains_key("gated_step"));
    }

    #[tokio::test]
    async fn test_executor_resume_rejects_non_waiting_run() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(1)),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(simple_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::Completed);

        let result = executor
            .resume(state.run_id, &wf, ApprovalDecision::Approved)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_executor_cancel_sets_cancelled() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(2)),
            Arc::new(AutoDenyHandler),
            None,
        );
        let wf = parse_workflow(gated_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::WaitingApproval);

        let cancelled = executor.cancel(state.run_id).await.unwrap();
        assert_eq!(cancelled.status, WorkflowStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_executor_unknown_run_is_not_found() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(0)),
            Arc::new(AutoApproveHandler),
            None,
        );
        assert!(matches!(
            executor.get_status(Uuid::nil()).await,
            Err(WorkflowError::RunNotFound { .. })
        ));
        assert!(matches!(
            executor.cancel(Uuid::nil()).await,
            Err(WorkflowError::RunNotFound { .. })
        ));
    }

    #[tokio::test]
    async fn test_executor_step_failure_with_fail_action() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::new(vec![Err("tool crashed".to_string())])),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(simple_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::Failed);
        assert!(state.error.unwrap().contains("tool crashed"));
    }

    #[tokio::test]
    async fn test_executor_step_failure_with_skip_action() {
        let yaml = r#"
name: skip_test
description: Test skip on error
steps:
  - id: failing
    tool: bad_tool
    params: {}
    on_error:
      action: skip
  - id: after
    tool: echo
    params:
      text: "continued"
"#;
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::new(vec![
                Err("fail".to_string()),
                Ok(Value::String("ok".to_string())),
            ])),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(yaml).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::Completed);
        assert!(state.step_outputs.contains_key("failing"));
        assert!(
            state.step_outputs["failing"]
                .as_str()
                .unwrap()
                .contains("skipped")
        );
    }

    #[tokio::test]
    async fn test_executor_step_failure_with_retry() {
        let yaml = r#"
name: retry_test
description: Test retry on error
steps:
  - id: flaky
    tool: flaky_tool
    params: {}
    on_error:
      action: retry
      max_retries: 3
"#;
        // Two failures, then success on the second retry.
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::new(vec![
                Err("fail1".to_string()),
                Err("fail2".to_string()),
                Ok(Value::String("success".to_string())),
            ])),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(yaml).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::Completed);
    }

    #[tokio::test]
    async fn test_executor_retry_exhausted_fails() {
        let yaml = r#"
name: retry_exhausted
description: Retries run out
steps:
  - id: flaky
    tool: flaky_tool
    params: {}
    on_error:
      action: retry
      max_retries: 2
"#;
        // Initial attempt plus two retries, all failing.
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::new(vec![
                Err("fail1".to_string()),
                Err("fail2".to_string()),
                Err("fail3".to_string()),
            ])),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(yaml).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::Failed);
        let error = state.error.unwrap();
        assert!(error.contains("after 2 retries"));
        assert!(error.contains("fail3"));
    }

    #[tokio::test]
    async fn test_executor_condition_skip_step() {
        let yaml = r#"
name: conditional
description: Conditional step test
steps:
  - id: check
    tool: echo
    params:
      text: "fail"
  - id: skipped
    tool: echo
    params:
      text: "should not run"
    condition: "{{ steps.check.output }} == 'pass'"
  - id: final_step
    tool: echo
    params:
      text: "done"
"#;
        // `skipped` never calls the tool, so only two responses are consumed.
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::new(vec![
                Ok(Value::String("fail".to_string())),
                Ok(Value::String("done".to_string())),
            ])),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(yaml).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        assert_eq!(state.status, WorkflowStatus::Completed);
        assert!(state.step_outputs.contains_key("check"));
        assert!(!state.step_outputs.contains_key("skipped"));
        assert!(state.step_outputs.contains_key("final_step"));
    }

    #[tokio::test]
    async fn test_executor_get_status_returns_current() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(1)),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(simple_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();
        let status = executor.get_status(state.run_id).await.unwrap();
        assert_eq!(status.status, WorkflowStatus::Completed);
        assert_eq!(status.run_id, state.run_id);
    }

    #[tokio::test]
    async fn test_executor_list_runs() {
        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(5)),
            Arc::new(AutoApproveHandler),
            None,
        );
        let wf = parse_workflow(simple_workflow_yaml()).unwrap();
        executor.start(&wf, HashMap::new()).await.unwrap();
        executor.start(&wf, HashMap::new()).await.unwrap();
        let runs = executor.list_runs().await;
        assert_eq!(runs.len(), 2);
    }

    #[tokio::test]
    async fn test_executor_state_persistence() {
        let temp_dir = tempfile::tempdir().unwrap();
        let state_path = temp_dir.path().to_path_buf();

        let executor = WorkflowExecutor::new(
            Arc::new(MockToolExecutor::succeeding(1)),
            Arc::new(AutoApproveHandler),
            Some(state_path.clone()),
        );
        let wf = parse_workflow(simple_workflow_yaml()).unwrap();
        let state = executor.start(&wf, HashMap::new()).await.unwrap();

        let loaded = WorkflowExecutor::load_state(&state_path, state.run_id)
            .await
            .unwrap();
        assert_eq!(loaded.run_id, state.run_id);
        assert_eq!(loaded.status, WorkflowStatus::Completed);
        assert_eq!(loaded.workflow_name, "test_workflow");
    }
}