1use crate::debug::{ExecutionTrace, TensorStats};
14use crate::profiling::ProfileData;
15use tensorlogic_ir::EinsumGraph;
16
/// Output formats a caller can request from the visualizers.
///
/// The variants carry no data; values are cheap to copy and compare.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VisualizationFormat {
    /// ASCII text output.
    Ascii,
    /// Graphviz DOT source (`digraph ...`).
    Dot,
    /// JSON document.
    Json,
    /// HTML output (no HTML renderer appears in this module).
    Html,
}
29
/// Layout options for [`TimelineVisualizer`].
#[derive(Clone, Debug)]
pub struct TimelineConfig {
    /// Width of the chart, in character cells; also the length of the `=`/`-` rules.
    pub width: usize,
    /// When set, each trace entry's label line includes `Node <id>: <operation>`.
    pub show_names: bool,
    /// When set, each trace entry's label line includes its duration in milliseconds.
    pub show_timing: bool,
    /// Request grouping of operations by type.
    /// NOTE(review): not currently read by `TimelineVisualizer`.
    pub group_by_type: bool,
}

impl Default for TimelineConfig {
    /// An 80-column chart showing names and timings, without grouping.
    fn default() -> Self {
        TimelineConfig {
            group_by_type: false,
            show_timing: true,
            show_names: true,
            width: 80,
        }
    }
}
53
/// Rendering options for [`GraphVisualizer`].
#[derive(Clone, Debug)]
pub struct GraphConfig {
    /// Request tensor shapes in rendered output.
    /// NOTE(review): not currently read by the renderers in this module.
    pub show_shapes: bool,
    /// Include each node's operation (its `Debug` form) in rendered node descriptions.
    pub show_op_types: bool,
    /// Request highlighting of the critical path.
    /// NOTE(review): not currently read by the renderers in this module.
    pub highlight_critical_path: bool,
    /// Request a vertical (top-to-bottom) layout rather than a horizontal one.
    pub vertical_layout: bool,
}

