1use crate::graph::{EinsumGraph, OpType};
7use std::collections::{HashMap, HashSet};
8
/// Result of running [`validate_graph`] over an [`EinsumGraph`].
///
/// A graph is considered valid when `errors` is empty; `warnings` never affect
/// validity.
#[derive(Debug, Clone)]
pub struct ValidationReport {
    /// Number of validation passes that were executed.
    pub checks_performed: usize,
    /// Problems that make the graph invalid.
    pub errors: Vec<ValidationError>,
    /// Non-fatal findings (unused tensors, generated-looking names, size concerns).
    pub warnings: Vec<ValidationWarning>,
    /// Counts gathered while validating.
    pub stats: GraphValidationStats,
}
21
/// A single problem that makes a graph invalid.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Category of the problem.
    pub kind: ValidationErrorKind,
    /// Human-readable description of the problem.
    pub message: String,
    /// Index into `graph.nodes` of the offending node, when the error is tied to one.
    pub node_index: Option<usize>,
    /// Offending tensor index, when the error is tied to one. For
    /// `TensorOutOfBounds` this index is itself outside `graph.tensors`.
    pub tensor_index: Option<usize>,
}
30
/// Category of a [`ValidationError`].
///
/// The enum is fieldless, so it is `Copy` and can be compared for full
/// equality or used as a `HashMap`/`HashSet` key (e.g. to group errors).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationErrorKind {
    /// A tensor index is not smaller than `graph.tensors.len()`.
    TensorOutOfBounds,
    /// Reserved; not emitted by the checks in this module.
    UndefinedTensor,
    /// Reserved; not emitted by the checks in this module.
    UnproducedTensor,
    /// A graph output is neither produced by any node nor a graph input.
    OutputWithoutProducer,
    /// The node dependency graph contains a cycle.
    CyclicDependency,
    /// An einsum node has an empty specification string.
    EmptyEinsumSpec,
    /// An einsum specification is missing the `->` separator.
    InvalidEinsumSpec,
    /// A node declares no output tensors.
    NoOutputs,
    /// A tensor appears as a node output more than once across the graph.
    DuplicateOutput,
}
53
/// A non-fatal finding about a graph; warnings do not make a graph invalid.
#[derive(Debug, Clone)]
pub struct ValidationWarning {
    /// Category of the finding.
    pub kind: ValidationWarningKind,
    /// Human-readable description of the finding.
    pub message: String,
    /// Tensor the warning refers to, when it refers to one.
    pub tensor_index: Option<usize>,
    /// Index into `graph.nodes` of the node the warning refers to, when it refers to one.
    pub node_index: Option<usize>,
}
62
/// Category of a [`ValidationWarning`].
///
/// The enum is fieldless, so it is `Copy` and can be compared for full
/// equality or used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationWarningKind {
    /// A tensor is produced by a node but never consumed and is not a graph output.
    UnusedTensor,
    /// A graph input is never consumed by any node.
    UnusedInput,
    /// A tensor name starts with `temp_`, `t_`, or `_`, suggesting it was auto-generated.
    GeneratedTensorName,
    /// The graph has more than 1000 nodes or more than 10000 tensors.
    LargeGraph,
    /// Reserved; not emitted by the checks in this module.
    DeepNesting,
}
77
/// Counts collected while validating a graph. All fields default to 0.
#[derive(Debug, Clone, Default)]
pub struct GraphValidationStats {
    /// Number of entries in `graph.tensors`.
    pub total_tensors: usize,
    /// Number of entries in `graph.nodes`.
    pub total_nodes: usize,
    /// Length of `graph.inputs` (duplicate indices are counted).
    pub input_tensors: usize,
    /// Length of `graph.outputs` (duplicate indices are counted).
    pub output_tensors: usize,
    /// Tensors produced by a node but neither consumed nor listed as a graph output.
    pub unused_tensors: usize,
    /// Longest chain of dependent operations.
    /// NOTE(review): `validate_graph` never assigns this field, so it currently
    /// always keeps its default value of 0.
    pub max_operation_depth: usize,
    /// Number of `OpType::Einsum` nodes.
    pub einsum_operations: usize,
    /// Number of `OpType::ElemUnary` nodes.
    pub elem_unary_operations: usize,
    /// Number of `OpType::ElemBinary` nodes.
    pub elem_binary_operations: usize,
    /// Number of `OpType::Reduce` nodes.
    pub reduce_operations: usize,
}
92
93impl ValidationReport {
94 pub fn is_valid(&self) -> bool {
96 self.errors.is_empty()
97 }
98
99 pub fn has_issues(&self) -> bool {
101 !self.errors.is_empty() || !self.warnings.is_empty()
102 }
103
104 pub fn summary(&self) -> String {
106 format!(
107 "Validation: {} errors, {} warnings ({} checks)",
108 self.errors.len(),
109 self.warnings.len(),
110 self.checks_performed
111 )
112 }
113}
114
115pub fn validate_graph(graph: &EinsumGraph) -> ValidationReport {
135 let mut report = ValidationReport {
136 checks_performed: 0,
137 errors: Vec::new(),
138 warnings: Vec::new(),
139 stats: GraphValidationStats::default(),
140 };
141
142 report.stats.total_tensors = graph.tensors.len();
144 report.stats.total_nodes = graph.nodes.len();
145 report.stats.input_tensors = graph.inputs.len();
146 report.stats.output_tensors = graph.outputs.len();
147
148 report.checks_performed += 1;
150 check_tensor_bounds(graph, &mut report);
151
152 report.checks_performed += 1;
154 let producers = analyze_producers(graph, &mut report);
155
156 report.checks_performed += 1;
158 let consumers = analyze_consumers(graph, &mut report);
159
160 report.checks_performed += 1;
162 check_output_producers(graph, &producers, &mut report);
163
164 report.checks_performed += 1;
166 check_unused_tensors(graph, &producers, &consumers, &mut report);
167
168 report.checks_performed += 1;
170 check_einsum_specs(graph, &mut report);
171
172 report.checks_performed += 1;
174 check_cycles(graph, &mut report);
175
176 report.checks_performed += 1;
178 check_node_outputs(graph, &mut report);
179
180 report.checks_performed += 1;
182 count_operations(graph, &mut report);
183
184 report.checks_performed += 1;
186 check_graph_size(graph, &mut report);
187
188 report
189}
190
191fn check_tensor_bounds(graph: &EinsumGraph, report: &mut ValidationReport) {
193 for (node_idx, node) in graph.nodes.iter().enumerate() {
194 for &input in &node.inputs {
195 if input >= graph.tensors.len() {
196 report.errors.push(ValidationError {
197 kind: ValidationErrorKind::TensorOutOfBounds,
198 message: format!(
199 "Input tensor {} is out of bounds (max: {})",
200 input,
201 graph.tensors.len() - 1
202 ),
203 node_index: Some(node_idx),
204 tensor_index: Some(input),
205 });
206 }
207 }
208
209 for &output in &node.outputs {
210 if output >= graph.tensors.len() {
211 report.errors.push(ValidationError {
212 kind: ValidationErrorKind::TensorOutOfBounds,
213 message: format!(
214 "Output tensor {} is out of bounds (max: {})",
215 output,
216 graph.tensors.len() - 1
217 ),
218 node_index: Some(node_idx),
219 tensor_index: Some(output),
220 });
221 }
222 }
223 }
224}
225
226fn analyze_producers(graph: &EinsumGraph, report: &mut ValidationReport) -> HashMap<usize, usize> {
228 let mut producers = HashMap::new();
229
230 for (node_idx, node) in graph.nodes.iter().enumerate() {
231 for &output in &node.outputs {
232 if let Some(existing_producer) = producers.insert(output, node_idx) {
233 report.errors.push(ValidationError {
234 kind: ValidationErrorKind::DuplicateOutput,
235 message: format!(
236 "Tensor {} is produced by multiple nodes: {} and {}",
237 output, existing_producer, node_idx
238 ),
239 node_index: Some(node_idx),
240 tensor_index: Some(output),
241 });
242 }
243 }
244 }
245
246 producers
247}
248
249fn analyze_consumers(
251 graph: &EinsumGraph,
252 _report: &mut ValidationReport,
253) -> HashMap<usize, Vec<usize>> {
254 let mut consumers: HashMap<usize, Vec<usize>> = HashMap::new();
255
256 for (node_idx, node) in graph.nodes.iter().enumerate() {
257 for &input in &node.inputs {
258 consumers.entry(input).or_default().push(node_idx);
259 }
260 }
261
262 consumers
263}
264
265fn check_output_producers(
267 graph: &EinsumGraph,
268 producers: &HashMap<usize, usize>,
269 report: &mut ValidationReport,
270) {
271 for &output_idx in &graph.outputs {
272 if output_idx >= graph.tensors.len() {
273 continue; }
275
276 if !producers.contains_key(&output_idx) && !graph.inputs.contains(&output_idx) {
277 report.errors.push(ValidationError {
278 kind: ValidationErrorKind::OutputWithoutProducer,
279 message: format!(
280 "Output tensor {} '{}' has no producer",
281 output_idx, graph.tensors[output_idx]
282 ),
283 node_index: None,
284 tensor_index: Some(output_idx),
285 });
286 }
287 }
288}
289
290fn check_unused_tensors(
292 graph: &EinsumGraph,
293 producers: &HashMap<usize, usize>,
294 consumers: &HashMap<usize, Vec<usize>>,
295 report: &mut ValidationReport,
296) {
297 for (tensor_idx, tensor_name) in graph.tensors.iter().enumerate() {
298 let is_input = graph.inputs.contains(&tensor_idx);
299 let is_output = graph.outputs.contains(&tensor_idx);
300 let has_producer = producers.contains_key(&tensor_idx);
301 let has_consumers = consumers.contains_key(&tensor_idx);
302
303 if has_producer && !has_consumers && !is_output {
305 report.warnings.push(ValidationWarning {
306 kind: ValidationWarningKind::UnusedTensor,
307 message: format!(
308 "Tensor {} '{}' is produced but never consumed",
309 tensor_idx, tensor_name
310 ),
311 tensor_index: Some(tensor_idx),
312 node_index: None,
313 });
314 report.stats.unused_tensors += 1;
315 }
316
317 if is_input && !has_consumers {
319 report.warnings.push(ValidationWarning {
320 kind: ValidationWarningKind::UnusedInput,
321 message: format!(
322 "Input tensor {} '{}' is never consumed",
323 tensor_idx, tensor_name
324 ),
325 tensor_index: Some(tensor_idx),
326 node_index: None,
327 });
328 }
329
330 if tensor_name.starts_with("temp_")
332 || tensor_name.starts_with("t_")
333 || tensor_name.starts_with("_")
334 {
335 report.warnings.push(ValidationWarning {
336 kind: ValidationWarningKind::GeneratedTensorName,
337 message: format!("Tensor {} has generated name '{}'", tensor_idx, tensor_name),
338 tensor_index: Some(tensor_idx),
339 node_index: None,
340 });
341 }
342 }
343}
344
345fn check_einsum_specs(graph: &EinsumGraph, report: &mut ValidationReport) {
347 for (node_idx, node) in graph.nodes.iter().enumerate() {
348 if let OpType::Einsum { spec } = &node.op {
349 if spec.is_empty() {
350 report.errors.push(ValidationError {
351 kind: ValidationErrorKind::EmptyEinsumSpec,
352 message: "Einsum operation has empty specification".to_string(),
353 node_index: Some(node_idx),
354 tensor_index: None,
355 });
356 }
357
358 if !spec.contains("->") {
360 report.errors.push(ValidationError {
361 kind: ValidationErrorKind::InvalidEinsumSpec,
362 message: format!("Einsum specification '{}' is invalid (missing '->')", spec),
363 node_index: Some(node_idx),
364 tensor_index: None,
365 });
366 }
367 }
368 }
369}
370
371fn check_cycles(graph: &EinsumGraph, report: &mut ValidationReport) {
373 let mut visited = HashSet::new();
375 let mut rec_stack = HashSet::new();
376
377 for node_idx in 0..graph.nodes.len() {
378 if !visited.contains(&node_idx)
379 && has_cycle_dfs(node_idx, graph, &mut visited, &mut rec_stack)
380 {
381 report.errors.push(ValidationError {
382 kind: ValidationErrorKind::CyclicDependency,
383 message: format!("Cyclic dependency detected involving node {}", node_idx),
384 node_index: Some(node_idx),
385 tensor_index: None,
386 });
387 }
388 }
389}
390
391fn has_cycle_dfs(
393 node_idx: usize,
394 graph: &EinsumGraph,
395 visited: &mut HashSet<usize>,
396 rec_stack: &mut HashSet<usize>,
397) -> bool {
398 visited.insert(node_idx);
399 rec_stack.insert(node_idx);
400
401 let node = &graph.nodes[node_idx];
402
403 for &output in &node.outputs {
405 for (next_node_idx, next_node) in graph.nodes.iter().enumerate() {
406 if next_node.inputs.contains(&output) {
407 if !visited.contains(&next_node_idx) {
408 if has_cycle_dfs(next_node_idx, graph, visited, rec_stack) {
409 return true;
410 }
411 } else if rec_stack.contains(&next_node_idx) {
412 return true;
413 }
414 }
415 }
416 }
417
418 rec_stack.remove(&node_idx);
419 false
420}
421
422fn check_node_outputs(graph: &EinsumGraph, report: &mut ValidationReport) {
424 for (node_idx, node) in graph.nodes.iter().enumerate() {
425 if node.outputs.is_empty() {
426 report.errors.push(ValidationError {
427 kind: ValidationErrorKind::NoOutputs,
428 message: format!("Node {} has no outputs", node_idx),
429 node_index: Some(node_idx),
430 tensor_index: None,
431 });
432 }
433 }
434}
435
436fn count_operations(graph: &EinsumGraph, report: &mut ValidationReport) {
438 for node in &graph.nodes {
439 match &node.op {
440 OpType::Einsum { .. } => report.stats.einsum_operations += 1,
441 OpType::ElemUnary { .. } => report.stats.elem_unary_operations += 1,
442 OpType::ElemBinary { .. } => report.stats.elem_binary_operations += 1,
443 OpType::Reduce { .. } => report.stats.reduce_operations += 1,
444 }
445 }
446}
447
448fn check_graph_size(graph: &EinsumGraph, report: &mut ValidationReport) {
450 if graph.nodes.len() > 1000 {
451 report.warnings.push(ValidationWarning {
452 kind: ValidationWarningKind::LargeGraph,
453 message: format!(
454 "Graph has {} operations (may be slow to execute)",
455 graph.nodes.len()
456 ),
457 tensor_index: None,
458 node_index: None,
459 });
460 }
461
462 if graph.tensors.len() > 10000 {
463 report.warnings.push(ValidationWarning {
464 kind: ValidationWarningKind::LargeGraph,
465 message: format!(
466 "Graph has {} tensors (may use significant memory)",
467 graph.tensors.len()
468 ),
469 tensor_index: None,
470 node_index: None,
471 });
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EinsumGraph, EinsumNode};

    // A graph with no tensors, nodes, inputs or outputs produces no errors.
    #[test]
    fn test_validate_empty_graph() {
        let graph = EinsumGraph::new();
        let report = validate_graph(&graph);
        assert!(report.is_valid());
        assert_eq!(report.errors.len(), 0);
    }

    // A single unary node from the declared input to the declared output is
    // valid, and the basic size statistics reflect the graph.
    #[test]
    fn test_validate_simple_graph() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());
        graph.inputs = vec![t0];
        graph.outputs = vec![t1];

        let node = EinsumNode::elem_unary("relu", t0, t1);
        graph.add_node(node).unwrap();

        let report = validate_graph(&graph);
        assert!(report.is_valid());
        assert_eq!(report.stats.total_tensors, 2);
        assert_eq!(report.stats.total_nodes, 1);
    }

    // A node writing to a tensor index that does not exist yields exactly one
    // TensorOutOfBounds error.
    #[test]
    fn test_detect_tensor_out_of_bounds() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        graph.add_tensor("output".to_string());

        let bad_node = EinsumNode::elem_unary("relu", t0, 999);
        // Pushed directly onto `graph.nodes` rather than via `add_node`,
        // presumably because `add_node` would reject the bad index — TODO confirm.
        graph.nodes.push(bad_node);

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert_eq!(report.errors.len(), 1);
        assert_eq!(
            report.errors[0].kind,
            ValidationErrorKind::TensorOutOfBounds
        );
    }

    // `intermediate` is produced but never consumed and is not an output, so it
    // is the only warning; the graph itself stays valid.
    #[test]
    fn test_detect_unused_tensor() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("intermediate".to_string());
        let t2 = graph.add_tensor("output".to_string());
        graph.inputs = vec![t0];
        graph.outputs = vec![t2];

        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("sigmoid", t0, t2))
            .unwrap();

        let report = validate_graph(&graph);
        assert!(report.is_valid());
        assert_eq!(report.warnings.len(), 1);
        assert_eq!(report.warnings[0].kind, ValidationWarningKind::UnusedTensor);
    }

    // The declared output has no producing node and is not an input.
    #[test]
    fn test_detect_output_without_producer() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());
        graph.inputs = vec![t0];
        graph.outputs = vec![t1];

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert_eq!(report.errors.len(), 1);
        assert_eq!(
            report.errors[0].kind,
            ValidationErrorKind::OutputWithoutProducer
        );
    }

    // An einsum node with an empty spec must be flagged as EmptyEinsumSpec.
    #[test]
    fn test_detect_empty_einsum_spec() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());

        let bad_node = EinsumNode::einsum("", vec![t0], vec![t1]);
        graph.nodes.push(bad_node);

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert!(report
            .errors
            .iter()
            .any(|e| e.kind == ValidationErrorKind::EmptyEinsumSpec));
    }

    // A spec without the `->` separator must be flagged as InvalidEinsumSpec.
    #[test]
    fn test_detect_invalid_einsum_spec() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());

        let bad_node = EinsumNode::einsum("ijk", vec![t0], vec![t1]);
        graph.nodes.push(bad_node);

        let report = validate_graph(&graph);
        assert!(!report.is_valid());
        assert!(report
            .errors
            .iter()
            .any(|e| e.kind == ValidationErrorKind::InvalidEinsumSpec));
    }

    // Operation counters and size totals are collected regardless of validity.
    #[test]
    fn test_statistics_collection() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("a".to_string());
        let t1 = graph.add_tensor("b".to_string());
        let t2 = graph.add_tensor("c".to_string());
        let t3 = graph.add_tensor("d".to_string());

        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_binary("add", t1, t2, t3))
            .unwrap();

        let report = validate_graph(&graph);
        assert_eq!(report.stats.elem_unary_operations, 1);
        assert_eq!(report.stats.elem_binary_operations, 1);
        assert_eq!(report.stats.total_nodes, 2);
        assert_eq!(report.stats.total_tensors, 4);
    }
}