tensorlogic_infer/
validation.rs1use std::collections::HashSet;
4
5use tensorlogic_ir::{EinsumGraph, OpType};
6
7use crate::error::ExecutorError;
8
/// Outcome of validating an [`EinsumGraph`]: a validity flag plus the collected messages.
///
/// Errors make the result invalid; warnings are informational and never affect
/// `is_valid`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// `true` until an error is recorded (via `add_error`) or an invalid result is merged in.
    pub is_valid: bool,
    /// Messages describing problems that make the graph invalid.
    pub errors: Vec<String>,
    /// Messages describing suspicious but non-fatal conditions.
    pub warnings: Vec<String>,
}
16
17impl ValidationResult {
18 pub fn new() -> Self {
19 ValidationResult {
20 is_valid: true,
21 errors: Vec::new(),
22 warnings: Vec::new(),
23 }
24 }
25
26 pub fn add_error(&mut self, error: impl Into<String>) {
27 self.is_valid = false;
28 self.errors.push(error.into());
29 }
30
31 pub fn add_warning(&mut self, warning: impl Into<String>) {
32 self.warnings.push(warning.into());
33 }
34
35 pub fn merge(&mut self, other: ValidationResult) {
36 self.is_valid &= other.is_valid;
37 self.errors.extend(other.errors);
38 self.warnings.extend(other.warnings);
39 }
40
41 pub fn summary(&self) -> String {
42 let mut summary = String::new();
43 if self.is_valid {
44 summary.push_str("✓ Graph is valid\n");
45 } else {
46 summary.push_str("✗ Graph validation failed\n");
47 }
48
49 if !self.errors.is_empty() {
50 summary.push_str("\nErrors:\n");
51 for error in &self.errors {
52 summary.push_str(&format!(" - {}\n", error));
53 }
54 }
55
56 if !self.warnings.is_empty() {
57 summary.push_str("\nWarnings:\n");
58 for warning in &self.warnings {
59 summary.push_str(&format!(" - {}\n", warning));
60 }
61 }
62
63 summary
64 }
65}
66
67impl Default for ValidationResult {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
/// Checks an [`EinsumGraph`] for structural problems before execution.
///
/// The validator holds no state, so it is freely copyable.
#[derive(Debug, Clone, Copy)]
pub struct GraphValidator;
75
76impl GraphValidator {
77 pub fn new() -> Self {
78 GraphValidator
79 }
80
81 pub fn validate(&self, graph: &EinsumGraph) -> ValidationResult {
83 let mut result = ValidationResult::new();
84
85 if graph.nodes.is_empty() {
87 result.add_warning("Graph has no computation nodes");
88 }
89
90 self.validate_tensor_indices(graph, &mut result);
92
93 self.validate_dependencies(graph, &mut result);
95
96 self.validate_operations(graph, &mut result);
98
99 self.validate_dag(graph, &mut result);
101
102 result
103 }
104
105 fn validate_tensor_indices(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
106 let num_tensors = graph.tensors.len();
107
108 for (node_idx, node) in graph.nodes.iter().enumerate() {
109 for &input_idx in &node.inputs {
110 let max_valid_idx = num_tensors + node_idx;
114
115 if input_idx >= max_valid_idx {
116 result.add_error(format!(
117 "Node {} references invalid tensor index {} (max valid: {})",
118 node_idx, input_idx, max_valid_idx
119 ));
120 }
121 }
122 }
123 }
124
125 fn validate_dependencies(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
126 let num_tensors = graph.tensors.len();
127
128 for (node_idx, node) in graph.nodes.iter().enumerate() {
129 for &input_idx in &node.inputs {
131 if input_idx >= num_tensors {
132 let dep_node_idx = input_idx - num_tensors;
133 if dep_node_idx >= node_idx {
134 result.add_error(format!(
135 "Node {} has forward dependency on node {}",
136 node_idx, dep_node_idx
137 ));
138 }
139 }
140 }
141 }
142 }
143
144 fn validate_operations(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
145 for (node_idx, node) in graph.nodes.iter().enumerate() {
146 match &node.op {
147 OpType::Einsum { spec } => {
148 if spec.is_empty() {
149 result.add_error(format!("Node {} has empty einsum spec", node_idx));
150 }
151 if node.inputs.is_empty() {
152 result.add_error(format!("Node {} einsum has no inputs", node_idx));
153 }
154 }
155 OpType::ElemUnary { op: _ } => {
156 if node.inputs.len() != 1 {
157 result.add_error(format!(
158 "Node {} unary operation requires exactly 1 input, got {}",
159 node_idx,
160 node.inputs.len()
161 ));
162 }
163 }
164 OpType::ElemBinary { op: _ } => {
165 if node.inputs.len() != 2 {
166 result.add_error(format!(
167 "Node {} binary operation requires exactly 2 inputs, got {}",
168 node_idx,
169 node.inputs.len()
170 ));
171 }
172 }
173 OpType::Reduce { op: _, axes } => {
174 if node.inputs.len() != 1 {
175 result.add_error(format!(
176 "Node {} reduce operation requires exactly 1 input, got {}",
177 node_idx,
178 node.inputs.len()
179 ));
180 }
181 if axes.is_empty() {
182 result.add_warning(format!(
183 "Node {} reduce operation has no axes (identity operation)",
184 node_idx
185 ));
186 }
187 }
188 }
189 }
190 }
191
192 fn validate_dag(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
193 let num_nodes = graph.nodes.len();
195 let num_tensors = graph.tensors.len();
196
197 let mut visited = vec![false; num_nodes];
199 let mut rec_stack = vec![false; num_nodes];
200
201 for node_idx in 0..num_nodes {
202 if !visited[node_idx]
203 && has_cycle_helper(node_idx, graph, num_tensors, &mut visited, &mut rec_stack)
204 {
205 result.add_error("Graph contains a cycle (not a DAG)");
206 break;
207 }
208 }
209 }
210
211 pub fn validate_or_error(&self, graph: &EinsumGraph) -> Result<(), ExecutorError> {
213 let result = self.validate(graph);
214 if result.is_valid {
215 Ok(())
216 } else {
217 Err(ExecutorError::GraphValidationError(
218 result.errors.join("; "),
219 ))
220 }
221 }
222
223 pub fn find_unreachable_nodes(&self, graph: &EinsumGraph) -> HashSet<usize> {
225 let num_nodes = graph.nodes.len();
226 let num_tensors = graph.tensors.len();
227
228 let mut reachable = HashSet::new();
229
230 if num_nodes > 0 {
232 let mut to_visit = vec![num_nodes - 1];
233 while let Some(node_idx) = to_visit.pop() {
234 if reachable.insert(node_idx) {
235 for &input_idx in &graph.nodes[node_idx].inputs {
237 if input_idx >= num_tensors {
238 let dep_node_idx = input_idx - num_tensors;
239 if !reachable.contains(&dep_node_idx) {
240 to_visit.push(dep_node_idx);
241 }
242 }
243 }
244 }
245 }
246 }
247
248 (0..num_nodes)
250 .filter(|idx| !reachable.contains(idx))
251 .collect()
252 }
253}
254
255#[allow(clippy::only_used_in_recursion)]
257fn has_cycle_helper(
258 node_idx: usize,
259 graph: &EinsumGraph,
260 num_tensors: usize,
261 visited: &mut [bool],
262 rec_stack: &mut [bool],
263) -> bool {
264 visited[node_idx] = true;
265 rec_stack[node_idx] = true;
266
267 for &input_idx in &graph.nodes[node_idx].inputs {
269 if input_idx >= num_tensors {
270 let dep_node_idx = input_idx - num_tensors;
271 if dep_node_idx >= visited.len() {
273 continue; }
275 if !visited[dep_node_idx] {
276 if has_cycle_helper(dep_node_idx, graph, num_tensors, visited, rec_stack) {
277 return true;
278 }
279 } else if rec_stack[dep_node_idx] {
280 return true;
281 }
282 }
283 }
284
285 rec_stack[node_idx] = false;
286 false
287}
288
289impl Default for GraphValidator {
290 fn default() -> Self {
291 Self::new()
292 }
293}