Skip to main content

torsh_fx/
graph_analysis.rs

1//! Graph analysis and linting utilities for FX graphs
2//!
3//! This module provides comprehensive graph analysis capabilities including:
4//! - Graph linting with best practice suggestions
5//! - Graph diff and merge functionality for version control
6//! - Advanced graph metrics and health checking
7//! - Pattern detection and architectural analysis
8
9use crate::{Edge, FxGraph, Node, TorshResult};
10use petgraph::graph::NodeIndex;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13
14/// Graph linting utilities for best practice validation
15#[derive(Debug)]
16pub struct GraphLinter {
17    rules: Vec<LintRule>,
18    severity_threshold: LintSeverity,
19}
20
21/// Lint rule for graph validation
22#[derive(Debug, Clone)]
23pub struct LintRule {
24    pub name: String,
25    pub description: String,
26    pub severity: LintSeverity,
27    pub checker: fn(&FxGraph) -> Vec<LintIssue>,
28}
29
30/// Severity levels for lint issues
31#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
32pub enum LintSeverity {
33    Info,
34    Warning,
35    Error,
36    Critical,
37}
38
39/// Individual lint issue found in graph
40#[derive(Debug, Clone)]
41pub struct LintIssue {
42    pub rule_name: String,
43    pub severity: LintSeverity,
44    pub message: String,
45    pub node_index: Option<NodeIndex>,
46    pub suggestions: Vec<String>,
47}
48
49/// Complete lint report for a graph
50#[derive(Debug, Serialize, Deserialize)]
51pub struct LintReport {
52    pub total_issues: u32,
53    pub issues_by_severity: HashMap<LintSeverity, u32>,
54    #[serde(skip)]
55    pub issues: Vec<LintIssue>,
56    pub overall_score: f64, // 0.0 (bad) to 1.0 (perfect)
57    pub recommendations: Vec<String>,
58}
59
60impl GraphLinter {
61    /// Create a new graph linter with default rules
62    pub fn new() -> Self {
63        let mut linter = Self {
64            rules: Vec::new(),
65            severity_threshold: LintSeverity::Info,
66        };
67        linter.add_default_rules();
68        linter
69    }
70
71    /// Set minimum severity threshold for reporting
72    pub fn with_severity_threshold(mut self, threshold: LintSeverity) -> Self {
73        self.severity_threshold = threshold;
74        self
75    }
76
77    /// Add a custom lint rule
78    pub fn add_rule(&mut self, rule: LintRule) {
79        self.rules.push(rule);
80    }
81
82    /// Lint a graph and return comprehensive report
83    pub fn lint_graph(&self, graph: &FxGraph) -> LintReport {
84        let mut all_issues = Vec::new();
85
86        // Run all lint rules
87        for rule in &self.rules {
88            let mut issues = (rule.checker)(graph);
89            // Add rule name to each issue
90            for issue in &mut issues {
91                issue.rule_name = rule.name.clone();
92            }
93            all_issues.extend(issues);
94        }
95
96        // Filter by severity threshold
97        all_issues.retain(|issue| issue.severity >= self.severity_threshold);
98
99        // Generate statistics
100        let total_issues = all_issues.len() as u32;
101        let mut issues_by_severity = HashMap::new();
102        for issue in &all_issues {
103            *issues_by_severity
104                .entry(issue.severity.clone())
105                .or_insert(0) += 1;
106        }
107
108        // Calculate overall score (0.0 to 1.0)
109        let overall_score = self.calculate_overall_score(&all_issues, graph);
110
111        // Generate global recommendations
112        let recommendations = self.generate_global_recommendations(&all_issues, graph);
113
114        LintReport {
115            total_issues,
116            issues_by_severity,
117            issues: all_issues,
118            overall_score,
119            recommendations,
120        }
121    }
122
123    /// Add default lint rules
124    fn add_default_rules(&mut self) {
125        // Rule: Check for disconnected nodes
126        self.add_rule(LintRule {
127            name: "disconnected_nodes".to_string(),
128            description: "Detect nodes with no incoming or outgoing connections".to_string(),
129            severity: LintSeverity::Warning,
130            checker: |graph| {
131                let mut issues = Vec::new();
132                for (idx, node) in graph.nodes() {
133                    let has_incoming = graph
134                        .graph
135                        .edges_directed(idx, petgraph::Incoming)
136                        .next()
137                        .is_some();
138                    let has_outgoing = graph
139                        .graph
140                        .edges_directed(idx, petgraph::Outgoing)
141                        .next()
142                        .is_some();
143
144                    if !has_incoming
145                        && !has_outgoing
146                        && !matches!(node, Node::Input(_) | Node::Output)
147                    {
148                        issues.push(LintIssue {
149                            rule_name: "".to_string(), // Will be filled by caller
150                            severity: LintSeverity::Warning,
151                            message: format!("Node {idx:?} is disconnected from the graph"),
152                            node_index: Some(idx),
153                            suggestions: vec![
154                                "Remove unused node or connect it to the graph".to_string()
155                            ],
156                        });
157                    }
158                }
159                issues
160            },
161        });
162
163        // Rule: Check for cycles in the graph
164        self.add_rule(LintRule {
165            name: "cycles".to_string(),
166            description: "Detect cycles that may cause infinite loops".to_string(),
167            severity: LintSeverity::Error,
168            checker: |graph| {
169                let mut issues = Vec::new();
170                if petgraph::algo::is_cyclic_directed(&graph.graph) {
171                    issues.push(LintIssue {
172                        rule_name: "".to_string(),
173                        severity: LintSeverity::Error,
174                        message: "Graph contains cycles which may cause infinite loops".to_string(),
175                        node_index: None,
176                        suggestions: vec![
177                            "Review loop constructs and ensure proper termination conditions"
178                                .to_string(),
179                            "Consider breaking cycles with merge nodes".to_string(),
180                        ],
181                    });
182                }
183                issues
184            },
185        });
186
187        // Rule: Check for missing inputs/outputs
188        self.add_rule(LintRule {
189            name: "missing_io".to_string(),
190            description: "Ensure graph has proper input and output nodes".to_string(),
191            severity: LintSeverity::Error,
192            checker: |graph| {
193                let mut issues = Vec::new();
194
195                if graph.inputs().is_empty() {
196                    issues.push(LintIssue {
197                        rule_name: "".to_string(),
198                        severity: LintSeverity::Error,
199                        message: "Graph has no input nodes".to_string(),
200                        node_index: None,
201                        suggestions: vec![
202                            "Add input nodes to define graph entry points".to_string()
203                        ],
204                    });
205                }
206
207                if graph.outputs().is_empty() {
208                    issues.push(LintIssue {
209                        rule_name: "".to_string(),
210                        severity: LintSeverity::Error,
211                        message: "Graph has no output nodes".to_string(),
212                        node_index: None,
213                        suggestions: vec!["Add output nodes to define graph results".to_string()],
214                    });
215                }
216                issues
217            },
218        });
219
220        // Rule: Check for inefficient patterns
221        self.add_rule(LintRule {
222            name: "inefficient_patterns".to_string(),
223            description: "Detect known inefficient operation patterns".to_string(),
224            severity: LintSeverity::Info,
225            checker: |graph| {
226                let mut issues = Vec::new();
227
228                // Check for consecutive transpose operations
229                for (idx, node) in graph.nodes() {
230                    if let Node::Call(op, _) = node {
231                        if op == "transpose" {
232                            // Check if followed by another transpose
233                            for neighbor in graph.graph.neighbors(idx) {
234                                if let Some(Node::Call(neighbor_op, _)) = graph.get_node(neighbor) {
235                                    if neighbor_op == "transpose" {
236                                        issues.push(LintIssue {
237                                            rule_name: "".to_string(),
238                                            severity: LintSeverity::Info,
239                                            message: "Consecutive transpose operations detected".to_string(),
240                                            node_index: Some(idx),
241                                            suggestions: vec!["Consider fusing consecutive transposes or eliminating them if they cancel out".to_string()],
242                                        });
243                                    }
244                                }
245                            }
246                        }
247                    }
248                }
249                issues
250            },
251        });
252
253        // Rule: Check for large fan-out
254        self.add_rule(LintRule {
255            name: "large_fanout".to_string(),
256            description: "Detect nodes with excessive fan-out".to_string(),
257            severity: LintSeverity::Warning,
258            checker: |graph| {
259                let mut issues = Vec::new();
260                const MAX_FANOUT: usize = 10;
261
262                for (idx, _node) in graph.nodes() {
263                    let fanout = graph.graph.edges_directed(idx, petgraph::Outgoing).count();
264                    if fanout > MAX_FANOUT {
265                        issues.push(LintIssue {
266                            rule_name: "".to_string(),
267                            severity: LintSeverity::Warning,
268                            message: format!("Node {idx:?} has high fan-out of {fanout}"),
269                            node_index: Some(idx),
270                            suggestions: vec![
271                                "Consider adding intermediate nodes to reduce fan-out".to_string(),
272                                "Verify if all outputs are necessary".to_string(),
273                            ],
274                        });
275                    }
276                }
277                issues
278            },
279        });
280    }
281
282    /// Calculate overall graph health score
283    fn calculate_overall_score(&self, issues: &[LintIssue], graph: &FxGraph) -> f64 {
284        let total_nodes = graph.node_count() as f64;
285        if total_nodes == 0.0 {
286            return 0.0;
287        }
288
289        let mut penalty = 0.0;
290        for issue in issues {
291            penalty += match issue.severity {
292                LintSeverity::Info => 0.1,
293                LintSeverity::Warning => 0.3,
294                LintSeverity::Error => 0.7,
295                LintSeverity::Critical => 1.0,
296            };
297        }
298
299        // Normalize penalty by graph size
300        let normalized_penalty = penalty / total_nodes;
301        (1.0 - normalized_penalty).max(0.0)
302    }
303
304    /// Generate global recommendations based on all issues
305    fn generate_global_recommendations(
306        &self,
307        issues: &[LintIssue],
308        graph: &FxGraph,
309    ) -> Vec<String> {
310        let mut recommendations = Vec::new();
311
312        // Count issue types
313        let error_count = issues
314            .iter()
315            .filter(|i| i.severity >= LintSeverity::Error)
316            .count();
317        let warning_count = issues
318            .iter()
319            .filter(|i| i.severity == LintSeverity::Warning)
320            .count();
321
322        if error_count > 0 {
323            recommendations.push("Fix critical errors before deploying the graph".to_string());
324        }
325
326        if warning_count > 3 {
327            recommendations.push("Consider refactoring to address multiple warnings".to_string());
328        }
329
330        if graph.node_count() > 100 {
331            recommendations
332                .push("Consider breaking large graphs into smaller subgraphs".to_string());
333        }
334
335        if graph.edge_count() > graph.node_count() * 2 {
336            recommendations.push(
337                "Graph appears to have high connectivity - verify if all connections are necessary"
338                    .to_string(),
339            );
340        }
341
342        recommendations.push("Run graph optimization passes to improve performance".to_string());
343        recommendations.push("Add comprehensive documentation for complex operations".to_string());
344
345        recommendations
346    }
347}
348
/// Namespace for graph comparison (`diff`) and patch application (`merge`),
/// intended for version-control style workflows.
pub struct GraphDiff;
351
352#[derive(Debug, Clone)]
353pub struct GraphDifference {
354    pub added_nodes: Vec<(NodeIndex, Node)>,
355    pub removed_nodes: Vec<(NodeIndex, Node)>,
356    pub modified_nodes: Vec<(NodeIndex, Node, Node)>, // (index, old, new)
357    pub added_edges: Vec<(NodeIndex, NodeIndex, Edge)>,
358    pub removed_edges: Vec<(NodeIndex, NodeIndex, Edge)>,
359}
360
361impl GraphDiff {
362    /// Calculate differences between two graphs
363    pub fn diff(old_graph: &FxGraph, new_graph: &FxGraph) -> GraphDifference {
364        let mut diff = GraphDifference {
365            added_nodes: Vec::new(),
366            removed_nodes: Vec::new(),
367            modified_nodes: Vec::new(),
368            added_edges: Vec::new(),
369            removed_edges: Vec::new(),
370        };
371
372        // Create node maps for comparison
373        let old_nodes: HashMap<String, (NodeIndex, &Node)> = old_graph
374            .nodes()
375            .map(|(idx, node)| (Self::node_key(node), (idx, node)))
376            .collect();
377
378        let new_nodes: HashMap<String, (NodeIndex, &Node)> = new_graph
379            .nodes()
380            .map(|(idx, node)| (Self::node_key(node), (idx, node)))
381            .collect();
382
383        // Find added and modified nodes
384        for (key, (new_idx, new_node)) in &new_nodes {
385            if let Some((_old_idx, old_node)) = old_nodes.get(key) {
386                if !Self::nodes_equal(old_node, new_node) {
387                    diff.modified_nodes
388                        .push((*new_idx, (*old_node).clone(), (*new_node).clone()));
389                }
390            } else {
391                diff.added_nodes.push((*new_idx, (*new_node).clone()));
392            }
393        }
394
395        // Find removed nodes
396        for (key, (old_idx, old_node)) in &old_nodes {
397            if !new_nodes.contains_key(key) {
398                diff.removed_nodes.push((*old_idx, (*old_node).clone()));
399            }
400        }
401
402        // Compare edges (simplified comparison)
403        let _old_edges: HashSet<String> = old_graph
404            .graph
405            .edge_references()
406            .map(|edge| {
407                use petgraph::visit::EdgeRef;
408                format!(
409                    "{}->{}:{}",
410                    edge.source().index(),
411                    edge.target().index(),
412                    edge.weight().name
413                )
414            })
415            .collect();
416
417        let _new_edges: HashSet<String> = new_graph
418            .graph
419            .edge_references()
420            .map(|edge| {
421                use petgraph::visit::EdgeRef;
422                format!(
423                    "{}->{}:{}",
424                    edge.source().index(),
425                    edge.target().index(),
426                    edge.weight().name
427                )
428            })
429            .collect();
430
431        // For now, just track edge count differences
432        // A more sophisticated implementation would track actual edge changes
433
434        diff
435    }
436
437    /// Merge changes from one graph into another
438    pub fn merge(base_graph: &FxGraph, diff: &GraphDifference) -> TorshResult<FxGraph> {
439        let mut merged_graph = base_graph.clone();
440
441        // Apply node additions
442        for (_idx, node) in &diff.added_nodes {
443            merged_graph.graph.add_node(node.clone());
444        }
445
446        // Apply node modifications (simplified)
447        for (idx, _old_node, new_node) in &diff.modified_nodes {
448            if let Some(node_weight) = merged_graph.graph.node_weight_mut(*idx) {
449                *node_weight = new_node.clone();
450            }
451        }
452
453        // Apply edge changes would require more complex tracking
454        // This is a simplified implementation
455
456        Ok(merged_graph)
457    }
458
459    /// Generate a unique key for a node for comparison
460    fn node_key(node: &Node) -> String {
461        match node {
462            Node::Input(name) => format!("input:{name}"),
463            Node::Call(op, args) => {
464                let args_str = args.join(",");
465                format!("call:{op}:{args_str}")
466            }
467            Node::Output => "output".to_string(),
468            Node::Conditional { condition, .. } => format!("conditional:{condition}"),
469            Node::Loop { condition, .. } => format!("loop:{}", condition),
470            Node::Merge { inputs } => format!("merge:{}", inputs.join(",")),
471            Node::GetAttr { target, attr } => format!("getattr:{}:{}", target, attr),
472        }
473    }
474
475    /// Check if two nodes are functionally equal
476    fn nodes_equal(node1: &Node, node2: &Node) -> bool {
477        std::mem::discriminant(node1) == std::mem::discriminant(node2)
478            && Self::node_key(node1) == Self::node_key(node2)
479    }
480}
481
482/// Advanced graph metrics and health checking
483#[derive(Debug, Serialize, Deserialize)]
484pub struct GraphMetrics {
485    pub node_count: usize,
486    pub edge_count: usize,
487    pub input_count: usize,
488    pub output_count: usize,
489    pub max_depth: usize,
490    pub average_fanout: f64,
491    pub connectivity_ratio: f64,
492    pub complexity_score: f64,
493    pub operation_distribution: HashMap<String, u32>,
494    pub critical_path_length: usize,
495}
496
/// Namespace for the structural pattern detectors (linear chains, high
/// fan-out nodes and potential bottlenecks).
pub struct PatternDetector;
499
500#[derive(Debug, Clone)]
501pub struct DetectedPattern {
502    pub pattern_type: String,
503    pub description: String,
504    pub nodes: Vec<NodeIndex>,
505    pub confidence: f64,
506    pub optimization_potential: String,
507}
508
509impl PatternDetector {
510    /// Detect common patterns in the graph
511    pub fn detect_patterns(graph: &FxGraph) -> Vec<DetectedPattern> {
512        let mut patterns = Vec::new();
513
514        // Detect linear chains
515        patterns.extend(Self::detect_linear_chains(graph));
516
517        // Detect fan-out patterns
518        patterns.extend(Self::detect_fanout_patterns(graph));
519
520        // Detect bottlenecks
521        patterns.extend(Self::detect_bottlenecks(graph));
522
523        patterns
524    }
525
526    /// Detect linear chains of operations
527    fn detect_linear_chains(graph: &FxGraph) -> Vec<DetectedPattern> {
528        let mut patterns = Vec::new();
529        let mut visited = HashSet::new();
530
531        for (start_idx, _) in graph.nodes() {
532            if visited.contains(&start_idx) {
533                continue;
534            }
535
536            let chain = Self::trace_linear_chain(graph, start_idx, &mut visited);
537            if chain.len() > 3 {
538                // Consider chains of 4+ nodes as patterns
539                patterns.push(DetectedPattern {
540                    pattern_type: "linear_chain".to_string(),
541                    description: format!("Linear chain of {} operations", chain.len()),
542                    nodes: chain,
543                    confidence: 0.9,
544                    optimization_potential: "Consider operation fusion for better performance"
545                        .to_string(),
546                });
547            }
548        }
549
550        patterns
551    }
552
553    /// Trace a linear chain from a starting node
554    fn trace_linear_chain(
555        graph: &FxGraph,
556        start: NodeIndex,
557        visited: &mut HashSet<NodeIndex>,
558    ) -> Vec<NodeIndex> {
559        let mut chain = vec![start];
560        visited.insert(start);
561        let mut current = start;
562
563        loop {
564            let neighbors: Vec<_> = graph.graph.neighbors(current).collect();
565            if neighbors.len() != 1 {
566                break; // Not a linear chain
567            }
568
569            let next = neighbors[0];
570            if visited.contains(&next) {
571                break; // Already processed
572            }
573
574            let incoming: Vec<_> = graph
575                .graph
576                .neighbors_directed(next, petgraph::Incoming)
577                .collect();
578            if incoming.len() != 1 {
579                break; // Next node has multiple inputs
580            }
581
582            chain.push(next);
583            visited.insert(next);
584            current = next;
585        }
586
587        chain
588    }
589
590    /// Detect fan-out patterns
591    fn detect_fanout_patterns(graph: &FxGraph) -> Vec<DetectedPattern> {
592        let mut patterns = Vec::new();
593
594        for (idx, _node) in graph.nodes() {
595            let fanout = graph.graph.neighbors(idx).count();
596            if fanout > 5 {
597                // High fan-out threshold
598                let _neighbors: Vec<_> = graph.graph.neighbors(idx).collect();
599                patterns.push(DetectedPattern {
600                    pattern_type: "high_fanout".to_string(),
601                    description: format!("High fan-out node with {} outputs", fanout),
602                    nodes: vec![idx],
603                    confidence: 1.0,
604                    optimization_potential: "Consider broadcast optimization or result caching"
605                        .to_string(),
606                });
607            }
608        }
609
610        patterns
611    }
612
613    /// Detect potential bottlenecks
614    fn detect_bottlenecks(graph: &FxGraph) -> Vec<DetectedPattern> {
615        let mut patterns = Vec::new();
616
617        for (idx, _node) in graph.nodes() {
618            let incoming = graph
619                .graph
620                .neighbors_directed(idx, petgraph::Incoming)
621                .count();
622            let _outgoing = graph
623                .graph
624                .neighbors_directed(idx, petgraph::Outgoing)
625                .count();
626
627            // High incoming connections might indicate a bottleneck
628            if incoming > 5 {
629                patterns.push(DetectedPattern {
630                    pattern_type: "potential_bottleneck".to_string(),
631                    description: format!("Node with {} incoming connections", incoming),
632                    nodes: vec![idx],
633                    confidence: 0.7,
634                    optimization_potential: "Consider parallelization or input batching"
635                        .to_string(),
636                });
637            }
638        }
639
640        patterns
641    }
642}
643
644/// Calculate comprehensive graph metrics
645pub fn calculate_graph_metrics(graph: &FxGraph) -> GraphMetrics {
646    let node_count = graph.node_count();
647    let edge_count = graph.edge_count();
648    let input_count = graph.inputs().len();
649    let output_count = graph.outputs().len();
650
651    // Calculate max depth using DFS
652    let max_depth = calculate_max_depth(graph);
653
654    // Calculate average fanout
655    let total_fanout: usize = graph
656        .nodes()
657        .map(|(idx, _)| graph.graph.neighbors(idx).count())
658        .sum();
659    let average_fanout = if node_count > 0 {
660        total_fanout as f64 / node_count as f64
661    } else {
662        0.0
663    };
664
665    // Calculate connectivity ratio
666    let max_possible_edges = if node_count > 1 {
667        node_count * (node_count - 1)
668    } else {
669        1
670    };
671    let connectivity_ratio = edge_count as f64 / max_possible_edges as f64;
672
673    // Calculate complexity score (heuristic)
674    let complexity_score =
675        (node_count as f64).ln() * (1.0 + connectivity_ratio) * (1.0 + average_fanout);
676
677    // Calculate operation distribution
678    let mut operation_distribution = HashMap::new();
679    for (_, node) in graph.nodes() {
680        let op_type = match node {
681            Node::Input(_) => "input",
682            Node::Call(op, _) => op,
683            Node::Output => "output",
684            Node::Conditional { .. } => "conditional",
685            Node::Loop { .. } => "loop",
686            Node::Merge { .. } => "merge",
687            Node::GetAttr { .. } => "getattr",
688        };
689        *operation_distribution
690            .entry(op_type.to_string())
691            .or_insert(0) += 1;
692    }
693
694    // Calculate critical path length (simplified)
695    let critical_path_length = max_depth;
696
697    GraphMetrics {
698        node_count,
699        edge_count,
700        input_count,
701        output_count,
702        max_depth,
703        average_fanout,
704        connectivity_ratio,
705        complexity_score,
706        operation_distribution,
707        critical_path_length,
708    }
709}
710
711/// Calculate maximum depth of the graph
712fn calculate_max_depth(graph: &FxGraph) -> usize {
713    let mut max_depth = 0;
714    let mut visited = HashSet::new();
715
716    for &input_idx in graph.inputs() {
717        let depth = calculate_depth_from_node(graph, input_idx, &mut visited, 0);
718        max_depth = max_depth.max(depth);
719    }
720
721    max_depth
722}
723
724/// Calculate depth from a specific node using DFS
725fn calculate_depth_from_node(
726    graph: &FxGraph,
727    node: NodeIndex,
728    visited: &mut HashSet<NodeIndex>,
729    current_depth: usize,
730) -> usize {
731    if visited.contains(&node) {
732        return current_depth;
733    }
734
735    visited.insert(node);
736    let mut max_child_depth = current_depth;
737
738    for neighbor in graph.graph.neighbors(node) {
739        let child_depth = calculate_depth_from_node(graph, neighbor, visited, current_depth + 1);
740        max_child_depth = max_child_depth.max(child_depth);
741    }
742
743    max_child_depth
744}
745
746impl Default for GraphLinter {
747    fn default() -> Self {
748        Self::new()
749    }
750}
751
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Edge, FxGraph, Node};
    use petgraph::graph::NodeIndex;

    /// Adds a `Call` node with the given operation name and argument names.
    fn add_call(graph: &mut FxGraph, op: &str, args: &[&str]) -> NodeIndex {
        let args = args.iter().map(|arg| (*arg).to_string()).collect();
        graph.graph.add_node(Node::Call(op.to_string(), args))
    }

    /// Adds an edge named `name` from `from` to `to`.
    fn connect(graph: &mut FxGraph, from: NodeIndex, to: NodeIndex, name: &str) {
        graph.graph.add_edge(
            from,
            to,
            Edge {
                name: name.to_string(),
            },
        );
    }

    /// Builds `input(x) -> relu -> output` with the input and output registered.
    fn input_relu_output() -> FxGraph {
        let mut graph = FxGraph::new();
        let input = graph.graph.add_node(Node::Input("x".to_string()));
        let relu = add_call(&mut graph, "relu", &["x"]);
        let output = graph.graph.add_node(Node::Output);

        connect(&mut graph, input, relu, "x");
        connect(&mut graph, relu, output, "relu_out");
        graph.inputs.push(input);
        graph.outputs.push(output);
        graph
    }

    #[test]
    fn test_graph_linter() {
        let graph = input_relu_output();

        let report = GraphLinter::new().lint_graph(&graph);

        assert_eq!(report.total_issues, 0); // Should be a clean graph
        assert!(report.overall_score > 0.8); // Should have a good score
    }

    #[test]
    fn test_graph_linter_with_issues() {
        // A lone call node, with no inputs or outputs registered.
        let mut graph = FxGraph::new();
        let _disconnected = add_call(&mut graph, "relu", &[]);

        let report = GraphLinter::new().lint_graph(&graph);

        assert!(report.total_issues > 0);
        assert!(report.overall_score < 1.0);
    }

    #[test]
    fn test_graph_diff() {
        let mut old_graph = FxGraph::new();
        old_graph.graph.add_node(Node::Input("x".to_string()));
        add_call(&mut old_graph, "relu", &["x"]);

        let mut new_graph = FxGraph::new();
        new_graph.graph.add_node(Node::Input("x".to_string()));
        add_call(&mut new_graph, "relu", &["x"]);
        add_call(&mut new_graph, "sigmoid", &["relu_out"]);

        let diff = GraphDiff::diff(&old_graph, &new_graph);

        assert_eq!(diff.added_nodes.len(), 1); // sigmoid node added
        assert_eq!(diff.removed_nodes.len(), 0);
    }

    #[test]
    fn test_pattern_detection() {
        let mut graph = FxGraph::new();
        let input = graph.graph.add_node(Node::Input("x".to_string()));
        let relu1 = add_call(&mut graph, "relu", &["x"]);
        let relu2 = add_call(&mut graph, "relu", &["relu1"]);
        let relu3 = add_call(&mut graph, "relu", &["relu2"]);
        let output = graph.graph.add_node(Node::Output);

        // Wire the five nodes into one linear chain.
        connect(&mut graph, input, relu1, "x");
        connect(&mut graph, relu1, relu2, "relu1");
        connect(&mut graph, relu2, relu3, "relu2");
        connect(&mut graph, relu3, output, "relu3");

        let patterns = PatternDetector::detect_patterns(&graph);

        assert!(!patterns.is_empty());
        assert!(patterns.iter().any(|p| p.pattern_type == "linear_chain"));
    }

    #[test]
    fn test_graph_metrics() {
        let graph = input_relu_output();

        let metrics = calculate_graph_metrics(&graph);

        assert_eq!(metrics.node_count, 3);
        assert_eq!(metrics.edge_count, 2);
        assert_eq!(metrics.input_count, 1);
        assert_eq!(metrics.output_count, 1);
        assert!(metrics.average_fanout > 0.0);
    }
}
913}