Skip to main content

tensorlogic_infer/
validation.rs

1//! Graph validation utilities for ensuring well-formed execution graphs.
2
3use std::collections::HashSet;
4
5use tensorlogic_ir::{EinsumGraph, OpType};
6
7use crate::error::ExecutorError;
8
/// Outcome of a graph validation pass, with human-readable diagnostics.
///
/// A fresh result is valid. Recording an error (directly or by merging another
/// result that is invalid) clears `is_valid`; warnings never affect validity.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// `true` until an error is recorded via [`ValidationResult::add_error`]
    /// or an invalid result is merged in.
    pub is_valid: bool,
    /// Error messages, in the order they were recorded.
    pub errors: Vec<String>,
    /// Warning messages, in the order they were recorded.
    pub warnings: Vec<String>,
}

impl ValidationResult {
    /// Creates a result that is valid and carries no diagnostics.
    pub fn new() -> Self {
        ValidationResult {
            is_valid: true,
            errors: Vec::new(),
            warnings: Vec::new(),
        }
    }

    /// Records an error message and marks the result as invalid.
    pub fn add_error(&mut self, error: impl Into<String>) {
        self.is_valid = false;
        self.errors.push(error.into());
    }

    /// Records a warning message; the validity flag is left untouched.
    pub fn add_warning(&mut self, warning: impl Into<String>) {
        self.warnings.push(warning.into());
    }

    /// Folds `other` into `self`.
    ///
    /// The merged result is valid only if both were valid. The diagnostics of
    /// `other` are appended after the ones already stored in `self`.
    pub fn merge(&mut self, other: ValidationResult) {
        self.is_valid &= other.is_valid;
        self.errors.extend(other.errors);
        self.warnings.extend(other.warnings);
    }

    /// Renders a multi-line report.
    ///
    /// The first line states whether the graph is valid. It is followed by an
    /// `Errors:` section and a `Warnings:` section, each preceded by a blank
    /// line and each omitted when it has no entries. Every entry is printed as
    /// `  - <message>` on its own line.
    pub fn summary(&self) -> String {
        /// Appends one titled bullet list to `out`; does nothing for an empty list.
        fn push_section(out: &mut String, title: &str, items: &[String]) {
            if items.is_empty() {
                return;
            }
            out.push('\n');
            out.push_str(title);
            out.push_str(":\n");
            for item in items {
                // Append piecewise instead of `format!` to avoid a temporary
                // `String` per diagnostic.
                out.push_str("  - ");
                out.push_str(item);
                out.push('\n');
            }
        }

        let mut summary = String::new();
        summary.push_str(if self.is_valid {
            "✓ Graph is valid\n"
        } else {
            "✗ Graph validation failed\n"
        });

        push_section(&mut summary, "Errors", &self.errors);
        push_section(&mut summary, "Warnings", &self.warnings);

        summary
    }
}

