1use std::collections::{HashMap, HashSet, VecDeque};
7
8use crate::error::{Result, WaypointError};
9use crate::migration::ResolvedMigration;
10
/// Directed graph of versioned migrations and the versions they depend on.
///
/// Created by [`DependencyGraph::build`]; [`DependencyGraph::topological_sort`]
/// turns it into an order in which every migration follows its dependencies.
#[derive(Debug, Clone)]
pub struct DependencyGraph {
    /// For each version, the set of versions it depends on (outgoing edges).
    edges: HashMap<String, HashSet<String>>,
    /// For each version, the set of versions that depend on it (incoming edges).
    reverse_edges: HashMap<String, HashSet<String>>,
    /// Every version in the graph, in ascending version order.
    all_versions: Vec<String>,
}
20
21impl DependencyGraph {
22 pub fn build(migrations: &[&ResolvedMigration], implicit_chain: bool) -> Result<Self> {
27 let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
28 let mut reverse_edges: HashMap<String, HashSet<String>> = HashMap::new();
29 let mut all_versions: Vec<String> = Vec::new();
30
31 let mut versioned: Vec<&ResolvedMigration> = migrations
33 .iter()
34 .filter(|m| m.is_versioned())
35 .copied()
36 .collect();
37 versioned.sort_by(|a, b| a.version().unwrap().cmp(b.version().unwrap()));
38
39 for m in &versioned {
40 let version = m.version().unwrap().raw.clone();
41 edges.entry(version.clone()).or_default();
42 reverse_edges.entry(version.clone()).or_default();
43 all_versions.push(version);
44 }
45
46 for m in &versioned {
48 let version = &m.version().unwrap().raw;
49 for dep in &m.directives.depends {
50 if !edges.contains_key(dep) {
51 return Err(WaypointError::MissingDependency {
52 version: version.clone(),
53 dependency: dep.clone(),
54 });
55 }
56 edges.get_mut(version.as_str()).unwrap().insert(dep.clone());
57 reverse_edges
58 .get_mut(dep.as_str())
59 .unwrap()
60 .insert(version.clone());
61 }
62 }
63
64 if implicit_chain {
66 for i in 1..all_versions.len() {
67 let current = &all_versions[i];
68 let previous = &all_versions[i - 1];
69 if edges.get(current).is_none_or(|deps| deps.is_empty()) {
71 edges.get_mut(current).unwrap().insert(previous.clone());
72 reverse_edges
73 .get_mut(previous)
74 .unwrap()
75 .insert(current.clone());
76 }
77 }
78 }
79
80 Ok(DependencyGraph {
81 edges,
82 reverse_edges,
83 all_versions,
84 })
85 }
86
87 pub fn topological_sort(&self) -> Result<Vec<String>> {
92 let mut in_degree: HashMap<&str, usize> = HashMap::new();
94 for v in &self.all_versions {
95 in_degree.insert(v, self.edges.get(v).map_or(0, |deps| deps.len()));
96 }
97
98 let mut queue: VecDeque<&str> = VecDeque::new();
100 for v in &self.all_versions {
101 if *in_degree.get(v.as_str()).unwrap_or(&0) == 0 {
102 queue.push_back(v);
103 }
104 }
105
106 let mut sorted = Vec::new();
107
108 while let Some(node) = queue.pop_front() {
109 sorted.push(node.to_string());
110
111 if let Some(dependents) = self.reverse_edges.get(node) {
113 for dep in dependents {
114 let deg = in_degree.get_mut(dep.as_str()).unwrap();
115 *deg -= 1;
116 if *deg == 0 {
117 queue.push_back(dep);
118 }
119 }
120 }
121 }
122
123 if sorted.len() != self.all_versions.len() {
124 let owned_in_degree: HashMap<String, usize> = in_degree
126 .iter()
127 .map(|(&k, &v)| (k.to_string(), v))
128 .collect();
129 let cycle_path = self.trace_cycle(&owned_in_degree);
130 return Err(WaypointError::DependencyCycle { path: cycle_path });
131 }
132
133 Ok(sorted)
134 }
135
136 fn trace_cycle(&self, in_degree: &HashMap<String, usize>) -> String {
138 let start = self
140 .all_versions
141 .iter()
142 .find(|v| *in_degree.get(*v).unwrap_or(&0) > 0);
143
144 let Some(start) = start else {
145 return "unknown cycle".to_string();
146 };
147
148 let mut path = vec![start.clone()];
150 let mut current = start.clone();
151 let mut visited = std::collections::HashSet::new();
152 visited.insert(current.clone());
153
154 loop {
155 let next = self
157 .edges
158 .get(¤t)
159 .and_then(|deps| deps.iter().find(|d| *in_degree.get(*d).unwrap_or(&0) > 0));
160
161 match next {
162 Some(n) => {
163 if !visited.insert(n.clone()) {
164 path.push(n.clone());
166 if let Some(pos) = path.iter().position(|v| v == n) {
168 let cycle: Vec<String> = path[pos..].to_vec();
169 return cycle.join(" -> ");
170 }
171 return path.join(" -> ");
172 }
173 path.push(n.clone());
174 current = n.clone();
175 }
176 None => {
177 let in_cycle: Vec<String> = self
179 .all_versions
180 .iter()
181 .filter(|v| *in_degree.get(*v).unwrap_or(&0) > 0)
182 .cloned()
183 .collect();
184 return format!("cycle involving: {}", in_cycle.join(", "));
185 }
186 }
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::directive::MigrationDirectives;
    use crate::migration::{MigrationKind, MigrationVersion, ResolvedMigration};

    /// Builds a versioned migration `version` whose `depends` directive lists `depends`.
    fn make_migration(version: &str, depends: Vec<&str>) -> ResolvedMigration {
        ResolvedMigration {
            kind: MigrationKind::Versioned(MigrationVersion::parse(version).unwrap()),
            description: format!("V{}", version),
            script: format!("V{}__test.sql", version),
            checksum: 0,
            sql: String::new(),
            directives: MigrationDirectives {
                depends: depends.into_iter().map(String::from).collect(),
                env: vec![],
                ..Default::default()
            },
        }
    }

    #[test]
    fn test_simple_chain() {
        let m1 = make_migration("1", vec![]);
        let m2 = make_migration("2", vec![]);
        let m3 = make_migration("3", vec![]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2, &m3];

        let graph = DependencyGraph::build(&migrations, true).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["1", "2", "3"]);
    }

    #[test]
    fn test_build_sorts_by_version() {
        // Input order must not matter: `build` sorts by version before chaining.
        let m1 = make_migration("1", vec![]);
        let m2 = make_migration("2", vec![]);
        let m3 = make_migration("3", vec![]);
        let migrations: Vec<&ResolvedMigration> = vec![&m3, &m1, &m2];

        let graph = DependencyGraph::build(&migrations, true).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["1", "2", "3"]);
    }

    #[test]
    fn test_explicit_dependency() {
        let m1 = make_migration("1", vec![]);
        let m2 = make_migration("2", vec![]);
        let m3 = make_migration("3", vec!["1"]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2, &m3];

        let graph = DependencyGraph::build(&migrations, false).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 3);
        let pos1 = order.iter().position(|v| v == "1").unwrap();
        let pos3 = order.iter().position(|v| v == "3").unwrap();
        assert!(pos1 < pos3);
    }

    #[test]
    fn test_cycle_detection() {
        let m1 = make_migration("1", vec!["2"]);
        let m2 = make_migration("2", vec!["1"]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2];

        let graph = DependencyGraph::build(&migrations, false).unwrap();
        assert!(matches!(
            graph.topological_sort(),
            Err(WaypointError::DependencyCycle { .. })
        ));
    }

    #[test]
    fn test_missing_dependency() {
        let m1 = make_migration("1", vec!["99"]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1];

        match DependencyGraph::build(&migrations, false) {
            Err(WaypointError::MissingDependency {
                version,
                dependency,
            }) => {
                assert_eq!(version, "1");
                assert_eq!(dependency, "99");
            }
            Err(other) => panic!("expected MissingDependency, got: {}", other),
            Ok(_) => panic!("expected MissingDependency, got a graph"),
        }
    }

    #[test]
    fn test_cycle_error_shows_path() {
        let m1 = make_migration("1", vec!["3"]);
        let m2 = make_migration("2", vec!["1"]);
        let m3 = make_migration("3", vec!["2"]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2, &m3];

        let graph = DependencyGraph::build(&migrations, false).unwrap();
        let err = graph.topological_sort().unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("->"), "Cycle error should show path: {}", msg);

        // Each node has exactly one dependency, so the traced path is fixed.
        match err {
            WaypointError::DependencyCycle { path } => assert_eq!(path, "1 -> 3 -> 2 -> 1"),
            other => panic!("expected DependencyCycle, got: {}", other),
        }
    }

    #[test]
    fn test_empty_migrations() {
        let migrations: Vec<&ResolvedMigration> = vec![];
        let graph = DependencyGraph::build(&migrations, true).unwrap();
        let order = graph.topological_sort().unwrap();
        assert!(order.is_empty());
    }

    #[test]
    fn test_single_migration() {
        let m1 = make_migration("1", vec![]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1];
        let graph = DependencyGraph::build(&migrations, true).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["1"]);
    }

    #[test]
    fn test_diamond_dependency() {
        let m1 = make_migration("1", vec![]);
        let m2 = make_migration("2", vec!["1"]);
        let m3 = make_migration("3", vec!["1"]);
        let m4 = make_migration("4", vec!["2", "3"]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2, &m3, &m4];

        let graph = DependencyGraph::build(&migrations, false).unwrap();
        let order = graph.topological_sort().unwrap();

        assert_eq!(order.len(), 4);
        assert_eq!(order[0], "1");
        let mut middle = order[1..3].to_vec();
        middle.sort();
        assert_eq!(middle, vec!["2", "3"]);
        assert_eq!(order[3], "4");
    }

    #[test]
    fn test_self_referencing_cycle() {
        let m1 = make_migration("1", vec!["1"]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1];

        let graph = DependencyGraph::build(&migrations, false).unwrap();
        match graph.topological_sort() {
            Err(WaypointError::DependencyCycle { path }) => assert_eq!(path, "1 -> 1"),
            other => panic!("expected DependencyCycle, got: {:?}", other),
        }
    }
}