worldinterface_core/flowspec/
topo.rs1use std::cmp::Reverse;
7use std::collections::{BinaryHeap, HashMap};
8
9use crate::flowspec::FlowSpec;
10use crate::id::NodeId;
11
12pub fn topological_sort(spec: &FlowSpec) -> Result<Vec<NodeId>, Vec<NodeId>> {
24 let node_ids: Vec<NodeId> = spec.nodes.iter().map(|n| n.id).collect();
25
26 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
28 let mut successors: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
29 for &id in &node_ids {
30 in_degree.entry(id).or_insert(0);
31 successors.entry(id).or_default();
32 }
33
34 let node_set: std::collections::HashSet<NodeId> = node_ids.iter().copied().collect();
35 for edge in &spec.edges {
36 if node_set.contains(&edge.from) && node_set.contains(&edge.to) {
37 *in_degree.entry(edge.to).or_insert(0) += 1;
38 successors.entry(edge.from).or_default().push(edge.to);
39 }
40 }
41
42 let mut heap: BinaryHeap<Reverse<NodeId>> =
44 in_degree.iter().filter(|(_, °)| deg == 0).map(|(&id, _)| Reverse(id)).collect();
45
46 let mut sorted = Vec::with_capacity(node_ids.len());
47
48 while let Some(Reverse(node_id)) = heap.pop() {
49 sorted.push(node_id);
50 if let Some(succs) = successors.get(&node_id) {
51 for &succ in succs {
52 if let Some(deg) = in_degree.get_mut(&succ) {
53 *deg -= 1;
54 if *deg == 0 {
55 heap.push(Reverse(succ));
56 }
57 }
58 }
59 }
60 }
61
62 if sorted.len() == node_ids.len() {
63 Ok(sorted)
64 } else {
65 let sorted_set: std::collections::HashSet<NodeId> = sorted.into_iter().collect();
67 let remaining: Vec<NodeId> =
68 node_ids.into_iter().filter(|id| !sorted_set.contains(id)).collect();
69 Err(remaining)
70 }
71}
72
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::flowspec::*;

    /// Builds a connector node with no label and empty params.
    fn connector_node(id: NodeId, name: &str) -> Node {
        Node {
            id,
            label: None,
            node_type: NodeType::Connector(ConnectorNode {
                connector: name.into(),
                params: json!({}),
                idempotency_config: None,
            }),
        }
    }

    /// Builds an unconditional edge.
    fn edge(from: NodeId, to: NodeId) -> Edge {
        Edge { from, to, condition: None }
    }

    /// Generates `n` fresh node IDs.
    fn make_ids(n: usize) -> Vec<NodeId> {
        (0..n).map(|_| NodeId::new()).collect()
    }

    /// Wraps nodes and edges in a spec with every other field unset.
    fn make_spec(nodes: Vec<Node>, edges: Vec<Edge>) -> FlowSpec {
        FlowSpec { id: None, name: None, nodes, edges, params: None }
    }

    /// Nodes a, b, c, d over `ids[0..4]`, declared in that order.
    fn diamond_nodes(ids: &[NodeId]) -> Vec<Node> {
        vec![
            connector_node(ids[0], "a"),
            connector_node(ids[1], "b"),
            connector_node(ids[2], "c"),
            connector_node(ids[3], "d"),
        ]
    }

    /// Diamond edges: a -> b, a -> c, b -> d, c -> d.
    fn diamond_edges(ids: &[NodeId]) -> Vec<Edge> {
        vec![
            edge(ids[0], ids[1]),
            edge(ids[0], ids[2]),
            edge(ids[1], ids[3]),
            edge(ids[2], ids[3]),
        ]
    }

    #[test]
    fn linear_flow_sort_order() {
        let ids = make_ids(3);
        let spec = make_spec(
            vec![
                connector_node(ids[0], "a"),
                connector_node(ids[1], "b"),
                connector_node(ids[2], "c"),
            ],
            vec![edge(ids[0], ids[1]), edge(ids[1], ids[2])],
        );
        let sorted = topological_sort(&spec).unwrap();
        assert_eq!(sorted, vec![ids[0], ids[1], ids[2]]);
    }

    #[test]
    fn diamond_flow_sort_order() {
        let ids = make_ids(4);
        let spec = make_spec(diamond_nodes(&ids), diamond_edges(&ids));
        let sorted = topological_sort(&spec).unwrap();
        assert_eq!(sorted.len(), 4);
        assert_eq!(sorted[0], ids[0]);
        // b and c both precede d, so d must come last.
        assert_eq!(sorted[3], ids[3]);
    }

    #[test]
    fn diamond_ties_broken_by_smallest_id() {
        let ids = make_ids(4);
        let spec = make_spec(diamond_nodes(&ids), diamond_edges(&ids));
        let sorted = topological_sort(&spec).unwrap();
        // b and c become ready together; the smaller ID must be emitted first.
        let lo = std::cmp::min(ids[1], ids[2]);
        let hi = std::cmp::max(ids[1], ids[2]);
        assert_eq!(&sorted[1..3], &[lo, hi]);
    }

    #[test]
    fn single_node() {
        let id = NodeId::new();
        let spec = make_spec(vec![connector_node(id, "a")], vec![]);
        let sorted = topological_sort(&spec).unwrap();
        assert_eq!(sorted, vec![id]);
    }

    #[test]
    fn empty_flow() {
        let spec = make_spec(vec![], vec![]);
        let sorted = topological_sort(&spec).unwrap();
        assert_eq!(sorted, Vec::<NodeId>::new());
    }

    #[test]
    fn edges_to_unknown_nodes_are_ignored() {
        let ids = make_ids(2);
        let stranger = NodeId::new();
        let spec = make_spec(
            vec![connector_node(ids[0], "a"), connector_node(ids[1], "b")],
            // If the stranger edges were counted, a would look blocked.
            vec![edge(ids[0], ids[1]), edge(stranger, ids[0]), edge(ids[1], stranger)],
        );
        let sorted = topological_sort(&spec).unwrap();
        assert_eq!(sorted, vec![ids[0], ids[1]]);
    }

    #[test]
    fn cycle_detection() {
        let ids = make_ids(2);
        let spec = make_spec(
            vec![connector_node(ids[0], "a"), connector_node(ids[1], "b")],
            vec![edge(ids[0], ids[1]), edge(ids[1], ids[0])],
        );
        let err = topological_sort(&spec).unwrap_err();
        assert_eq!(err, vec![ids[0], ids[1]]);
    }

    #[test]
    fn cycle_reports_downstream_nodes() {
        let ids = make_ids(4);
        let spec = make_spec(
            vec![
                connector_node(ids[0], "a"),
                connector_node(ids[1], "b"),
                connector_node(ids[2], "c"),
                connector_node(ids[3], "d"),
            ],
            // a <-> b is a cycle, c depends on it, d is independent.
            vec![edge(ids[0], ids[1]), edge(ids[1], ids[0]), edge(ids[1], ids[2])],
        );
        let err = topological_sort(&spec).unwrap_err();
        assert_eq!(err, vec![ids[0], ids[1], ids[2]]);
    }

    #[test]
    fn self_loop_is_a_cycle() {
        let id = NodeId::new();
        let spec = make_spec(vec![connector_node(id, "a")], vec![edge(id, id)]);
        let err = topological_sort(&spec).unwrap_err();
        assert_eq!(err, vec![id]);
    }

    #[test]
    fn determinism() {
        let ids = make_ids(4);
        let spec = make_spec(diamond_nodes(&ids), diamond_edges(&ids));
        let r1 = topological_sort(&spec).unwrap();
        let r2 = topological_sort(&spec).unwrap();
        assert_eq!(r1, r2);
    }

    #[test]
    fn result_is_independent_of_declaration_order() {
        let ids = make_ids(4);
        let forward = make_spec(diamond_nodes(&ids), diamond_edges(&ids));

        let mut reversed_nodes = diamond_nodes(&ids);
        reversed_nodes.reverse();
        let mut reversed_edges = diamond_edges(&ids);
        reversed_edges.reverse();
        let backward = make_spec(reversed_nodes, reversed_edges);

        assert_eq!(topological_sort(&forward).unwrap(), topological_sort(&backward).unwrap());
    }
}