1use std::fmt;
7use tensorlogic_ir::EinsumGraph;
8
9use crate::shape::TensorShape;
10
/// How serious a diagnostic is; variants are ordered from least to most
/// severe, so `Severity` values can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
    /// Purely informational note.
    Info,
    /// Suspicious but non-fatal condition.
    Warning,
    /// A failure in the current operation.
    Error,
    /// The most severe failure level.
    Critical,
}

impl fmt::Display for Severity {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Map each variant to its uppercase label, then emit it once.
        let label = match self {
            Severity::Info => "INFO",
            Severity::Warning => "WARNING",
            Severity::Error => "ERROR",
            Severity::Critical => "CRITICAL",
        };
        f.write_str(label)
    }
}
34
/// An optional file/line/column position attached to a diagnostic.
///
/// Every field is optional; missing parts are simply omitted from the
/// rendered form (see the `Display` impl below).
///
/// `Default` is derived — all fields default to `None` — which replaces
/// the previous hand-written `impl Default` that duplicated `new()`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SourceLocation {
    /// Source file name, if known.
    pub file: Option<String>,
    /// Line number, if known.
    pub line: Option<usize>,
    /// Column number, if known.
    pub column: Option<usize>,
}

impl SourceLocation {
    /// Creates an empty location with no file, line, or column.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter for the file name.
    pub fn with_file(mut self, file: String) -> Self {
        self.file = Some(file);
        self
    }

    /// Builder-style setter for the line number.
    pub fn with_line(mut self, line: usize) -> Self {
        self.line = Some(line);
        self
    }

    /// Builder-style setter for the column number.
    pub fn with_column(mut self, column: usize) -> Self {
        self.column = Some(column);
        self
    }
}

