1use crate::EinsumGraph;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet, VecDeque};
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16pub enum IssueSeverity {
17 Error,
19 Warning,
21 Info,
23}
24
25impl std::fmt::Display for IssueSeverity {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 match self {
28 IssueSeverity::Error => write!(f, "ERROR"),
29 IssueSeverity::Warning => write!(f, "WARNING"),
30 IssueSeverity::Info => write!(f, "INFO"),
31 }
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ValidationIssue {
38 pub severity: IssueSeverity,
40 pub code: String,
42 pub message: String,
44 pub node_index: Option<usize>,
46}
47
48impl std::fmt::Display for ValidationIssue {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 if let Some(idx) = self.node_index {
51 write!(
52 f,
53 "[{}] {} (node {}): {}",
54 self.severity, self.code, idx, self.message
55 )
56 } else {
57 write!(f, "[{}] {}: {}", self.severity, self.code, self.message)
58 }
59 }
60}
61
62#[derive(Debug, Clone, Default, Serialize, Deserialize)]
64pub struct ValidationResult {
65 pub issues: Vec<ValidationIssue>,
67}
68
69impl ValidationResult {
70 pub fn is_valid(&self) -> bool {
72 !self
73 .issues
74 .iter()
75 .any(|i| i.severity == IssueSeverity::Error)
76 }
77
78 pub fn error_count(&self) -> usize {
80 self.issues
81 .iter()
82 .filter(|i| i.severity == IssueSeverity::Error)
83 .count()
84 }
85
86 pub fn warning_count(&self) -> usize {
88 self.issues
89 .iter()
90 .filter(|i| i.severity == IssueSeverity::Warning)
91 .count()
92 }
93
94 pub fn info_count(&self) -> usize {
96 self.issues
97 .iter()
98 .filter(|i| i.severity == IssueSeverity::Info)
99 .count()
100 }
101
102 pub fn summary(&self) -> String {
104 format!(
105 "{} errors, {} warnings",
106 self.error_count(),
107 self.warning_count()
108 )
109 }
110
111 pub fn issues_by_severity(&self, severity: IssueSeverity) -> Vec<&ValidationIssue> {
113 self.issues
114 .iter()
115 .filter(|i| i.severity == severity)
116 .collect()
117 }
118
119 pub fn issues_by_code(&self, code: &str) -> Vec<&ValidationIssue> {
121 self.issues.iter().filter(|i| i.code == code).collect()
122 }
123}
124
125impl std::fmt::Display for ValidationResult {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 writeln!(f, "Validation: {}", self.summary())?;
128 for issue in &self.issues {
129 writeln!(f, " {}", issue)?;
130 }
131 Ok(())
132 }
133}
134
135pub fn validate_einsum_graph(graph: &EinsumGraph) -> ValidationResult {
156 let mut result = ValidationResult::default();
157
158 check_empty_graph(graph, &mut result);
159 check_duplicate_outputs(graph, &mut result);
160 check_node_input_refs(graph, &mut result);
161 check_node_output_refs(graph, &mut result);
162 check_unreachable_nodes(graph, &mut result);
163 check_output_refs(graph, &mut result);
164 check_outputs_have_producers(graph, &mut result);
165 check_cycles(graph, &mut result);
166 check_empty_node_outputs(graph, &mut result);
167 check_duplicate_tensor_names(graph, &mut result);
168
169 result
170}
171
172fn check_empty_graph(graph: &EinsumGraph, result: &mut ValidationResult) {
177 if graph.nodes.is_empty() {
178 result.issues.push(ValidationIssue {
179 severity: IssueSeverity::Warning,
180 code: "empty-graph".to_string(),
181 message: "Graph has no nodes".to_string(),
182 node_index: None,
183 });
184 }
185}
186
187fn check_duplicate_outputs(graph: &EinsumGraph, result: &mut ValidationResult) {
188 let mut seen = HashSet::new();
189 for &output in &graph.outputs {
190 if !seen.insert(output) {
191 result.issues.push(ValidationIssue {
192 severity: IssueSeverity::Warning,
193 code: "duplicate-output".to_string(),
194 message: format!("Duplicate output tensor index: {}", output),
195 node_index: None,
196 });
197 }
198 }
199}
200
201fn check_node_input_refs(graph: &EinsumGraph, result: &mut ValidationResult) {
202 let num_tensors = graph.tensors.len();
203 for (node_idx, node) in graph.nodes.iter().enumerate() {
204 for &input_idx in &node.inputs {
205 if input_idx >= num_tensors {
206 result.issues.push(ValidationIssue {
207 severity: IssueSeverity::Error,
208 code: "invalid-input-ref".to_string(),
209 message: format!(
210 "Node {} input references tensor index {} but only {} tensors exist",
211 node_idx, input_idx, num_tensors
212 ),
213 node_index: Some(node_idx),
214 });
215 }
216 }
217 }
218}
219
220fn check_node_output_refs(graph: &EinsumGraph, result: &mut ValidationResult) {
221 let num_tensors = graph.tensors.len();
222 for (node_idx, node) in graph.nodes.iter().enumerate() {
223 for &output_idx in &node.outputs {
224 if output_idx >= num_tensors {
225 result.issues.push(ValidationIssue {
226 severity: IssueSeverity::Error,
227 code: "invalid-output-ref".to_string(),
228 message: format!(
229 "Node {} output references tensor index {} but only {} tensors exist",
230 node_idx, output_idx, num_tensors
231 ),
232 node_index: Some(node_idx),
233 });
234 }
235 }
236 }
237}
238
239fn check_unreachable_nodes(graph: &EinsumGraph, result: &mut ValidationResult) {
240 if graph.nodes.is_empty() {
241 return;
242 }
243
244 let output_set: HashSet<usize> = graph.outputs.iter().copied().collect();
247
248 let mut consumed_tensors: HashSet<usize> = HashSet::new();
250 for node in &graph.nodes {
251 for &inp in &node.inputs {
252 consumed_tensors.insert(inp);
253 }
254 }
255
256 for (node_idx, node) in graph.nodes.iter().enumerate() {
257 let any_output_used = node
258 .outputs
259 .iter()
260 .any(|o| consumed_tensors.contains(o) || output_set.contains(o));
261 if !any_output_used {
262 result.issues.push(ValidationIssue {
263 severity: IssueSeverity::Warning,
264 code: "unreachable-node".to_string(),
265 message: format!(
266 "Node {} outputs are never consumed and not graph outputs",
267 node_idx
268 ),
269 node_index: Some(node_idx),
270 });
271 }
272 }
273}
274
275fn check_output_refs(graph: &EinsumGraph, result: &mut ValidationResult) {
276 let num_tensors = graph.tensors.len();
277 for &output_idx in &graph.outputs {
278 if output_idx >= num_tensors {
279 result.issues.push(ValidationIssue {
280 severity: IssueSeverity::Error,
281 code: "invalid-graph-output".to_string(),
282 message: format!(
283 "Graph output references tensor index {} but only {} tensors exist",
284 output_idx, num_tensors
285 ),
286 node_index: None,
287 });
288 }
289 }
290}
291
292fn check_outputs_have_producers(graph: &EinsumGraph, result: &mut ValidationResult) {
293 let mut produced: HashSet<usize> = HashSet::new();
295 for node in &graph.nodes {
296 for &out in &node.outputs {
297 produced.insert(out);
298 }
299 }
300
301 let input_set: HashSet<usize> = graph.inputs.iter().copied().collect();
302
303 for &output_idx in &graph.outputs {
304 if output_idx >= graph.tensors.len() {
305 continue; }
307 if !produced.contains(&output_idx) && !input_set.contains(&output_idx) {
308 result.issues.push(ValidationIssue {
309 severity: IssueSeverity::Error,
310 code: "output-no-producer".to_string(),
311 message: format!(
312 "Output tensor {} ('{}') is not produced by any node and is not a graph input",
313 output_idx, graph.tensors[output_idx]
314 ),
315 node_index: None,
316 });
317 }
318 }
319}
320
321fn check_cycles(graph: &EinsumGraph, result: &mut ValidationResult) {
322 if graph.nodes.is_empty() {
323 return;
324 }
325
326 let num_nodes = graph.nodes.len();
328 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
329
330 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
332 for (nidx, node) in graph.nodes.iter().enumerate() {
333 for &out in &node.outputs {
334 tensor_producer.insert(out, nidx);
335 }
336 }
337
338 for (nidx, node) in graph.nodes.iter().enumerate() {
340 for &out in &node.outputs {
341 for (other_idx, other_node) in graph.nodes.iter().enumerate() {
342 if other_idx != nidx && other_node.inputs.contains(&out) {
343 adj[nidx].push(other_idx);
344 }
345 }
346 }
347 }
348
349 let mut visited = vec![0u8; num_nodes]; for start in 0..num_nodes {
353 if visited[start] == 0 && dfs_has_cycle(start, &adj, &mut visited) {
354 result.issues.push(ValidationIssue {
355 severity: IssueSeverity::Error,
356 code: "cycle-detected".to_string(),
357 message: format!("Cyclic dependency detected involving node {}", start),
358 node_index: Some(start),
359 });
360 }
361 }
362}
363
/// Depth-first search from `node` that reports whether it reaches a node
/// currently marked as in progress.
///
/// `adj[i]` lists the successors of node `i`. `visited` holds one mark per
/// node: `0` = unvisited, `1` = in progress, `2` = finished. The start node is
/// marked in progress on entry.
///
/// Returns `true` as soon as an edge to an in-progress node is found; in that
/// case every node on the current search path keeps mark `1`. Returns `false`
/// once everything reachable through unvisited nodes has been explored, and
/// all explored nodes are then marked `2`.
///
/// The search uses an explicit stack instead of recursion, so a long
/// dependency chain cannot overflow the call stack.
///
/// # Panics
///
/// Panics if `node` or any successor index is out of bounds for `adj` or
/// `visited`.
fn dfs_has_cycle(node: usize, adj: &[Vec<usize>], visited: &mut [u8]) -> bool {
    const UNVISITED: u8 = 0;
    const IN_PROGRESS: u8 = 1;
    const FINISHED: u8 = 2;

    visited[node] = IN_PROGRESS;
    // Each frame is (node, position of the next successor to examine).
    let mut stack: Vec<(usize, usize)> = vec![(node, 0)];

    while let Some(frame) = stack.last_mut() {
        let (current, cursor) = *frame;
        match adj[current].get(cursor) {
            Some(&next) => {
                frame.1 += 1;
                match visited[next] {
                    // Edge back into the active path.
                    IN_PROGRESS => return true,
                    UNVISITED => {
                        visited[next] = IN_PROGRESS;
                        stack.push((next, 0));
                    }
                    // Already fully explored: nothing new behind it.
                    _ => {}
                }
            }
            None => {
                // All successors handled; this node is finished.
                visited[current] = FINISHED;
                stack.pop();
            }
        }
    }

    false
}
377
378fn check_empty_node_outputs(graph: &EinsumGraph, result: &mut ValidationResult) {
379 for (node_idx, node) in graph.nodes.iter().enumerate() {
380 if node.outputs.is_empty() {
381 result.issues.push(ValidationIssue {
382 severity: IssueSeverity::Error,
383 code: "node-no-outputs".to_string(),
384 message: format!("Node {} produces no outputs", node_idx),
385 node_index: Some(node_idx),
386 });
387 }
388 }
389}
390
391fn check_duplicate_tensor_names(graph: &EinsumGraph, result: &mut ValidationResult) {
392 let mut name_indices: HashMap<&str, Vec<usize>> = HashMap::new();
393 for (idx, name) in graph.tensors.iter().enumerate() {
394 name_indices.entry(name.as_str()).or_default().push(idx);
395 }
396 for (name, indices) in &name_indices {
397 if indices.len() > 1 {
398 result.issues.push(ValidationIssue {
399 severity: IssueSeverity::Info,
400 code: "duplicate-tensor-name".to_string(),
401 message: format!("Tensor name '{}' is used by indices {:?}", name, indices),
402 node_index: None,
403 });
404 }
405 }
406}
407
408#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
414pub struct GraphSanitizationStats {
415 pub node_count: usize,
417 pub output_count: usize,
419 pub tensor_count: usize,
421 pub has_cycles: bool,
423 pub unreachable_count: usize,
425 pub max_depth: usize,
427}
428
429pub fn compute_graph_stats(graph: &EinsumGraph) -> GraphSanitizationStats {
431 let validation = validate_einsum_graph(graph);
432
433 let unreachable_count = validation
434 .issues
435 .iter()
436 .filter(|i| i.code == "unreachable-node")
437 .count();
438
439 let has_cycles = validation.issues.iter().any(|i| i.code == "cycle-detected");
440
441 let max_depth = compute_max_depth(graph);
442
443 GraphSanitizationStats {
444 node_count: graph.nodes.len(),
445 output_count: graph.outputs.len(),
446 tensor_count: graph.tensors.len(),
447 has_cycles,
448 unreachable_count,
449 max_depth,
450 }
451}
452
453fn compute_max_depth(graph: &EinsumGraph) -> usize {
455 if graph.nodes.is_empty() {
456 return 0;
457 }
458
459 let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
461 for (nidx, node) in graph.nodes.iter().enumerate() {
462 for &out in &node.outputs {
463 tensor_producer.insert(out, nidx);
464 }
465 }
466
467 let num_nodes = graph.nodes.len();
469 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
470 let mut in_degree: Vec<usize> = vec![0; num_nodes];
471
472 for (nidx, node) in graph.nodes.iter().enumerate() {
473 for &out in &node.outputs {
474 for (other_idx, other_node) in graph.nodes.iter().enumerate() {
475 if other_idx != nidx && other_node.inputs.contains(&out) {
476 adj[nidx].push(other_idx);
477 in_degree[other_idx] += 1;
478 }
479 }
480 }
481 }
482
483 let mut depth = vec![0usize; num_nodes];
485 let mut queue: VecDeque<usize> = VecDeque::new();
486
487 for (i, °) in in_degree.iter().enumerate() {
488 if deg == 0 {
489 queue.push_back(i);
490 }
491 }
492
493 let mut max_d = 0usize;
494 while let Some(n) = queue.pop_front() {
495 for &next in &adj[n] {
496 let new_depth = depth[n] + 1;
497 if new_depth > depth[next] {
498 depth[next] = new_depth;
499 }
500 in_degree[next] -= 1;
501 if in_degree[next] == 0 {
502 queue.push_back(next);
503 }
504 }
505 if depth[n] > max_d {
506 max_d = depth[n];
507 }
508 }
509
510 max_d
511}
512
513pub fn sanitize_graph(graph: &EinsumGraph) -> EinsumGraph {
524 let mut sanitized = graph.clone();
525
526 let mut seen = HashSet::new();
528 sanitized.outputs.retain(|o| seen.insert(*o));
529
530 sanitized
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EinsumGraph, EinsumNode};

    /// Builds a minimal well-formed graph: one `relu` node reading the graph
    /// input tensor and writing the graph output tensor.
    fn make_valid_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let source = graph.add_tensor("input");
        let target = graph.add_tensor("output");
        graph.inputs = vec![source];
        graph.outputs = vec![target];
        graph
            .add_node(EinsumNode::elem_unary("relu", source, target))
            .expect("failed to add node");
        graph
    }

    /// Whether `result` contains at least one issue with the given code.
    fn has_code(result: &ValidationResult, code: &str) -> bool {
        result.issues.iter().any(|issue| issue.code == code)
    }

    /// Builds a result holding a single issue of the given severity.
    fn result_with_single_issue(severity: IssueSeverity, code: &str, message: &str) -> ValidationResult {
        let mut result = ValidationResult::default();
        result.issues.push(ValidationIssue {
            severity,
            code: code.to_string(),
            message: message.to_string(),
            node_index: None,
        });
        result
    }

    #[test]
    fn test_validate_empty_graph() {
        let result = validate_einsum_graph(&EinsumGraph::new());
        assert!(
            result.is_valid(),
            "empty graph should be valid (only warnings)"
        );
        assert!(
            has_code(&result, "empty-graph"),
            "should have empty-graph warning"
        );
    }

    #[test]
    fn test_validate_valid_graph() {
        let result = validate_einsum_graph(&make_valid_graph());
        assert!(
            result.is_valid(),
            "well-formed graph should be valid: {:?}",
            result
        );
    }

    #[test]
    fn test_validate_duplicate_outputs() {
        let mut graph = make_valid_graph();
        let first_output = graph.outputs[0];
        graph.outputs.push(first_output);

        let result = validate_einsum_graph(&graph);
        assert!(
            has_code(&result, "duplicate-output"),
            "should detect duplicate outputs"
        );
    }

    #[test]
    fn test_validate_result_summary() {
        let summary = validate_einsum_graph(&EinsumGraph::new()).summary();
        let mentions_both = summary.contains("errors") && summary.contains("warnings");
        assert!(
            mentions_both,
            "summary should mention errors and warnings: {}",
            summary
        );
    }

    #[test]
    fn test_validate_error_count() {
        // No nodes and no graph inputs, so the listed output has no producer.
        let mut graph = EinsumGraph::new();
        let orphan = graph.add_tensor("a");
        let _unused = graph.add_tensor("b");
        graph.outputs = vec![orphan];

        let result = validate_einsum_graph(&graph);
        assert!(
            result.error_count() >= 1,
            "should have at least one error for output without producer"
        );
    }

    #[test]
    fn test_validate_warning_count() {
        let result = validate_einsum_graph(&EinsumGraph::new());
        assert!(
            result.warning_count() >= 1,
            "empty graph should have at least one warning"
        );
    }

    #[test]
    fn test_graph_stats_node_count() {
        let stats = compute_graph_stats(&make_valid_graph());
        assert_eq!(stats.node_count, 1);
    }

    #[test]
    fn test_graph_stats_output_count() {
        let stats = compute_graph_stats(&make_valid_graph());
        assert_eq!(stats.output_count, 1);
    }

    #[test]
    fn test_sanitize_dedup_outputs() {
        let mut graph = make_valid_graph();
        let first_output = graph.outputs[0];
        graph.outputs.push(first_output);
        assert_eq!(graph.outputs.len(), 2);

        let sanitized = sanitize_graph(&graph);
        assert_eq!(sanitized.outputs.len(), 1, "duplicates should be removed");
    }

    #[test]
    fn test_sanitize_preserves_valid() {
        let graph = make_valid_graph();
        let sanitized = sanitize_graph(&graph);
        assert_eq!(sanitized.tensors, graph.tensors);
        assert_eq!(sanitized.nodes, graph.nodes);
        assert_eq!(sanitized.outputs, graph.outputs);
        assert_eq!(sanitized.inputs, graph.inputs);
    }

    #[test]
    fn test_issue_severity_eq() {
        assert_eq!(IssueSeverity::Error, IssueSeverity::Error);
        assert_eq!(IssueSeverity::Warning, IssueSeverity::Warning);
        assert_eq!(IssueSeverity::Info, IssueSeverity::Info);
        assert_ne!(IssueSeverity::Error, IssueSeverity::Warning);
    }

    #[test]
    fn test_validation_result_default() {
        let result = ValidationResult::default();
        assert!(result.issues.is_empty());
        assert!(result.is_valid());
    }

    #[test]
    fn test_validation_result_is_valid_no_errors() {
        let result = result_with_single_issue(IssueSeverity::Warning, "test", "just a warning");
        assert!(result.is_valid(), "warnings only => valid");
    }

    #[test]
    fn test_validation_result_is_valid_with_errors() {
        let result = result_with_single_issue(IssueSeverity::Error, "test-error", "an error");
        assert!(!result.is_valid(), "errors => not valid");
    }

    #[test]
    fn test_graph_stats_default() {
        let stats = GraphSanitizationStats::default();
        assert_eq!(stats.node_count, 0);
        assert_eq!(stats.output_count, 0);
        assert_eq!(stats.tensor_count, 0);
        assert!(!stats.has_cycles);
        assert_eq!(stats.unreachable_count, 0);
        assert_eq!(stats.max_depth, 0);
    }

    #[test]
    fn test_validate_outputs_reference() {
        // Only one tensor exists, so index 999 is out of range.
        let mut graph = EinsumGraph::new();
        graph.add_tensor("a");
        graph.outputs = vec![999];

        let result = validate_einsum_graph(&graph);
        assert!(
            has_code(&result, "invalid-graph-output"),
            "should detect invalid graph output reference"
        );
        assert!(!result.is_valid());
    }

    #[test]
    fn test_sanitize_returns_clone() {
        let mut graph = make_valid_graph();
        let sanitized = sanitize_graph(&graph);
        // Changing the source afterwards must not show up in the sanitized copy.
        graph.tensors.push("extra".to_string());
        assert_ne!(graph.tensors.len(), sanitized.tensors.len());
    }

    #[test]
    fn test_compute_stats_empty() {
        let stats = compute_graph_stats(&EinsumGraph::new());
        assert_eq!(stats.node_count, 0);
        assert_eq!(stats.output_count, 0);
        assert_eq!(stats.tensor_count, 0);
        assert!(!stats.has_cycles);
        assert_eq!(stats.unreachable_count, 0);
        assert_eq!(stats.max_depth, 0);
    }
}