Skip to main content

tensorlogic_infer/
visualization.rs

1//! Visualization utilities for execution analysis and debugging.
2//!
3//! This module provides tools for visualizing:
4//! - Execution timelines
5//! - Computation graphs
6//! - Performance data
7//!
//! Supports multiple output formats:
//! - ASCII art for terminal display
//! - GraphViz DOT format
//! - JSON for custom rendering
//! - GraphML for external graph tools (via `ExportFormat::graph_to_graphml`)
12
13use crate::debug::{ExecutionTrace, TensorStats};
14use crate::profiling::ProfileData;
15use tensorlogic_ir::EinsumGraph;
16
/// Target format for rendering a visualization.
///
/// NOTE(review): no function in this module currently produces `Html` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VisualizationFormat {
    /// Text rendering intended for a terminal.
    Ascii,
    /// GraphViz DOT source text.
    Dot,
    /// A JSON document for custom renderers.
    Json,
    /// An HTML document.
    Html,
}
29
/// Layout options for the timeline visualizer.
#[derive(Debug, Clone)]
pub struct TimelineConfig {
    /// Number of characters used for separators and for the full span of a timeline bar.
    pub width: usize,
    /// Whether to print each entry's node id and operation name.
    pub show_names: bool,
    /// Whether to print each entry's duration in milliseconds.
    pub show_timing: bool,
    /// Whether entries should be grouped by operation type.
    pub group_by_type: bool,
}

impl Default for TimelineConfig {
    /// 80 columns wide, names and timings shown, no grouping.
    fn default() -> Self {
        let width = 80;
        let (show_names, show_timing) = (true, true);
        TimelineConfig {
            width,
            show_names,
            show_timing,
            group_by_type: false,
        }
    }
}
53
/// Rendering options for the computation-graph visualizer.
#[derive(Debug, Clone)]
pub struct GraphConfig {
    /// Whether tensor shapes should be displayed.
    pub show_shapes: bool,
    /// Whether each node's operation type should be displayed.
    pub show_op_types: bool,
    /// Whether the critical path should be highlighted.
    pub highlight_critical_path: bool,
    /// `true` for a top-to-bottom (vertical) layout, `false` for left-to-right.
    pub vertical_layout: bool,
}

