Skip to main content

tensorlogic_infer/
diagnostics.rs

1//! Enhanced error diagnostics with helpful suggestions.
2//!
3//! This module provides rich error messages with context, suggestions,
4//! and actionable advice for common mistakes.
5
6use std::fmt;
7use tensorlogic_ir::EinsumGraph;
8
9use crate::shape::TensorShape;
10
/// Diagnostic severity level.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// ranks them `Info < Warning < Error < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
    /// Informational message
    Info,
    /// Warning (non-fatal)
    Warning,
    /// Error (fatal, prevents execution)
    Error,
    /// Critical error (system-level issue)
    Critical,
}

impl fmt::Display for Severity {
    /// Writes the upper-case label (`INFO`, `WARNING`, `ERROR`, `CRITICAL`).
    ///
    /// Uses `Formatter::pad` so width, fill and alignment flags are honored,
    /// e.g. `format!("{:<8}", severity)` for column-aligned output. Without
    /// flags the label is written unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Severity::Info => "INFO",
            Severity::Warning => "WARNING",
            Severity::Error => "ERROR",
            Severity::Critical => "CRITICAL",
        };
        f.pad(label)
    }
}
34
/// Source location for error reporting.
///
/// Every component is optional. `Display` renders it as `file:line:column`,
/// using `<unknown>` in place of a missing file.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SourceLocation {
    /// File path or name, if known.
    pub file: Option<String>,
    /// Line number, if known (printed exactly as stored).
    pub line: Option<usize>,
    /// Column number, if known. Only printed when `line` is also set.
    pub column: Option<usize>,
}

impl SourceLocation {
    /// Create a location with no file, line, or column.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the file component.
    pub fn with_file(mut self, file: String) -> Self {
        self.file = Some(file);
        self
    }

    /// Set the line component.
    pub fn with_line(mut self, line: usize) -> Self {
        self.line = Some(line);
        self
    }

    /// Set the column component.
    pub fn with_column(mut self, column: usize) -> Self {
        self.column = Some(column);
        self
    }
}

