Skip to main content

tensorlogic_infer/
diagnostics.rs

1//! Enhanced error diagnostics with helpful suggestions.
2//!
3//! This module provides rich error messages with context, suggestions,
4//! and actionable advice for common mistakes.
5
6use std::fmt;
7use tensorlogic_ir::EinsumGraph;
8
9use crate::shape::TensorShape;
10
/// Diagnostic severity level, ordered from least to most severe.
///
/// The derived `Ord` follows declaration order, so
/// `Info < Warning < Error < Critical`; callers can compare severities
/// directly (e.g. `severity >= Severity::Error`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Severity {
    /// Informational message
    Info,
    /// Warning (non-fatal)
    Warning,
    /// Error (fatal, prevents execution)
    Error,
    /// Critical error (system-level issue)
    Critical,
}

impl Severity {
    /// Upper-case label used when rendering this severity (e.g. `"ERROR"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Severity::Info => "INFO",
            Severity::Warning => "WARNING",
            Severity::Error => "ERROR",
            Severity::Critical => "CRITICAL",
        }
    }
}

impl fmt::Display for Severity {
    /// Writes the upper-case label.
    ///
    /// `Formatter::pad` is used (rather than a fixed `write!`) so width,
    /// fill and alignment flags such as `{:<8}` are honoured, which keeps
    /// columnar log output aligned.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
34
/// Source location for error reporting.
///
/// Every component is optional. It renders as `file:line:column`, and any
/// unknown trailing parts are omitted.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SourceLocation {
    pub file: Option<String>,
    pub line: Option<usize>,
    pub column: Option<usize>,
}

impl SourceLocation {
    /// Create a location with no file, line, or column.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the file name. Accepts `String`, `&str`, or anything `Into<String>`.
    pub fn with_file(mut self, file: impl Into<String>) -> Self {
        self.file = Some(file.into());
        self
    }

    /// Set the (line) number.
    pub fn with_line(mut self, line: usize) -> Self {
        self.line = Some(line);
        self
    }

    /// Set the column number. It is only rendered when a line is also set.
    pub fn with_column(mut self, column: usize) -> Self {
        self.column = Some(column);
        self
    }
}

