Skip to main content

tensorlogic_infer/
visualization.rs

1//! Visualization utilities for execution analysis and debugging.
2//!
3//! This module provides tools for visualizing:
4//! - Execution timelines
5//! - Computation graphs
6//! - Performance data
7//!
8//! Supports multiple output formats:
9//! - ASCII art for terminal display
10//! - GraphViz DOT format
11//! - JSON for custom rendering
12
13use crate::debug::{ExecutionTrace, TensorStats};
14use crate::profiling::ProfileData;
15use tensorlogic_ir::EinsumGraph;
16
/// Output format in which a visualization can be produced.
///
/// NOTE(review): no renderer in this module currently emits HTML; `Html` may be
/// consumed elsewhere — confirm before relying on it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VisualizationFormat {
    /// Plain text intended for a terminal.
    Ascii,
    /// GraphViz DOT source.
    Dot,
    /// A JSON document.
    Json,
    /// An HTML document.
    Html,
}
29
/// Options controlling text rendering of execution timelines and profiles.
#[derive(Debug, Clone)]
pub struct TimelineConfig {
    /// Number of columns used for separator rules and timeline bars.
    pub width: usize,
    /// Whether to print each entry's node id and operation name.
    pub show_names: bool,
    /// Whether to print each entry's duration.
    pub show_timing: bool,
    /// Whether entries should be grouped by operation type.
    pub group_by_type: bool,
}

impl Default for TimelineConfig {
    /// An 80-column layout showing names and timings, without grouping.
    fn default() -> Self {
        let width = 80;
        let group_by_type = false;
        TimelineConfig {
            width,
            show_timing: true,
            show_names: true,
            group_by_type,
        }
    }
}
53
/// Options controlling how a computation graph is rendered.
#[derive(Debug, Clone)]
pub struct GraphConfig {
    /// Whether tensor shapes should be displayed.
    pub show_shapes: bool,
    /// Whether each node's operation should be displayed.
    pub show_op_types: bool,
    /// Whether the critical path should be highlighted.
    pub highlight_critical_path: bool,
    /// Prefer a top-to-bottom (vertical) layout over a left-to-right one.
    pub vertical_layout: bool,
}