impl Default for GraphConfig {
    /// Shapes and op types shown, no critical-path highlighting, vertical layout.
    fn default() -> Self {
        let shown = true;
        GraphConfig {
            show_shapes: shown,
            show_op_types: shown,
            highlight_critical_path: !shown,
            vertical_layout: shown,
        }
    }
}
77
78/// Timeline visualizer.
79pub struct TimelineVisualizer {
80    config: TimelineConfig,
81}
82
83impl TimelineVisualizer {
84    /// Create a new timeline visualizer.
85    pub fn new(config: TimelineConfig) -> Self {
86        Self { config }
87    }
88
89    /// Visualize an execution trace as ASCII timeline.
90    pub fn visualize_trace(&self, trace: &ExecutionTrace) -> String {
91        let mut output = String::new();
92
93        // Header
94        output.push_str(&format!(
95            "Execution Timeline ({:.2}ms total)\n",
96            trace.total_duration_ms()
97        ));
98        output.push_str(&"=".repeat(self.config.width));
99        output.push('\n');
100
101        if trace.entries().is_empty() {
102            output.push_str("No operations recorded\n");
103            return output;
104        }
105
106        // Find time bounds
107        let start_time = trace.entries()[0].start_time;
108        let total_duration = trace.total_duration();
109
110        // Draw timeline
111        for entry in trace.entries() {
112            let elapsed = entry.start_time.duration_since(start_time);
113            let duration = entry.duration;
114
115            // Calculate bar position and width
116            let start_pos = (elapsed.as_secs_f64() / total_duration.as_secs_f64()
117                * self.config.width as f64) as usize;
118            let bar_width = ((duration.as_secs_f64() / total_duration.as_secs_f64()
119                * self.config.width as f64) as usize)
120                .max(1);
121
122            // Operation name
123            if self.config.show_names {
124                output.push_str(&format!("Node {}: {} ", entry.node_id, entry.operation));
125            }
126
127            if self.config.show_timing {
128                output.push_str(&format!("({:.2}ms)\n", entry.duration_ms()));
129            } else {
130                output.push('\n');
131            }
132
133            // Timeline bar
134            output.push_str(&" ".repeat(start_pos));
135            output.push_str(&"█".repeat(bar_width));
136            output.push('\n');
137        }
138
139        output.push_str(&"=".repeat(self.config.width));
140        output.push('\n');
141
142        output
143    }
144
145    /// Visualize profile data as text report.
146    pub fn visualize_profile(&self, profile: &ProfileData) -> String {
147        let mut output = String::new();
148
149        output.push_str("Performance Profile\n");
150        output.push_str(&"=".repeat(self.config.width));
151        output.push('\n');
152
153        // Sort operations by total time
154        let mut ops: Vec<_> = profile.op_profiles.iter().collect();
155        ops.sort_by(|(_, a), (_, b)| {
156            let a_total_ms = a.avg_time.as_secs_f64() * 1000.0 * a.count as f64;
157            let b_total_ms = b.avg_time.as_secs_f64() * 1000.0 * b.count as f64;
158            b_total_ms
159                .partial_cmp(&a_total_ms)
160                .unwrap_or(std::cmp::Ordering::Equal)
161        });
162
163        // Header
164        output.push_str(&format!(
165            "{:<30} {:>10} {:>10} {:>15}\n",
166            "Operation", "Count", "Avg (ms)", "Total (ms)"
167        ));
168        output.push_str(&"-".repeat(self.config.width));
169        output.push('\n');
170
171        // Operations
172        for (name, stats) in ops {
173            let avg_time_ms = stats.avg_time.as_secs_f64() * 1000.0;
174            let total_time_ms = avg_time_ms * stats.count as f64;
175            output.push_str(&format!(
176                "{:<30} {:>10} {:>10.2} {:>15.2}\n",
177                name, stats.count, avg_time_ms, total_time_ms
178            ));
179        }
180
181        output.push_str(&"=".repeat(self.config.width));
182        output.push('\n');
183
184        output
185    }
186}
187
188/// Graph visualizer.
189pub struct GraphVisualizer {
190    config: GraphConfig,
191}
192
193impl GraphVisualizer {
194    /// Create a new graph visualizer.
195    pub fn new(config: GraphConfig) -> Self {
196        Self { config }
197    }
198
199    /// Visualize a computation graph as ASCII art.
200    pub fn visualize_ascii(&self, graph: &EinsumGraph) -> String {
201        let mut output = String::new();
202
203        output.push_str("Computation Graph\n");
204        output.push_str("=================\n\n");
205
206        if graph.nodes.is_empty() {
207            output.push_str("Empty graph\n");
208            return output;
209        }
210
211        for (node_idx, node) in graph.nodes.iter().enumerate() {
212            // Node representation
213            output.push_str(&format!("Node {}:\n", node_idx));
214
215            // Operation type
216            if self.config.show_op_types {
217                output.push_str(&format!("  Op: {:?}\n", node.op));
218            }
219
220            // Inputs
221            if !node.inputs.is_empty() {
222                output.push_str("  Inputs: ");
223                for (i, input_id) in node.inputs.iter().enumerate() {
224                    if i > 0 {
225                        output.push_str(", ");
226                    }
227                    output.push_str(&format!("{}", input_id));
228                }
229                output.push('\n');
230            }
231
232            output.push('\n');
233        }
234
235        output
236    }
237
238    /// Generate GraphViz DOT format.
239    pub fn visualize_dot(&self, graph: &EinsumGraph) -> String {
240        let mut output = String::new();
241
242        output.push_str("digraph ComputationGraph {\n");
243        output.push_str("  rankdir=TB;\n");
244        output.push_str("  node [shape=box, style=rounded];\n\n");
245
246        // Nodes
247        for (node_idx, node) in graph.nodes.iter().enumerate() {
248            let label = format!("Node {}\\n{:?}", node_idx, node.op);
249            output.push_str(&format!("  n{} [label=\"{}\"];\n", node_idx, label));
250        }
251
252        output.push('\n');
253
254        // Edges
255        for (node_idx, node) in graph.nodes.iter().enumerate() {
256            for input_id in &node.inputs {
257                output.push_str(&format!("  n{} -> n{};\n", input_id, node_idx));
258            }
259        }
260
261        output.push_str("}\n");
262
263        output
264    }
265
266    /// Generate JSON representation.
267    pub fn visualize_json(&self, graph: &EinsumGraph) -> String {
268        let mut output = String::new();
269
270        output.push_str("{\n");
271        output.push_str("  \"nodes\": [\n");
272
273        for (node_idx, node) in graph.nodes.iter().enumerate() {
274            if node_idx > 0 {
275                output.push_str(",\n");
276            }
277            output.push_str("    {\n");
278            output.push_str(&format!("      \"id\": {},\n", node_idx));
279            output.push_str(&format!("      \"op\": \"{:?}\",\n", node.op));
280            output.push_str("      \"inputs\": [");
281
282            for (j, input_id) in node.inputs.iter().enumerate() {
283                if j > 0 {
284                    output.push_str(", ");
285                }
286                output.push_str(&format!("{}", input_id));
287            }
288
289            output.push_str("]\n");
290            output.push_str("    }");
291        }
292
293        output.push_str("\n  ]\n");
294        output.push_str("}\n");
295
296        output
297    }
298}
299
300/// Tensor statistics visualizer.
301pub struct TensorStatsVisualizer;
302
303impl TensorStatsVisualizer {
304    /// Visualize tensor statistics as a text report.
305    pub fn visualize(&self, stats: &TensorStats) -> String {
306        format!("{}", stats)
307    }
308
309    /// Visualize multiple tensor statistics as a table.
310    pub fn visualize_table(&self, stats: &[TensorStats]) -> String {
311        let mut output = String::new();
312
313        output.push_str("Tensor Statistics\n");
314        output.push_str(&"=".repeat(80));
315        output.push('\n');
316
317        if stats.is_empty() {
318            output.push_str("No tensors recorded\n");
319            return output;
320        }
321
322        // Header
323        output.push_str(&format!(
324            "{:<8} {:<20} {:<15} {:>10} {:>10}\n",
325            "ID", "Shape", "DType", "NaNs", "Infs"
326        ));
327        output.push_str(&"-".repeat(80));
328        output.push('\n');
329
330        // Rows
331        for stat in stats {
332            let shape_str = format!("{:?}", stat.shape);
333            let nans = stat.num_nans.unwrap_or(0);
334            let infs = stat.num_infs.unwrap_or(0);
335
336            output.push_str(&format!(
337                "{:<8} {:<20} {:<15} {:>10} {:>10}\n",
338                stat.tensor_id, shape_str, stat.dtype, nans, infs
339            ));
340
341            // Highlight issues
342            if stat.has_numerical_issues() {
343                output.push_str("         ⚠️  Numerical issues detected!\n");
344            }
345        }
346
347        output.push_str(&"=".repeat(80));
348        output.push('\n');
349
350        output
351    }
352
353    /// Generate histogram of tensor values (ASCII).
354    pub fn histogram(&self, values: &[f64], bins: usize) -> String {
355        let mut output = String::new();
356
357        if values.is_empty() {
358            return "No values\n".to_string();
359        }
360
361        let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
362        let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
363        let range = max - min;
364
365        if range == 0.0 {
366            return format!("All values are {:.6}\n", min);
367        }
368
369        // Create bins
370        let mut counts = vec![0; bins];
371        for &value in values {
372            let bin = ((value - min) / range * bins as f64) as usize;
373            let bin = bin.min(bins - 1);
374            counts[bin] += 1;
375        }
376
377        let max_count = *counts
378            .iter()
379            .max()
380            .expect("counts has bins elements, so max always exists");
381
382        // Draw histogram
383        output.push_str("Value Distribution\n");
384        output.push_str(&"=".repeat(50));
385        output.push('\n');
386
387        for (i, &count) in counts.iter().enumerate() {
388            let bin_start = min + (i as f64 / bins as f64) * range;
389            let bin_end = min + ((i + 1) as f64 / bins as f64) * range;
390            let bar_width = if max_count > 0 {
391                (count as f64 / max_count as f64 * 40.0) as usize
392            } else {
393                0
394            };
395
396            output.push_str(&format!(
397                "[{:>8.2}, {:>8.2}): {} ({})\n",
398                bin_start,
399                bin_end,
400                "█".repeat(bar_width),
401                count
402            ));
403        }
404
405        output.push_str(&"=".repeat(50));
406        output.push('\n');
407
408        output
409    }
410}
411
412/// Export formats for external visualization tools.
413pub struct ExportFormat;
414
415impl ExportFormat {
416    /// Export trace to JSON for custom visualization.
417    pub fn trace_to_json(trace: &ExecutionTrace) -> String {
418        let mut output = String::new();
419
420        output.push_str("{\n");
421        output.push_str(&format!(
422            "  \"total_duration_ms\": {},\n",
423            trace.total_duration_ms()
424        ));
425        output.push_str("  \"entries\": [\n");
426
427        for (i, entry) in trace.entries().iter().enumerate() {
428            if i > 0 {
429                output.push_str(",\n");
430            }
431            output.push_str("    {\n");
432            output.push_str(&format!("      \"entry_id\": {},\n", entry.entry_id));
433            output.push_str(&format!("      \"node_id\": {},\n", entry.node_id));
434            output.push_str(&format!("      \"operation\": \"{}\",\n", entry.operation));
435            output.push_str(&format!(
436                "      \"duration_ms\": {},\n",
437                entry.duration_ms()
438            ));
439            output.push_str(&format!("      \"input_ids\": {:?},\n", entry.input_ids));
440            output.push_str(&format!("      \"output_ids\": {:?}\n", entry.output_ids));
441            output.push_str("    }");
442        }
443
444        output.push_str("\n  ]\n");
445        output.push_str("}\n");
446
447        output
448    }
449
450    /// Export graph to GraphML format.
451    pub fn graph_to_graphml(graph: &EinsumGraph) -> String {
452        let mut output = String::new();
453
454        output.push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n");
455        output.push_str("<graphml xmlns=\"http://graphml.graphdrawing.org/xmlns\">\n");
456        output.push_str("  <graph id=\"G\" edgedefault=\"directed\">\n");
457
458        // Nodes
459        for (node_idx, node) in graph.nodes.iter().enumerate() {
460            output.push_str(&format!("    <node id=\"n{}\">\n", node_idx));
461            output.push_str(&format!(
462                "      <data key=\"operation\">{:?}</data>\n",
463                node.op
464            ));
465            output.push_str("    </node>\n");
466        }
467
468        // Edges
469        for (node_idx, node) in graph.nodes.iter().enumerate() {
470            for input_id in &node.inputs {
471                output.push_str(&format!(
472                    "    <edge source=\"n{}\" target=\"n{}\"/>\n",
473                    input_id, node_idx
474                ));
475            }
476        }
477
478        output.push_str("  </graph>\n");
479        output.push_str("</graphml>\n");
480
481        output
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::debug::{ExecutionTracer, TensorStats};
    use std::collections::HashMap;
    use std::time::Duration;
    use tensorlogic_ir::EinsumNode;

    /// Build the single-node graph (`ij->ij`, writing tensor 1) shared by the graph tests.
    fn single_node_graph() -> EinsumGraph {
        EinsumGraph {
            tensors: vec!["input".to_string(), "output".to_string()],
            nodes: vec![EinsumNode::new("ij->ij", vec![], vec![1])],
            inputs: vec![0],
            outputs: vec![1],
            tensor_metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_timeline_visualizer() {
        let mut tracer = ExecutionTracer::new();
        tracer.enable();
        tracer.start_trace(Some(1));

        let started = tracer.record_operation_start(0, "einsum", vec![]);
        std::thread::sleep(Duration::from_millis(10));
        tracer.record_operation_end(started, 0, "einsum", vec![], vec![1], HashMap::new());

        let rendered =
            TimelineVisualizer::new(TimelineConfig::default()).visualize_trace(tracer.get_trace());

        assert!(rendered.contains("Execution Timeline"));
        assert!(rendered.contains("Node 0"));
        assert!(rendered.contains("einsum"));
    }

    #[test]
    fn test_graph_visualizer_ascii() {
        let graph = single_node_graph();
        let rendered = GraphVisualizer::new(GraphConfig::default()).visualize_ascii(&graph);

        assert!(rendered.contains("Computation Graph"));
        assert!(rendered.contains("Node 0"));
    }

    #[test]
    fn test_graph_visualizer_dot() {
        let graph = single_node_graph();
        let rendered = GraphVisualizer::new(GraphConfig::default()).visualize_dot(&graph);

        assert!(rendered.contains("digraph ComputationGraph"));
        assert!(rendered.contains("n0"));
    }

    #[test]
    fn test_graph_visualizer_json() {
        let graph = single_node_graph();
        let rendered = GraphVisualizer::new(GraphConfig::default()).visualize_json(&graph);

        assert!(rendered.contains("\"nodes\""));
        assert!(rendered.contains("\"id\": 0"));
    }

    #[test]
    fn test_tensor_stats_visualizer() {
        let stats =
            TensorStats::new(0, vec![2, 3], "f64").with_statistics(0.0, 1.0, 0.5, 0.25, 0, 0);

        let rendered = TensorStatsVisualizer.visualize(&stats);

        assert!(rendered.contains("Tensor 0"));
        assert!(rendered.contains("f64"));
    }

    #[test]
    fn test_tensor_stats_table() {
        let stats = vec![
            TensorStats::new(0, vec![2, 3], "f64"),
            TensorStats::new(1, vec![4, 5], "f64"),
        ];

        let rendered = TensorStatsVisualizer.visualize_table(&stats);

        assert!(rendered.contains("Tensor Statistics"));
        assert!(rendered.contains("ID"));
        assert!(rendered.contains("Shape"));
    }

    #[test]
    fn test_histogram() {
        let values = vec![1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0];
        let rendered = TensorStatsVisualizer.histogram(&values, 5);

        assert!(rendered.contains("Value Distribution"));
        assert!(rendered.contains("█"));
    }

    #[test]
    fn test_export_trace_to_json() {
        let mut tracer = ExecutionTracer::new();
        tracer.enable();
        tracer.start_trace(Some(1));

        let started = tracer.record_operation_start(0, "einsum", vec![]);
        tracer.record_operation_end(started, 0, "einsum", vec![], vec![1], HashMap::new());

        let json = ExportFormat::trace_to_json(tracer.get_trace());

        assert!(json.contains("total_duration_ms"));
        assert!(json.contains("entries"));
        assert!(json.contains("\"operation\": \"einsum\""));
    }

    #[test]
    fn test_export_graph_to_graphml() {
        let graph = single_node_graph();
        let graphml = ExportFormat::graph_to_graphml(&graph);

        assert!(graphml.contains("<?xml"));
        assert!(graphml.contains("<graphml"));
        assert!(graphml.contains("<node id=\"n0\""));
    }
}