1use crate::debug::{ExecutionTrace, TensorStats};
14use crate::profiling::ProfileData;
15use tensorlogic_ir::EinsumGraph;
16
/// Target output formats a caller can ask a visualizer to produce.
///
/// The type is `Copy` and equality-comparable, so it can be passed by value and
/// matched on directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VisualizationFormat {
    /// Plain-text rendering for terminals and logs.
    Ascii,
    /// Graphviz DOT source.
    Dot,
    /// JSON document.
    Json,
    /// HTML document (NOTE(review): no HTML renderer is visible in this module).
    Html,
}
29
/// Layout options for the text timeline renderer.
#[derive(Debug, Clone)]
pub struct TimelineConfig {
    /// Character width of the rulers; also the number of columns timeline bars are scaled to.
    pub width: usize,
    /// Whether each entry's node id and operation name are printed.
    pub show_names: bool,
    /// Whether each entry's duration (in milliseconds) is printed.
    pub show_timing: bool,
    /// Whether entries should be grouped by operation type
    /// (NOTE(review): not read by the rendering code in this file).
    pub group_by_type: bool,
}

impl Default for TimelineConfig {
    /// An 80-column layout that shows names and timings and does not group entries.
    fn default() -> Self {
        let width = 80;
        TimelineConfig {
            group_by_type: false,
            show_timing: true,
            show_names: true,
            width,
        }
    }
}
53
/// Rendering options for the computation-graph visualizer.
#[derive(Debug, Clone)]
pub struct GraphConfig {
    /// Whether tensor shapes should be shown (NOTE(review): not read by the code in this file).
    pub show_shapes: bool,
    /// Whether each node's operation is printed.
    pub show_op_types: bool,
    /// Whether the critical path should be highlighted
    /// (NOTE(review): not read by the code in this file).
    pub highlight_critical_path: bool,
    /// Whether a vertical (top-to-bottom) layout is requested.
    pub vertical_layout: bool,
}

