Skip to main content

tenflowers_core/graph/
analysis.rs

1//! Graph Analysis Operations
2//!
3//! This module provides graph analysis functionality including topological sorting,
4//! validation, and identification of input/output nodes.
5
6use super::core::*;
7use crate::error::TensorError;
8use std::collections::{HashMap, HashSet, VecDeque};
9
10impl Graph {
11    /// Compute and cache the topological order of nodes
12    pub fn compute_topological_order(&mut self) -> Result<&[NodeId], TensorError> {
13        if let Some(ref order) = self.topological_order {
14            return Ok(order);
15        }
16
17        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
18        let mut adjacency: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
19
20        // Initialize in-degree count and adjacency list
21        for node in self.nodes.values() {
22            in_degree.insert(node.id, 0);
23            adjacency.insert(node.id, Vec::new());
24        }
25
26        // Build adjacency list and count in-degrees
27        for edge in self.edges.values() {
28            if !edge.is_control {
29                // Only consider data dependencies for topological order
30                adjacency
31                    .get_mut(&edge.from_node)
32                    .expect("Adjacency entry must exist for all nodes")
33                    .push(edge.to_node);
34                *in_degree
35                    .get_mut(&edge.to_node)
36                    .expect("In-degree entry must exist for all nodes") += 1;
37            }
38        }
39
40        // Kahn's algorithm
41        let mut queue: VecDeque<NodeId> = VecDeque::new();
42        let mut result: Vec<NodeId> = Vec::new();
43
44        // Start with nodes that have no incoming edges
45        for (&node_id, &degree) in &in_degree {
46            if degree == 0 {
47                queue.push_back(node_id);
48            }
49        }
50
51        while let Some(node_id) = queue.pop_front() {
52            result.push(node_id);
53
54            // Remove this node from the graph and update in-degrees
55            for &neighbor in adjacency
56                .get(&node_id)
57                .expect("Adjacency entry must exist for all nodes")
58            {
59                let neighbor_degree = in_degree
60                    .get_mut(&neighbor)
61                    .expect("In-degree entry must exist for all nodes");
62                *neighbor_degree -= 1;
63                if *neighbor_degree == 0 {
64                    queue.push_back(neighbor);
65                }
66            }
67        }
68
69        // Check for cycles
70        if result.len() != self.nodes.len() {
71            return Err(TensorError::invalid_argument(
72                "Graph contains cycles".to_string(),
73            ));
74        }
75
76        self.topological_order = Some(result);
77        Ok(self
78            .topological_order
79            .as_ref()
80            .expect("Topological order must be present after assignment"))
81    }
82
83    /// Validate the graph structure
84    pub fn validate(&self) -> Result<(), TensorError> {
85        // Check that all edge endpoints reference valid nodes
86        for edge in self.edges.values() {
87            if !self.nodes.contains_key(&edge.from_node) {
88                return Err(TensorError::invalid_argument(format!(
89                    "Edge {} references non-existent source node {}",
90                    edge.id, edge.from_node
91                )));
92            }
93            if !self.nodes.contains_key(&edge.to_node) {
94                return Err(TensorError::invalid_argument(format!(
95                    "Edge {} references non-existent destination node {}",
96                    edge.id, edge.to_node
97                )));
98            }
99        }
100
101        // Check that node edge lists are consistent with actual edges
102        for node in self.nodes.values() {
103            for &edge_id in &node.inputs {
104                if let Some(edge) = self.edges.get(&edge_id) {
105                    if edge.to_node != node.id {
106                        return Err(TensorError::invalid_argument(format!(
107                            "Node {} lists edge {} as input, but edge points to node {}",
108                            node.id, edge_id, edge.to_node
109                        )));
110                    }
111                } else {
112                    return Err(TensorError::invalid_argument(format!(
113                        "Node {} references non-existent input edge {}",
114                        node.id, edge_id
115                    )));
116                }
117            }
118
119            for &edge_id in &node.outputs {
120                if let Some(edge) = self.edges.get(&edge_id) {
121                    if edge.from_node != node.id {
122                        return Err(TensorError::invalid_argument(format!(
123                            "Node {} lists edge {} as output, but edge comes from node {}",
124                            node.id, edge_id, edge.from_node
125                        )));
126                    }
127                } else {
128                    return Err(TensorError::invalid_argument(format!(
129                        "Node {} references non-existent output edge {}",
130                        node.id, edge_id
131                    )));
132                }
133            }
134        }
135
136        // Check for name uniqueness
137        let mut seen_names = HashSet::new();
138        for node in self.nodes.values() {
139            if !seen_names.insert(&node.name) {
140                return Err(TensorError::invalid_argument(format!(
141                    "Duplicate node name: '{}'",
142                    node.name
143                )));
144            }
145        }
146
147        Ok(())
148    }
149
150    /// Find all input nodes (nodes with no incoming data edges)
151    pub fn input_nodes(&self) -> Vec<NodeId> {
152        self.nodes
153            .values()
154            .filter(|node| {
155                !node.inputs.iter().any(|&edge_id| {
156                    self.edges
157                        .get(&edge_id)
158                        .is_some_and(|edge| !edge.is_control)
159                })
160            })
161            .map(|node| node.id)
162            .collect()
163    }
164
165    /// Find all output nodes (nodes with no outgoing data edges)
166    pub fn output_nodes(&self) -> Vec<NodeId> {
167        self.nodes
168            .values()
169            .filter(|node| {
170                !node.outputs.iter().any(|&edge_id| {
171                    self.edges
172                        .get(&edge_id)
173                        .is_some_and(|edge| !edge.is_control)
174                })
175            })
176            .map(|node| node.id)
177            .collect()
178    }
179}