Skip to main content

sh_layer2/workflow_engine/
dag.rs

//! # DAG Implementation
//!
//! Directed acyclic graph (DAG) data structure.
4
5use std::collections::{HashMap, HashSet};
6
7use crate::types::Layer2Result;
8
9use super::node::Node;
10
/// Directed graph of [`Node`]s addressed by their string IDs.
///
/// Adjacency is stored in both directions so that successors and
/// predecessors of a node can each be looked up directly. Inserting an edge
/// does not check for cycles; `has_cycle` and `topological_sort` perform
/// that check on demand.
pub struct Dag {
    /// All nodes, keyed by node ID.
    nodes: HashMap<String, Node>,
    /// Forward adjacency: node ID -> IDs of the nodes its outgoing edges point to.
    edges: HashMap<String, Vec<String>>,
    /// Reverse adjacency: node ID -> IDs of the nodes whose edges point at it.
    reverse_edges: HashMap<String, Vec<String>>,
}
17
18impl Dag {
19    pub fn new() -> Self {
20        Self {
21            nodes: HashMap::new(),
22            edges: HashMap::new(),
23            reverse_edges: HashMap::new(),
24        }
25    }
26
27    /// 添加节点
28    pub fn add_node(&mut self, node: Node) -> Layer2Result<()> {
29        let id = node.id.clone();
30        self.nodes.insert(id.clone(), node);
31
32        self.edges.entry(id.clone()).or_default();
33        self.reverse_edges.entry(id).or_default();
34
35        Ok(())
36    }
37
38    /// 添加边
39    pub fn add_edge(&mut self, from: &str, to: &str) -> Layer2Result<()> {
40        // 验证节点存在
41        if !self.nodes.contains_key(from) {
42            return Err(anyhow::anyhow!("Source node not found: {}", from));
43        }
44
45        if !self.nodes.contains_key(to) {
46            return Err(anyhow::anyhow!("Target node not found: {}", to));
47        }
48
49        // 添加边
50        self.edges.get_mut(from).unwrap().push(to.to_string());
51        self.reverse_edges
52            .get_mut(to)
53            .unwrap()
54            .push(from.to_string());
55
56        Ok(())
57    }
58
59    /// 验证是否有环
60    pub fn has_cycle(&self) -> bool {
61        let mut visited = HashSet::new();
62        let mut rec_stack = HashSet::new();
63
64        for node_id in self.nodes.keys() {
65            if self.dfs_cycle(node_id, &mut visited, &mut rec_stack) {
66                return true;
67            }
68        }
69
70        false
71    }
72
73    fn dfs_cycle(
74        &self,
75        node_id: &str,
76        visited: &mut HashSet<String>,
77        rec_stack: &mut HashSet<String>,
78    ) -> bool {
79        if rec_stack.contains(node_id) {
80            return true;
81        }
82
83        if visited.contains(node_id) {
84            return false;
85        }
86
87        visited.insert(node_id.to_string());
88        rec_stack.insert(node_id.to_string());
89
90        if let Some(neighbors) = self.edges.get(node_id) {
91            for neighbor in neighbors {
92                if self.dfs_cycle(neighbor, visited, rec_stack) {
93                    return true;
94                }
95            }
96        }
97
98        rec_stack.remove(node_id);
99        false
100    }
101
102    /// 拓扑排序
103    pub fn topological_sort(&self) -> Layer2Result<Vec<String>> {
104        if self.has_cycle() {
105            return Err(anyhow::anyhow!("DAG contains cycle"));
106        }
107
108        let mut in_degree: HashMap<String, i32> = HashMap::new();
109        let mut result = Vec::new();
110        let mut queue = Vec::new();
111
112        // 计算入度
113        for node_id in self.nodes.keys() {
114            in_degree.insert(node_id.clone(), 0);
115        }
116
117        for node_id in self.nodes.keys() {
118            if let Some(neighbors) = self.edges.get(node_id) {
119                for neighbor in neighbors {
120                    *in_degree.get_mut(neighbor).unwrap() += 1;
121                }
122            }
123        }
124
125        // 找到所有入度为 0 的节点
126        for (node_id, &degree) in &in_degree {
127            if degree == 0 {
128                queue.push(node_id.clone());
129            }
130        }
131
132        // Kahn 算法
133        while !queue.is_empty() {
134            let node_id = queue.remove(0);
135            result.push(node_id.clone());
136
137            if let Some(neighbors) = self.edges.get(&node_id) {
138                for neighbor in neighbors {
139                    let degree = in_degree.get_mut(neighbor).unwrap();
140                    *degree -= 1;
141                    if *degree == 0 {
142                        queue.push(neighbor.clone());
143                    }
144                }
145            }
146        }
147
148        Ok(result)
149    }
150
151    /// 获取节点的依赖
152    pub fn get_dependencies(&self, node_id: &str) -> Vec<String> {
153        self.reverse_edges.get(node_id).cloned().unwrap_or_default()
154    }
155
156    /// 获取节点的后继
157    pub fn get_successors(&self, node_id: &str) -> Vec<String> {
158        self.edges.get(node_id).cloned().unwrap_or_default()
159    }
160
161    /// 获取节点
162    pub fn get_node(&self, node_id: &str) -> Option<&Node> {
163        self.nodes.get(node_id)
164    }
165
166    /// 获取所有节点 ID
167    pub fn node_ids(&self) -> Vec<String> {
168        self.nodes.keys().cloned().collect()
169    }
170
171    /// 节点数量
172    pub fn node_count(&self) -> usize {
173        self.nodes.len()
174    }
175
176    /// 边数量
177    pub fn edge_count(&self) -> usize {
178        self.edges.values().map(|v| v.len()).sum()
179    }
180}
181
182impl Default for Dag {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph that contains one node per given ID and no edges.
    fn dag_with_nodes(ids: &[&'static str]) -> Dag {
        let mut dag = Dag::new();
        for &id in ids {
            dag.add_node(Node::new(id, "Test Node")).unwrap();
        }
        dag
    }

    #[test]
    fn test_dag_creation() {
        let dag = Dag::new();
        assert_eq!(dag.node_count(), 0);
    }

    #[test]
    fn test_add_node() {
        let mut dag = Dag::new();
        let node = Node::new("test", "Test Node");
        dag.add_node(node).unwrap();

        assert_eq!(dag.node_count(), 1);
    }

    #[test]
    fn test_topological_sort() {
        let mut dag = Dag::new();

        let node_a = Node::new("a", "Node A");
        let node_b = Node::new("b", "Node B");
        let node_c = Node::new("c", "Node C");

        dag.add_node(node_a).unwrap();
        dag.add_node(node_b).unwrap();
        dag.add_node(node_c).unwrap();

        dag.add_edge("a", "b").unwrap();
        dag.add_edge("b", "c").unwrap();

        let sorted = dag.topological_sort().unwrap();
        assert_eq!(sorted, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut dag = Dag::new();

        let node_a = Node::new("a", "Node A");
        let node_b = Node::new("b", "Node B");

        dag.add_node(node_a).unwrap();
        dag.add_node(node_b).unwrap();

        dag.add_edge("a", "b").unwrap();
        dag.add_edge("b", "a").unwrap();

        assert!(dag.has_cycle());
    }

    #[test]
    fn test_add_edge_rejects_unknown_nodes() {
        let mut dag = dag_with_nodes(&["a"]);

        assert!(dag.add_edge("a", "missing").is_err());
        assert!(dag.add_edge("missing", "a").is_err());
        // A rejected edge must not be recorded.
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_acyclic_graph_has_no_cycle() {
        let mut dag = dag_with_nodes(&["a", "b", "c"]);
        dag.add_edge("a", "b").unwrap();
        dag.add_edge("a", "c").unwrap();
        dag.add_edge("b", "c").unwrap();

        assert!(!dag.has_cycle());
    }

    #[test]
    fn test_self_loop_is_a_cycle() {
        let mut dag = dag_with_nodes(&["a"]);
        dag.add_edge("a", "a").unwrap();

        assert!(dag.has_cycle());
    }

    #[test]
    fn test_topological_sort_rejects_cycle() {
        let mut dag = dag_with_nodes(&["a", "b"]);
        dag.add_edge("a", "b").unwrap();
        dag.add_edge("b", "a").unwrap();

        assert!(dag.topological_sort().is_err());
    }

    #[test]
    fn test_topological_sort_diamond() {
        let mut dag = dag_with_nodes(&["a", "b", "c", "d"]);
        dag.add_edge("a", "b").unwrap();
        dag.add_edge("a", "c").unwrap();
        dag.add_edge("b", "d").unwrap();
        dag.add_edge("c", "d").unwrap();

        let sorted = dag.topological_sort().unwrap();
        let position = |id: &str| {
            sorted
                .iter()
                .position(|node_id| node_id.as_str() == id)
                .unwrap()
        };

        // Only the ordering constraints are asserted, not one exact order.
        assert_eq!(sorted.len(), 4);
        assert!(position("a") < position("b"));
        assert!(position("a") < position("c"));
        assert!(position("b") < position("d"));
        assert!(position("c") < position("d"));
    }

    #[test]
    fn test_dependencies_and_successors() {
        let mut dag = dag_with_nodes(&["a", "b"]);
        dag.add_edge("a", "b").unwrap();

        assert_eq!(dag.get_successors("a"), vec!["b".to_string()]);
        assert_eq!(dag.get_dependencies("b"), vec!["a".to_string()]);
        assert!(dag.get_dependencies("a").is_empty());
        assert!(dag.get_successors("missing").is_empty());
        assert_eq!(dag.edge_count(), 1);
    }
}
241}