impl Default for GraphConfig {
    /// Shapes and op types on, no critical-path highlighting, vertical layout.
    fn default() -> Self {
        let highlight_critical_path = false;
        Self {
            vertical_layout: true,
            highlight_critical_path,
            show_op_types: true,
            show_shapes: true,
        }
    }
}
77
78pub struct TimelineVisualizer {
80 config: TimelineConfig,
81}
82
83impl TimelineVisualizer {
84 pub fn new(config: TimelineConfig) -> Self {
86 Self { config }
87 }
88
89 pub fn visualize_trace(&self, trace: &ExecutionTrace) -> String {
91 let mut output = String::new();
92
93 output.push_str(&format!(
95 "Execution Timeline ({:.2}ms total)\n",
96 trace.total_duration_ms()
97 ));
98 output.push_str(&"=".repeat(self.config.width));
99 output.push('\n');
100
101 if trace.entries().is_empty() {
102 output.push_str("No operations recorded\n");
103 return output;
104 }
105
106 let start_time = trace.entries()[0].start_time;
108 let total_duration = trace.total_duration();
109
110 for entry in trace.entries() {
112 let elapsed = entry.start_time.duration_since(start_time);
113 let duration = entry.duration;
114
115 let start_pos = (elapsed.as_secs_f64() / total_duration.as_secs_f64()
117 * self.config.width as f64) as usize;
118 let bar_width = ((duration.as_secs_f64() / total_duration.as_secs_f64()
119 * self.config.width as f64) as usize)
120 .max(1);
121
122 if self.config.show_names {
124 output.push_str(&format!("Node {}: {} ", entry.node_id, entry.operation));
125 }
126
127 if self.config.show_timing {
128 output.push_str(&format!("({:.2}ms)\n", entry.duration_ms()));
129 } else {
130 output.push('\n');
131 }
132
133 output.push_str(&" ".repeat(start_pos));
135 output.push_str(&"█".repeat(bar_width));
136 output.push('\n');
137 }
138
139 output.push_str(&"=".repeat(self.config.width));
140 output.push('\n');
141
142 output
143 }
144
145 pub fn visualize_profile(&self, profile: &ProfileData) -> String {
147 let mut output = String::new();
148
149 output.push_str("Performance Profile\n");
150 output.push_str(&"=".repeat(self.config.width));
151 output.push('\n');
152
153 let mut ops: Vec<_> = profile.op_profiles.iter().collect();
155 ops.sort_by(|(_, a), (_, b)| {
156 let a_total_ms = a.avg_time.as_secs_f64() * 1000.0 * a.count as f64;
157 let b_total_ms = b.avg_time.as_secs_f64() * 1000.0 * b.count as f64;
158 b_total_ms.partial_cmp(&a_total_ms).unwrap()
159 });
160
161 output.push_str(&format!(
163 "{:<30} {:>10} {:>10} {:>15}\n",
164 "Operation", "Count", "Avg (ms)", "Total (ms)"
165 ));
166 output.push_str(&"-".repeat(self.config.width));
167 output.push('\n');
168
169 for (name, stats) in ops {
171 let avg_time_ms = stats.avg_time.as_secs_f64() * 1000.0;
172 let total_time_ms = avg_time_ms * stats.count as f64;
173 output.push_str(&format!(
174 "{:<30} {:>10} {:>10.2} {:>15.2}\n",
175 name, stats.count, avg_time_ms, total_time_ms
176 ));
177 }
178
179 output.push_str(&"=".repeat(self.config.width));
180 output.push('\n');
181
182 output
183 }
184}
185
186pub struct GraphVisualizer {
188 config: GraphConfig,
189}
190
191impl GraphVisualizer {
192 pub fn new(config: GraphConfig) -> Self {
194 Self { config }
195 }
196
197 pub fn visualize_ascii(&self, graph: &EinsumGraph) -> String {
199 let mut output = String::new();
200
201 output.push_str("Computation Graph\n");
202 output.push_str("=================\n\n");
203
204 if graph.nodes.is_empty() {
205 output.push_str("Empty graph\n");
206 return output;
207 }
208
209 for (node_idx, node) in graph.nodes.iter().enumerate() {
210 output.push_str(&format!("Node {}:\n", node_idx));
212
213 if self.config.show_op_types {
215 output.push_str(&format!(" Op: {:?}\n", node.op));
216 }
217
218 if !node.inputs.is_empty() {
220 output.push_str(" Inputs: ");
221 for (i, input_id) in node.inputs.iter().enumerate() {
222 if i > 0 {
223 output.push_str(", ");
224 }
225 output.push_str(&format!("{}", input_id));
226 }
227 output.push('\n');
228 }
229
230 output.push('\n');
231 }
232
233 output
234 }
235
236 pub fn visualize_dot(&self, graph: &EinsumGraph) -> String {
238 let mut output = String::new();
239
240 output.push_str("digraph ComputationGraph {\n");
241 output.push_str(" rankdir=TB;\n");
242 output.push_str(" node [shape=box, style=rounded];\n\n");
243
244 for (node_idx, node) in graph.nodes.iter().enumerate() {
246 let label = format!("Node {}\\n{:?}", node_idx, node.op);
247 output.push_str(&format!(" n{} [label=\"{}\"];\n", node_idx, label));
248 }
249
250 output.push('\n');
251
252 for (node_idx, node) in graph.nodes.iter().enumerate() {
254 for input_id in &node.inputs {
255 output.push_str(&format!(" n{} -> n{};\n", input_id, node_idx));
256 }
257 }
258
259 output.push_str("}\n");
260
261 output
262 }
263
264 pub fn visualize_json(&self, graph: &EinsumGraph) -> String {
266 let mut output = String::new();
267
268 output.push_str("{\n");
269 output.push_str(" \"nodes\": [\n");
270
271 for (node_idx, node) in graph.nodes.iter().enumerate() {
272 if node_idx > 0 {
273 output.push_str(",\n");
274 }
275 output.push_str(" {\n");
276 output.push_str(&format!(" \"id\": {},\n", node_idx));
277 output.push_str(&format!(" \"op\": \"{:?}\",\n", node.op));
278 output.push_str(" \"inputs\": [");
279
280 for (j, input_id) in node.inputs.iter().enumerate() {
281 if j > 0 {
282 output.push_str(", ");
283 }
284 output.push_str(&format!("{}", input_id));
285 }
286
287 output.push_str("]\n");
288 output.push_str(" }");
289 }
290
291 output.push_str("\n ]\n");
292 output.push_str("}\n");
293
294 output
295 }
296}
297
298pub struct TensorStatsVisualizer;
300
301impl TensorStatsVisualizer {
302 pub fn visualize(&self, stats: &TensorStats) -> String {
304 format!("{}", stats)
305 }
306
307 pub fn visualize_table(&self, stats: &[TensorStats]) -> String {
309 let mut output = String::new();
310
311 output.push_str("Tensor Statistics\n");
312 output.push_str(&"=".repeat(80));
313 output.push('\n');
314
315 if stats.is_empty() {
316 output.push_str("No tensors recorded\n");
317 return output;
318 }
319
320 output.push_str(&format!(
322 "{:<8} {:<20} {:<15} {:>10} {:>10}\n",
323 "ID", "Shape", "DType", "NaNs", "Infs"
324 ));
325 output.push_str(&"-".repeat(80));
326 output.push('\n');
327
328 for stat in stats {
330 let shape_str = format!("{:?}", stat.shape);
331 let nans = stat.num_nans.unwrap_or(0);
332 let infs = stat.num_infs.unwrap_or(0);
333
334 output.push_str(&format!(
335 "{:<8} {:<20} {:<15} {:>10} {:>10}\n",
336 stat.tensor_id, shape_str, stat.dtype, nans, infs
337 ));
338
339 if stat.has_numerical_issues() {
341 output.push_str(" ⚠️ Numerical issues detected!\n");
342 }
343 }
344
345 output.push_str(&"=".repeat(80));
346 output.push('\n');
347
348 output
349 }
350
351 pub fn histogram(&self, values: &[f64], bins: usize) -> String {
353 let mut output = String::new();
354
355 if values.is_empty() {
356 return "No values\n".to_string();
357 }
358
359 let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
360 let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
361 let range = max - min;
362
363 if range == 0.0 {
364 return format!("All values are {:.6}\n", min);
365 }
366
367 let mut counts = vec![0; bins];
369 for &value in values {
370 let bin = ((value - min) / range * bins as f64) as usize;
371 let bin = bin.min(bins - 1);
372 counts[bin] += 1;
373 }
374
375 let max_count = *counts.iter().max().unwrap();
376
377 output.push_str("Value Distribution\n");
379 output.push_str(&"=".repeat(50));
380 output.push('\n');
381
382 for (i, &count) in counts.iter().enumerate() {
383 let bin_start = min + (i as f64 / bins as f64) * range;
384 let bin_end = min + ((i + 1) as f64 / bins as f64) * range;
385 let bar_width = if max_count > 0 {
386 (count as f64 / max_count as f64 * 40.0) as usize
387 } else {
388 0
389 };
390
391 output.push_str(&format!(
392 "[{:>8.2}, {:>8.2}): {} ({})\n",
393 bin_start,
394 bin_end,
395 "█".repeat(bar_width),
396 count
397 ));
398 }
399
400 output.push_str(&"=".repeat(50));
401 output.push('\n');
402
403 output
404 }
405}
406
407pub struct ExportFormat;
409
410impl ExportFormat {
411 pub fn trace_to_json(trace: &ExecutionTrace) -> String {
413 let mut output = String::new();
414
415 output.push_str("{\n");
416 output.push_str(&format!(
417 " \"total_duration_ms\": {},\n",
418 trace.total_duration_ms()
419 ));
420 output.push_str(" \"entries\": [\n");
421
422 for (i, entry) in trace.entries().iter().enumerate() {
423 if i > 0 {
424 output.push_str(",\n");
425 }
426 output.push_str(" {\n");
427 output.push_str(&format!(" \"entry_id\": {},\n", entry.entry_id));
428 output.push_str(&format!(" \"node_id\": {},\n", entry.node_id));
429 output.push_str(&format!(" \"operation\": \"{}\",\n", entry.operation));
430 output.push_str(&format!(
431 " \"duration_ms\": {},\n",
432 entry.duration_ms()
433 ));
434 output.push_str(&format!(" \"input_ids\": {:?},\n", entry.input_ids));
435 output.push_str(&format!(" \"output_ids\": {:?}\n", entry.output_ids));
436 output.push_str(" }");
437 }
438
439 output.push_str("\n ]\n");
440 output.push_str("}\n");
441
442 output
443 }
444
445 pub fn graph_to_graphml(graph: &EinsumGraph) -> String {
447 let mut output = String::new();
448
449 output.push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n");
450 output.push_str("<graphml xmlns=\"http://graphml.graphdrawing.org/xmlns\">\n");
451 output.push_str(" <graph id=\"G\" edgedefault=\"directed\">\n");
452
453 for (node_idx, node) in graph.nodes.iter().enumerate() {
455 output.push_str(&format!(" <node id=\"n{}\">\n", node_idx));
456 output.push_str(&format!(
457 " <data key=\"operation\">{:?}</data>\n",
458 node.op
459 ));
460 output.push_str(" </node>\n");
461 }
462
463 for (node_idx, node) in graph.nodes.iter().enumerate() {
465 for input_id in &node.inputs {
466 output.push_str(&format!(
467 " <edge source=\"n{}\" target=\"n{}\"/>\n",
468 input_id, node_idx
469 ));
470 }
471 }
472
473 output.push_str(" </graph>\n");
474 output.push_str("</graphml>\n");
475
476 output
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::debug::{ExecutionTracer, TensorStats};
    use std::collections::HashMap;
    use std::time::Duration;
    use tensorlogic_ir::EinsumNode;

    /// Builds the one-node `ij->ij` graph (tensor 0 in, tensor 1 out) shared by the graph tests.
    fn identity_graph() -> EinsumGraph {
        EinsumGraph {
            tensors: vec!["input".to_string(), "output".to_string()],
            nodes: vec![EinsumNode::new("ij->ij", vec![], vec![1])],
            inputs: vec![0],
            outputs: vec![1],
            tensor_metadata: HashMap::new(),
        }
    }

    /// Enables `tracer` and records one `einsum` on node 0 producing tensor 1,
    /// sleeping for `pause` between the start and end events when one is given.
    fn trace_single_einsum(tracer: &mut ExecutionTracer, pause: Option<Duration>) {
        tracer.enable();
        tracer.start_trace(Some(1));
        let handle = tracer.record_operation_start(0, "einsum", vec![]);
        if let Some(pause) = pause {
            std::thread::sleep(pause);
        }
        tracer.record_operation_end(handle, 0, "einsum", vec![], vec![1], HashMap::new());
    }

    #[test]
    fn test_timeline_visualizer() {
        let mut tracer = ExecutionTracer::new();
        trace_single_einsum(&mut tracer, Some(Duration::from_millis(10)));

        let output = TimelineVisualizer::new(TimelineConfig::default())
            .visualize_trace(tracer.get_trace());

        for needle in ["Execution Timeline", "Node 0", "einsum"] {
            assert!(output.contains(needle), "missing {:?} in:\n{}", needle, output);
        }
    }

    #[test]
    fn test_graph_visualizer_ascii() {
        let output =
            GraphVisualizer::new(GraphConfig::default()).visualize_ascii(&identity_graph());

        assert!(output.contains("Computation Graph"));
        assert!(output.contains("Node 0"));
    }

    #[test]
    fn test_graph_visualizer_dot() {
        let output = GraphVisualizer::new(GraphConfig::default()).visualize_dot(&identity_graph());

        assert!(output.contains("digraph ComputationGraph"));
        assert!(output.contains("n0"));
    }

    #[test]
    fn test_graph_visualizer_json() {
        let output =
            GraphVisualizer::new(GraphConfig::default()).visualize_json(&identity_graph());

        assert!(output.contains("\"nodes\""));
        assert!(output.contains("\"id\": 0"));
    }

    #[test]
    fn test_tensor_stats_visualizer() {
        let stats =
            TensorStats::new(0, vec![2, 3], "f64").with_statistics(0.0, 1.0, 0.5, 0.25, 0, 0);

        let output = TensorStatsVisualizer.visualize(&stats);

        assert!(output.contains("Tensor 0"));
        assert!(output.contains("f64"));
    }

    #[test]
    fn test_tensor_stats_table() {
        let stats = [
            TensorStats::new(0, vec![2, 3], "f64"),
            TensorStats::new(1, vec![4, 5], "f64"),
        ];

        let output = TensorStatsVisualizer.visualize_table(&stats);

        for needle in ["Tensor Statistics", "ID", "Shape"] {
            assert!(output.contains(needle), "missing {:?} in:\n{}", needle, output);
        }
    }

    #[test]
    fn test_histogram() {
        let values = [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0];

        let output = TensorStatsVisualizer.histogram(&values, 5);

        assert!(output.contains("Value Distribution"));
        assert!(output.contains("█"));
    }

    #[test]
    fn test_export_trace_to_json() {
        let mut tracer = ExecutionTracer::new();
        trace_single_einsum(&mut tracer, None);

        let json = ExportFormat::trace_to_json(tracer.get_trace());

        assert!(json.contains("total_duration_ms"));
        assert!(json.contains("entries"));
        assert!(json.contains("\"operation\": \"einsum\""));
    }

    #[test]
    fn test_export_graph_to_graphml() {
        let graphml = ExportFormat::graph_to_graphml(&identity_graph());

        assert!(graphml.contains("<?xml"));
        assert!(graphml.contains("<graphml"));
        assert!(graphml.contains("<node id=\"n0\""));
    }
}