impl fmt::Display for SourceLocation {
    /// Formats as `file[:line[:column]]`.
    ///
    /// A missing file is shown as `<unknown>`, but a known line and column
    /// are still printed (e.g. `<unknown>:42`), so position information is
    /// not lost. A column without a line is omitted, because `file:column`
    /// would be indistinguishable from `file:line`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.file.as_deref().unwrap_or("<unknown>"))?;
        if let Some(line) = self.line {
            write!(f, ":{}", line)?;
            if let Some(column) = self.column {
                write!(f, ":{}", column)?;
            }
        }
        Ok(())
    }
}
90
91/// Detailed diagnostic message
92#[derive(Debug, Clone)]
93pub struct Diagnostic {
94    /// Severity level
95    pub severity: Severity,
96    /// Primary error message
97    pub message: String,
98    /// Source location
99    pub location: Option<SourceLocation>,
100    /// Additional context
101    pub context: Vec<String>,
102    /// Suggested fixes
103    pub suggestions: Vec<String>,
104    /// Related nodes or operations
105    pub related: Vec<String>,
106    /// Error code (for documentation lookup)
107    pub code: Option<String>,
108}
109
110impl Diagnostic {
111    /// Create a new diagnostic
112    pub fn new(severity: Severity, message: impl Into<String>) -> Self {
113        Diagnostic {
114            severity,
115            message: message.into(),
116            location: None,
117            context: Vec::new(),
118            suggestions: Vec::new(),
119            related: Vec::new(),
120            code: None,
121        }
122    }
123
124    /// Create an error diagnostic
125    pub fn error(message: impl Into<String>) -> Self {
126        Self::new(Severity::Error, message)
127    }
128
129    /// Create a warning diagnostic
130    pub fn warning(message: impl Into<String>) -> Self {
131        Self::new(Severity::Warning, message)
132    }
133
134    /// Create an info diagnostic
135    pub fn info(message: impl Into<String>) -> Self {
136        Self::new(Severity::Info, message)
137    }
138
139    /// Add source location
140    pub fn with_location(mut self, location: SourceLocation) -> Self {
141        self.location = Some(location);
142        self
143    }
144
145    /// Add context information
146    pub fn with_context(mut self, context: impl Into<String>) -> Self {
147        self.context.push(context.into());
148        self
149    }
150
151    /// Add suggestion
152    pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
153        self.suggestions.push(suggestion.into());
154        self
155    }
156
157    /// Add related information
158    pub fn with_related(mut self, related: impl Into<String>) -> Self {
159        self.related.push(related.into());
160        self
161    }
162
163    /// Add error code
164    pub fn with_code(mut self, code: impl Into<String>) -> Self {
165        self.code = Some(code.into());
166        self
167    }
168
169    /// Format as user-friendly string
170    pub fn format(&self) -> String {
171        let mut output = String::new();
172
173        // Header
174        output.push_str(&format!("[{}] {}\n", self.severity, self.message));
175
176        // Location
177        if let Some(ref loc) = self.location {
178            output.push_str(&format!("  at {}\n", loc));
179        }
180
181        // Error code
182        if let Some(ref code) = self.code {
183            output.push_str(&format!("  code: {}\n", code));
184        }
185
186        // Context
187        if !self.context.is_empty() {
188            output.push_str("\nContext:\n");
189            for ctx in &self.context {
190                output.push_str(&format!("  {}\n", ctx));
191            }
192        }
193
194        // Suggestions
195        if !self.suggestions.is_empty() {
196            output.push_str("\nSuggestions:\n");
197            for (i, suggestion) in self.suggestions.iter().enumerate() {
198                output.push_str(&format!("  {}. {}\n", i + 1, suggestion));
199            }
200        }
201
202        // Related
203        if !self.related.is_empty() {
204            output.push_str("\nRelated:\n");
205            for rel in &self.related {
206                output.push_str(&format!("  - {}\n", rel));
207            }
208        }
209
210        output
211    }
212}
213
214impl fmt::Display for Diagnostic {
215    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        write!(f, "{}", self.format())
217    }
218}
219
220/// Shape mismatch diagnostic builder
221pub struct ShapeMismatchDiagnostic;
222
223impl ShapeMismatchDiagnostic {
224    pub fn create(expected: &TensorShape, actual: &TensorShape, operation: &str) -> Diagnostic {
225        let mut diag = Diagnostic::error(format!("Shape mismatch in {} operation", operation))
226            .with_code("E001")
227            .with_context(format!(
228                "Expected shape: {:?}, but got: {:?}",
229                expected.dims, actual.dims
230            ));
231
232        // Add specific suggestions based on the mismatch
233        if expected.rank() != actual.rank() {
234            diag = diag
235                .with_suggestion(format!(
236                    "Expected rank {} but got rank {}. Consider reshaping your tensor.",
237                    expected.rank(),
238                    actual.rank()
239                ))
240                .with_suggestion(format!(
241                    "Use tensor.reshape({:?}) to match the expected shape",
242                    expected.dims
243                ));
244        } else {
245            // Same rank, dimension mismatch
246            let mismatches: Vec<_> = expected
247                .dims
248                .iter()
249                .zip(actual.dims.iter())
250                .enumerate()
251                .filter(|(_, (e, a))| e != a)
252                .collect();
253
254            for (dim, (exp, act)) in mismatches {
255                diag = diag.with_context(format!(
256                    "Dimension {} mismatch: expected {:?}, got {:?}",
257                    dim, exp, act
258                ));
259            }
260
261            diag = diag.with_suggestion(
262                "Check your input tensor shapes match the expected dimensions".to_string(),
263            );
264        }
265
266        diag
267    }
268}
269
270/// Type mismatch diagnostic builder
271pub struct TypeMismatchDiagnostic;
272
273impl TypeMismatchDiagnostic {
274    pub fn create(expected: &str, actual: &str, context: &str) -> Diagnostic {
275        Diagnostic::error(format!("Type mismatch in {}", context))
276            .with_code("E002")
277            .with_context(format!("Expected type: {}, but got: {}", expected, actual))
278            .with_suggestion(format!("Convert your data to {} type", expected))
279            .with_suggestion("Check the input data types match the expected types".to_string())
280    }
281}
282
283/// Node execution diagnostic builder
284pub struct NodeExecutionDiagnostic;
285
286impl NodeExecutionDiagnostic {
287    pub fn create(node_id: usize, error: &str, graph: &EinsumGraph) -> Diagnostic {
288        let mut diag = Diagnostic::error(format!("Failed to execute node {}", node_id))
289            .with_code("E003")
290            .with_context(error.to_string());
291
292        // Add node information
293        if let Some(node) = graph.nodes.get(node_id) {
294            diag = diag.with_context(format!("Node operation: {:?}", node.op));
295
296            // Add input information
297            if !node.inputs.is_empty() {
298                diag = diag.with_context(format!("Input nodes: {:?}", node.inputs));
299            }
300
301            // Add suggestions based on operation type
302            diag = diag.with_suggestion(
303                "Check that all input tensors are properly initialized".to_string(),
304            );
305            diag = diag.with_suggestion(
306                "Verify input tensor shapes are compatible with this operation".to_string(),
307            );
308        }
309
310        // Add related nodes
311        for input_id in graph
312            .nodes
313            .get(node_id)
314            .map(|n| &n.inputs)
315            .unwrap_or(&vec![])
316        {
317            diag = diag.with_related(format!("Input node: {}", input_id));
318        }
319
320        diag
321    }
322}
323
324/// Memory diagnostic builder
325pub struct MemoryDiagnostic;
326
327impl MemoryDiagnostic {
328    pub fn out_of_memory(requested_bytes: usize, available_bytes: usize) -> Diagnostic {
329        let requested_mb = requested_bytes as f64 / (1024.0 * 1024.0);
330        let available_mb = available_bytes as f64 / (1024.0 * 1024.0);
331
332        Diagnostic::error("Out of memory")
333            .with_code("E004")
334            .with_context(format!(
335                "Requested: {:.2} MB, Available: {:.2} MB",
336                requested_mb, available_mb
337            ))
338            .with_suggestion("Reduce batch size to lower memory usage".to_string())
339            .with_suggestion("Enable streaming execution for large datasets".to_string())
340            .with_suggestion("Consider using a machine with more memory".to_string())
341            .with_suggestion("Enable memory pooling to reuse allocations".to_string())
342    }
343
344    pub fn memory_leak_warning(leaked_bytes: usize) -> Diagnostic {
345        let leaked_mb = leaked_bytes as f64 / (1024.0 * 1024.0);
346
347        Diagnostic::warning(format!(
348            "Potential memory leak detected: {:.2} MB",
349            leaked_mb
350        ))
351        .with_code("W001")
352        .with_suggestion("Check that all tensors are properly released".to_string())
353        .with_suggestion("Enable memory profiling to identify the leak source".to_string())
354        .with_suggestion("Use memory pooling to manage allocations".to_string())
355    }
356}
357
358impl ShapeMismatchDiagnostic {
359    /// If `expected` and `actual` are permutations of each other, add a transpose suggestion.
360    pub fn with_transpose_suggestion(
361        mut diag: Diagnostic,
362        expected: &[usize],
363        actual: &[usize],
364    ) -> Diagnostic {
365        if expected.len() == actual.len() {
366            let mut sorted_expected = expected.to_vec();
367            let mut sorted_actual = actual.to_vec();
368            sorted_expected.sort_unstable();
369            sorted_actual.sort_unstable();
370            if sorted_expected == sorted_actual {
371                // Find the permutation that maps actual → expected.
372                let perm: Vec<usize> = expected
373                    .iter()
374                    .map(|&e| actual.iter().position(|&a| a == e).unwrap_or(0))
375                    .collect();
376                diag = diag.with_suggestion(format!(
377                    "Shapes are permutations of each other. Consider transposing with axes {:?}",
378                    perm
379                ));
380            }
381        }
382        diag
383    }
384
385    /// If the ranks differ by 1 and broadcast/unsqueeze could reconcile them, add a suggestion.
386    pub fn with_broadcast_suggestion(
387        mut diag: Diagnostic,
388        expected: &[usize],
389        actual: &[usize],
390    ) -> Diagnostic {
391        let rank_diff = (expected.len() as isize - actual.len() as isize).unsigned_abs();
392        if rank_diff == 1 {
393            let (longer, shorter) = if expected.len() > actual.len() {
394                (expected, actual)
395            } else {
396                (actual, expected)
397            };
398            // Check if shorter is a suffix of longer (broadcast-compatible).
399            let suffix_matches = longer
400                .iter()
401                .rev()
402                .zip(shorter.iter().rev())
403                .all(|(&l, &s)| l == s || l == 1 || s == 1);
404            if suffix_matches {
405                diag = diag.with_suggestion(format!(
406                    "Ranks differ by 1. Try unsqueezing to shape {:?} or using broadcasting",
407                    longer
408                ));
409            }
410        }
411        diag
412    }
413}
414
415/// Performance diagnostic builder
416pub struct PerformanceDiagnostic;
417
418impl PerformanceDiagnostic {
419    pub fn slow_operation(
420        operation: &str,
421        actual_time_ms: f64,
422        expected_time_ms: f64,
423    ) -> Diagnostic {
424        let slowdown = actual_time_ms / expected_time_ms;
425
426        Diagnostic::warning(format!(
427            "Slow {} operation: {:.2}x slower than expected",
428            operation, slowdown
429        ))
430        .with_code("W002")
431        .with_context(format!(
432            "Actual: {:.2}ms, Expected: {:.2}ms",
433            actual_time_ms, expected_time_ms
434        ))
435        .with_suggestion("Enable graph optimization to improve performance".to_string())
436        .with_suggestion("Check if operation fusion is enabled".to_string())
437        .with_suggestion("Consider using a more powerful device (GPU)".to_string())
438        .with_suggestion("Profile the execution to identify bottlenecks".to_string())
439    }
440
441    pub fn high_memory_usage(peak_mb: f64, threshold_mb: f64) -> Diagnostic {
442        Diagnostic::warning(format!("High memory usage: {:.2} MB", peak_mb))
443            .with_code("W003")
444            .with_context(format!("Threshold: {:.2} MB", threshold_mb))
445            .with_suggestion("Enable memory optimization".to_string())
446            .with_suggestion("Reduce batch size".to_string())
447            .with_suggestion("Use streaming execution for large datasets".to_string())
448    }
449
450    /// Suggest increasing parallelism when independent ops exceed current thread count.
451    pub fn parallelism_available(num_independent_ops: usize, current_threads: usize) -> Diagnostic {
452        Diagnostic::info(format!(
453            "Parallelism opportunity: {} independent ops, only {} threads active",
454            num_independent_ops, current_threads
455        ))
456        .with_code("P001")
457        .with_context(format!(
458            "{} operations could run in parallel but only {} worker threads are available",
459            num_independent_ops, current_threads
460        ))
461        .with_suggestion(format!(
462            "Increase thread pool size to at least {} for maximum throughput",
463            num_independent_ops
464        ))
465        .with_suggestion(
466            "Use rayon or a work-stealing scheduler for automatic parallelism".to_string(),
467        )
468    }
469
470    /// Suggest memory pooling when the allocation rate exceeds a threshold.
471    pub fn high_allocation_rate(allocs_per_second: f64, threshold: f64) -> Diagnostic {
472        Diagnostic::warning(format!(
473            "High allocation rate: {:.1} allocs/s (threshold: {:.1})",
474            allocs_per_second, threshold
475        ))
476        .with_code("P002")
477        .with_context(format!(
478            "Tensor allocations are occurring at {:.1} per second",
479            allocs_per_second
480        ))
481        .with_suggestion("Enable a memory pool (WorkspacePool) to reuse buffers".to_string())
482        .with_suggestion("Pre-allocate output tensors where output shapes are known".to_string())
483    }
484
485    /// Suggest operation fusion when several fuseable ops are detected.
486    pub fn fusion_opportunity(num_fuseable: usize, op_names: &[&str]) -> Diagnostic {
487        Diagnostic::info(format!(
488            "Fusion opportunity: {} operations could be fused",
489            num_fuseable
490        ))
491        .with_code("P003")
492        .with_context(format!("Fuseable operations: {}", op_names.join(", ")))
493        .with_suggestion(
494            "Enable the FusionOptimizer pass to reduce kernel launch overhead".to_string(),
495        )
496        .with_suggestion("Consider using FusionStrategy::Aggressive for maximum fusion".to_string())
497    }
498
499    /// Suggest reducing f64 → f32 when a meaningful speedup is expected.
500    pub fn precision_downgrade_available(estimated_speedup: f64) -> Diagnostic {
501        Diagnostic::info(format!(
502            "Precision downgrade available: estimated {:.1}x speedup using f32",
503            estimated_speedup
504        ))
505        .with_code("P004")
506        .with_context("Computation is currently using f64 (double) precision".to_string())
507        .with_suggestion(
508            "Switch to f32 (single precision) if model accuracy tolerates it".to_string(),
509        )
510        .with_suggestion(
511            "Use MixedPrecisionConfig to selectively apply f16/f32 where safe".to_string(),
512        )
513    }
514}
515
516/// Diagnostic collector for gathering multiple diagnostics
517#[derive(Debug, Default)]
518pub struct DiagnosticCollector {
519    diagnostics: Vec<Diagnostic>,
520}
521
522impl DiagnosticCollector {
523    pub fn new() -> Self {
524        Self::default()
525    }
526
527    /// Add a diagnostic
528    pub fn add(&mut self, diagnostic: Diagnostic) {
529        self.diagnostics.push(diagnostic);
530    }
531
532    /// Get all diagnostics
533    pub fn diagnostics(&self) -> &[Diagnostic] {
534        &self.diagnostics
535    }
536
537    /// Check if there are any errors
538    pub fn has_errors(&self) -> bool {
539        self.diagnostics
540            .iter()
541            .any(|d| d.severity >= Severity::Error)
542    }
543
544    /// Get error count
545    pub fn error_count(&self) -> usize {
546        self.diagnostics
547            .iter()
548            .filter(|d| d.severity == Severity::Error)
549            .count()
550    }
551
552    /// Get warning count
553    pub fn warning_count(&self) -> usize {
554        self.diagnostics
555            .iter()
556            .filter(|d| d.severity == Severity::Warning)
557            .count()
558    }
559
560    /// Format all diagnostics
561    pub fn format_all(&self) -> String {
562        let mut output = String::new();
563        for diag in &self.diagnostics {
564            output.push_str(&diag.format());
565            output.push('\n');
566        }
567
568        output.push_str(&format!(
569            "\nSummary: {} error(s), {} warning(s)\n",
570            self.error_count(),
571            self.warning_count()
572        ));
573
574        output
575    }
576
577    /// Clear all diagnostics
578    pub fn clear(&mut self) {
579        self.diagnostics.clear();
580    }
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    /// Bytes per mebibyte, used to size the out-of-memory scenario.
    const MIB: usize = 1024 * 1024;

    #[test]
    fn test_diagnostic_creation() {
        let built = Diagnostic::error("Test error")
            .with_code("E001")
            .with_context("Additional context")
            .with_suggestion("Try this fix");

        assert_eq!(built.severity, Severity::Error);
        assert_eq!(built.message, "Test error");
        assert_eq!(built.code.as_deref(), Some("E001"));
        assert_eq!(built.context.len(), 1);
        assert_eq!(built.suggestions.len(), 1);
    }

    #[test]
    fn test_shape_mismatch_diagnostic() {
        let wanted = TensorShape::static_shape(vec![64, 128]);
        let got = TensorShape::static_shape(vec![64, 256]);
        let built = ShapeMismatchDiagnostic::create(&wanted, &got, "matmul");

        assert_eq!(built.severity, Severity::Error);
        assert!(built.message.contains("Shape mismatch"));
        assert!(!built.suggestions.is_empty());
    }

    #[test]
    fn test_type_mismatch_diagnostic() {
        let built = TypeMismatchDiagnostic::create("f32", "f64", "tensor operation");

        assert_eq!(built.severity, Severity::Error);
        assert!(built.message.contains("Type mismatch"));
        assert_eq!(built.code.as_deref(), Some("E002"));
    }

    #[test]
    fn test_memory_diagnostic() {
        let built = MemoryDiagnostic::out_of_memory(1024 * MIB, 512 * MIB);

        assert_eq!(built.severity, Severity::Error);
        assert!(built.message.contains("Out of memory"));
        assert!(!built.suggestions.is_empty());
    }

    #[test]
    fn test_performance_diagnostic() {
        let built = PerformanceDiagnostic::slow_operation("einsum", 100.0, 50.0);

        assert_eq!(built.severity, Severity::Warning);
        assert!(built.message.contains("Slow"));
        assert!(built.message.contains("2.00x"));
    }

    #[test]
    fn test_diagnostic_collector() {
        let mut collector = DiagnosticCollector::new();
        for entry in [
            Diagnostic::error("Error 1"),
            Diagnostic::warning("Warning 1"),
            Diagnostic::error("Error 2"),
        ] {
            collector.add(entry);
        }

        assert_eq!(collector.error_count(), 2);
        assert_eq!(collector.warning_count(), 1);
        assert!(collector.has_errors());

        let rendered = collector.format_all();
        assert!(rendered.contains("2 error(s), 1 warning(s)"));
    }

    #[test]
    fn test_source_location() {
        let position = SourceLocation::new()
            .with_file("test.rs".to_string())
            .with_line(42)
            .with_column(10);

        assert_eq!(position.to_string(), "test.rs:42:10");
    }

    #[test]
    fn test_severity_ordering() {
        let ascending = [
            Severity::Info,
            Severity::Warning,
            Severity::Error,
            Severity::Critical,
        ];
        assert!(ascending.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn test_transpose_suggestion_added() {
        // [3, 2] and [2, 3] are permutations of each other.
        let built = ShapeMismatchDiagnostic::with_transpose_suggestion(
            Diagnostic::error("shape mismatch"),
            &[3, 2],
            &[2, 3],
        );
        let mentions_transpose = built.suggestions.iter().any(|s| s.contains("transpos"));
        assert!(
            mentions_transpose,
            "Expected transpose suggestion, got: {:?}",
            built.suggestions
        );
    }

    #[test]
    fn test_broadcast_suggestion_added() {
        // [1, 4] vs [4] differ by 1 rank; [4] is a suffix of [1, 4].
        let built = ShapeMismatchDiagnostic::with_broadcast_suggestion(
            Diagnostic::error("shape mismatch"),
            &[1, 4],
            &[4],
        );
        let mentions_broadcast = built
            .suggestions
            .iter()
            .any(|s| s.contains("unsqueez") || s.contains("broadcast"));
        assert!(
            mentions_broadcast,
            "Expected broadcast suggestion, got: {:?}",
            built.suggestions
        );
    }

    #[test]
    fn test_parallelism_diagnostic() {
        let built = PerformanceDiagnostic::parallelism_available(8, 2);

        assert_eq!(built.severity, Severity::Info);
        assert!(built.message.contains("Parallelism opportunity"));
        assert!(!built.suggestions.is_empty());
    }

    #[test]
    fn test_fusion_opportunity_diagnostic() {
        let built = PerformanceDiagnostic::fusion_opportunity(3, &["relu", "matmul", "add"]);

        assert_eq!(built.severity, Severity::Info);
        assert!(built.message.contains("Fusion opportunity"));
        assert!(built.context.iter().any(|line| line.contains("relu")));
    }
}