/// The default result is the same as [`ValidationResult::new`]: valid and empty.
impl Default for ValidationResult {
    fn default() -> Self {
        Self::new()
    }
}
72
/// Stateless validator for [`EinsumGraph`] execution graphs.
///
/// The type has no fields; all checks are exposed as methods that take the
/// graph to inspect as an argument.
#[derive(Debug, Clone)]
pub struct GraphValidator;
75
76impl GraphValidator {
77    pub fn new() -> Self {
78        GraphValidator
79    }
80
81    /// Validate a complete execution graph
82    pub fn validate(&self, graph: &EinsumGraph) -> ValidationResult {
83        let mut result = ValidationResult::new();
84
85        // Check if graph is empty
86        if graph.nodes.is_empty() {
87            result.add_warning("Graph has no computation nodes");
88        }
89
90        // Validate tensor indices
91        self.validate_tensor_indices(graph, &mut result);
92
93        // Validate node dependencies
94        self.validate_dependencies(graph, &mut result);
95
96        // Validate operations
97        self.validate_operations(graph, &mut result);
98
99        // Check for cycles (DAG property)
100        self.validate_dag(graph, &mut result);
101
102        result
103    }
104
105    fn validate_tensor_indices(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
106        let num_tensors = graph.tensors.len();
107
108        for (node_idx, node) in graph.nodes.iter().enumerate() {
109            for &input_idx in &node.inputs {
110                // Input indices should be either:
111                // 1. Input tensors (< num_tensors)
112                // 2. Outputs from previous nodes (>= num_tensors && < num_tensors + node_idx)
113                let max_valid_idx = num_tensors + node_idx;
114
115                if input_idx >= max_valid_idx {
116                    result.add_error(format!(
117                        "Node {} references invalid tensor index {} (max valid: {})",
118                        node_idx, input_idx, max_valid_idx
119                    ));
120                }
121            }
122        }
123    }
124
125    fn validate_dependencies(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
126        let num_tensors = graph.tensors.len();
127
128        for (node_idx, node) in graph.nodes.iter().enumerate() {
129            // Check that all dependencies come from earlier in the graph
130            for &input_idx in &node.inputs {
131                if input_idx >= num_tensors {
132                    let dep_node_idx = input_idx - num_tensors;
133                    if dep_node_idx >= node_idx {
134                        result.add_error(format!(
135                            "Node {} has forward dependency on node {}",
136                            node_idx, dep_node_idx
137                        ));
138                    }
139                }
140            }
141        }
142    }
143
144    fn validate_operations(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
145        for (node_idx, node) in graph.nodes.iter().enumerate() {
146            match &node.op {
147                OpType::Einsum { spec } => {
148                    if spec.is_empty() {
149                        result.add_error(format!("Node {} has empty einsum spec", node_idx));
150                    }
151                    if node.inputs.is_empty() {
152                        result.add_error(format!("Node {} einsum has no inputs", node_idx));
153                    }
154                }
155                OpType::ElemUnary { op: _ } => {
156                    if node.inputs.len() != 1 {
157                        result.add_error(format!(
158                            "Node {} unary operation requires exactly 1 input, got {}",
159                            node_idx,
160                            node.inputs.len()
161                        ));
162                    }
163                }
164                OpType::ElemBinary { op: _ } => {
165                    if node.inputs.len() != 2 {
166                        result.add_error(format!(
167                            "Node {} binary operation requires exactly 2 inputs, got {}",
168                            node_idx,
169                            node.inputs.len()
170                        ));
171                    }
172                }
173                OpType::Reduce { op: _, axes } => {
174                    if node.inputs.len() != 1 {
175                        result.add_error(format!(
176                            "Node {} reduce operation requires exactly 1 input, got {}",
177                            node_idx,
178                            node.inputs.len()
179                        ));
180                    }
181                    if axes.is_empty() {
182                        result.add_warning(format!(
183                            "Node {} reduce operation has no axes (identity operation)",
184                            node_idx
185                        ));
186                    }
187                }
188            }
189        }
190    }
191
192    fn validate_dag(&self, graph: &EinsumGraph, result: &mut ValidationResult) {
193        // Build adjacency list
194        let num_nodes = graph.nodes.len();
195        let num_tensors = graph.tensors.len();
196
197        // Detect cycles using DFS
198        let mut visited = vec![false; num_nodes];
199        let mut rec_stack = vec![false; num_nodes];
200
201        for node_idx in 0..num_nodes {
202            if !visited[node_idx]
203                && has_cycle_helper(node_idx, graph, num_tensors, &mut visited, &mut rec_stack)
204            {
205                result.add_error("Graph contains a cycle (not a DAG)");
206                break;
207            }
208        }
209    }
210
211    /// Quick validation that returns an error if graph is invalid
212    pub fn validate_or_error(&self, graph: &EinsumGraph) -> Result<(), ExecutorError> {
213        let result = self.validate(graph);
214        if result.is_valid {
215            Ok(())
216        } else {
217            Err(ExecutorError::GraphValidationError(
218                result.errors.join("; "),
219            ))
220        }
221    }
222
223    /// Check if graph has any unreachable nodes
224    pub fn find_unreachable_nodes(&self, graph: &EinsumGraph) -> HashSet<usize> {
225        let num_nodes = graph.nodes.len();
226        let num_tensors = graph.tensors.len();
227
228        let mut reachable = HashSet::new();
229
230        // Work backwards from the last node
231        if num_nodes > 0 {
232            let mut to_visit = vec![num_nodes - 1];
233            while let Some(node_idx) = to_visit.pop() {
234                if reachable.insert(node_idx) {
235                    // Add dependencies to visit list
236                    for &input_idx in &graph.nodes[node_idx].inputs {
237                        if input_idx >= num_tensors {
238                            let dep_node_idx = input_idx - num_tensors;
239                            if !reachable.contains(&dep_node_idx) {
240                                to_visit.push(dep_node_idx);
241                            }
242                        }
243                    }
244                }
245            }
246        }
247
248        // Return nodes that are not reachable
249        (0..num_nodes)
250            .filter(|idx| !reachable.contains(idx))
251            .collect()
252    }
253}
254
255// Helper function to detect cycles (separate to avoid clippy recursion warning)
256#[allow(clippy::only_used_in_recursion)]
257fn has_cycle_helper(
258    node_idx: usize,
259    graph: &EinsumGraph,
260    num_tensors: usize,
261    visited: &mut [bool],
262    rec_stack: &mut [bool],
263) -> bool {
264    visited[node_idx] = true;
265    rec_stack[node_idx] = true;
266
267    // Check all dependencies
268    for &input_idx in &graph.nodes[node_idx].inputs {
269        if input_idx >= num_tensors {
270            let dep_node_idx = input_idx - num_tensors;
271            // Bounds check to prevent panic on invalid indices
272            if dep_node_idx >= visited.len() {
273                continue; // Skip invalid indices - they'll be caught by validate_tensor_indices
274            }
275            if !visited[dep_node_idx] {
276                if has_cycle_helper(dep_node_idx, graph, num_tensors, visited, rec_stack) {
277                    return true;
278                }
279            } else if rec_stack[dep_node_idx] {
280                return true;
281            }
282        }
283    }
284
285    rec_stack[node_idx] = false;
286    false
287}
288
289impl Default for GraphValidator {
290    fn default() -> Self {
291        Self::new()
292    }
293}