Skip to main content

waypoint_core/
dependency.rs

//! Migration dependency graph with topological sort.
//!
//! Supports `-- waypoint:depends V3,V5` directives for non-linear
//! migration ordering, resolved with Kahn's algorithm; dependency
//! cycles are reported as errors.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use crate::error::{Result, WaypointError};
9use crate::migration::ResolvedMigration;
10
/// A directed graph of migration dependencies.
///
/// Nodes are versioned migrations, identified by their raw version string. The graph is
/// meant to be acyclic, but that is only checked by [`DependencyGraph::topological_sort`],
/// which reports any cycle it finds as an error.
#[derive(Debug)]
pub struct DependencyGraph {
    /// version -> set of versions it depends on
    edges: HashMap<String, HashSet<String>>,
    /// version -> set of versions that depend on it (the inverse of `edges`)
    reverse_edges: HashMap<String, HashSet<String>>,
    /// All known versions, in ascending version order
    all_versions: Vec<String>,
}
20
21impl DependencyGraph {
22    /// Build a dependency graph from resolved migrations.
23    ///
24    /// If `implicit_chain` is true, each versioned migration implicitly depends
25    /// on the previous version in sort order (backward-compatible default).
26    pub fn build(migrations: &[&ResolvedMigration], implicit_chain: bool) -> Result<Self> {
27        let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
28        let mut reverse_edges: HashMap<String, HashSet<String>> = HashMap::new();
29        let mut all_versions: Vec<String> = Vec::new();
30
31        // Collect all versioned migrations sorted by version
32        let mut versioned: Vec<&ResolvedMigration> = migrations
33            .iter()
34            .filter(|m| m.is_versioned())
35            .copied()
36            .collect();
37        versioned.sort_by(|a, b| a.version().unwrap().cmp(b.version().unwrap()));
38
39        for m in &versioned {
40            let version = m.version().unwrap().raw.clone();
41            edges.entry(version.clone()).or_default();
42            reverse_edges.entry(version.clone()).or_default();
43            all_versions.push(version);
44        }
45
46        // Add explicit dependencies from directives
47        for m in &versioned {
48            let version = &m.version().unwrap().raw;
49            for dep in &m.directives.depends {
50                if !edges.contains_key(dep) {
51                    return Err(WaypointError::MissingDependency {
52                        version: version.clone(),
53                        dependency: dep.clone(),
54                    });
55                }
56                edges.get_mut(version.as_str()).unwrap().insert(dep.clone());
57                reverse_edges
58                    .get_mut(dep.as_str())
59                    .unwrap()
60                    .insert(version.clone());
61            }
62        }
63
64        // Add implicit chain dependencies (each version depends on previous)
65        if implicit_chain {
66            for i in 1..all_versions.len() {
67                let current = &all_versions[i];
68                let previous = &all_versions[i - 1];
69                // Only add implicit dependency if no explicit dependencies are set
70                if edges.get(current).is_none_or(|deps| deps.is_empty()) {
71                    edges.get_mut(current).unwrap().insert(previous.clone());
72                    reverse_edges
73                        .get_mut(previous)
74                        .unwrap()
75                        .insert(current.clone());
76                }
77            }
78        }
79
80        Ok(DependencyGraph {
81            edges,
82            reverse_edges,
83            all_versions,
84        })
85    }
86
87    /// Produce a topologically sorted order of versions using Kahn's algorithm.
88    ///
89    /// Uses borrowed `&str` references internally to avoid cloning during
90    /// the sort; only clones into owned `String`s for the output.
91    pub fn topological_sort(&self) -> Result<Vec<String>> {
92        // Compute in-degree for each node using borrowed keys
93        let mut in_degree: HashMap<&str, usize> = HashMap::new();
94        for v in &self.all_versions {
95            in_degree.insert(v, self.edges.get(v).map_or(0, |deps| deps.len()));
96        }
97
98        // Start with nodes that have no dependencies
99        let mut queue: VecDeque<&str> = VecDeque::new();
100        for v in &self.all_versions {
101            if *in_degree.get(v.as_str()).unwrap_or(&0) == 0 {
102                queue.push_back(v);
103            }
104        }
105
106        let mut sorted = Vec::new();
107
108        while let Some(node) = queue.pop_front() {
109            sorted.push(node.to_string());
110
111            // For each node that depends on this one, decrement in-degree
112            if let Some(dependents) = self.reverse_edges.get(node) {
113                for dep in dependents {
114                    let deg = in_degree.get_mut(dep.as_str()).unwrap();
115                    *deg -= 1;
116                    if *deg == 0 {
117                        queue.push_back(dep);
118                    }
119                }
120            }
121        }
122
123        if sorted.len() != self.all_versions.len() {
124            // Trace an actual cycle path — convert in_degree to owned keys for trace_cycle
125            let owned_in_degree: HashMap<String, usize> = in_degree
126                .iter()
127                .map(|(&k, &v)| (k.to_string(), v))
128                .collect();
129            let cycle_path = self.trace_cycle(&owned_in_degree);
130            return Err(WaypointError::DependencyCycle { path: cycle_path });
131        }
132
133        Ok(sorted)
134    }
135
136    /// Trace an actual cycle path for error reporting.
137    fn trace_cycle(&self, in_degree: &HashMap<String, usize>) -> String {
138        // Start from any node still in the cycle
139        let start = self
140            .all_versions
141            .iter()
142            .find(|v| *in_degree.get(*v).unwrap_or(&0) > 0);
143
144        let Some(start) = start else {
145            return "unknown cycle".to_string();
146        };
147
148        // Follow dependency edges to trace the cycle
149        let mut path = vec![start.clone()];
150        let mut current = start.clone();
151        let mut visited = std::collections::HashSet::new();
152        visited.insert(current.clone());
153
154        loop {
155            // Find a dependency of `current` that is also in the cycle
156            let next = self
157                .edges
158                .get(&current)
159                .and_then(|deps| deps.iter().find(|d| *in_degree.get(*d).unwrap_or(&0) > 0));
160
161            match next {
162                Some(n) => {
163                    if !visited.insert(n.clone()) {
164                        // We've come back to a visited node — complete the cycle
165                        path.push(n.clone());
166                        // Trim path to start from the cycle entry point
167                        if let Some(pos) = path.iter().position(|v| v == n) {
168                            let cycle: Vec<String> = path[pos..].to_vec();
169                            return cycle.join(" -> ");
170                        }
171                        return path.join(" -> ");
172                    }
173                    path.push(n.clone());
174                    current = n.clone();
175                }
176                None => {
177                    // Fallback: list all nodes in cycle
178                    let in_cycle: Vec<String> = self
179                        .all_versions
180                        .iter()
181                        .filter(|v| *in_degree.get(*v).unwrap_or(&0) > 0)
182                        .cloned()
183                        .collect();
184                    return format!("cycle involving: {}", in_cycle.join(", "));
185                }
186            }
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::directive::MigrationDirectives;
    use crate::migration::{MigrationKind, MigrationVersion, ResolvedMigration};

    /// Build a versioned migration `V{version}` whose `depends` directive lists `depends`.
    fn make_migration(version: &str, depends: Vec<&str>) -> ResolvedMigration {
        ResolvedMigration {
            kind: MigrationKind::Versioned(MigrationVersion::parse(version).unwrap()),
            description: format!("V{}", version),
            script: format!("V{}__test.sql", version),
            checksum: 0,
            sql: String::new(),
            directives: MigrationDirectives {
                depends: depends.into_iter().map(String::from).collect(),
                env: vec![],
                ..Default::default()
            },
        }
    }

    /// Index of `version` within `order`, failing the test if it is absent.
    fn pos(order: &[String], version: &str) -> usize {
        order
            .iter()
            .position(|v| v == version)
            .unwrap_or_else(|| panic!("version {version} missing from {order:?}"))
    }

    /// Build a graph without the implicit chain and return the rendered cycle error.
    fn cycle_message(migrations: &[&ResolvedMigration]) -> String {
        let graph = DependencyGraph::build(migrations, false).unwrap();
        graph.topological_sort().unwrap_err().to_string()
    }

    #[test]
    fn test_simple_chain() {
        let m1 = make_migration("1", vec![]);
        let m2 = make_migration("2", vec![]);
        let m3 = make_migration("3", vec![]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2, &m3];

        let graph = DependencyGraph::build(&migrations, true).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["1", "2", "3"]);
    }

    #[test]
    fn test_implicit_chain_sorts_unordered_input() {
        let m1 = make_migration("1", vec![]);
        let m2 = make_migration("2", vec![]);
        let m3 = make_migration("3", vec![]);
        let migrations: Vec<&ResolvedMigration> = vec![&m3, &m1, &m2];

        let graph = DependencyGraph::build(&migrations, true).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["1", "2", "3"]);
    }

    #[test]
    fn test_explicit_dependency() {
        let m1 = make_migration("1", vec![]);
        let m2 = make_migration("2", vec![]);
        let m3 = make_migration("3", vec!["1"]); // V3 depends on V1, skipping V2
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2, &m3];

        let graph = DependencyGraph::build(&migrations, false).unwrap();
        let order = graph.topological_sort().unwrap();
        // V1 must come before V3; V2 has no deps so it can be anywhere.
        assert_eq!(order.len(), 3);
        assert!(pos(&order, "1") < pos(&order, "3"));
    }

    #[test]
    fn test_explicit_dependency_overrides_implicit_chain() {
        let m1 = make_migration("1", vec![]);
        let m2 = make_migration("2", vec![]);
        let m3 = make_migration("3", vec!["1"]); // replaces the implicit link to V2
        let m4 = make_migration("4", vec![]); // implicitly depends on V3
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2, &m3, &m4];

        let graph = DependencyGraph::build(&migrations, true).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 4);
        assert!(pos(&order, "1") < pos(&order, "2"));
        assert!(pos(&order, "1") < pos(&order, "3"));
        assert!(pos(&order, "3") < pos(&order, "4"));
    }

    #[test]
    fn test_cycle_detection() {
        let m1 = make_migration("1", vec!["2"]);
        let m2 = make_migration("2", vec!["1"]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2];

        let graph = DependencyGraph::build(&migrations, false).unwrap();
        assert!(graph.topological_sort().is_err());
    }

    #[test]
    fn test_missing_dependency() {
        let m1 = make_migration("1", vec!["99"]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1];

        let err = DependencyGraph::build(&migrations, false)
            .err()
            .expect("a dependency on an unknown version must be rejected");
        assert!(
            matches!(
                &err,
                WaypointError::MissingDependency { version, dependency }
                    if version == "1" && dependency == "99"
            ),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_cycle_error_shows_path() {
        let m1 = make_migration("1", vec!["3"]);
        let m2 = make_migration("2", vec!["1"]);
        let m3 = make_migration("3", vec!["2"]);

        let msg = cycle_message(&[&m1, &m2, &m3]);
        // Each arrow reads "depends on".
        assert!(
            msg.contains("1 -> 3 -> 2 -> 1"),
            "Cycle error should show path: {}",
            msg
        );
    }

    #[test]
    fn test_cycle_path_skips_nodes_outside_cycle() {
        // V1 is blocked by the V2 <-> V3 cycle but is not part of it.
        let m1 = make_migration("1", vec!["2"]);
        let m2 = make_migration("2", vec!["3"]);
        let m3 = make_migration("3", vec!["2"]);

        let msg = cycle_message(&[&m1, &m2, &m3]);
        assert!(
            msg.contains("2 -> 3 -> 2"),
            "Cycle error should show only the cycle: {}",
            msg
        );
    }

    #[test]
    fn test_empty_migrations() {
        let migrations: Vec<&ResolvedMigration> = vec![];
        let graph = DependencyGraph::build(&migrations, true).unwrap();
        let order = graph.topological_sort().unwrap();
        assert!(order.is_empty());
    }

    #[test]
    fn test_single_migration() {
        let m1 = make_migration("1", vec![]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1];
        let graph = DependencyGraph::build(&migrations, true).unwrap();
        let order = graph.topological_sort().unwrap();
        assert_eq!(order, vec!["1"]);
    }

    #[test]
    fn test_diamond_dependency() {
        let m1 = make_migration("1", vec![]);
        let m2 = make_migration("2", vec!["1"]);
        let m3 = make_migration("3", vec!["1"]);
        let m4 = make_migration("4", vec!["2", "3"]);
        let migrations: Vec<&ResolvedMigration> = vec![&m1, &m2, &m3, &m4];

        let graph = DependencyGraph::build(&migrations, false).unwrap();
        let order = graph.topological_sort().unwrap();

        // V1 must be first, V4 must be last.
        assert_eq!(order.len(), 4);
        assert_eq!(order[0], "1");
        assert_eq!(order[3], "4");
    }

    #[test]
    fn test_self_referencing_cycle() {
        let m1 = make_migration("1", vec!["1"]);

        let msg = cycle_message(&[&m1]);
        assert!(msg.contains("1 -> 1"), "Self-cycle should show path: {}", msg);
    }
}