Skip to main content

tensorlogic_ir/graph/
validation.rs

1//! Graph validation utilities for post-compilation checks.
2//!
3//! This module provides comprehensive validation for `EinsumGraph` instances,
4//! checking for common errors and structural issues that may occur during compilation.
5
6use crate::graph::{EinsumGraph, OpType};
7use std::collections::{HashMap, HashSet};
8
9/// Result of graph validation with detailed diagnostics.
10#[derive(Debug, Clone)]
11pub struct ValidationReport {
12    /// Total number of validation checks performed
13    pub checks_performed: usize,
14    /// List of errors found
15    pub errors: Vec<ValidationError>,
16    /// List of warnings (non-fatal issues)
17    pub warnings: Vec<ValidationWarning>,
18    /// Graph statistics
19    pub stats: GraphValidationStats,
20}
21
22/// Validation error with severity and context.
23#[derive(Debug, Clone)]
24pub struct ValidationError {
25    pub kind: ValidationErrorKind,
26    pub message: String,
27    pub node_index: Option<usize>,
28    pub tensor_index: Option<usize>,
29}
30
/// Categories of fatal problems that validation can report.
///
/// Implements `Eq` and `Hash` in addition to `PartialEq`, so kinds can be
/// collected into sets or used as map keys (for example to tally errors by
/// category).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ValidationErrorKind {
    /// Tensor index out of bounds
    TensorOutOfBounds,
    /// Undefined tensor referenced
    UndefinedTensor,
    /// Tensor is never produced (no node writes to it)
    UnproducedTensor,
    /// Output tensor has no producer
    OutputWithoutProducer,
    /// Cyclic dependency detected
    CyclicDependency,
    /// Empty einsum specification
    EmptyEinsumSpec,
    /// Invalid einsum specification syntax
    InvalidEinsumSpec,
    /// Node has no outputs
    NoOutputs,
    /// Duplicate output (two nodes write to same tensor)
    DuplicateOutput,
}
53
54/// Validation warning (non-fatal issue).
55#[derive(Debug, Clone)]
56pub struct ValidationWarning {
57    pub kind: ValidationWarningKind,
58    pub message: String,
59    pub tensor_index: Option<usize>,
60    pub node_index: Option<usize>,
61}
62
/// Categories of non-fatal findings that validation can report.
///
/// Implements `Eq` and `Hash` in addition to `PartialEq`, so kinds can be
/// collected into sets or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ValidationWarningKind {
    /// Tensor is produced but never consumed
    UnusedTensor,
    /// Input tensor is never used
    UnusedInput,
    /// Tensor has unnamed or generated name
    GeneratedTensorName,
    /// Large number of operations (may be slow)
    LargeGraph,
    /// Deep operation nesting (may cause stack issues)
    DeepNesting,
}
77
/// Counts describing the graph that was validated.
///
/// Every field defaults to zero. Implements `PartialEq`/`Eq` so two sets of
/// statistics can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphValidationStats {
    /// Number of entries in the graph's tensor list.
    pub total_tensors: usize,
    /// Number of nodes (operations) in the graph.
    pub total_nodes: usize,
    /// Number of tensors listed as graph inputs.
    pub input_tensors: usize,
    /// Number of tensors listed as graph outputs.
    pub output_tensors: usize,
    /// Number of tensors that some node produces but no node consumes and
    /// that are not listed as graph outputs.
    pub unused_tensors: usize,
    /// Maximum operation depth.
    ///
    /// NOTE(review): no check visible in this file assigns this field, so it
    /// keeps its default of 0 — confirm whether it is populated elsewhere.
    pub max_operation_depth: usize,
    /// Number of einsum nodes.
    pub einsum_operations: usize,
    /// Number of element-wise unary nodes.
    pub elem_unary_operations: usize,
    /// Number of element-wise binary nodes.
    pub elem_binary_operations: usize,
    /// Number of reduction nodes.
    pub reduce_operations: usize,
}
92
93impl ValidationReport {
94    /// Check if validation passed (no errors).
95    pub fn is_valid(&self) -> bool {
96        self.errors.is_empty()
97    }
98
99    /// Check if there are any issues (errors or warnings).
100    pub fn has_issues(&self) -> bool {
101        !self.errors.is_empty() || !self.warnings.is_empty()
102    }
103
104    /// Get a summary string of the validation results.
105    pub fn summary(&self) -> String {
106        format!(
107            "Validation: {} errors, {} warnings ({} checks)",
108            self.errors.len(),
109            self.warnings.len(),
110            self.checks_performed
111        )
112    }
113}
114
115/// Validate an `EinsumGraph` with comprehensive checks.
116///
117/// # Example
118///
119/// ```
120/// use tensorlogic_ir::{EinsumGraph, EinsumNode, validate_graph};
121///
122/// let mut graph = EinsumGraph::new();
123/// let t0 = graph.add_tensor("input".to_string());
124/// let t1 = graph.add_tensor("output".to_string());
125/// graph.inputs = vec![t0];
126/// graph.outputs = vec![t1];
127///
128/// let node = EinsumNode::elem_unary("relu", t0, t1);
129/// graph.add_node(node).unwrap();
130///
131/// let report = validate_graph(&graph);
132/// assert!(report.is_valid());
133/// ```
134pub fn validate_graph(graph: &EinsumGraph) -> ValidationReport {
135    let mut report = ValidationReport {
136        checks_performed: 0,
137        errors: Vec::new(),
138        warnings: Vec::new(),
139        stats: GraphValidationStats::default(),
140    };
141
142    // Collect statistics
143    report.stats.total_tensors = graph.tensors.len();
144    report.stats.total_nodes = graph.nodes.len();
145    report.stats.input_tensors = graph.inputs.len();
146    report.stats.output_tensors = graph.outputs.len();
147
148    // Check 1: Tensor index bounds
149    report.checks_performed += 1;
150    check_tensor_bounds(graph, &mut report);
151
152    // Check 2: Producer analysis (which tensors are written to)
153    report.checks_performed += 1;
154    let producers = analyze_producers(graph, &mut report);
155
156    // Check 3: Consumer analysis (which tensors are read from)
157    report.checks_performed += 1;
158    let consumers = analyze_consumers(graph, &mut report);
159
160    // Check 4: Output tensors have producers
161    report.checks_performed += 1;
162    check_output_producers(graph, &producers, &mut report);
163
164    // Check 5: Check for unused tensors
165    report.checks_performed += 1;
166    check_unused_tensors(graph, &producers, &consumers, &mut report);
167
168    // Check 6: Einsum specification validity
169    report.checks_performed += 1;
170    check_einsum_specs(graph, &mut report);
171
172    // Check 7: Check for cycles
173    report.checks_performed += 1;
174    check_cycles(graph, &mut report);
175
176    // Check 8: Nodes have outputs
177    report.checks_performed += 1;
178    check_node_outputs(graph, &mut report);
179
180    // Check 9: Count operation types
181    report.checks_performed += 1;
182    count_operations(graph, &mut report);
183
184    // Check 10: Check graph size warnings
185    report.checks_performed += 1;
186    check_graph_size(graph, &mut report);
187
188    report
189}
190
191/// Check that all tensor indices are within bounds.
192fn check_tensor_bounds(graph: &EinsumGraph, report: &mut ValidationReport) {
193    for (node_idx, node) in graph.nodes.iter().enumerate() {
194        for &input in &node.inputs {
195            if input >= graph.tensors.len() {
196                report.errors.push(ValidationError {
197                    kind: ValidationErrorKind::TensorOutOfBounds,
198                    message: format!(
199                        "Input tensor {} is out of bounds (max: {})",
200                        input,
201                        graph.tensors.len() - 1
202                    ),
203                    node_index: Some(node_idx),
204                    tensor_index: Some(input),
205                });
206            }
207        }
208
209        for &output in &node.outputs {
210            if output >= graph.tensors.len() {
211                report.errors.push(ValidationError {
212                    kind: ValidationErrorKind::TensorOutOfBounds,
213                    message: format!(
214                        "Output tensor {} is out of bounds (max: {})",
215                        output,
216                        graph.tensors.len() - 1
217                    ),
218                    node_index: Some(node_idx),
219                    tensor_index: Some(output),
220                });
221            }
222        }
223    }
224}
225
226/// Analyze which nodes produce which tensors.
227fn analyze_producers(graph: &EinsumGraph, report: &mut ValidationReport) -> HashMap<usize, usize> {
228    let mut producers = HashMap::new();
229
230    for (node_idx, node) in graph.nodes.iter().enumerate() {
231        for &output in &node.outputs {
232            if let Some(existing_producer) = producers.insert(output, node_idx) {
233                report.errors.push(ValidationError {
234                    kind: ValidationErrorKind::DuplicateOutput,
235                    message: format!(
236                        "Tensor {} is produced by multiple nodes: {} and {}",
237                        output, existing_producer, node_idx
238                    ),
239                    node_index: Some(node_idx),
240                    tensor_index: Some(output),
241                });
242            }
243        }
244    }
245
246    producers
247}
248
249/// Analyze which nodes consume which tensors.
250fn analyze_consumers(
251    graph: &EinsumGraph,
252    _report: &mut ValidationReport,
253) -> HashMap<usize, Vec<usize>> {
254    let mut consumers: HashMap<usize, Vec<usize>> = HashMap::new();
255
256    for (node_idx, node) in graph.nodes.iter().enumerate() {
257        for &input in &node.inputs {
258            consumers.entry(input).or_default().push(node_idx);
259        }
260    }
261
262    consumers
263}
264
265/// Check that output tensors have producers.
266fn check_output_producers(
267    graph: &EinsumGraph,
268    producers: &HashMap<usize, usize>,
269    report: &mut ValidationReport,
270) {
271    for &output_idx in &graph.outputs {
272        if output_idx >= graph.tensors.len() {
273            continue; // Already reported in bounds check
274        }
275
276        if !producers.contains_key(&output_idx) && !graph.inputs.contains(&output_idx) {
277            report.errors.push(ValidationError {
278                kind: ValidationErrorKind::OutputWithoutProducer,
279                message: format!(
280                    "Output tensor {} '{}' has no producer",
281                    output_idx, graph.tensors[output_idx]
282                ),
283                node_index: None,
284                tensor_index: Some(output_idx),
285            });
286        }
287    }
288}
289
290/// Check for unused tensors.
291fn check_unused_tensors(
292    graph: &EinsumGraph,
293    producers: &HashMap<usize, usize>,
294    consumers: &HashMap<usize, Vec<usize>>,
295    report: &mut ValidationReport,
296) {
297    for (tensor_idx, tensor_name) in graph.tensors.iter().enumerate() {
298        let is_input = graph.inputs.contains(&tensor_idx);
299        let is_output = graph.outputs.contains(&tensor_idx);
300        let has_producer = producers.contains_key(&tensor_idx);
301        let has_consumers = consumers.contains_key(&tensor_idx);
302
303        // Tensor is produced but never used (and not an output)
304        if has_producer && !has_consumers && !is_output {
305            report.warnings.push(ValidationWarning {
306                kind: ValidationWarningKind::UnusedTensor,
307                message: format!(
308                    "Tensor {} '{}' is produced but never consumed",
309                    tensor_idx, tensor_name
310                ),
311                tensor_index: Some(tensor_idx),
312                node_index: None,
313            });
314            report.stats.unused_tensors += 1;
315        }
316
317        // Input tensor is never used
318        if is_input && !has_consumers {
319            report.warnings.push(ValidationWarning {
320                kind: ValidationWarningKind::UnusedInput,
321                message: format!(
322                    "Input tensor {} '{}' is never consumed",
323                    tensor_idx, tensor_name
324                ),
325                tensor_index: Some(tensor_idx),
326                node_index: None,
327            });
328        }
329
330        // Check for generated names (e.g., "temp_0", "t_123")
331        if tensor_name.starts_with("temp_")
332            || tensor_name.starts_with("t_")
333            || tensor_name.starts_with("_")
334        {
335            report.warnings.push(ValidationWarning {
336                kind: ValidationWarningKind::GeneratedTensorName,
337                message: format!("Tensor {} has generated name '{}'", tensor_idx, tensor_name),
338                tensor_index: Some(tensor_idx),
339                node_index: None,
340            });
341        }
342    }
343}
344
345/// Check einsum specifications for validity.
346fn check_einsum_specs(graph: &EinsumGraph, report: &mut ValidationReport) {
347    for (node_idx, node) in graph.nodes.iter().enumerate() {
348        if let OpType::Einsum { spec } = &node.op {
349            if spec.is_empty() {
350                report.errors.push(ValidationError {
351                    kind: ValidationErrorKind::EmptyEinsumSpec,
352                    message: "Einsum operation has empty specification".to_string(),
353                    node_index: Some(node_idx),
354                    tensor_index: None,
355                });
356            }
357
358            // Basic syntax check: should contain "->"
359            if !spec.contains("->") {
360                report.errors.push(ValidationError {
361                    kind: ValidationErrorKind::InvalidEinsumSpec,
362                    message: format!("Einsum specification '{}' is invalid (missing '->')", spec),
363                    node_index: Some(node_idx),
364                    tensor_index: None,
365                });
366            }
367        }
368    }
369}
370
371/// Check for cyclic dependencies in the graph.
372fn check_cycles(graph: &EinsumGraph, report: &mut ValidationReport) {
373    // Build dependency map: which tensors does each node depend on
374    let mut visited = HashSet::new();
375    let mut rec_stack = HashSet::new();
376
377    for node_idx in 0..graph.nodes.len() {
378        if !visited.contains(&node_idx)
379            && has_cycle_dfs(node_idx, graph, &mut visited, &mut rec_stack)
380        {
381            report.errors.push(ValidationError {
382                kind: ValidationErrorKind::CyclicDependency,
383                message: format!("Cyclic dependency detected involving node {}", node_idx),
384                node_index: Some(node_idx),
385                tensor_index: None,
386            });
387        }
388    }
389}
390
391/// DFS helper for cycle detection.
392fn has_cycle_dfs(
393    node_idx: usize,
394    graph: &EinsumGraph,
395    visited: &mut HashSet<usize>,
396    rec_stack: &mut HashSet<usize>,
397) -> bool {
398    visited.insert(node_idx);
399    rec_stack.insert(node_idx);
400
401    let node = &graph.nodes[node_idx];
402
403    // Find nodes that depend on this node's outputs
404    for &output in &node.outputs {
405        for (next_node_idx, next_node) in graph.nodes.iter().enumerate() {
406            if next_node.inputs.contains(&output) {
407                if !visited.contains(&next_node_idx) {
408                    if has_cycle_dfs(next_node_idx, graph, visited, rec_stack) {
409                        return true;
410                    }
411                } else if rec_stack.contains(&next_node_idx) {
412                    return true;
413                }
414            }
415        }
416    }
417
418    rec_stack.remove(&node_idx);
419    false
420}
421
422/// Check that all nodes produce outputs.
423fn check_node_outputs(graph: &EinsumGraph, report: &mut ValidationReport) {
424    for (node_idx, node) in graph.nodes.iter().enumerate() {
425        if node.outputs.is_empty() {
426            report.errors.push(ValidationError {
427                kind: ValidationErrorKind::NoOutputs,
428                message: format!("Node {} has no outputs", node_idx),
429                node_index: Some(node_idx),
430                tensor_index: None,
431            });
432        }
433    }
434}
435
436/// Count operation types for statistics.
437fn count_operations(graph: &EinsumGraph, report: &mut ValidationReport) {
438    for node in &graph.nodes {
439        match &node.op {
440            OpType::Einsum { .. } => report.stats.einsum_operations += 1,
441            OpType::ElemUnary { .. } => report.stats.elem_unary_operations += 1,
442            OpType::ElemBinary { .. } => report.stats.elem_binary_operations += 1,
443            OpType::Reduce { .. } => report.stats.reduce_operations += 1,
444        }
445    }
446}
447
448/// Check for large graphs that may have performance issues.
449fn check_graph_size(graph: &EinsumGraph, report: &mut ValidationReport) {
450    if graph.nodes.len() > 1000 {
451        report.warnings.push(ValidationWarning {
452            kind: ValidationWarningKind::LargeGraph,
453            message: format!(
454                "Graph has {} operations (may be slow to execute)",
455                graph.nodes.len()
456            ),
457            tensor_index: None,
458            node_index: None,
459        });
460    }
461
462    if graph.tensors.len() > 10000 {
463        report.warnings.push(ValidationWarning {
464            kind: ValidationWarningKind::LargeGraph,
465            message: format!(
466                "Graph has {} tensors (may use significant memory)",
467                graph.tensors.len()
468            ),
469            tensor_index: None,
470            node_index: None,
471        });
472    }
473}
474
#[cfg(test)]
mod tests {
    //! Unit tests for graph validation.
    //!
    //! Malformed graphs are built by pushing nodes straight onto
    //! `graph.nodes` instead of going through `add_node`.

