// tenflowers_core/graph/analysis.rs
use std::collections::{HashMap, HashSet, VecDeque};

use super::core::*;
use crate::error::TensorError;

10impl Graph {
11 pub fn compute_topological_order(&mut self) -> Result<&[NodeId], TensorError> {
13 if let Some(ref order) = self.topological_order {
14 return Ok(order);
15 }
16
17 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
18 let mut adjacency: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
19
20 for node in self.nodes.values() {
22 in_degree.insert(node.id, 0);
23 adjacency.insert(node.id, Vec::new());
24 }
25
26 for edge in self.edges.values() {
28 if !edge.is_control {
29 adjacency
31 .get_mut(&edge.from_node)
32 .expect("Adjacency entry must exist for all nodes")
33 .push(edge.to_node);
34 *in_degree
35 .get_mut(&edge.to_node)
36 .expect("In-degree entry must exist for all nodes") += 1;
37 }
38 }
39
40 let mut queue: VecDeque<NodeId> = VecDeque::new();
42 let mut result: Vec<NodeId> = Vec::new();
43
44 for (&node_id, °ree) in &in_degree {
46 if degree == 0 {
47 queue.push_back(node_id);
48 }
49 }
50
51 while let Some(node_id) = queue.pop_front() {
52 result.push(node_id);
53
54 for &neighbor in adjacency
56 .get(&node_id)
57 .expect("Adjacency entry must exist for all nodes")
58 {
59 let neighbor_degree = in_degree
60 .get_mut(&neighbor)
61 .expect("In-degree entry must exist for all nodes");
62 *neighbor_degree -= 1;
63 if *neighbor_degree == 0 {
64 queue.push_back(neighbor);
65 }
66 }
67 }
68
69 if result.len() != self.nodes.len() {
71 return Err(TensorError::invalid_argument(
72 "Graph contains cycles".to_string(),
73 ));
74 }
75
76 self.topological_order = Some(result);
77 Ok(self
78 .topological_order
79 .as_ref()
80 .expect("Topological order must be present after assignment"))
81 }
82
83 pub fn validate(&self) -> Result<(), TensorError> {
85 for edge in self.edges.values() {
87 if !self.nodes.contains_key(&edge.from_node) {
88 return Err(TensorError::invalid_argument(format!(
89 "Edge {} references non-existent source node {}",
90 edge.id, edge.from_node
91 )));
92 }
93 if !self.nodes.contains_key(&edge.to_node) {
94 return Err(TensorError::invalid_argument(format!(
95 "Edge {} references non-existent destination node {}",
96 edge.id, edge.to_node
97 )));
98 }
99 }
100
101 for node in self.nodes.values() {
103 for &edge_id in &node.inputs {
104 if let Some(edge) = self.edges.get(&edge_id) {
105 if edge.to_node != node.id {
106 return Err(TensorError::invalid_argument(format!(
107 "Node {} lists edge {} as input, but edge points to node {}",
108 node.id, edge_id, edge.to_node
109 )));
110 }
111 } else {
112 return Err(TensorError::invalid_argument(format!(
113 "Node {} references non-existent input edge {}",
114 node.id, edge_id
115 )));
116 }
117 }
118
119 for &edge_id in &node.outputs {
120 if let Some(edge) = self.edges.get(&edge_id) {
121 if edge.from_node != node.id {
122 return Err(TensorError::invalid_argument(format!(
123 "Node {} lists edge {} as output, but edge comes from node {}",
124 node.id, edge_id, edge.from_node
125 )));
126 }
127 } else {
128 return Err(TensorError::invalid_argument(format!(
129 "Node {} references non-existent output edge {}",
130 node.id, edge_id
131 )));
132 }
133 }
134 }
135
136 let mut seen_names = HashSet::new();
138 for node in self.nodes.values() {
139 if !seen_names.insert(&node.name) {
140 return Err(TensorError::invalid_argument(format!(
141 "Duplicate node name: '{}'",
142 node.name
143 )));
144 }
145 }
146
147 Ok(())
148 }
149
150 pub fn input_nodes(&self) -> Vec<NodeId> {
152 self.nodes
153 .values()
154 .filter(|node| {
155 !node.inputs.iter().any(|&edge_id| {
156 self.edges
157 .get(&edge_id)
158 .is_some_and(|edge| !edge.is_control)
159 })
160 })
161 .map(|node| node.id)
162 .collect()
163 }
164
165 pub fn output_nodes(&self) -> Vec<NodeId> {
167 self.nodes
168 .values()
169 .filter(|node| {
170 !node.outputs.iter().any(|&edge_id| {
171 self.edges
172 .get(&edge_id)
173 .is_some_and(|edge| !edge.is_control)
174 })
175 })
176 .map(|node| node.id)
177 .collect()
178 }
179}