use std::fmt;

use tensorlogic_ir::EinsumGraph;

use crate::shape::TensorShape;

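/// Severity of a diagnostic, ordered from least (`Info`) to most severe (`Critical`).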
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
    Info,
    Warning,
    Error,
    Critical,
}

impl fmt::Display for Severity {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Severity::Info => write!(f, "INFO"),
            Severity::Warning => write!(f, "WARNING"),
            Severity::Error => write!(f, "ERROR"),
            Severity::Critical => write!(f, "CRITICAL"),
        }
    }
}

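/// Optional source location (file, line, and column) attached to a diagnostic.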
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SourceLocation {
    pub file: Option<String>,
    pub line: Option<usize>,
    pub column: Option<usize>,
}

impl SourceLocation {
    pub fn new() -> Self {
        SourceLocation {
            file: None,
            line: None,
            column: None,
        }
    }

    pub fn with_file(mut self, file: String) -> Self {
        self.file = Some(file);
        self
    }

    pub fn with_line(mut self, line: usize) -> Self {
        self.line = Some(line);
        self
    }

    pub fn with_column(mut self, column: usize) -> Self {
        self.column = Some(column);
        self
    }
}

impl Default for SourceLocation {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Display for SourceLocation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(ref file) = self.file {
            write!(f, "{}", file)?;
            if let Some(line) = self.line {
                write!(f, ":{}", line)?;
                if let Some(column) = self.column {
                    write!(f, ":{}", column)?;
                }
            }
        } else {
            write!(f, "<unknown>")?;
        }
        Ok(())
    }
}

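/// A structured diagnostic: a severity and message plus optional location, context lines, fix suggestions, related notes, and an error code.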
#[derive(Debug, Clone)]
pub struct Diagnostic {
    /// How serious the problem is.
    pub severity: Severity,
    /// Primary human-readable description.
    pub message: String,
    /// Where in the source the problem was detected, if known.
    pub location: Option<SourceLocation>,
    /// Additional lines of context shown under the message.
    pub context: Vec<String>,
    /// Suggested fixes, listed in order of preference.
    pub suggestions: Vec<String>,
    /// Related items (for example, other nodes) worth inspecting.
    pub related: Vec<String>,
    /// Stable code such as "E001" or "W002", if assigned.
    pub code: Option<String>,
}

impl Diagnostic {
    /// Creates a diagnostic with the given severity and message.
    pub fn new(severity: Severity, message: impl Into<String>) -> Self {
        Diagnostic {
            severity,
            message: message.into(),
            location: None,
            context: Vec::new(),
            suggestions: Vec::new(),
            related: Vec::new(),
            code: None,
        }
    }

    /// Shorthand for an error-level diagnostic.
    pub fn error(message: impl Into<String>) -> Self {
        Self::new(Severity::Error, message)
    }

    /// Shorthand for a warning-level diagnostic.
    pub fn warning(message: impl Into<String>) -> Self {
        Self::new(Severity::Warning, message)
    }

    /// Shorthand for an info-level diagnostic.
    pub fn info(message: impl Into<String>) -> Self {
        Self::new(Severity::Info, message)
    }

    /// Attaches a source location.
    pub fn with_location(mut self, location: SourceLocation) -> Self {
        self.location = Some(location);
        self
    }

    /// Appends a line of context.
    pub fn with_context(mut self, context: impl Into<String>) -> Self {
        self.context.push(context.into());
        self
    }

    /// Appends a suggested fix.
    pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
        self.suggestions.push(suggestion.into());
        self
    }

    /// Appends a related item.
    pub fn with_related(mut self, related: impl Into<String>) -> Self {
        self.related.push(related.into());
        self
    }

    /// Sets the diagnostic code, e.g. "E001".
    pub fn with_code(mut self, code: impl Into<String>) -> Self {
        self.code = Some(code.into());
        self
    }

    /// Renders the diagnostic as a multi-line, human-readable report.
    pub fn format(&self) -> String {
        let mut output = String::new();

        // Header: severity and primary message.
        output.push_str(&format!("[{}] {}\n", self.severity, self.message));

        if let Some(ref loc) = self.location {
            output.push_str(&format!(" at {}\n", loc));
        }

        if let Some(ref code) = self.code {
            output.push_str(&format!(" code: {}\n", code));
        }

        if !self.context.is_empty() {
            output.push_str("\nContext:\n");
            for ctx in &self.context {
                output.push_str(&format!(" {}\n", ctx));
            }
        }

        if !self.suggestions.is_empty() {
            output.push_str("\nSuggestions:\n");
            for (i, suggestion) in self.suggestions.iter().enumerate() {
                output.push_str(&format!(" {}. {}\n", i + 1, suggestion));
            }
        }

        if !self.related.is_empty() {
            output.push_str("\nRelated:\n");
            for rel in &self.related {
                output.push_str(&format!(" - {}\n", rel));
            }
        }

        output
    }
}

impl fmt::Display for Diagnostic {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.format())
    }
}

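/// Builds "E001" diagnostics for tensor shape mismatches, with per-dimension context.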
pub struct ShapeMismatchDiagnostic;

impl ShapeMismatchDiagnostic {
    pub fn create(expected: &TensorShape, actual: &TensorShape, operation: &str) -> Diagnostic {
        let mut diag = Diagnostic::error(format!("Shape mismatch in {} operation", operation))
            .with_code("E001")
            .with_context(format!(
                "Expected shape: {:?}, but got: {:?}",
                expected.dims, actual.dims
            ));

        if expected.rank() != actual.rank() {
            diag = diag
                .with_suggestion(format!(
                    "Expected rank {} but got rank {}. Consider reshaping your tensor.",
                    expected.rank(),
                    actual.rank()
                ))
                .with_suggestion(format!(
                    "Use tensor.reshape({:?}) to match the expected shape",
                    expected.dims
                ));
        } else {
            let mismatches: Vec<_> = expected
                .dims
                .iter()
                .zip(actual.dims.iter())
                .enumerate()
                .filter(|(_, (e, a))| e != a)
                .collect();

            for (dim, (exp, act)) in mismatches {
                diag = diag.with_context(format!(
                    "Dimension {} mismatch: expected {:?}, got {:?}",
                    dim, exp, act
                ));
            }

            diag = diag.with_suggestion(
                "Check your input tensor shapes match the expected dimensions".to_string(),
            );
        }

        diag
    }
}

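/// Builds "E002" diagnostics for data type mismatches.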
pub struct TypeMismatchDiagnostic;

impl TypeMismatchDiagnostic {
    pub fn create(expected: &str, actual: &str, context: &str) -> Diagnostic {
        Diagnostic::error(format!("Type mismatch in {}", context))
            .with_code("E002")
            .with_context(format!("Expected type: {}, but got: {}", expected, actual))
            .with_suggestion(format!("Convert your data to {} type", expected))
            .with_suggestion("Check the input data types match the expected types".to_string())
    }
}

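/// Builds "E003" diagnostics for node execution failures in an `EinsumGraph`.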
pub struct NodeExecutionDiagnostic;

impl NodeExecutionDiagnostic {
    pub fn create(node_id: usize, error: &str, graph: &EinsumGraph) -> Diagnostic {
        let mut diag = Diagnostic::error(format!("Failed to execute node {}", node_id))
            .with_code("E003")
            .with_context(error.to_string());

        if let Some(node) = graph.nodes.get(node_id) {
            diag = diag.with_context(format!("Node operation: {:?}", node.op));

            if !node.inputs.is_empty() {
                diag = diag.with_context(format!("Input nodes: {:?}", node.inputs));
            }

            diag = diag.with_suggestion(
                "Check that all input tensors are properly initialized".to_string(),
            );
            diag = diag.with_suggestion(
                "Verify input tensor shapes are compatible with this operation".to_string(),
            );

            // Cross-reference each input node so the failure is easier to trace.
            for input_id in &node.inputs {
                diag = diag.with_related(format!("Input node: {}", input_id));
            }
        }

        diag
    }
}

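/// Builders for memory-related diagnostics: out-of-memory errors and leak warnings.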
pub struct MemoryDiagnostic;

impl MemoryDiagnostic {
    pub fn out_of_memory(requested_bytes: usize, available_bytes: usize) -> Diagnostic {
        let requested_mb = requested_bytes as f64 / (1024.0 * 1024.0);
        let available_mb = available_bytes as f64 / (1024.0 * 1024.0);

        Diagnostic::error("Out of memory")
            .with_code("E004")
            .with_context(format!(
                "Requested: {:.2} MB, Available: {:.2} MB",
                requested_mb, available_mb
            ))
            .with_suggestion("Reduce batch size to lower memory usage".to_string())
            .with_suggestion("Enable streaming execution for large datasets".to_string())
            .with_suggestion("Consider using a machine with more memory".to_string())
            .with_suggestion("Enable memory pooling to reuse allocations".to_string())
    }

    pub fn memory_leak_warning(leaked_bytes: usize) -> Diagnostic {
        let leaked_mb = leaked_bytes as f64 / (1024.0 * 1024.0);

        Diagnostic::warning(format!(
            "Potential memory leak detected: {:.2} MB",
            leaked_mb
        ))
        .with_code("W001")
        .with_suggestion("Check that all tensors are properly released".to_string())
        .with_suggestion("Enable memory profiling to identify the leak source".to_string())
        .with_suggestion("Use memory pooling to manage allocations".to_string())
    }
}

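/// Builders for performance warnings: slow operations and high memory usage.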
pub struct PerformanceDiagnostic;

impl PerformanceDiagnostic {
    pub fn slow_operation(
        operation: &str,
        actual_time_ms: f64,
        expected_time_ms: f64,
    ) -> Diagnostic {
        let slowdown = actual_time_ms / expected_time_ms;

        Diagnostic::warning(format!(
            "Slow {} operation: {:.2}x slower than expected",
            operation, slowdown
        ))
        .with_code("W002")
        .with_context(format!(
            "Actual: {:.2}ms, Expected: {:.2}ms",
            actual_time_ms, expected_time_ms
        ))
        .with_suggestion("Enable graph optimization to improve performance".to_string())
        .with_suggestion("Check if operation fusion is enabled".to_string())
        .with_suggestion("Consider using a more powerful device (GPU)".to_string())
        .with_suggestion("Profile the execution to identify bottlenecks".to_string())
    }

    pub fn high_memory_usage(peak_mb: f64, threshold_mb: f64) -> Diagnostic {
        Diagnostic::warning(format!("High memory usage: {:.2} MB", peak_mb))
            .with_code("W003")
            .with_context(format!("Threshold: {:.2} MB", threshold_mb))
            .with_suggestion("Enable memory optimization".to_string())
            .with_suggestion("Reduce batch size".to_string())
            .with_suggestion("Use streaming execution for large datasets".to_string())
    }
}

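/// Accumulates diagnostics during execution and summarizes error and warning counts.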
#[derive(Debug, Default)]
pub struct DiagnosticCollector {
    diagnostics: Vec<Diagnostic>,
}

impl DiagnosticCollector {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn add(&mut self, diagnostic: Diagnostic) {
        self.diagnostics.push(diagnostic);
    }

    pub fn diagnostics(&self) -> &[Diagnostic] {
        &self.diagnostics
    }

    pub fn has_errors(&self) -> bool {
        self.diagnostics
            .iter()
            .any(|d| d.severity >= Severity::Error)
    }

    pub fn error_count(&self) -> usize {
        self.diagnostics
            .iter()
            .filter(|d| d.severity == Severity::Error)
            .count()
    }

    pub fn warning_count(&self) -> usize {
        self.diagnostics
            .iter()
            .filter(|d| d.severity == Severity::Warning)
            .count()
    }

    pub fn format_all(&self) -> String {
        let mut output = String::new();
        for diag in &self.diagnostics {
            output.push_str(&diag.format());
            output.push('\n');
        }

        output.push_str(&format!(
            "\nSummary: {} error(s), {} warning(s)\n",
            self.error_count(),
            self.warning_count()
        ));

        output
    }

    pub fn clear(&mut self) {
        self.diagnostics.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_diagnostic_creation() {
        let diag = Diagnostic::error("Test error")
            .with_code("E001")
            .with_context("Additional context")
            .with_suggestion("Try this fix");

        assert_eq!(diag.severity, Severity::Error);
        assert_eq!(diag.message, "Test error");
        assert_eq!(diag.code, Some("E001".to_string()));
        assert_eq!(diag.context.len(), 1);
        assert_eq!(diag.suggestions.len(), 1);
    }

    #[test]
    fn test_shape_mismatch_diagnostic() {
        let expected = TensorShape::static_shape(vec![64, 128]);
        let actual = TensorShape::static_shape(vec![64, 256]);

        let diag = ShapeMismatchDiagnostic::create(&expected, &actual, "matmul");

        assert_eq!(diag.severity, Severity::Error);
        assert!(diag.message.contains("Shape mismatch"));
        assert!(!diag.suggestions.is_empty());
    }

    #[test]
    fn test_type_mismatch_diagnostic() {
        let diag = TypeMismatchDiagnostic::create("f32", "f64", "tensor operation");

        assert_eq!(diag.severity, Severity::Error);
        assert!(diag.message.contains("Type mismatch"));
        assert_eq!(diag.code, Some("E002".to_string()));
    }

    #[test]
    fn test_memory_diagnostic() {
        let diag = MemoryDiagnostic::out_of_memory(1024 * 1024 * 1024, 512 * 1024 * 1024);

        assert_eq!(diag.severity, Severity::Error);
        assert!(diag.message.contains("Out of memory"));
        assert!(!diag.suggestions.is_empty());
    }

    #[test]
    fn test_performance_diagnostic() {
        let diag = PerformanceDiagnostic::slow_operation("einsum", 100.0, 50.0);

        assert_eq!(diag.severity, Severity::Warning);
        assert!(diag.message.contains("Slow"));
        assert!(diag.message.contains("2.00x"));
    }

    #[test]
    fn test_diagnostic_collector() {
        let mut collector = DiagnosticCollector::new();

        collector.add(Diagnostic::error("Error 1"));
        collector.add(Diagnostic::warning("Warning 1"));
        collector.add(Diagnostic::error("Error 2"));

        assert_eq!(collector.error_count(), 2);
        assert_eq!(collector.warning_count(), 1);
        assert!(collector.has_errors());

        let formatted = collector.format_all();
        assert!(formatted.contains("2 error(s), 1 warning(s)"));
    }

    #[test]
    fn test_source_location() {
        let loc = SourceLocation::new()
            .with_file("test.rs".to_string())
            .with_line(42)
            .with_column(10);

        assert_eq!(loc.to_string(), "test.rs:42:10");
    }

    #[test]
    fn test_severity_ordering() {
        assert!(Severity::Info < Severity::Warning);
        assert!(Severity::Warning < Severity::Error);
        assert!(Severity::Error < Severity::Critical);
    }
}