tauri_typegen/build/
dependency_resolver.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2use thiserror::Error;
3
4#[derive(Error, Debug)]
5pub enum DependencyError {
6    #[error("Circular dependency detected: {0}")]
7    CircularDependency(String),
8    #[error("Unresolved dependency: {0} required by {1}")]
9    UnresolvedDependency(String, String),
10    #[error("Invalid dependency specification: {0}")]
11    InvalidSpecification(String),
12}
13
14#[derive(Debug, Clone, PartialEq, Eq, Hash)]
15pub struct DependencyNode {
16    pub name: String,
17    pub path: String,
18    pub node_type: DependencyNodeType,
19}
20
/// The kind of item a [`DependencyNode`] stands for.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum DependencyNodeType {
    /// A command entry point (NOTE(review): presumably a Tauri command, given the crate name — confirm).
    Command,
    /// A struct definition.
    Struct,
    /// An enum definition.
    Enum,
    /// Some other named type (assumed to be e.g. an alias — confirm with the node producer).
    Type,
    /// A module.
    Module,
}
29
30#[derive(Debug, Clone)]
31pub struct Dependency {
32    pub from: DependencyNode,
33    pub to: DependencyNode,
34    pub dependency_type: DependencyType,
35}
36
/// How the `from` side of a [`Dependency`] refers to its `to` side.
#[derive(Clone, Debug, PartialEq)]
pub enum DependencyType {
    /// Used directly, e.g. as a parameter or a return type.
    Direct,
    /// Appears as the type of a struct field.
    Field,
    /// Appears inside an enum variant.
    Variant,
    /// Brought into scope by an import/`use` statement.
    Import,
    /// Used as a generic type argument.
    Generic,
}
50
51pub struct DependencyResolver {
52    dependencies: Vec<Dependency>,
53    nodes: HashSet<DependencyNode>,
54}
55
56impl DependencyResolver {
57    pub fn new() -> Self {
58        Self {
59            dependencies: Vec::new(),
60            nodes: HashSet::new(),
61        }
62    }
63
64    /// Add a dependency relationship
65    pub fn add_dependency(&mut self, dependency: Dependency) {
66        self.nodes.insert(dependency.from.clone());
67        self.nodes.insert(dependency.to.clone());
68        self.dependencies.push(dependency);
69    }
70
71    /// Add a node without dependencies
72    pub fn add_node(&mut self, node: DependencyNode) {
73        self.nodes.insert(node);
74    }
75
76    /// Resolve dependencies and return them in topological order
77    pub fn resolve_build_order(&self) -> Result<Vec<DependencyNode>, DependencyError> {
78        let mut in_degree = HashMap::new();
79        let mut adjacency = HashMap::new();
80
81        // Initialize in-degree count and adjacency list
82        for node in &self.nodes {
83            in_degree.insert(node.clone(), 0);
84            adjacency.insert(node.clone(), Vec::new());
85        }
86
87        // Build adjacency list and count in-degrees
88        // If "from" uses "to", then "to" should be processed before "from"
89        // So we create an edge from "to" to "from" for the topological sort
90        for dep in &self.dependencies {
91            adjacency.get_mut(&dep.to).unwrap().push(dep.from.clone());
92
93            *in_degree.get_mut(&dep.from).unwrap() += 1;
94        }
95
96        // Topological sort using Kahn's algorithm
97        let mut queue: VecDeque<_> = in_degree
98            .iter()
99            .filter_map(|(node, &degree)| {
100                if degree == 0 {
101                    Some(node.clone())
102                } else {
103                    None
104                }
105            })
106            .collect();
107        let mut result = Vec::new();
108
109        while let Some(node) = queue.pop_front() {
110            result.push(node.clone());
111
112            // Remove this node and update in-degrees of adjacent nodes
113            if let Some(adjacent_nodes) = adjacency.get(&node) {
114                for adjacent in adjacent_nodes {
115                    let degree = in_degree.get_mut(adjacent).unwrap();
116                    *degree -= 1;
117                    if *degree == 0 {
118                        queue.push_back(adjacent.clone());
119                    }
120                }
121            }
122        }
123
124        // Check for circular dependencies
125        if result.len() != self.nodes.len() {
126            let remaining: Vec<String> = self
127                .nodes
128                .iter()
129                .filter(|n| !result.contains(n))
130                .map(|n| n.name.clone())
131                .collect();
132            return Err(DependencyError::CircularDependency(remaining.join(", ")));
133        }
134
135        Ok(result)
136    }
137
138    /// Get dependencies for a specific node
139    pub fn get_dependencies_for(&self, node: &DependencyNode) -> Vec<&Dependency> {
140        self.dependencies
141            .iter()
142            .filter(|dep| dep.from == *node)
143            .collect()
144    }
145
146    /// Get reverse dependencies (dependents) for a specific node
147    pub fn get_dependents_of(&self, node: &DependencyNode) -> Vec<&Dependency> {
148        self.dependencies
149            .iter()
150            .filter(|dep| dep.to == *node)
151            .collect()
152    }
153
154    /// Check if there are any unresolved dependencies
155    pub fn validate_dependencies(&self) -> Result<(), DependencyError> {
156        for dep in &self.dependencies {
157            if !self.nodes.contains(&dep.from) {
158                return Err(DependencyError::UnresolvedDependency(
159                    dep.from.name.clone(),
160                    "unknown".to_string(),
161                ));
162            }
163            if !self.nodes.contains(&dep.to) {
164                return Err(DependencyError::UnresolvedDependency(
165                    dep.to.name.clone(),
166                    dep.from.name.clone(),
167                ));
168            }
169        }
170        Ok(())
171    }
172
173    /// Generate a visual representation of the dependency graph
174    pub fn generate_dot_graph(&self) -> String {
175        let mut dot = String::from("digraph Dependencies {\n");
176        dot.push_str("    rankdir=LR;\n");
177        dot.push_str("    node [shape=box];\n\n");
178
179        // Add nodes with different shapes based on type
180        for node in &self.nodes {
181            let (shape, color) = match node.node_type {
182                DependencyNodeType::Command => ("ellipse", "lightblue"),
183                DependencyNodeType::Struct => ("box", "lightgreen"),
184                DependencyNodeType::Enum => ("diamond", "lightyellow"),
185                DependencyNodeType::Type => ("circle", "lightgray"),
186                DependencyNodeType::Module => ("folder", "lightcoral"),
187            };
188
189            dot.push_str(&format!(
190                "    \"{}\" [shape={}, fillcolor={}, style=filled];\n",
191                node.name, shape, color
192            ));
193        }
194
195        dot.push('\n');
196
197        // Add edges with different styles based on dependency type
198        for dep in &self.dependencies {
199            let style = match dep.dependency_type {
200                DependencyType::Direct => "solid",
201                DependencyType::Field => "dashed",
202                DependencyType::Variant => "dotted",
203                DependencyType::Import => "bold",
204                DependencyType::Generic => "double",
205            };
206
207            dot.push_str(&format!(
208                "    \"{}\" -> \"{}\" [style={}];\n",
209                dep.from.name, dep.to.name, style
210            ));
211        }
212
213        dot.push_str("}\n");
214        dot
215    }
216
217    /// Generate a text-based visualization of dependencies
218    pub fn generate_text_graph(&self) -> String {
219        let mut output = String::from("Dependency Graph:\n");
220        output.push_str("=================\n\n");
221
222        for node in &self.nodes {
223            let deps = self.get_dependencies_for(node);
224            let dependents = self.get_dependents_of(node);
225
226            output.push_str(&format!("{} ({:?})\n", node.name, node.node_type));
227
228            if !deps.is_empty() {
229                output.push_str("  Dependencies:\n");
230                for dep in deps {
231                    output.push_str(&format!(
232                        "    -> {} ({:?})\n",
233                        dep.to.name, dep.dependency_type
234                    ));
235                }
236            }
237
238            if !dependents.is_empty() {
239                output.push_str("  Dependents:\n");
240                for dep in dependents {
241                    output.push_str(&format!(
242                        "    <- {} ({:?})\n",
243                        dep.from.name, dep.dependency_type
244                    ));
245                }
246            }
247
248            output.push('\n');
249        }
250
251        output
252    }
253
254    /// Group nodes by their type for organized code generation
255    pub fn group_by_type(&self) -> HashMap<DependencyNodeType, Vec<DependencyNode>> {
256        let mut groups = HashMap::new();
257
258        for node in &self.nodes {
259            groups
260                .entry(node.node_type.clone())
261                .or_insert_with(Vec::new)
262                .push(node.clone());
263        }
264
265        groups
266    }
267
268    /// Get the dependency depth for a node (longest path to a leaf)
269    pub fn get_dependency_depth(&self, node: &DependencyNode) -> usize {
270        let mut visited = HashSet::new();
271        self.calculate_depth(node, &mut visited)
272    }
273
274    fn calculate_depth(
275        &self,
276        node: &DependencyNode,
277        visited: &mut HashSet<DependencyNode>,
278    ) -> usize {
279        if visited.contains(node) {
280            return 0; // Avoid infinite recursion on cycles
281        }
282
283        visited.insert(node.clone());
284
285        let max_child_depth = self
286            .get_dependencies_for(node)
287            .iter()
288            .map(|dep| self.calculate_depth(&dep.to, visited))
289            .max()
290            .unwrap_or(0);
291
292        visited.remove(node);
293        max_child_depth + 1
294    }
295}
296
297impl Default for DependencyResolver {
298    fn default() -> Self {
299        Self::new()
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node called `name` located at the synthetic path `/test/<name>.rs`.
    fn create_test_node(name: &str, node_type: DependencyNodeType) -> DependencyNode {
        let path = format!("/test/{}.rs", name);
        DependencyNode {
            name: name.to_owned(),
            path,
            node_type,
        }
    }

    /// Records on `resolver` that `from` uses `to`.
    fn link(
        resolver: &mut DependencyResolver,
        from: &DependencyNode,
        to: &DependencyNode,
        dependency_type: DependencyType,
    ) {
        resolver.add_dependency(Dependency {
            from: from.clone(),
            to: to.clone(),
            dependency_type,
        });
    }

    /// Maps each node name to its index within `order`.
    fn positions_of(order: &[DependencyNode]) -> HashMap<String, usize> {
        order
            .iter()
            .enumerate()
            .map(|(index, node)| (node.name.clone(), index))
            .collect()
    }

    #[test]
    fn test_simple_dependency_resolution() {
        let mut resolver = DependencyResolver::new();
        let a = create_test_node("A", DependencyNodeType::Struct);
        let b = create_test_node("B", DependencyNodeType::Struct);
        resolver.add_node(a.clone());
        resolver.add_node(b.clone());

        // B uses A ("from" uses "to"), so A has to be scheduled first.
        link(&mut resolver, &b, &a, DependencyType::Direct);

        let order = resolver.resolve_build_order().unwrap();
        assert_eq!(order.len(), 2);

        let positions = positions_of(&order);
        assert!(
            positions["A"] < positions["B"],
            "A should come before B, but got order: {:?}",
            order.iter().map(|n| &n.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_circular_dependency_detection() {
        let mut resolver = DependencyResolver::new();
        let a = create_test_node("A", DependencyNodeType::Struct);
        let b = create_test_node("B", DependencyNodeType::Struct);
        resolver.add_node(a.clone());
        resolver.add_node(b.clone());

        // A and B use each other, which no ordering can satisfy.
        link(&mut resolver, &a, &b, DependencyType::Direct);
        link(&mut resolver, &b, &a, DependencyType::Direct);

        let result = resolver.resolve_build_order();
        assert!(result.is_err());
        assert!(
            matches!(result, Err(DependencyError::CircularDependency(_))),
            "Expected circular dependency error"
        );
    }

    #[test]
    fn test_complex_dependency_chain() {
        let mut resolver = DependencyResolver::new();
        let nodes = [
            create_test_node("A", DependencyNodeType::Struct),
            create_test_node("B", DependencyNodeType::Struct),
            create_test_node("C", DependencyNodeType::Struct),
            create_test_node("D", DependencyNodeType::Command),
        ];
        for node in &nodes {
            resolver.add_node(node.clone());
        }
        let (a, b, c, d) = (&nodes[0], &nodes[1], &nodes[2], &nodes[3]);

        // D uses C, C holds a B field, B uses A.
        link(&mut resolver, d, c, DependencyType::Direct);
        link(&mut resolver, c, b, DependencyType::Field);
        link(&mut resolver, b, a, DependencyType::Direct);

        let order = resolver.resolve_build_order().unwrap();
        assert_eq!(order.len(), 4);

        // Expect A, then B, then C, then D.
        let positions = positions_of(&order);
        for pair in ["A", "B", "C", "D"].windows(2) {
            assert!(positions[pair[0]] < positions[pair[1]]);
        }
    }

    #[test]
    fn test_dependency_depth_calculation() {
        let mut resolver = DependencyResolver::new();
        let a = create_test_node("A", DependencyNodeType::Struct);
        let b = create_test_node("B", DependencyNodeType::Struct);
        let c = create_test_node("C", DependencyNodeType::Command);
        resolver.add_node(a.clone());
        resolver.add_node(b.clone());
        resolver.add_node(c.clone());

        // Chain: C -> B -> A
        link(&mut resolver, &c, &b, DependencyType::Direct);
        link(&mut resolver, &b, &a, DependencyType::Direct);

        // Depth counts the node itself plus the longest chain beneath it:
        // A is a leaf (1), B adds A (2), C adds B and A (3).
        let expectations = [(&a, 1), (&b, 2), (&c, 3)];
        for &(node, depth) in expectations.iter() {
            assert_eq!(resolver.get_dependency_depth(node), depth);
        }
    }

    #[test]
    fn test_group_by_type() {
        let mut resolver = DependencyResolver::new();
        let specs = [
            ("StructA", DependencyNodeType::Struct),
            ("StructB", DependencyNodeType::Struct),
            ("CommandA", DependencyNodeType::Command),
            ("EnumA", DependencyNodeType::Enum),
        ];
        for (name, kind) in specs.iter() {
            resolver.add_node(create_test_node(name, kind.clone()));
        }

        let groups = resolver.group_by_type();
        let group_len = |kind: DependencyNodeType| groups.get(&kind).unwrap().len();

        assert_eq!(group_len(DependencyNodeType::Struct), 2);
        assert_eq!(group_len(DependencyNodeType::Command), 1);
        assert_eq!(group_len(DependencyNodeType::Enum), 1);
    }
}