Skip to main content

terminals_core/primitives/
dag.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2
3pub struct DAG<T: Clone + Eq + std::hash::Hash> {
4    nodes: HashSet<T>,
5    edges: HashMap<T, Vec<T>>,
6    in_degree: HashMap<T, usize>,
7}
8
9impl<T: Clone + Eq + std::hash::Hash + Ord> DAG<T> {
10    pub fn new() -> Self {
11        Self {
12            nodes: HashSet::new(),
13            edges: HashMap::new(),
14            in_degree: HashMap::new(),
15        }
16    }
17
18    pub fn add_node(&mut self, node: T) {
19        self.nodes.insert(node.clone());
20        self.edges.entry(node.clone()).or_default();
21        self.in_degree.entry(node).or_insert(0);
22    }
23
24    pub fn add_edge(&mut self, from: T, to: T) {
25        self.edges.entry(from.clone()).or_default().push(to.clone());
26        *self.in_degree.entry(to).or_insert(0) += 1;
27        self.in_degree.entry(from).or_insert(0);
28    }
29
30    /// Kahn's algorithm for topological sort (deterministic: zero-degree nodes sorted before enqueue)
31    pub fn toposort(&self) -> Result<Vec<T>, DAGError> {
32        let mut in_deg = self.in_degree.clone();
33        let mut zero: Vec<T> = in_deg
34            .iter()
35            .filter(|(_, &d)| d == 0)
36            .map(|(n, _)| n.clone())
37            .collect();
38        zero.sort();
39        let mut queue: VecDeque<T> = zero.into();
40        let mut result = Vec::with_capacity(self.nodes.len());
41
42        while let Some(node) = queue.pop_front() {
43            result.push(node.clone());
44            if let Some(neighbors) = self.edges.get(&node) {
45                let mut nexts: Vec<T> = neighbors
46                    .iter()
47                    .filter_map(|next| {
48                        let deg = in_deg.get_mut(next).unwrap();
49                        *deg -= 1;
50                        if *deg == 0 {
51                            Some(next.clone())
52                        } else {
53                            None
54                        }
55                    })
56                    .collect();
57                nexts.sort();
58                for n in nexts {
59                    queue.push_back(n);
60                }
61            }
62        }
63
64        if result.len() != self.nodes.len() {
65            Err(DAGError::CycleDetected)
66        } else {
67            Ok(result)
68        }
69    }
70
71    /// Group nodes into parallel execution levels
72    pub fn parallel_groups(&self) -> Result<Vec<Vec<T>>, DAGError> {
73        let mut in_deg = self.in_degree.clone();
74        let mut current: Vec<T> = in_deg
75            .iter()
76            .filter(|(_, &d)| d == 0)
77            .map(|(n, _)| n.clone())
78            .collect();
79        current.sort();
80        let mut groups = Vec::new();
81        let mut processed = 0;
82
83        while !current.is_empty() {
84            let mut next = Vec::new();
85            for node in &current {
86                if let Some(neighbors) = self.edges.get(node) {
87                    for n in neighbors {
88                        let deg = in_deg.get_mut(n).unwrap();
89                        *deg -= 1;
90                        if *deg == 0 {
91                            next.push(n.clone());
92                        }
93                    }
94                }
95            }
96            processed += current.len();
97            groups.push(current);
98            next.sort();
99            current = next;
100        }
101
102        if processed != self.nodes.len() {
103            Err(DAGError::CycleDetected)
104        } else {
105            Ok(groups)
106        }
107    }
108}
109
110impl<T: Clone + Eq + std::hash::Hash + Ord> Default for DAG<T> {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
/// Errors reported by graph traversals.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DAGError {
    /// The graph contains a cycle, so its nodes cannot all be ordered.
    CycleDetected,
}

impl std::fmt::Display for DAGError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CycleDetected => f.write_str("Cycle detected in DAG"),
        }
    }
}

impl std::error::Error for DAGError {}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// A three-node chain has exactly one valid ordering.
    #[test]
    fn test_toposort_linear() {
        let mut graph = DAG::new();
        for &name in &["a", "b", "c"] {
            graph.add_node(name);
        }
        for &(src, dst) in &[("a", "b"), ("b", "c")] {
            graph.add_edge(src, dst);
        }

        let order = graph.toposort().unwrap();
        assert_eq!(order, vec!["a", "b", "c"]);
    }

    /// In a diamond the single source comes first and the single sink last.
    #[test]
    fn test_toposort_diamond() {
        let mut graph = DAG::new();
        for &name in &["a", "b", "c", "d"] {
            graph.add_node(name);
        }
        for &(src, dst) in &[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")] {
            graph.add_edge(src, dst);
        }

        let order = graph.toposort().unwrap();
        assert_eq!(order[0], "a");
        assert_eq!(*order.last().unwrap(), "d");
    }

    /// Two nodes pointing at each other must be rejected.
    #[test]
    fn test_cycle_detection() {
        let mut graph = DAG::new();
        for &name in &["a", "b"] {
            graph.add_node(name);
        }
        for &(src, dst) in &[("a", "b"), ("b", "a")] {
            graph.add_edge(src, dst);
        }

        assert!(graph.toposort().is_err());
    }

    /// One root fanning out to two children yields two levels: [a], [b, c].
    #[test]
    fn test_parallel_groups() {
        let mut graph = DAG::new();
        for &name in &["a", "b", "c"] {
            graph.add_node(name);
        }
        for &(src, dst) in &[("a", "b"), ("a", "c")] {
            graph.add_edge(src, dst);
        }

        let levels = graph.parallel_groups().unwrap();
        assert_eq!(levels.len(), 2);
        assert_eq!(levels[0], vec!["a"]);
    }
}