impl Default for GraphConfig {
    /// Shapes and op types shown, vertical layout, no critical-path highlight.
    fn default() -> Self {
        let highlight_critical_path = false;
        let vertical_layout = true;
        GraphConfig {
            vertical_layout,
            highlight_critical_path,
            show_op_types: true,
            show_shapes: true,
        }
    }
}
77
78/// Timeline visualizer.
79pub struct TimelineVisualizer {
80    config: TimelineConfig,
81}
82
83impl TimelineVisualizer {
84    /// Create a new timeline visualizer.
85    pub fn new(config: TimelineConfig) -> Self {
86        Self { config }
87    }
88
89    /// Visualize an execution trace as ASCII timeline.
90    pub fn visualize_trace(&self, trace: &ExecutionTrace) -> String {
91        let mut output = String::new();
92
93        // Header
94        output.push_str(&format!(
95            "Execution Timeline ({:.2}ms total)\n",
96            trace.total_duration_ms()
97        ));
98        output.push_str(&"=".repeat(self.config.width));
99        output.push('\n');
100
101        if trace.entries().is_empty() {
102            output.push_str("No operations recorded\n");
103            return output;
104        }
105
106        // Find time bounds
107        let start_time = trace.entries()[0].start_time;
108        let total_duration = trace.total_duration();
109
110        // Draw timeline
111        for entry in trace.entries() {
112            let elapsed = entry.start_time.duration_since(start_time);
113            let duration = entry.duration;
114
115            // Calculate bar position and width
116            let start_pos = (elapsed.as_secs_f64() / total_duration.as_secs_f64()
117                * self.config.width as f64) as usize;
118            let bar_width = ((duration.as_secs_f64() / total_duration.as_secs_f64()
119                * self.config.width as f64) as usize)
120                .max(1);
121
122            // Operation name
123            if self.config.show_names {
124                output.push_str(&format!("Node {}: {} ", entry.node_id, entry.operation));
125            }
126
127            if self.config.show_timing {
128                output.push_str(&format!("({:.2}ms)\n", entry.duration_ms()));
129            } else {
130                output.push('\n');
131            }
132
133            // Timeline bar
134            output.push_str(&" ".repeat(start_pos));
135            output.push_str(&"█".repeat(bar_width));
136            output.push('\n');
137        }
138
139        output.push_str(&"=".repeat(self.config.width));
140        output.push('\n');
141
142        output
143    }
144
145    /// Visualize profile data as text report.
146    pub fn visualize_profile(&self, profile: &ProfileData) -> String {
147        let mut output = String::new();
148
149        output.push_str("Performance Profile\n");
150        output.push_str(&"=".repeat(self.config.width));
151        output.push('\n');
152
153        // Sort operations by total time
154        let mut ops: Vec<_> = profile.op_profiles.iter().collect();
155        ops.sort_by(|(_, a), (_, b)| {
156            let a_total_ms = a.avg_time.as_secs_f64() * 1000.0 * a.count as f64;
157            let b_total_ms = b.avg_time.as_secs_f64() * 1000.0 * b.count as f64;
158            b_total_ms.partial_cmp(&a_total_ms).unwrap()
159        });
160
161        // Header
162        output.push_str(&format!(
163            "{:<30} {:>10} {:>10} {:>15}\n",
164            "Operation", "Count", "Avg (ms)", "Total (ms)"
165        ));
166        output.push_str(&"-".repeat(self.config.width));
167        output.push('\n');
168
169        // Operations
170        for (name, stats) in ops {
171            let avg_time_ms = stats.avg_time.as_secs_f64() * 1000.0;
172            let total_time_ms = avg_time_ms * stats.count as f64;
173            output.push_str(&format!(
174                "{:<30} {:>10} {:>10.2} {:>15.2}\n",
175                name, stats.count, avg_time_ms, total_time_ms
176            ));
177        }
178
179        output.push_str(&"=".repeat(self.config.width));
180        output.push('\n');
181
182        output
183    }
184}
185
186/// Graph visualizer.
187pub struct GraphVisualizer {
188    config: GraphConfig,
189}
190
191impl GraphVisualizer {
192    /// Create a new graph visualizer.
193    pub fn new(config: GraphConfig) -> Self {
194        Self { config }
195    }
196
197    /// Visualize a computation graph as ASCII art.
198    pub fn visualize_ascii(&self, graph: &EinsumGraph) -> String {
199        let mut output = String::new();
200
201        output.push_str("Computation Graph\n");
202        output.push_str("=================\n\n");
203
204        if graph.nodes.is_empty() {
205            output.push_str("Empty graph\n");
206            return output;
207        }
208
209        for (node_idx, node) in graph.nodes.iter().enumerate() {
210            // Node representation
211            output.push_str(&format!("Node {}:\n", node_idx));
212
213            // Operation type
214            if self.config.show_op_types {
215                output.push_str(&format!("  Op: {:?}\n", node.op));
216            }
217
218            // Inputs
219            if !node.inputs.is_empty() {
220                output.push_str("  Inputs: ");
221                for (i, input_id) in node.inputs.iter().enumerate() {
222                    if i > 0 {
223                        output.push_str(", ");
224                    }
225                    output.push_str(&format!("{}", input_id));
226                }
227                output.push('\n');
228            }
229
230            output.push('\n');
231        }
232
233        output
234    }
235
236    /// Generate GraphViz DOT format.
237    pub fn visualize_dot(&self, graph: &EinsumGraph) -> String {
238        let mut output = String::new();
239
240        output.push_str("digraph ComputationGraph {\n");
241        output.push_str("  rankdir=TB;\n");
242        output.push_str("  node [shape=box, style=rounded];\n\n");
243
244        // Nodes
245        for (node_idx, node) in graph.nodes.iter().enumerate() {
246            let label = format!("Node {}\\n{:?}", node_idx, node.op);
247            output.push_str(&format!("  n{} [label=\"{}\"];\n", node_idx, label));
248        }
249
250        output.push('\n');
251
252        // Edges
253        for (node_idx, node) in graph.nodes.iter().enumerate() {
254            for input_id in &node.inputs {
255                output.push_str(&format!("  n{} -> n{};\n", input_id, node_idx));
256            }
257        }
258
259        output.push_str("}\n");
260
261        output
262    }
263
264    /// Generate JSON representation.
265    pub fn visualize_json(&self, graph: &EinsumGraph) -> String {
266        let mut output = String::new();
267
268        output.push_str("{\n");
269        output.push_str("  \"nodes\": [\n");
270
271        for (node_idx, node) in graph.nodes.iter().enumerate() {
272            if node_idx > 0 {
273                output.push_str(",\n");
274            }
275            output.push_str("    {\n");
276            output.push_str(&format!("      \"id\": {},\n", node_idx));
277            output.push_str(&format!("      \"op\": \"{:?}\",\n", node.op));
278            output.push_str("      \"inputs\": [");
279
280            for (j, input_id) in node.inputs.iter().enumerate() {
281                if j > 0 {
282                    output.push_str(", ");
283                }
284                output.push_str(&format!("{}", input_id));
285            }
286
287            output.push_str("]\n");
288            output.push_str("    }");
289        }
290
291        output.push_str("\n  ]\n");
292        output.push_str("}\n");
293
294        output
295    }
296}
297
298/// Tensor statistics visualizer.
299pub struct TensorStatsVisualizer;
300
301impl TensorStatsVisualizer {
302    /// Visualize tensor statistics as a text report.
303    pub fn visualize(&self, stats: &TensorStats) -> String {
304        format!("{}", stats)
305    }
306
307    /// Visualize multiple tensor statistics as a table.
308    pub fn visualize_table(&self, stats: &[TensorStats]) -> String {
309        let mut output = String::new();
310
311        output.push_str("Tensor Statistics\n");
312        output.push_str(&"=".repeat(80));
313        output.push('\n');
314
315        if stats.is_empty() {
316            output.push_str("No tensors recorded\n");
317            return output;
318        }
319
320        // Header
321        output.push_str(&format!(
322            "{:<8} {:<20} {:<15} {:>10} {:>10}\n",
323            "ID", "Shape", "DType", "NaNs", "Infs"
324        ));
325        output.push_str(&"-".repeat(80));
326        output.push('\n');
327
328        // Rows
329        for stat in stats {
330            let shape_str = format!("{:?}", stat.shape);
331            let nans = stat.num_nans.unwrap_or(0);
332            let infs = stat.num_infs.unwrap_or(0);
333
334            output.push_str(&format!(
335                "{:<8} {:<20} {:<15} {:>10} {:>10}\n",
336                stat.tensor_id, shape_str, stat.dtype, nans, infs
337            ));
338
339            // Highlight issues
340            if stat.has_numerical_issues() {
341                output.push_str("         ⚠️  Numerical issues detected!\n");
342            }
343        }
344
345        output.push_str(&"=".repeat(80));
346        output.push('\n');
347
348        output
349    }
350
351    /// Generate histogram of tensor values (ASCII).
352    pub fn histogram(&self, values: &[f64], bins: usize) -> String {
353        let mut output = String::new();
354
355        if values.is_empty() {
356            return "No values\n".to_string();
357        }
358
359        let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
360        let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
361        let range = max - min;
362
363        if range == 0.0 {
364            return format!("All values are {:.6}\n", min);
365        }
366
367        // Create bins
368        let mut counts = vec![0; bins];
369        for &value in values {
370            let bin = ((value - min) / range * bins as f64) as usize;
371            let bin = bin.min(bins - 1);
372            counts[bin] += 1;
373        }
374
375        let max_count = *counts.iter().max().unwrap();
376
377        // Draw histogram
378        output.push_str("Value Distribution\n");
379        output.push_str(&"=".repeat(50));
380        output.push('\n');
381
382        for (i, &count) in counts.iter().enumerate() {
383            let bin_start = min + (i as f64 / bins as f64) * range;
384            let bin_end = min + ((i + 1) as f64 / bins as f64) * range;
385            let bar_width = if max_count > 0 {
386                (count as f64 / max_count as f64 * 40.0) as usize
387            } else {
388                0
389            };
390
391            output.push_str(&format!(
392                "[{:>8.2}, {:>8.2}): {} ({})\n",
393                bin_start,
394                bin_end,
395                "█".repeat(bar_width),
396                count
397            ));
398        }
399
400        output.push_str(&"=".repeat(50));
401        output.push('\n');
402
403        output
404    }
405}
406
407/// Export formats for external visualization tools.
408pub struct ExportFormat;
409
410impl ExportFormat {
411    /// Export trace to JSON for custom visualization.
412    pub fn trace_to_json(trace: &ExecutionTrace) -> String {
413        let mut output = String::new();
414
415        output.push_str("{\n");
416        output.push_str(&format!(
417            "  \"total_duration_ms\": {},\n",
418            trace.total_duration_ms()
419        ));
420        output.push_str("  \"entries\": [\n");
421
422        for (i, entry) in trace.entries().iter().enumerate() {
423            if i > 0 {
424                output.push_str(",\n");
425            }
426            output.push_str("    {\n");
427            output.push_str(&format!("      \"entry_id\": {},\n", entry.entry_id));
428            output.push_str(&format!("      \"node_id\": {},\n", entry.node_id));
429            output.push_str(&format!("      \"operation\": \"{}\",\n", entry.operation));
430            output.push_str(&format!(
431                "      \"duration_ms\": {},\n",
432                entry.duration_ms()
433            ));
434            output.push_str(&format!("      \"input_ids\": {:?},\n", entry.input_ids));
435            output.push_str(&format!("      \"output_ids\": {:?}\n", entry.output_ids));
436            output.push_str("    }");
437        }
438
439        output.push_str("\n  ]\n");
440        output.push_str("}\n");
441
442        output
443    }
444
445    /// Export graph to GraphML format.
446    pub fn graph_to_graphml(graph: &EinsumGraph) -> String {
447        let mut output = String::new();
448
449        output.push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n");
450        output.push_str("<graphml xmlns=\"http://graphml.graphdrawing.org/xmlns\">\n");
451        output.push_str("  <graph id=\"G\" edgedefault=\"directed\">\n");
452
453        // Nodes
454        for (node_idx, node) in graph.nodes.iter().enumerate() {
455            output.push_str(&format!("    <node id=\"n{}\">\n", node_idx));
456            output.push_str(&format!(
457                "      <data key=\"operation\">{:?}</data>\n",
458                node.op
459            ));
460            output.push_str("    </node>\n");
461        }
462
463        // Edges
464        for (node_idx, node) in graph.nodes.iter().enumerate() {
465            for input_id in &node.inputs {
466                output.push_str(&format!(
467                    "    <edge source=\"n{}\" target=\"n{}\"/>\n",
468                    input_id, node_idx
469                ));
470            }
471        }
472
473        output.push_str("  </graph>\n");
474        output.push_str("</graphml>\n");
475
476        output
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::debug::{ExecutionTracer, TensorStats};
    use std::collections::HashMap;
    use std::time::Duration;
    use tensorlogic_ir::EinsumNode;

    /// Two-tensor graph with a single `ij->ij` einsum node writing tensor 1.
    fn sample_graph() -> EinsumGraph {
        EinsumGraph {
            tensors: vec!["input".to_string(), "output".to_string()],
            nodes: vec![EinsumNode::new("ij->ij", vec![], vec![1])],
            inputs: vec![0],
            outputs: vec![1],
            tensor_metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_timeline_visualizer() {
        let mut tracer = ExecutionTracer::new();
        tracer.enable();
        tracer.start_trace(Some(1));

        let op = tracer.record_operation_start(0, "einsum", vec![]);
        std::thread::sleep(Duration::from_millis(10));
        tracer.record_operation_end(op, 0, "einsum", vec![], vec![1], HashMap::new());

        let rendered = TimelineVisualizer::new(TimelineConfig::default())
            .visualize_trace(tracer.get_trace());

        for needle in ["Execution Timeline", "Node 0", "einsum"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_graph_visualizer_ascii() {
        let rendered = GraphVisualizer::new(GraphConfig::default()).visualize_ascii(&sample_graph());

        assert!(rendered.contains("Computation Graph"));
        assert!(rendered.contains("Node 0"));
    }

    #[test]
    fn test_graph_visualizer_dot() {
        let rendered = GraphVisualizer::new(GraphConfig::default()).visualize_dot(&sample_graph());

        assert!(rendered.contains("digraph ComputationGraph"));
        assert!(rendered.contains("n0"));
    }

    #[test]
    fn test_graph_visualizer_json() {
        let rendered = GraphVisualizer::new(GraphConfig::default()).visualize_json(&sample_graph());

        assert!(rendered.contains("\"nodes\""));
        assert!(rendered.contains("\"id\": 0"));
    }

    #[test]
    fn test_tensor_stats_visualizer() {
        let stats =
            TensorStats::new(0, vec![2, 3], "f64").with_statistics(0.0, 1.0, 0.5, 0.25, 0, 0);

        let rendered = TensorStatsVisualizer.visualize(&stats);

        assert!(rendered.contains("Tensor 0"));
        assert!(rendered.contains("f64"));
    }

    #[test]
    fn test_tensor_stats_table() {
        let rows = vec![
            TensorStats::new(0, vec![2, 3], "f64"),
            TensorStats::new(1, vec![4, 5], "f64"),
        ];

        let rendered = TensorStatsVisualizer.visualize_table(&rows);

        for needle in ["Tensor Statistics", "ID", "Shape"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_histogram() {
        let samples = vec![1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0];

        let rendered = TensorStatsVisualizer.histogram(&samples, 5);

        assert!(rendered.contains("Value Distribution"));
        assert!(rendered.contains("█"));
    }

    #[test]
    fn test_export_trace_to_json() {
        let mut tracer = ExecutionTracer::new();
        tracer.enable();
        tracer.start_trace(Some(1));

        let op = tracer.record_operation_start(0, "einsum", vec![]);
        tracer.record_operation_end(op, 0, "einsum", vec![], vec![1], HashMap::new());

        let json = ExportFormat::trace_to_json(tracer.get_trace());

        for needle in ["total_duration_ms", "entries", "\"operation\": \"einsum\""] {
            assert!(json.contains(needle));
        }
    }

    #[test]
    fn test_export_graph_to_graphml() {
        let graphml = ExportFormat::graph_to_graphml(&sample_graph());

        for needle in ["<?xml", "<graphml", "<node id=\"n0\""] {
            assert!(graphml.contains(needle));
        }
    }
}