impl fmt::Display for SourceLocation {
    /// Renders as `file`, `file:line`, or `file:line:column`.
    ///
    /// Prints `<unknown>` when no file is set; the line is shown only when
    /// the file is known, and the column only when the line is also known.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.file {
            Some(ref file) => {
                write!(f, "{}", file)?;
                if let Some(line) = self.line {
                    write!(f, ":{}", line)?;
                    if let Some(column) = self.column {
                        write!(f, ":{}", column)?;
                    }
                }
                Ok(())
            }
            None => write!(f, "<unknown>"),
        }
    }
}
90
91#[derive(Debug, Clone)]
93pub struct Diagnostic {
94 pub severity: Severity,
96 pub message: String,
98 pub location: Option<SourceLocation>,
100 pub context: Vec<String>,
102 pub suggestions: Vec<String>,
104 pub related: Vec<String>,
106 pub code: Option<String>,
108}
109
110impl Diagnostic {
111 pub fn new(severity: Severity, message: impl Into<String>) -> Self {
113 Diagnostic {
114 severity,
115 message: message.into(),
116 location: None,
117 context: Vec::new(),
118 suggestions: Vec::new(),
119 related: Vec::new(),
120 code: None,
121 }
122 }
123
124 pub fn error(message: impl Into<String>) -> Self {
126 Self::new(Severity::Error, message)
127 }
128
129 pub fn warning(message: impl Into<String>) -> Self {
131 Self::new(Severity::Warning, message)
132 }
133
134 pub fn info(message: impl Into<String>) -> Self {
136 Self::new(Severity::Info, message)
137 }
138
139 pub fn with_location(mut self, location: SourceLocation) -> Self {
141 self.location = Some(location);
142 self
143 }
144
145 pub fn with_context(mut self, context: impl Into<String>) -> Self {
147 self.context.push(context.into());
148 self
149 }
150
151 pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
153 self.suggestions.push(suggestion.into());
154 self
155 }
156
157 pub fn with_related(mut self, related: impl Into<String>) -> Self {
159 self.related.push(related.into());
160 self
161 }
162
163 pub fn with_code(mut self, code: impl Into<String>) -> Self {
165 self.code = Some(code.into());
166 self
167 }
168
169 pub fn format(&self) -> String {
171 let mut output = String::new();
172
173 output.push_str(&format!("[{}] {}\n", self.severity, self.message));
175
176 if let Some(ref loc) = self.location {
178 output.push_str(&format!(" at {}\n", loc));
179 }
180
181 if let Some(ref code) = self.code {
183 output.push_str(&format!(" code: {}\n", code));
184 }
185
186 if !self.context.is_empty() {
188 output.push_str("\nContext:\n");
189 for ctx in &self.context {
190 output.push_str(&format!(" {}\n", ctx));
191 }
192 }
193
194 if !self.suggestions.is_empty() {
196 output.push_str("\nSuggestions:\n");
197 for (i, suggestion) in self.suggestions.iter().enumerate() {
198 output.push_str(&format!(" {}. {}\n", i + 1, suggestion));
199 }
200 }
201
202 if !self.related.is_empty() {
204 output.push_str("\nRelated:\n");
205 for rel in &self.related {
206 output.push_str(&format!(" - {}\n", rel));
207 }
208 }
209
210 output
211 }
212}
213
214impl fmt::Display for Diagnostic {
215 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216 write!(f, "{}", self.format())
217 }
218}
219
220pub struct ShapeMismatchDiagnostic;
222
223impl ShapeMismatchDiagnostic {
224 pub fn create(expected: &TensorShape, actual: &TensorShape, operation: &str) -> Diagnostic {
225 let mut diag = Diagnostic::error(format!("Shape mismatch in {} operation", operation))
226 .with_code("E001")
227 .with_context(format!(
228 "Expected shape: {:?}, but got: {:?}",
229 expected.dims, actual.dims
230 ));
231
232 if expected.rank() != actual.rank() {
234 diag = diag
235 .with_suggestion(format!(
236 "Expected rank {} but got rank {}. Consider reshaping your tensor.",
237 expected.rank(),
238 actual.rank()
239 ))
240 .with_suggestion(format!(
241 "Use tensor.reshape({:?}) to match the expected shape",
242 expected.dims
243 ));
244 } else {
245 let mismatches: Vec<_> = expected
247 .dims
248 .iter()
249 .zip(actual.dims.iter())
250 .enumerate()
251 .filter(|(_, (e, a))| e != a)
252 .collect();
253
254 for (dim, (exp, act)) in mismatches {
255 diag = diag.with_context(format!(
256 "Dimension {} mismatch: expected {:?}, got {:?}",
257 dim, exp, act
258 ));
259 }
260
261 diag = diag.with_suggestion(
262 "Check your input tensor shapes match the expected dimensions".to_string(),
263 );
264 }
265
266 diag
267 }
268}
269
270pub struct TypeMismatchDiagnostic;
272
273impl TypeMismatchDiagnostic {
274 pub fn create(expected: &str, actual: &str, context: &str) -> Diagnostic {
275 Diagnostic::error(format!("Type mismatch in {}", context))
276 .with_code("E002")
277 .with_context(format!("Expected type: {}, but got: {}", expected, actual))
278 .with_suggestion(format!("Convert your data to {} type", expected))
279 .with_suggestion("Check the input data types match the expected types".to_string())
280 }
281}
282
283pub struct NodeExecutionDiagnostic;
285
286impl NodeExecutionDiagnostic {
287 pub fn create(node_id: usize, error: &str, graph: &EinsumGraph) -> Diagnostic {
288 let mut diag = Diagnostic::error(format!("Failed to execute node {}", node_id))
289 .with_code("E003")
290 .with_context(error.to_string());
291
292 if let Some(node) = graph.nodes.get(node_id) {
294 diag = diag.with_context(format!("Node operation: {:?}", node.op));
295
296 if !node.inputs.is_empty() {
298 diag = diag.with_context(format!("Input nodes: {:?}", node.inputs));
299 }
300
301 diag = diag.with_suggestion(
303 "Check that all input tensors are properly initialized".to_string(),
304 );
305 diag = diag.with_suggestion(
306 "Verify input tensor shapes are compatible with this operation".to_string(),
307 );
308 }
309
310 for input_id in graph
312 .nodes
313 .get(node_id)
314 .map(|n| &n.inputs)
315 .unwrap_or(&vec![])
316 {
317 diag = diag.with_related(format!("Input node: {}", input_id));
318 }
319
320 diag
321 }
322}
323
324pub struct MemoryDiagnostic;
326
327impl MemoryDiagnostic {
328 pub fn out_of_memory(requested_bytes: usize, available_bytes: usize) -> Diagnostic {
329 let requested_mb = requested_bytes as f64 / (1024.0 * 1024.0);
330 let available_mb = available_bytes as f64 / (1024.0 * 1024.0);
331
332 Diagnostic::error("Out of memory")
333 .with_code("E004")
334 .with_context(format!(
335 "Requested: {:.2} MB, Available: {:.2} MB",
336 requested_mb, available_mb
337 ))
338 .with_suggestion("Reduce batch size to lower memory usage".to_string())
339 .with_suggestion("Enable streaming execution for large datasets".to_string())
340 .with_suggestion("Consider using a machine with more memory".to_string())
341 .with_suggestion("Enable memory pooling to reuse allocations".to_string())
342 }
343
344 pub fn memory_leak_warning(leaked_bytes: usize) -> Diagnostic {
345 let leaked_mb = leaked_bytes as f64 / (1024.0 * 1024.0);
346
347 Diagnostic::warning(format!(
348 "Potential memory leak detected: {:.2} MB",
349 leaked_mb
350 ))
351 .with_code("W001")
352 .with_suggestion("Check that all tensors are properly released".to_string())
353 .with_suggestion("Enable memory profiling to identify the leak source".to_string())
354 .with_suggestion("Use memory pooling to manage allocations".to_string())
355 }
356}
357
358impl ShapeMismatchDiagnostic {
359 pub fn with_transpose_suggestion(
361 mut diag: Diagnostic,
362 expected: &[usize],
363 actual: &[usize],
364 ) -> Diagnostic {
365 if expected.len() == actual.len() {
366 let mut sorted_expected = expected.to_vec();
367 let mut sorted_actual = actual.to_vec();
368 sorted_expected.sort_unstable();
369 sorted_actual.sort_unstable();
370 if sorted_expected == sorted_actual {
371 let perm: Vec<usize> = expected
373 .iter()
374 .map(|&e| actual.iter().position(|&a| a == e).unwrap_or(0))
375 .collect();
376 diag = diag.with_suggestion(format!(
377 "Shapes are permutations of each other. Consider transposing with axes {:?}",
378 perm
379 ));
380 }
381 }
382 diag
383 }
384
385 pub fn with_broadcast_suggestion(
387 mut diag: Diagnostic,
388 expected: &[usize],
389 actual: &[usize],
390 ) -> Diagnostic {
391 let rank_diff = (expected.len() as isize - actual.len() as isize).unsigned_abs();
392 if rank_diff == 1 {
393 let (longer, shorter) = if expected.len() > actual.len() {
394 (expected, actual)
395 } else {
396 (actual, expected)
397 };
398 let suffix_matches = longer
400 .iter()
401 .rev()
402 .zip(shorter.iter().rev())
403 .all(|(&l, &s)| l == s || l == 1 || s == 1);
404 if suffix_matches {
405 diag = diag.with_suggestion(format!(
406 "Ranks differ by 1. Try unsqueezing to shape {:?} or using broadcasting",
407 longer
408 ));
409 }
410 }
411 diag
412 }
413}
414
415pub struct PerformanceDiagnostic;
417
418impl PerformanceDiagnostic {
419 pub fn slow_operation(
420 operation: &str,
421 actual_time_ms: f64,
422 expected_time_ms: f64,
423 ) -> Diagnostic {
424 let slowdown = actual_time_ms / expected_time_ms;
425
426 Diagnostic::warning(format!(
427 "Slow {} operation: {:.2}x slower than expected",
428 operation, slowdown
429 ))
430 .with_code("W002")
431 .with_context(format!(
432 "Actual: {:.2}ms, Expected: {:.2}ms",
433 actual_time_ms, expected_time_ms
434 ))
435 .with_suggestion("Enable graph optimization to improve performance".to_string())
436 .with_suggestion("Check if operation fusion is enabled".to_string())
437 .with_suggestion("Consider using a more powerful device (GPU)".to_string())
438 .with_suggestion("Profile the execution to identify bottlenecks".to_string())
439 }
440
441 pub fn high_memory_usage(peak_mb: f64, threshold_mb: f64) -> Diagnostic {
442 Diagnostic::warning(format!("High memory usage: {:.2} MB", peak_mb))
443 .with_code("W003")
444 .with_context(format!("Threshold: {:.2} MB", threshold_mb))
445 .with_suggestion("Enable memory optimization".to_string())
446 .with_suggestion("Reduce batch size".to_string())
447 .with_suggestion("Use streaming execution for large datasets".to_string())
448 }
449
450 pub fn parallelism_available(num_independent_ops: usize, current_threads: usize) -> Diagnostic {
452 Diagnostic::info(format!(
453 "Parallelism opportunity: {} independent ops, only {} threads active",
454 num_independent_ops, current_threads
455 ))
456 .with_code("P001")
457 .with_context(format!(
458 "{} operations could run in parallel but only {} worker threads are available",
459 num_independent_ops, current_threads
460 ))
461 .with_suggestion(format!(
462 "Increase thread pool size to at least {} for maximum throughput",
463 num_independent_ops
464 ))
465 .with_suggestion(
466 "Use rayon or a work-stealing scheduler for automatic parallelism".to_string(),
467 )
468 }
469
470 pub fn high_allocation_rate(allocs_per_second: f64, threshold: f64) -> Diagnostic {
472 Diagnostic::warning(format!(
473 "High allocation rate: {:.1} allocs/s (threshold: {:.1})",
474 allocs_per_second, threshold
475 ))
476 .with_code("P002")
477 .with_context(format!(
478 "Tensor allocations are occurring at {:.1} per second",
479 allocs_per_second
480 ))
481 .with_suggestion("Enable a memory pool (WorkspacePool) to reuse buffers".to_string())
482 .with_suggestion("Pre-allocate output tensors where output shapes are known".to_string())
483 }
484
485 pub fn fusion_opportunity(num_fuseable: usize, op_names: &[&str]) -> Diagnostic {
487 Diagnostic::info(format!(
488 "Fusion opportunity: {} operations could be fused",
489 num_fuseable
490 ))
491 .with_code("P003")
492 .with_context(format!("Fuseable operations: {}", op_names.join(", ")))
493 .with_suggestion(
494 "Enable the FusionOptimizer pass to reduce kernel launch overhead".to_string(),
495 )
496 .with_suggestion("Consider using FusionStrategy::Aggressive for maximum fusion".to_string())
497 }
498
499 pub fn precision_downgrade_available(estimated_speedup: f64) -> Diagnostic {
501 Diagnostic::info(format!(
502 "Precision downgrade available: estimated {:.1}x speedup using f32",
503 estimated_speedup
504 ))
505 .with_code("P004")
506 .with_context("Computation is currently using f64 (double) precision".to_string())
507 .with_suggestion(
508 "Switch to f32 (single precision) if model accuracy tolerates it".to_string(),
509 )
510 .with_suggestion(
511 "Use MixedPrecisionConfig to selectively apply f16/f32 where safe".to_string(),
512 )
513 }
514}
515
516#[derive(Debug, Default)]
518pub struct DiagnosticCollector {
519 diagnostics: Vec<Diagnostic>,
520}
521
522impl DiagnosticCollector {
523 pub fn new() -> Self {
524 Self::default()
525 }
526
527 pub fn add(&mut self, diagnostic: Diagnostic) {
529 self.diagnostics.push(diagnostic);
530 }
531
532 pub fn diagnostics(&self) -> &[Diagnostic] {
534 &self.diagnostics
535 }
536
537 pub fn has_errors(&self) -> bool {
539 self.diagnostics
540 .iter()
541 .any(|d| d.severity >= Severity::Error)
542 }
543
544 pub fn error_count(&self) -> usize {
546 self.diagnostics
547 .iter()
548 .filter(|d| d.severity == Severity::Error)
549 .count()
550 }
551
552 pub fn warning_count(&self) -> usize {
554 self.diagnostics
555 .iter()
556 .filter(|d| d.severity == Severity::Warning)
557 .count()
558 }
559
560 pub fn format_all(&self) -> String {
562 let mut output = String::new();
563 for diag in &self.diagnostics {
564 output.push_str(&diag.format());
565 output.push('\n');
566 }
567
568 output.push_str(&format!(
569 "\nSummary: {} error(s), {} warning(s)\n",
570 self.error_count(),
571 self.warning_count()
572 ));
573
574 output
575 }
576
577 pub fn clear(&mut self) {
579 self.diagnostics.clear();
580 }
581}
582
583#[cfg(test)]
584mod tests {
585 use super::*;
586
587 #[test]
588 fn test_diagnostic_creation() {
589 let diag = Diagnostic::error("Test error")
590 .with_code("E001")
591 .with_context("Additional context")
592 .with_suggestion("Try this fix");
593
594 assert_eq!(diag.severity, Severity::Error);
595 assert_eq!(diag.message, "Test error");
596 assert_eq!(diag.code, Some("E001".to_string()));
597 assert_eq!(diag.context.len(), 1);
598 assert_eq!(diag.suggestions.len(), 1);
599 }
600
601 #[test]
602 fn test_shape_mismatch_diagnostic() {
603 let expected = TensorShape::static_shape(vec![64, 128]);
604 let actual = TensorShape::static_shape(vec![64, 256]);
605
606 let diag = ShapeMismatchDiagnostic::create(&expected, &actual, "matmul");
607
608 assert_eq!(diag.severity, Severity::Error);
609 assert!(diag.message.contains("Shape mismatch"));
610 assert!(!diag.suggestions.is_empty());
611 }
612
613 #[test]
614 fn test_type_mismatch_diagnostic() {
615 let diag = TypeMismatchDiagnostic::create("f32", "f64", "tensor operation");
616
617 assert_eq!(diag.severity, Severity::Error);
618 assert!(diag.message.contains("Type mismatch"));
619 assert_eq!(diag.code, Some("E002".to_string()));
620 }
621
622 #[test]
623 fn test_memory_diagnostic() {
624 let diag = MemoryDiagnostic::out_of_memory(1024 * 1024 * 1024, 512 * 1024 * 1024);
625
626 assert_eq!(diag.severity, Severity::Error);
627 assert!(diag.message.contains("Out of memory"));
628 assert!(!diag.suggestions.is_empty());
629 }
630
631 #[test]
632 fn test_performance_diagnostic() {
633 let diag = PerformanceDiagnostic::slow_operation("einsum", 100.0, 50.0);
634
635 assert_eq!(diag.severity, Severity::Warning);
636 assert!(diag.message.contains("Slow"));
637 assert!(diag.message.contains("2.00x"));
638 }
639
640 #[test]
641 fn test_diagnostic_collector() {
642 let mut collector = DiagnosticCollector::new();
643
644 collector.add(Diagnostic::error("Error 1"));
645 collector.add(Diagnostic::warning("Warning 1"));
646 collector.add(Diagnostic::error("Error 2"));
647
648 assert_eq!(collector.error_count(), 2);
649 assert_eq!(collector.warning_count(), 1);
650 assert!(collector.has_errors());
651
652 let formatted = collector.format_all();
653 assert!(formatted.contains("2 error(s), 1 warning(s)"));
654 }
655
656 #[test]
657 fn test_source_location() {
658 let loc = SourceLocation::new()
659 .with_file("test.rs".to_string())
660 .with_line(42)
661 .with_column(10);
662
663 assert_eq!(loc.to_string(), "test.rs:42:10");
664 }
665
666 #[test]
667 fn test_severity_ordering() {
668 assert!(Severity::Info < Severity::Warning);
669 assert!(Severity::Warning < Severity::Error);
670 assert!(Severity::Error < Severity::Critical);
671 }
672
673 #[test]
674 fn test_transpose_suggestion_added() {
675 let base = Diagnostic::error("shape mismatch");
676 let diag = ShapeMismatchDiagnostic::with_transpose_suggestion(base, &[3, 2], &[2, 3]);
678 assert!(
679 diag.suggestions.iter().any(|s| s.contains("transpos")),
680 "Expected transpose suggestion, got: {:?}",
681 diag.suggestions
682 );
683 }
684
685 #[test]
686 fn test_broadcast_suggestion_added() {
687 let base = Diagnostic::error("shape mismatch");
688 let diag = ShapeMismatchDiagnostic::with_broadcast_suggestion(base, &[1, 4], &[4]);
690 assert!(
691 diag.suggestions
692 .iter()
693 .any(|s| s.contains("unsqueez") || s.contains("broadcast")),
694 "Expected broadcast suggestion, got: {:?}",
695 diag.suggestions
696 );
697 }
698
699 #[test]
700 fn test_parallelism_diagnostic() {
701 let diag = PerformanceDiagnostic::parallelism_available(8, 2);
702 assert_eq!(diag.severity, Severity::Info);
703 assert!(diag.message.contains("Parallelism opportunity"));
704 assert!(!diag.suggestions.is_empty());
705 }
706
707 #[test]
708 fn test_fusion_opportunity_diagnostic() {
709 let diag = PerformanceDiagnostic::fusion_opportunity(3, &["relu", "matmul", "add"]);
710 assert_eq!(diag.severity, Severity::Info);
711 assert!(diag.message.contains("Fusion opportunity"));
712 assert!(diag.context.iter().any(|c| c.contains("relu")));
713 }
714}