    use super::*;
    use crate::{EinsumGraph, EinsumNode};

    #[test]
    fn test_validate_empty_graph() {
        let graph = EinsumGraph::new();
        let report = validate_graph(&graph);
        assert!(report.is_valid());
        assert_eq!(report.errors.len(), 0);
    }

    #[test]
    fn test_empty_graph_summary_and_issue_flags() {
        let graph = EinsumGraph::new();
        let report = validate_graph(&graph);
        assert!(!report.has_issues());
        assert_eq!(
            report.summary(),
            "Validation: 0 errors, 0 warnings (10 checks)"
        );
    }

    #[test]
    fn test_validate_simple_graph() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());
        graph.inputs = vec![t0];
        graph.outputs = vec![t1];

        let node = EinsumNode::elem_unary("relu", t0, t1);
        graph.add_node(node).unwrap();

        let report = validate_graph(&graph);
        assert!(report.is_valid());
        assert_eq!(report.stats.total_tensors, 2);
        assert_eq!(report.stats.total_nodes, 1);
    }

    #[test]
    fn test_detect_tensor_out_of_bounds() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        graph.add_tensor("output".to_string());

        // Create node with invalid tensor index
        let bad_node = EinsumNode::elem_unary("relu", t0, 999);
        graph.nodes.push(bad_node);

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert_eq!(report.errors.len(), 1);
        assert_eq!(
            report.errors[0].kind,
            ValidationErrorKind::TensorOutOfBounds
        );
    }

    #[test]
    fn test_detect_unused_tensor() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("intermediate".to_string());
        let t2 = graph.add_tensor("output".to_string());
        graph.inputs = vec![t0];
        graph.outputs = vec![t2];

        // t1 is produced but never used
        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("sigmoid", t0, t2))
            .unwrap();

        let report = validate_graph(&graph);
        assert!(report.is_valid()); // No errors, just warnings
        assert_eq!(report.warnings.len(), 1);
        assert_eq!(report.warnings[0].kind, ValidationWarningKind::UnusedTensor);
    }

    #[test]
    fn test_detect_output_without_producer() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());
        graph.inputs = vec![t0];
        graph.outputs = vec![t1]; // t1 is output but no node produces it

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert_eq!(report.errors.len(), 1);
        assert_eq!(
            report.errors[0].kind,
            ValidationErrorKind::OutputWithoutProducer
        );
    }

    #[test]
    fn test_detect_empty_einsum_spec() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());

        let bad_node = EinsumNode::einsum("", vec![t0], vec![t1]);
        graph.nodes.push(bad_node);

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert!(report
            .errors
            .iter()
            .any(|e| e.kind == ValidationErrorKind::EmptyEinsumSpec));
    }

    #[test]
    fn test_detect_invalid_einsum_spec() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());

        let bad_node = EinsumNode::einsum("ijk", vec![t0], vec![t1]); // Missing "->"
        graph.nodes.push(bad_node);

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert!(report
            .errors
            .iter()
            .any(|e| e.kind == ValidationErrorKind::InvalidEinsumSpec));
    }

    #[test]
    fn test_detect_cyclic_dependency() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("a".to_string());
        let t1 = graph.add_tensor("b".to_string());

        // Each node consumes what the other one produces.
        graph.nodes.push(EinsumNode::elem_unary("relu", t0, t1));
        graph.nodes.push(EinsumNode::elem_unary("sigmoid", t1, t0));

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert!(report
            .errors
            .iter()
            .any(|e| e.kind == ValidationErrorKind::CyclicDependency));
    }

    #[test]
    fn test_detect_duplicate_output() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("a".to_string());
        let t1 = graph.add_tensor("b".to_string());
        let t2 = graph.add_tensor("c".to_string());

        // Both nodes write to t2.
        graph.nodes.push(EinsumNode::elem_unary("relu", t0, t2));
        graph.nodes.push(EinsumNode::elem_unary("sigmoid", t1, t2));

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert!(report.errors.iter().any(|e| {
            e.kind == ValidationErrorKind::DuplicateOutput && e.tensor_index == Some(t2)
        }));
    }

    #[test]
    fn test_statistics_collection() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("a".to_string());
        let t1 = graph.add_tensor("b".to_string());
        let t2 = graph.add_tensor("c".to_string());
        let t3 = graph.add_tensor("d".to_string());

        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_binary("add", t1, t2, t3))
            .unwrap();

        let report = validate_graph(&graph);
        assert_eq!(report.stats.elem_unary_operations, 1);
        assert_eq!(report.stats.elem_binary_operations, 1);
        assert_eq!(report.stats.total_nodes, 2);
        assert_eq!(report.stats.total_tensors, 4);
    }
}