//! Directed acyclic graph primitive: topological sorting and parallel grouping.
//!
//! File: `terminals_core/primitives/dag.rs`.
use std::collections::{HashMap, HashSet, VecDeque};

/// A directed graph over nodes of type `T` that is expected to be acyclic.
///
/// Edges are stored as adjacency lists next to each node's count of incoming
/// edges. Inserting edges performs no cycle check; cycles are reported only when
/// [`toposort`](DAG::toposort) or [`parallel_groups`](DAG::parallel_groups) runs.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the type
/// itself can be named for any `T`.
#[derive(Debug, Clone)]
pub struct DAG<T> {
    /// Every node registered in the graph.
    nodes: HashSet<T>,
    /// Outgoing adjacency list per node; an edge added twice appears twice.
    edges: HashMap<T, Vec<T>>,
    /// Number of incoming edges recorded for each node.
    in_degree: HashMap<T, usize>,
}

9impl<T: Clone + Eq + std::hash::Hash + Ord> DAG<T> {
10 pub fn new() -> Self {
11 Self {
12 nodes: HashSet::new(),
13 edges: HashMap::new(),
14 in_degree: HashMap::new(),
15 }
16 }
17
18 pub fn add_node(&mut self, node: T) {
19 self.nodes.insert(node.clone());
20 self.edges.entry(node.clone()).or_default();
21 self.in_degree.entry(node).or_insert(0);
22 }
23
24 pub fn add_edge(&mut self, from: T, to: T) {
25 self.edges.entry(from.clone()).or_default().push(to.clone());
26 *self.in_degree.entry(to).or_insert(0) += 1;
27 self.in_degree.entry(from).or_insert(0);
28 }
29
30 pub fn toposort(&self) -> Result<Vec<T>, DAGError> {
32 let mut in_deg = self.in_degree.clone();
33 let mut zero: Vec<T> = in_deg
34 .iter()
35 .filter(|(_, &d)| d == 0)
36 .map(|(n, _)| n.clone())
37 .collect();
38 zero.sort();
39 let mut queue: VecDeque<T> = zero.into();
40 let mut result = Vec::with_capacity(self.nodes.len());
41
42 while let Some(node) = queue.pop_front() {
43 result.push(node.clone());
44 if let Some(neighbors) = self.edges.get(&node) {
45 let mut nexts: Vec<T> = neighbors
46 .iter()
47 .filter_map(|next| {
48 let deg = in_deg.get_mut(next).unwrap();
49 *deg -= 1;
50 if *deg == 0 {
51 Some(next.clone())
52 } else {
53 None
54 }
55 })
56 .collect();
57 nexts.sort();
58 for n in nexts {
59 queue.push_back(n);
60 }
61 }
62 }
63
64 if result.len() != self.nodes.len() {
65 Err(DAGError::CycleDetected)
66 } else {
67 Ok(result)
68 }
69 }
70
71 pub fn parallel_groups(&self) -> Result<Vec<Vec<T>>, DAGError> {
73 let mut in_deg = self.in_degree.clone();
74 let mut current: Vec<T> = in_deg
75 .iter()
76 .filter(|(_, &d)| d == 0)
77 .map(|(n, _)| n.clone())
78 .collect();
79 current.sort();
80 let mut groups = Vec::new();
81 let mut processed = 0;
82
83 while !current.is_empty() {
84 let mut next = Vec::new();
85 for node in ¤t {
86 if let Some(neighbors) = self.edges.get(node) {
87 for n in neighbors {
88 let deg = in_deg.get_mut(n).unwrap();
89 *deg -= 1;
90 if *deg == 0 {
91 next.push(n.clone());
92 }
93 }
94 }
95 }
96 processed += current.len();
97 groups.push(current);
98 next.sort();
99 current = next;
100 }
101
102 if processed != self.nodes.len() {
103 Err(DAGError::CycleDetected)
104 } else {
105 Ok(groups)
106 }
107 }
108}
109
110impl<T: Clone + Eq + std::hash::Hash + Ord> Default for DAG<T> {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
/// Errors reported by [`DAG`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DAGError {
    /// The graph contains at least one cycle, so no topological order exists.
    CycleDetected,
}

121impl std::fmt::Display for DAGError {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 match self {
124 Self::CycleDetected => write!(f, "Cycle detected in DAG"),
125 }
126 }
127}
128
129impl std::error::Error for DAGError {}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_toposort_linear() {
        let mut dag = DAG::new();
        dag.add_node("a");
        dag.add_node("b");
        dag.add_node("c");
        dag.add_edge("a", "b");
        dag.add_edge("b", "c");
        let sorted = dag.toposort().unwrap();
        assert_eq!(sorted, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_toposort_diamond() {
        let mut dag = DAG::new();
        dag.add_node("a");
        dag.add_node("b");
        dag.add_node("c");
        dag.add_node("d");
        dag.add_edge("a", "b");
        dag.add_edge("a", "c");
        dag.add_edge("b", "d");
        dag.add_edge("c", "d");
        let sorted = dag.toposort().unwrap();
        // Siblings released together are enqueued in ascending order, so the
        // whole order is deterministic, not just its endpoints.
        assert_eq!(sorted, vec!["a", "b", "c", "d"]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut dag = DAG::new();
        dag.add_node("a");
        dag.add_node("b");
        dag.add_edge("a", "b");
        dag.add_edge("b", "a");
        assert!(dag.toposort().is_err());
        assert!(dag.parallel_groups().is_err());
    }

    #[test]
    fn test_self_loop_is_a_cycle() {
        let mut dag = DAG::new();
        dag.add_node("a");
        dag.add_edge("a", "a");
        let err = dag.toposort().unwrap_err();
        assert_eq!(err.to_string(), "Cycle detected in DAG");
    }

    #[test]
    fn test_parallel_groups() {
        let mut dag = DAG::new();
        dag.add_node("a");
        dag.add_node("b");
        dag.add_node("c");
        dag.add_edge("a", "b");
        dag.add_edge("a", "c");
        let groups = dag.parallel_groups().unwrap();
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0], vec!["a"]);
        assert_eq!(groups[1], vec!["b", "c"]);
    }

    #[test]
    fn test_empty_graph() {
        let dag: DAG<&str> = DAG::default();
        assert!(dag.toposort().unwrap().is_empty());
        assert!(dag.parallel_groups().unwrap().is_empty());
    }
}