impl fmt::Display for SourceLocation {
    /// Renders `file[:line[:column]]`, with `<unknown>` standing in for a
    /// missing file.
    ///
    /// A known line (and column) is still printed when the file is unknown,
    /// giving `<unknown>:42:10`, so position information is never dropped.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.file.as_deref().unwrap_or("<unknown>"))?;
        if let Some(line) = self.line {
            write!(f, ":{}", line)?;
            if let Some(column) = self.column {
                write!(f, ":{}", column)?;
            }
        }
        Ok(())
    }
}
90
91/// Detailed diagnostic message
92#[derive(Debug, Clone)]
93pub struct Diagnostic {
94    /// Severity level
95    pub severity: Severity,
96    /// Primary error message
97    pub message: String,
98    /// Source location
99    pub location: Option<SourceLocation>,
100    /// Additional context
101    pub context: Vec<String>,
102    /// Suggested fixes
103    pub suggestions: Vec<String>,
104    /// Related nodes or operations
105    pub related: Vec<String>,
106    /// Error code (for documentation lookup)
107    pub code: Option<String>,
108}
109
110impl Diagnostic {
111    /// Create a new diagnostic
112    pub fn new(severity: Severity, message: impl Into<String>) -> Self {
113        Diagnostic {
114            severity,
115            message: message.into(),
116            location: None,
117            context: Vec::new(),
118            suggestions: Vec::new(),
119            related: Vec::new(),
120            code: None,
121        }
122    }
123
124    /// Create an error diagnostic
125    pub fn error(message: impl Into<String>) -> Self {
126        Self::new(Severity::Error, message)
127    }
128
129    /// Create a warning diagnostic
130    pub fn warning(message: impl Into<String>) -> Self {
131        Self::new(Severity::Warning, message)
132    }
133
134    /// Create an info diagnostic
135    pub fn info(message: impl Into<String>) -> Self {
136        Self::new(Severity::Info, message)
137    }
138
139    /// Add source location
140    pub fn with_location(mut self, location: SourceLocation) -> Self {
141        self.location = Some(location);
142        self
143    }
144
145    /// Add context information
146    pub fn with_context(mut self, context: impl Into<String>) -> Self {
147        self.context.push(context.into());
148        self
149    }
150
151    /// Add suggestion
152    pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
153        self.suggestions.push(suggestion.into());
154        self
155    }
156
157    /// Add related information
158    pub fn with_related(mut self, related: impl Into<String>) -> Self {
159        self.related.push(related.into());
160        self
161    }
162
163    /// Add error code
164    pub fn with_code(mut self, code: impl Into<String>) -> Self {
165        self.code = Some(code.into());
166        self
167    }
168
169    /// Format as user-friendly string
170    pub fn format(&self) -> String {
171        let mut output = String::new();
172
173        // Header
174        output.push_str(&format!("[{}] {}\n", self.severity, self.message));
175
176        // Location
177        if let Some(ref loc) = self.location {
178            output.push_str(&format!("  at {}\n", loc));
179        }
180
181        // Error code
182        if let Some(ref code) = self.code {
183            output.push_str(&format!("  code: {}\n", code));
184        }
185
186        // Context
187        if !self.context.is_empty() {
188            output.push_str("\nContext:\n");
189            for ctx in &self.context {
190                output.push_str(&format!("  {}\n", ctx));
191            }
192        }
193
194        // Suggestions
195        if !self.suggestions.is_empty() {
196            output.push_str("\nSuggestions:\n");
197            for (i, suggestion) in self.suggestions.iter().enumerate() {
198                output.push_str(&format!("  {}. {}\n", i + 1, suggestion));
199            }
200        }
201
202        // Related
203        if !self.related.is_empty() {
204            output.push_str("\nRelated:\n");
205            for rel in &self.related {
206                output.push_str(&format!("  - {}\n", rel));
207            }
208        }
209
210        output
211    }
212}
213
214impl fmt::Display for Diagnostic {
215    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        write!(f, "{}", self.format())
217    }
218}
219
220/// Shape mismatch diagnostic builder
221pub struct ShapeMismatchDiagnostic;
222
223impl ShapeMismatchDiagnostic {
224    pub fn create(expected: &TensorShape, actual: &TensorShape, operation: &str) -> Diagnostic {
225        let mut diag = Diagnostic::error(format!("Shape mismatch in {} operation", operation))
226            .with_code("E001")
227            .with_context(format!(
228                "Expected shape: {:?}, but got: {:?}",
229                expected.dims, actual.dims
230            ));
231
232        // Add specific suggestions based on the mismatch
233        if expected.rank() != actual.rank() {
234            diag = diag
235                .with_suggestion(format!(
236                    "Expected rank {} but got rank {}. Consider reshaping your tensor.",
237                    expected.rank(),
238                    actual.rank()
239                ))
240                .with_suggestion(format!(
241                    "Use tensor.reshape({:?}) to match the expected shape",
242                    expected.dims
243                ));
244        } else {
245            // Same rank, dimension mismatch
246            let mismatches: Vec<_> = expected
247                .dims
248                .iter()
249                .zip(actual.dims.iter())
250                .enumerate()
251                .filter(|(_, (e, a))| e != a)
252                .collect();
253
254            for (dim, (exp, act)) in mismatches {
255                diag = diag.with_context(format!(
256                    "Dimension {} mismatch: expected {:?}, got {:?}",
257                    dim, exp, act
258                ));
259            }
260
261            diag = diag.with_suggestion(
262                "Check your input tensor shapes match the expected dimensions".to_string(),
263            );
264        }
265
266        diag
267    }
268}
269
270/// Type mismatch diagnostic builder
271pub struct TypeMismatchDiagnostic;
272
273impl TypeMismatchDiagnostic {
274    pub fn create(expected: &str, actual: &str, context: &str) -> Diagnostic {
275        Diagnostic::error(format!("Type mismatch in {}", context))
276            .with_code("E002")
277            .with_context(format!("Expected type: {}, but got: {}", expected, actual))
278            .with_suggestion(format!("Convert your data to {} type", expected))
279            .with_suggestion("Check the input data types match the expected types".to_string())
280    }
281}
282
283/// Node execution diagnostic builder
284pub struct NodeExecutionDiagnostic;
285
286impl NodeExecutionDiagnostic {
287    pub fn create(node_id: usize, error: &str, graph: &EinsumGraph) -> Diagnostic {
288        let mut diag = Diagnostic::error(format!("Failed to execute node {}", node_id))
289            .with_code("E003")
290            .with_context(error.to_string());
291
292        // Add node information
293        if let Some(node) = graph.nodes.get(node_id) {
294            diag = diag.with_context(format!("Node operation: {:?}", node.op));
295
296            // Add input information
297            if !node.inputs.is_empty() {
298                diag = diag.with_context(format!("Input nodes: {:?}", node.inputs));
299            }
300
301            // Add suggestions based on operation type
302            diag = diag.with_suggestion(
303                "Check that all input tensors are properly initialized".to_string(),
304            );
305            diag = diag.with_suggestion(
306                "Verify input tensor shapes are compatible with this operation".to_string(),
307            );
308        }
309
310        // Add related nodes
311        for input_id in graph
312            .nodes
313            .get(node_id)
314            .map(|n| &n.inputs)
315            .unwrap_or(&vec![])
316        {
317            diag = diag.with_related(format!("Input node: {}", input_id));
318        }
319
320        diag
321    }
322}
323
324/// Memory diagnostic builder
325pub struct MemoryDiagnostic;
326
327impl MemoryDiagnostic {
328    pub fn out_of_memory(requested_bytes: usize, available_bytes: usize) -> Diagnostic {
329        let requested_mb = requested_bytes as f64 / (1024.0 * 1024.0);
330        let available_mb = available_bytes as f64 / (1024.0 * 1024.0);
331
332        Diagnostic::error("Out of memory")
333            .with_code("E004")
334            .with_context(format!(
335                "Requested: {:.2} MB, Available: {:.2} MB",
336                requested_mb, available_mb
337            ))
338            .with_suggestion("Reduce batch size to lower memory usage".to_string())
339            .with_suggestion("Enable streaming execution for large datasets".to_string())
340            .with_suggestion("Consider using a machine with more memory".to_string())
341            .with_suggestion("Enable memory pooling to reuse allocations".to_string())
342    }
343
344    pub fn memory_leak_warning(leaked_bytes: usize) -> Diagnostic {
345        let leaked_mb = leaked_bytes as f64 / (1024.0 * 1024.0);
346
347        Diagnostic::warning(format!(
348            "Potential memory leak detected: {:.2} MB",
349            leaked_mb
350        ))
351        .with_code("W001")
352        .with_suggestion("Check that all tensors are properly released".to_string())
353        .with_suggestion("Enable memory profiling to identify the leak source".to_string())
354        .with_suggestion("Use memory pooling to manage allocations".to_string())
355    }
356}
357
358/// Performance diagnostic builder
359pub struct PerformanceDiagnostic;
360
361impl PerformanceDiagnostic {
362    pub fn slow_operation(
363        operation: &str,
364        actual_time_ms: f64,
365        expected_time_ms: f64,
366    ) -> Diagnostic {
367        let slowdown = actual_time_ms / expected_time_ms;
368
369        Diagnostic::warning(format!(
370            "Slow {} operation: {:.2}x slower than expected",
371            operation, slowdown
372        ))
373        .with_code("W002")
374        .with_context(format!(
375            "Actual: {:.2}ms, Expected: {:.2}ms",
376            actual_time_ms, expected_time_ms
377        ))
378        .with_suggestion("Enable graph optimization to improve performance".to_string())
379        .with_suggestion("Check if operation fusion is enabled".to_string())
380        .with_suggestion("Consider using a more powerful device (GPU)".to_string())
381        .with_suggestion("Profile the execution to identify bottlenecks".to_string())
382    }
383
384    pub fn high_memory_usage(peak_mb: f64, threshold_mb: f64) -> Diagnostic {
385        Diagnostic::warning(format!("High memory usage: {:.2} MB", peak_mb))
386            .with_code("W003")
387            .with_context(format!("Threshold: {:.2} MB", threshold_mb))
388            .with_suggestion("Enable memory optimization".to_string())
389            .with_suggestion("Reduce batch size".to_string())
390            .with_suggestion("Use streaming execution for large datasets".to_string())
391    }
392}
393
394/// Diagnostic collector for gathering multiple diagnostics
395#[derive(Debug, Default)]
396pub struct DiagnosticCollector {
397    diagnostics: Vec<Diagnostic>,
398}
399
400impl DiagnosticCollector {
401    pub fn new() -> Self {
402        Self::default()
403    }
404
405    /// Add a diagnostic
406    pub fn add(&mut self, diagnostic: Diagnostic) {
407        self.diagnostics.push(diagnostic);
408    }
409
410    /// Get all diagnostics
411    pub fn diagnostics(&self) -> &[Diagnostic] {
412        &self.diagnostics
413    }
414
415    /// Check if there are any errors
416    pub fn has_errors(&self) -> bool {
417        self.diagnostics
418            .iter()
419            .any(|d| d.severity >= Severity::Error)
420    }
421
422    /// Get error count
423    pub fn error_count(&self) -> usize {
424        self.diagnostics
425            .iter()
426            .filter(|d| d.severity == Severity::Error)
427            .count()
428    }
429
430    /// Get warning count
431    pub fn warning_count(&self) -> usize {
432        self.diagnostics
433            .iter()
434            .filter(|d| d.severity == Severity::Warning)
435            .count()
436    }
437
438    /// Format all diagnostics
439    pub fn format_all(&self) -> String {
440        let mut output = String::new();
441        for diag in &self.diagnostics {
442            output.push_str(&diag.format());
443            output.push('\n');
444        }
445
446        output.push_str(&format!(
447            "\nSummary: {} error(s), {} warning(s)\n",
448            self.error_count(),
449            self.warning_count()
450        ));
451
452        output
453    }
454
455    /// Clear all diagnostics
456    pub fn clear(&mut self) {
457        self.diagnostics.clear();
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods populate the corresponding fields.
    #[test]
    fn test_diagnostic_creation() {
        let built = Diagnostic::error("Test error")
            .with_code("E001")
            .with_context("Additional context")
            .with_suggestion("Try this fix");

        assert_eq!(built.severity, Severity::Error);
        assert_eq!(built.message, "Test error");
        assert_eq!(built.code.as_deref(), Some("E001"));
        assert_eq!((built.context.len(), built.suggestions.len()), (1, 1));
    }

    /// Same-rank shapes differing in one dimension yield an error with suggestions.
    #[test]
    fn test_shape_mismatch_diagnostic() {
        let want = TensorShape::static_shape(vec![64, 128]);
        let got = TensorShape::static_shape(vec![64, 256]);

        let built = ShapeMismatchDiagnostic::create(&want, &got, "matmul");

        assert_eq!(built.severity, Severity::Error);
        assert!(built.message.contains("Shape mismatch"));
        assert!(!built.suggestions.is_empty());
    }

    /// Type mismatches are errors tagged with code E002.
    #[test]
    fn test_type_mismatch_diagnostic() {
        let built = TypeMismatchDiagnostic::create("f32", "f64", "tensor operation");

        assert_eq!(built.severity, Severity::Error);
        assert!(built.message.contains("Type mismatch"));
        assert_eq!(built.code.as_deref(), Some("E002"));
    }

    /// Out-of-memory diagnostics are errors carrying remediation suggestions.
    #[test]
    fn test_memory_diagnostic() {
        let gib: usize = 1024 * 1024 * 1024;
        let built = MemoryDiagnostic::out_of_memory(gib, gib / 2);

        assert_eq!(built.severity, Severity::Error);
        assert!(built.message.contains("Out of memory"));
        assert!(!built.suggestions.is_empty());
    }

    /// A 2x slowdown is reported in the headline as "2.00x".
    #[test]
    fn test_performance_diagnostic() {
        let built = PerformanceDiagnostic::slow_operation("einsum", 100.0, 50.0);

        assert_eq!(built.severity, Severity::Warning);
        for needle in ["Slow", "2.00x"].iter() {
            assert!(built.message.contains(needle));
        }
    }

    /// The collector counts by severity and summarises in `format_all`.
    #[test]
    fn test_diagnostic_collector() {
        let mut collector = DiagnosticCollector::new();
        let inputs = vec![
            Diagnostic::error("Error 1"),
            Diagnostic::warning("Warning 1"),
            Diagnostic::error("Error 2"),
        ];
        for entry in inputs {
            collector.add(entry);
        }

        assert_eq!(collector.error_count(), 2);
        assert_eq!(collector.warning_count(), 1);
        assert!(collector.has_errors());

        let rendered = collector.format_all();
        assert!(rendered.contains("2 error(s), 1 warning(s)"));
    }

    /// A fully specified location renders as `file:line:column`.
    #[test]
    fn test_source_location() {
        let where_ = SourceLocation::new()
            .with_file("test.rs".to_string())
            .with_line(42)
            .with_column(10);

        assert_eq!(format!("{}", where_), "test.rs:42:10");
    }

    /// Severities are strictly increasing in declaration order.
    #[test]
    fn test_severity_ordering() {
        let ladder = [
            Severity::Info,
            Severity::Warning,
            Severity::Error,
            Severity::Critical,
        ];
        assert!(ladder.windows(2).all(|pair| pair[0] < pair[1]));
    }
}