impl Default for GraphConfig {
    /// Shapes and op types on, vertical layout, no critical-path highlighting.
    fn default() -> Self {
        GraphConfig {
            vertical_layout: true,
            highlight_critical_path: false,
            show_op_types: true,
            show_shapes: true,
        }
    }
}
77
78pub struct TimelineVisualizer {
80 config: TimelineConfig,
81}
82
83impl TimelineVisualizer {
84 pub fn new(config: TimelineConfig) -> Self {
86 Self { config }
87 }
88
89 pub fn visualize_trace(&self, trace: &ExecutionTrace) -> String {
91 let mut output = String::new();
92
93 output.push_str(&format!(
95 "Execution Timeline ({:.2}ms total)\n",
96 trace.total_duration_ms()
97 ));
98 output.push_str(&"=".repeat(self.config.width));
99 output.push('\n');
100
101 if trace.entries().is_empty() {
102 output.push_str("No operations recorded\n");
103 return output;
104 }
105
106 let start_time = trace.entries()[0].start_time;
108 let total_duration = trace.total_duration();
109
110 for entry in trace.entries() {
112 let elapsed = entry.start_time.duration_since(start_time);
113 let duration = entry.duration;
114
115 let start_pos = (elapsed.as_secs_f64() / total_duration.as_secs_f64()
117 * self.config.width as f64) as usize;
118 let bar_width = ((duration.as_secs_f64() / total_duration.as_secs_f64()
119 * self.config.width as f64) as usize)
120 .max(1);
121
122 if self.config.show_names {
124 output.push_str(&format!("Node {}: {} ", entry.node_id, entry.operation));
125 }
126
127 if self.config.show_timing {
128 output.push_str(&format!("({:.2}ms)\n", entry.duration_ms()));
129 } else {
130 output.push('\n');
131 }
132
133 output.push_str(&" ".repeat(start_pos));
135 output.push_str(&"█".repeat(bar_width));
136 output.push('\n');
137 }
138
139 output.push_str(&"=".repeat(self.config.width));
140 output.push('\n');
141
142 output
143 }
144
145 pub fn visualize_profile(&self, profile: &ProfileData) -> String {
147 let mut output = String::new();
148
149 output.push_str("Performance Profile\n");
150 output.push_str(&"=".repeat(self.config.width));
151 output.push('\n');
152
153 let mut ops: Vec<_> = profile.op_profiles.iter().collect();
155 ops.sort_by(|(_, a), (_, b)| {
156 let a_total_ms = a.avg_time.as_secs_f64() * 1000.0 * a.count as f64;
157 let b_total_ms = b.avg_time.as_secs_f64() * 1000.0 * b.count as f64;
158 b_total_ms
159 .partial_cmp(&a_total_ms)
160 .unwrap_or(std::cmp::Ordering::Equal)
161 });
162
163 output.push_str(&format!(
165 "{:<30} {:>10} {:>10} {:>15}\n",
166 "Operation", "Count", "Avg (ms)", "Total (ms)"
167 ));
168 output.push_str(&"-".repeat(self.config.width));
169 output.push('\n');
170
171 for (name, stats) in ops {
173 let avg_time_ms = stats.avg_time.as_secs_f64() * 1000.0;
174 let total_time_ms = avg_time_ms * stats.count as f64;
175 output.push_str(&format!(
176 "{:<30} {:>10} {:>10.2} {:>15.2}\n",
177 name, stats.count, avg_time_ms, total_time_ms
178 ));
179 }
180
181 output.push_str(&"=".repeat(self.config.width));
182 output.push('\n');
183
184 output
185 }
186}
187
188pub struct GraphVisualizer {
190 config: GraphConfig,
191}
192
193impl GraphVisualizer {
194 pub fn new(config: GraphConfig) -> Self {
196 Self { config }
197 }
198
199 pub fn visualize_ascii(&self, graph: &EinsumGraph) -> String {
201 let mut output = String::new();
202
203 output.push_str("Computation Graph\n");
204 output.push_str("=================\n\n");
205
206 if graph.nodes.is_empty() {
207 output.push_str("Empty graph\n");
208 return output;
209 }
210
211 for (node_idx, node) in graph.nodes.iter().enumerate() {
212 output.push_str(&format!("Node {}:\n", node_idx));
214
215 if self.config.show_op_types {
217 output.push_str(&format!(" Op: {:?}\n", node.op));
218 }
219
220 if !node.inputs.is_empty() {
222 output.push_str(" Inputs: ");
223 for (i, input_id) in node.inputs.iter().enumerate() {
224 if i > 0 {
225 output.push_str(", ");
226 }
227 output.push_str(&format!("{}", input_id));
228 }
229 output.push('\n');
230 }
231
232 output.push('\n');
233 }
234
235 output
236 }
237
238 pub fn visualize_dot(&self, graph: &EinsumGraph) -> String {
240 let mut output = String::new();
241
242 output.push_str("digraph ComputationGraph {\n");
243 output.push_str(" rankdir=TB;\n");
244 output.push_str(" node [shape=box, style=rounded];\n\n");
245
246 for (node_idx, node) in graph.nodes.iter().enumerate() {
248 let label = format!("Node {}\\n{:?}", node_idx, node.op);
249 output.push_str(&format!(" n{} [label=\"{}\"];\n", node_idx, label));
250 }
251
252 output.push('\n');
253
254 for (node_idx, node) in graph.nodes.iter().enumerate() {
256 for input_id in &node.inputs {
257 output.push_str(&format!(" n{} -> n{};\n", input_id, node_idx));
258 }
259 }
260
261 output.push_str("}\n");
262
263 output
264 }
265
266 pub fn visualize_json(&self, graph: &EinsumGraph) -> String {
268 let mut output = String::new();
269
270 output.push_str("{\n");
271 output.push_str(" \"nodes\": [\n");
272
273 for (node_idx, node) in graph.nodes.iter().enumerate() {
274 if node_idx > 0 {
275 output.push_str(",\n");
276 }
277 output.push_str(" {\n");
278 output.push_str(&format!(" \"id\": {},\n", node_idx));
279 output.push_str(&format!(" \"op\": \"{:?}\",\n", node.op));
280 output.push_str(" \"inputs\": [");
281
282 for (j, input_id) in node.inputs.iter().enumerate() {
283 if j > 0 {
284 output.push_str(", ");
285 }
286 output.push_str(&format!("{}", input_id));
287 }
288
289 output.push_str("]\n");
290 output.push_str(" }");
291 }
292
293 output.push_str("\n ]\n");
294 output.push_str("}\n");
295
296 output
297 }
298}
299
300pub struct TensorStatsVisualizer;
302
303impl TensorStatsVisualizer {
304 pub fn visualize(&self, stats: &TensorStats) -> String {
306 format!("{}", stats)
307 }
308
309 pub fn visualize_table(&self, stats: &[TensorStats]) -> String {
311 let mut output = String::new();
312
313 output.push_str("Tensor Statistics\n");
314 output.push_str(&"=".repeat(80));
315 output.push('\n');
316
317 if stats.is_empty() {
318 output.push_str("No tensors recorded\n");
319 return output;
320 }
321
322 output.push_str(&format!(
324 "{:<8} {:<20} {:<15} {:>10} {:>10}\n",
325 "ID", "Shape", "DType", "NaNs", "Infs"
326 ));
327 output.push_str(&"-".repeat(80));
328 output.push('\n');
329
330 for stat in stats {
332 let shape_str = format!("{:?}", stat.shape);
333 let nans = stat.num_nans.unwrap_or(0);
334 let infs = stat.num_infs.unwrap_or(0);
335
336 output.push_str(&format!(
337 "{:<8} {:<20} {:<15} {:>10} {:>10}\n",
338 stat.tensor_id, shape_str, stat.dtype, nans, infs
339 ));
340
341 if stat.has_numerical_issues() {
343 output.push_str(" ⚠️ Numerical issues detected!\n");
344 }
345 }
346
347 output.push_str(&"=".repeat(80));
348 output.push('\n');
349
350 output
351 }
352
353 pub fn histogram(&self, values: &[f64], bins: usize) -> String {
355 let mut output = String::new();
356
357 if values.is_empty() {
358 return "No values\n".to_string();
359 }
360
361 let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
362 let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
363 let range = max - min;
364
365 if range == 0.0 {
366 return format!("All values are {:.6}\n", min);
367 }
368
369 let mut counts = vec![0; bins];
371 for &value in values {
372 let bin = ((value - min) / range * bins as f64) as usize;
373 let bin = bin.min(bins - 1);
374 counts[bin] += 1;
375 }
376
377 let max_count = *counts
378 .iter()
379 .max()
380 .expect("counts has bins elements, so max always exists");
381
382 output.push_str("Value Distribution\n");
384 output.push_str(&"=".repeat(50));
385 output.push('\n');
386
387 for (i, &count) in counts.iter().enumerate() {
388 let bin_start = min + (i as f64 / bins as f64) * range;
389 let bin_end = min + ((i + 1) as f64 / bins as f64) * range;
390 let bar_width = if max_count > 0 {
391 (count as f64 / max_count as f64 * 40.0) as usize
392 } else {
393 0
394 };
395
396 output.push_str(&format!(
397 "[{:>8.2}, {:>8.2}): {} ({})\n",
398 bin_start,
399 bin_end,
400 "█".repeat(bar_width),
401 count
402 ));
403 }
404
405 output.push_str(&"=".repeat(50));
406 output.push('\n');
407
408 output
409 }
410}
411
412pub struct ExportFormat;
414
415impl ExportFormat {
416 pub fn trace_to_json(trace: &ExecutionTrace) -> String {
418 let mut output = String::new();
419
420 output.push_str("{\n");
421 output.push_str(&format!(
422 " \"total_duration_ms\": {},\n",
423 trace.total_duration_ms()
424 ));
425 output.push_str(" \"entries\": [\n");
426
427 for (i, entry) in trace.entries().iter().enumerate() {
428 if i > 0 {
429 output.push_str(",\n");
430 }
431 output.push_str(" {\n");
432 output.push_str(&format!(" \"entry_id\": {},\n", entry.entry_id));
433 output.push_str(&format!(" \"node_id\": {},\n", entry.node_id));
434 output.push_str(&format!(" \"operation\": \"{}\",\n", entry.operation));
435 output.push_str(&format!(
436 " \"duration_ms\": {},\n",
437 entry.duration_ms()
438 ));
439 output.push_str(&format!(" \"input_ids\": {:?},\n", entry.input_ids));
440 output.push_str(&format!(" \"output_ids\": {:?}\n", entry.output_ids));
441 output.push_str(" }");
442 }
443
444 output.push_str("\n ]\n");
445 output.push_str("}\n");
446
447 output
448 }
449
450 pub fn graph_to_graphml(graph: &EinsumGraph) -> String {
452 let mut output = String::new();
453
454 output.push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n");
455 output.push_str("<graphml xmlns=\"http://graphml.graphdrawing.org/xmlns\">\n");
456 output.push_str(" <graph id=\"G\" edgedefault=\"directed\">\n");
457
458 for (node_idx, node) in graph.nodes.iter().enumerate() {
460 output.push_str(&format!(" <node id=\"n{}\">\n", node_idx));
461 output.push_str(&format!(
462 " <data key=\"operation\">{:?}</data>\n",
463 node.op
464 ));
465 output.push_str(" </node>\n");
466 }
467
468 for (node_idx, node) in graph.nodes.iter().enumerate() {
470 for input_id in &node.inputs {
471 output.push_str(&format!(
472 " <edge source=\"n{}\" target=\"n{}\"/>\n",
473 input_id, node_idx
474 ));
475 }
476 }
477
478 output.push_str(" </graph>\n");
479 output.push_str("</graphml>\n");
480
481 output
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::debug::{ExecutionTracer, TensorStats};
    use std::collections::HashMap;
    use std::time::Duration;
    use tensorlogic_ir::EinsumNode;

    /// Single identity-einsum node graph shared by the graph rendering tests.
    fn sample_graph() -> EinsumGraph {
        EinsumGraph {
            tensors: vec!["input".to_string(), "output".to_string()],
            nodes: vec![EinsumNode::new("ij->ij", vec![], vec![1])],
            inputs: vec![0],
            outputs: vec![1],
            tensor_metadata: HashMap::new(),
        }
    }

    /// Graph visualizer with the default configuration.
    fn graph_visualizer() -> GraphVisualizer {
        GraphVisualizer::new(GraphConfig::default())
    }

    #[test]
    fn test_timeline_visualizer() {
        let mut tracer = ExecutionTracer::new();
        tracer.enable();
        tracer.start_trace(Some(1));

        // Sleep so the recorded operation has a non-trivial duration.
        let handle = tracer.record_operation_start(0, "einsum", vec![]);
        std::thread::sleep(Duration::from_millis(10));
        tracer.record_operation_end(handle, 0, "einsum", vec![], vec![1], HashMap::new());

        let trace = tracer.get_trace();
        let rendered = TimelineVisualizer::new(TimelineConfig::default()).visualize_trace(trace);

        assert!(rendered.contains("Execution Timeline"));
        assert!(rendered.contains("Node 0"));
        assert!(rendered.contains("einsum"));
    }

    #[test]
    fn test_graph_visualizer_ascii() {
        let rendered = graph_visualizer().visualize_ascii(&sample_graph());

        assert!(rendered.contains("Computation Graph"));
        assert!(rendered.contains("Node 0"));
    }

    #[test]
    fn test_graph_visualizer_dot() {
        let rendered = graph_visualizer().visualize_dot(&sample_graph());

        assert!(rendered.contains("digraph ComputationGraph"));
        assert!(rendered.contains("n0"));
    }

    #[test]
    fn test_graph_visualizer_json() {
        let rendered = graph_visualizer().visualize_json(&sample_graph());

        assert!(rendered.contains("\"nodes\""));
        assert!(rendered.contains("\"id\": 0"));
    }

    #[test]
    fn test_tensor_stats_visualizer() {
        let stats =
            TensorStats::new(0, vec![2, 3], "f64").with_statistics(0.0, 1.0, 0.5, 0.25, 0, 0);

        let rendered = TensorStatsVisualizer.visualize(&stats);

        assert!(rendered.contains("Tensor 0"));
        assert!(rendered.contains("f64"));
    }

    #[test]
    fn test_tensor_stats_table() {
        let stats = vec![
            TensorStats::new(0, vec![2, 3], "f64"),
            TensorStats::new(1, vec![4, 5], "f64"),
        ];

        let rendered = TensorStatsVisualizer.visualize_table(&stats);

        assert!(rendered.contains("Tensor Statistics"));
        assert!(rendered.contains("ID"));
        assert!(rendered.contains("Shape"));
    }

    #[test]
    fn test_histogram() {
        let values = vec![1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0];

        let rendered = TensorStatsVisualizer.histogram(&values, 5);

        assert!(rendered.contains("Value Distribution"));
        assert!(rendered.contains("█"));
    }

    #[test]
    fn test_export_trace_to_json() {
        let mut tracer = ExecutionTracer::new();
        tracer.enable();
        tracer.start_trace(Some(1));

        let handle = tracer.record_operation_start(0, "einsum", vec![]);
        tracer.record_operation_end(handle, 0, "einsum", vec![], vec![1], HashMap::new());

        let trace = tracer.get_trace();
        let json = ExportFormat::trace_to_json(trace);

        assert!(json.contains("total_duration_ms"));
        assert!(json.contains("entries"));
        assert!(json.contains("\"operation\": \"einsum\""));
    }

    #[test]
    fn test_export_graph_to_graphml() {
        let graphml = ExportFormat::graph_to_graphml(&sample_graph());

        assert!(graphml.contains("<?xml"));
        assert!(graphml.contains("<graphml"));
        assert!(graphml.contains("<node id=\"n0\""));
    }
}