sh_layer2/workflow_engine/
executor.rs1use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Instant;
10
11use crate::types::{Layer2Result, TaskId};
12
13use super::{
14 Dag, Node, NodeExecutor, NodeResult, NodeStatus, WorkflowEngineTrait, WorkflowInput,
15 WorkflowOutput, WorkflowStatus,
16};
17
18pub struct WorkflowExecutor {
20 dag: RwLock<Dag>,
21 task_status: RwLock<HashMap<TaskId, WorkflowStatus>>,
22 node_executors: RwLock<HashMap<String, Arc<dyn NodeExecutor>>>,
23}
24
25impl WorkflowExecutor {
26 pub fn new() -> Self {
27 Self {
28 dag: RwLock::new(Dag::new()),
29 task_status: RwLock::new(HashMap::new()),
30 node_executors: RwLock::new(HashMap::new()),
31 }
32 }
33
34 pub fn register_executor(&self, node_type: &str, executor: Arc<dyn NodeExecutor>) {
36 self.node_executors
37 .write()
38 .insert(node_type.to_string(), executor);
39 }
40
41 #[allow(clippy::type_complexity)]
43 fn get_node_info(
44 &self,
45 node_id: &str,
46 ) -> Option<(Node, Option<Arc<dyn NodeExecutor>>, String)> {
47 let dag = self.dag.read();
48 let node = dag.get_node(node_id)?;
49 let node_type = node.node_type.clone();
50 let node_clone = node.clone();
51
52 drop(dag);
53
54 let executors = self.node_executors.read();
55 let executor = executors.get(&node_type).cloned();
56 drop(executors);
57
58 Some((node_clone, executor, node_type))
59 }
60}
61
62impl Default for WorkflowExecutor {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68#[async_trait]
69impl WorkflowEngineTrait for WorkflowExecutor {
70 fn add_node(&mut self, node: Node) -> Layer2Result<()> {
71 self.dag.write().add_node(node)
72 }
73
74 fn add_edge(&mut self, from: &str, to: &str) -> Layer2Result<()> {
75 self.dag.write().add_edge(from, to)
76 }
77
78 async fn execute(&self, input: WorkflowInput) -> Layer2Result<WorkflowOutput> {
79 let task_id = TaskId::new();
80 let start = Instant::now();
81
82 self.task_status
84 .write()
85 .insert(task_id.clone(), WorkflowStatus::Running);
86
87 let sorted_nodes = {
89 let dag = self.dag.read();
90 dag.topological_sort()?
91 };
92
93 let mut results = Vec::new();
94
95 for node_id in sorted_nodes {
96 if let Some((node, executor, node_type)) = self.get_node_info(&node_id) {
98 let node_start = Instant::now();
99
100 let (status, output, error) = if let Some(exec) = executor {
101 match exec.execute(&node, &input).await {
102 Ok(out) => (NodeStatus::Completed, Some(out), None),
103 Err(e) => (NodeStatus::Failed, None, Some(e.to_string())),
104 }
105 } else {
106 (
107 NodeStatus::Skipped,
108 None,
109 Some(format!("No executor for node type: {}", node_type)),
110 )
111 };
112
113 results.push(NodeResult {
114 node_id: node_id.clone(),
115 status,
116 output,
117 error,
118 duration_ms: node_start.elapsed().as_millis() as u64,
119 });
120 }
121 }
122
123 let final_status = if results.iter().all(|r| r.status == NodeStatus::Completed) {
124 WorkflowStatus::Completed
125 } else if results.iter().any(|r| r.status == NodeStatus::Failed) {
126 WorkflowStatus::Failed
127 } else {
128 WorkflowStatus::Completed
129 };
130
131 self.task_status
132 .write()
133 .insert(task_id.clone(), final_status);
134
135 Ok(WorkflowOutput {
136 task_id,
137 results,
138 status: final_status,
139 duration_ms: start.elapsed().as_millis() as u64,
140 })
141 }
142
143 async fn cancel(&self, task_id: &TaskId) -> Layer2Result<bool> {
144 let mut status = self.task_status.write();
145 if let Some(s) = status.get_mut(task_id) {
146 *s = WorkflowStatus::Cancelled;
147 Ok(true)
148 } else {
149 Ok(false)
150 }
151 }
152
153 fn status(&self, task_id: &TaskId) -> Layer2Result<WorkflowStatus> {
154 let status = self.task_status.read();
155 status
156 .get(task_id)
157 .copied()
158 .ok_or_else(|| anyhow::anyhow!("Task not found: {}", task_id))
159 }
160
161 fn validate(&self) -> Layer2Result<Vec<String>> {
162 let dag = self.dag.read();
163 let mut errors = Vec::new();
164
165 if dag.has_cycle() {
166 errors.push("DAG contains cycle".to_string());
167 }
168
169 Ok(errors)
170 }
171
172 fn node_count(&self) -> usize {
173 self.dag.read().node_count()
174 }
175
176 fn edge_count(&self) -> usize {
177 self.dag.read().edge_count()
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed executor reports an empty DAG.
    #[test]
    fn test_executor_creation() {
        let engine = WorkflowExecutor::new();
        let nodes = engine.node_count();
        assert_eq!(nodes, 0);
    }

    /// Adding one node is reflected in the node count.
    #[test]
    fn test_add_node() {
        let mut engine = WorkflowExecutor::new();
        engine.add_node(Node::new("test", "Test")).unwrap();
        let nodes = engine.node_count();
        assert_eq!(nodes, 1);
    }
}