Skip to main content

sh_layer2/workflow_engine/
executor.rs

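//! # Workflow Executor
//!
//! Workflow executor implementation: runs the nodes of a workflow DAG in
//! topological order and tracks the status of each execution.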
4
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Instant;
10
11use crate::types::{Layer2Result, TaskId};
12
13use super::{
14    Dag, Node, NodeExecutor, NodeResult, NodeStatus, WorkflowEngineTrait, WorkflowInput,
15    WorkflowOutput, WorkflowStatus,
16};
17
18/// 工作流执行器
19pub struct WorkflowExecutor {
20    dag: RwLock<Dag>,
21    task_status: RwLock<HashMap<TaskId, WorkflowStatus>>,
22    node_executors: RwLock<HashMap<String, Arc<dyn NodeExecutor>>>,
23}
24
25impl WorkflowExecutor {
26    pub fn new() -> Self {
27        Self {
28            dag: RwLock::new(Dag::new()),
29            task_status: RwLock::new(HashMap::new()),
30            node_executors: RwLock::new(HashMap::new()),
31        }
32    }
33
34    /// 注册节点执行器
35    pub fn register_executor(&self, node_type: &str, executor: Arc<dyn NodeExecutor>) {
36        self.node_executors
37            .write()
38            .insert(node_type.to_string(), executor);
39    }
40
41    /// 获取节点和执行器信息(不持有锁)
42    #[allow(clippy::type_complexity)]
43    fn get_node_info(
44        &self,
45        node_id: &str,
46    ) -> Option<(Node, Option<Arc<dyn NodeExecutor>>, String)> {
47        let dag = self.dag.read();
48        let node = dag.get_node(node_id)?;
49        let node_type = node.node_type.clone();
50        let node_clone = node.clone();
51
52        drop(dag);
53
54        let executors = self.node_executors.read();
55        let executor = executors.get(&node_type).cloned();
56        drop(executors);
57
58        Some((node_clone, executor, node_type))
59    }
60}
61
62impl Default for WorkflowExecutor {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68#[async_trait]
69impl WorkflowEngineTrait for WorkflowExecutor {
70    fn add_node(&mut self, node: Node) -> Layer2Result<()> {
71        self.dag.write().add_node(node)
72    }
73
74    fn add_edge(&mut self, from: &str, to: &str) -> Layer2Result<()> {
75        self.dag.write().add_edge(from, to)
76    }
77
78    async fn execute(&self, input: WorkflowInput) -> Layer2Result<WorkflowOutput> {
79        let task_id = TaskId::new();
80        let start = Instant::now();
81
82        // 设置状态为运行中
83        self.task_status
84            .write()
85            .insert(task_id.clone(), WorkflowStatus::Running);
86
87        // 获取排序后的节点列表(释放锁后再执行)
88        let sorted_nodes = {
89            let dag = self.dag.read();
90            dag.topological_sort()?
91        };
92
93        let mut results = Vec::new();
94
95        for node_id in sorted_nodes {
96            // 获取节点信息(不持有锁)
97            if let Some((node, executor, node_type)) = self.get_node_info(&node_id) {
98                let node_start = Instant::now();
99
100                let (status, output, error) = if let Some(exec) = executor {
101                    match exec.execute(&node, &input).await {
102                        Ok(out) => (NodeStatus::Completed, Some(out), None),
103                        Err(e) => (NodeStatus::Failed, None, Some(e.to_string())),
104                    }
105                } else {
106                    (
107                        NodeStatus::Skipped,
108                        None,
109                        Some(format!("No executor for node type: {}", node_type)),
110                    )
111                };
112
113                results.push(NodeResult {
114                    node_id: node_id.clone(),
115                    status,
116                    output,
117                    error,
118                    duration_ms: node_start.elapsed().as_millis() as u64,
119                });
120            }
121        }
122
123        let final_status = if results.iter().all(|r| r.status == NodeStatus::Completed) {
124            WorkflowStatus::Completed
125        } else if results.iter().any(|r| r.status == NodeStatus::Failed) {
126            WorkflowStatus::Failed
127        } else {
128            WorkflowStatus::Completed
129        };
130
131        self.task_status
132            .write()
133            .insert(task_id.clone(), final_status);
134
135        Ok(WorkflowOutput {
136            task_id,
137            results,
138            status: final_status,
139            duration_ms: start.elapsed().as_millis() as u64,
140        })
141    }
142
143    async fn cancel(&self, task_id: &TaskId) -> Layer2Result<bool> {
144        let mut status = self.task_status.write();
145        if let Some(s) = status.get_mut(task_id) {
146            *s = WorkflowStatus::Cancelled;
147            Ok(true)
148        } else {
149            Ok(false)
150        }
151    }
152
153    fn status(&self, task_id: &TaskId) -> Layer2Result<WorkflowStatus> {
154        let status = self.task_status.read();
155        status
156            .get(task_id)
157            .copied()
158            .ok_or_else(|| anyhow::anyhow!("Task not found: {}", task_id))
159    }
160
161    fn validate(&self) -> Layer2Result<Vec<String>> {
162        let dag = self.dag.read();
163        let mut errors = Vec::new();
164
165        if dag.has_cycle() {
166            errors.push("DAG contains cycle".to_string());
167        }
168
169        Ok(errors)
170    }
171
172    fn node_count(&self) -> usize {
173        self.dag.read().node_count()
174    }
175
176    fn edge_count(&self) -> usize {
177        self.dag.read().edge_count()
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed executor has an empty DAG.
    #[test]
    fn test_executor_creation() {
        let executor = WorkflowExecutor::new();
        assert_eq!(executor.node_count(), 0);
        assert_eq!(executor.edge_count(), 0);
    }

    /// `Default` yields the same empty executor as `new`.
    #[test]
    fn test_default_is_empty() {
        let executor = WorkflowExecutor::default();
        assert_eq!(executor.node_count(), 0);
        assert_eq!(executor.edge_count(), 0);
    }

    /// Adding a node is reflected in `node_count`.
    #[test]
    fn test_add_node() {
        let mut executor = WorkflowExecutor::new();
        let node = Node::new("test", "Test");
        executor.add_node(node).unwrap();
        assert_eq!(executor.node_count(), 1);
    }

    /// Querying a task id that was never executed is an error, not a default status.
    #[test]
    fn test_status_of_unknown_task_is_error() {
        let executor = WorkflowExecutor::new();
        assert!(executor.status(&TaskId::new()).is_err());
    }
}