// sh_layer2/workflow_engine/dag.rs
use std::collections::{HashMap, HashSet, VecDeque};

use crate::types::Layer2Result;

use super::node::Node;

11pub struct Dag {
13 nodes: HashMap<String, Node>,
14 edges: HashMap<String, Vec<String>>,
15 reverse_edges: HashMap<String, Vec<String>>,
16}
17
18impl Dag {
19 pub fn new() -> Self {
20 Self {
21 nodes: HashMap::new(),
22 edges: HashMap::new(),
23 reverse_edges: HashMap::new(),
24 }
25 }
26
27 pub fn add_node(&mut self, node: Node) -> Layer2Result<()> {
29 let id = node.id.clone();
30 self.nodes.insert(id.clone(), node);
31
32 self.edges.entry(id.clone()).or_default();
33 self.reverse_edges.entry(id).or_default();
34
35 Ok(())
36 }
37
38 pub fn add_edge(&mut self, from: &str, to: &str) -> Layer2Result<()> {
40 if !self.nodes.contains_key(from) {
42 return Err(anyhow::anyhow!("Source node not found: {}", from));
43 }
44
45 if !self.nodes.contains_key(to) {
46 return Err(anyhow::anyhow!("Target node not found: {}", to));
47 }
48
49 self.edges.get_mut(from).unwrap().push(to.to_string());
51 self.reverse_edges
52 .get_mut(to)
53 .unwrap()
54 .push(from.to_string());
55
56 Ok(())
57 }
58
59 pub fn has_cycle(&self) -> bool {
61 let mut visited = HashSet::new();
62 let mut rec_stack = HashSet::new();
63
64 for node_id in self.nodes.keys() {
65 if self.dfs_cycle(node_id, &mut visited, &mut rec_stack) {
66 return true;
67 }
68 }
69
70 false
71 }
72
73 fn dfs_cycle(
74 &self,
75 node_id: &str,
76 visited: &mut HashSet<String>,
77 rec_stack: &mut HashSet<String>,
78 ) -> bool {
79 if rec_stack.contains(node_id) {
80 return true;
81 }
82
83 if visited.contains(node_id) {
84 return false;
85 }
86
87 visited.insert(node_id.to_string());
88 rec_stack.insert(node_id.to_string());
89
90 if let Some(neighbors) = self.edges.get(node_id) {
91 for neighbor in neighbors {
92 if self.dfs_cycle(neighbor, visited, rec_stack) {
93 return true;
94 }
95 }
96 }
97
98 rec_stack.remove(node_id);
99 false
100 }
101
102 pub fn topological_sort(&self) -> Layer2Result<Vec<String>> {
104 if self.has_cycle() {
105 return Err(anyhow::anyhow!("DAG contains cycle"));
106 }
107
108 let mut in_degree: HashMap<String, i32> = HashMap::new();
109 let mut result = Vec::new();
110 let mut queue = Vec::new();
111
112 for node_id in self.nodes.keys() {
114 in_degree.insert(node_id.clone(), 0);
115 }
116
117 for node_id in self.nodes.keys() {
118 if let Some(neighbors) = self.edges.get(node_id) {
119 for neighbor in neighbors {
120 *in_degree.get_mut(neighbor).unwrap() += 1;
121 }
122 }
123 }
124
125 for (node_id, °ree) in &in_degree {
127 if degree == 0 {
128 queue.push(node_id.clone());
129 }
130 }
131
132 while !queue.is_empty() {
134 let node_id = queue.remove(0);
135 result.push(node_id.clone());
136
137 if let Some(neighbors) = self.edges.get(&node_id) {
138 for neighbor in neighbors {
139 let degree = in_degree.get_mut(neighbor).unwrap();
140 *degree -= 1;
141 if *degree == 0 {
142 queue.push(neighbor.clone());
143 }
144 }
145 }
146 }
147
148 Ok(result)
149 }
150
151 pub fn get_dependencies(&self, node_id: &str) -> Vec<String> {
153 self.reverse_edges.get(node_id).cloned().unwrap_or_default()
154 }
155
156 pub fn get_successors(&self, node_id: &str) -> Vec<String> {
158 self.edges.get(node_id).cloned().unwrap_or_default()
159 }
160
161 pub fn get_node(&self, node_id: &str) -> Option<&Node> {
163 self.nodes.get(node_id)
164 }
165
166 pub fn node_ids(&self) -> Vec<String> {
168 self.nodes.keys().cloned().collect()
169 }
170
171 pub fn node_count(&self) -> usize {
173 self.nodes.len()
174 }
175
176 pub fn edge_count(&self) -> usize {
178 self.edges.values().map(|v| v.len()).sum()
179 }
180}
181
182impl Default for Dag {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph from `(id, name)` node pairs and `(from, to)` edges,
    /// panicking if any insertion fails.
    fn build(nodes: &[(&'static str, &'static str)], edges: &[(&str, &str)]) -> Dag {
        let mut dag = Dag::new();
        for &(id, name) in nodes {
            dag.add_node(Node::new(id, name)).unwrap();
        }
        for &(from, to) in edges {
            dag.add_edge(from, to).unwrap();
        }
        dag
    }

    const DIAMOND_NODES: &[(&str, &str)] = &[
        ("a", "Node A"),
        ("b", "Node B"),
        ("c", "Node C"),
        ("d", "Node D"),
    ];
    const DIAMOND_EDGES: &[(&str, &str)] = &[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")];

    #[test]
    fn test_dag_creation() {
        let dag = Dag::new();
        assert_eq!(dag.node_count(), 0);
        assert_eq!(dag.edge_count(), 0);
        assert!(dag.topological_sort().unwrap().is_empty());
    }

    #[test]
    fn test_add_node() {
        let mut dag = Dag::new();
        let node = Node::new("test", "Test Node");
        dag.add_node(node).unwrap();

        assert_eq!(dag.node_count(), 1);
        assert!(dag.get_node("test").is_some());
        assert!(dag.get_node("missing").is_none());
    }

    #[test]
    fn test_add_edge_rejects_unknown_nodes() {
        let mut dag = build(&[("a", "Node A")], &[]);

        assert!(dag.add_edge("a", "missing").is_err());
        assert!(dag.add_edge("missing", "a").is_err());
        assert_eq!(dag.edge_count(), 0);
    }

    #[test]
    fn test_dependencies_and_successors() {
        let dag = build(
            &[("a", "Node A"), ("b", "Node B"), ("c", "Node C")],
            &[("a", "b"), ("a", "c")],
        );

        assert_eq!(dag.get_successors("a"), vec!["b", "c"]);
        assert_eq!(dag.get_dependencies("b"), vec!["a"]);
        assert!(dag.get_dependencies("a").is_empty());
        assert!(dag.get_successors("missing").is_empty());
        assert_eq!(dag.edge_count(), 2);
    }

    #[test]
    fn test_topological_sort() {
        let dag = build(
            &[("a", "Node A"), ("b", "Node B"), ("c", "Node C")],
            &[("a", "b"), ("b", "c")],
        );

        let sorted = dag.topological_sort().unwrap();
        assert_eq!(sorted, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_topological_sort_diamond() {
        let dag = build(DIAMOND_NODES, DIAMOND_EDGES);

        assert_eq!(dag.topological_sort().unwrap(), vec!["a", "b", "c", "d"]);
    }

    #[test]
    fn test_cycle_detection() {
        let dag = build(&[("a", "Node A"), ("b", "Node B")], &[("a", "b"), ("b", "a")]);

        assert!(dag.has_cycle());
    }

    #[test]
    fn test_topological_sort_rejects_cycle() {
        let dag = build(
            &[("a", "Node A"), ("b", "Node B"), ("c", "Node C")],
            &[("a", "b"), ("b", "c"), ("c", "b")],
        );

        assert!(dag.has_cycle());
        assert!(dag.topological_sort().is_err());
    }

    #[test]
    fn test_self_loop_is_cycle() {
        let dag = build(&[("a", "Node A")], &[("a", "a")]);

        assert!(dag.has_cycle());
        assert!(dag.topological_sort().is_err());
    }

    #[test]
    fn test_acyclic_graph_has_no_cycle() {
        let dag = build(DIAMOND_NODES, DIAMOND_EDGES);

        assert!(!dag.has_cycle());
    }
}